Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wire in Persona matching to muxing #1244

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e4c05d7591a8"
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from codegate.config import Config, ConfigurationError
from codegate.db.connection import (
init_db_sync,
init_session_if_not_exists,
init_instance,
init_session_if_not_exists,
)
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class Config:
# The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py
# It's the threshold value to determine if a persona description is similar to existing personas
persona_diff_desc_threshold = 0.3
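# Weight factor for distances in the persona description similarity calculation. Check
# the function _weight_distances for more details. Range is (0, 1]: 1 weights all messages
# equally and smaller values penalise older messages more strongly. 0 is degenerate because
# the weights of every message except the last one become zero.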
distances_weight_factor = 0.8

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
Expand Down
41 changes: 19 additions & 22 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ async def init_instance(self) -> None:
await self._execute_with_no_return(sql, instance.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Instance already initialized.")
raise AlreadyExistsError("Instance already initialized.")


class DbReader(DbCodeGate):
Expand Down Expand Up @@ -1059,6 +1059,24 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
)
return personas[0] if personas else None

async def get_persona_embed_by_name(self, persona_name: str) -> Optional[PersonaEmbedding]:
    """
    Look up a persona by name, including the embedding of its description.

    Selects id, name, description and description_embedding from the personas
    table for the given name. Returns the first matching row as a
    PersonaEmbedding, or None when no row matches.
    """
    query = text(
        """
        SELECT
            id, name, description, description_embedding
        FROM personas
        WHERE name = :name
        """
    )
    matches = await self._exec_select_conditions_to_pydantic(
        PersonaEmbedding, query, {"name": persona_name}, should_raise=True
    )
    if not matches:
        return None
    return matches[0]

async def get_distance_to_existing_personas(
self, query_embedding: np.ndarray, exclude_id: Optional[str]
) -> List[PersonaDistance]:
Expand Down Expand Up @@ -1086,27 +1104,6 @@ async def get_distance_to_existing_personas(
)
return persona_distances

async def get_distance_to_persona(
    self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
    """
    Get the distance between a persona and a query embedding.

    The distance is computed inside the database with vec_distance_cosine between
    the stored description_embedding of the persona and the given query embedding.

    Args:
        persona_id: id of the persona row to compare against.
        query_embedding: embedding of the query text.

    Returns:
        The first row of the query result as a PersonaDistance.
    """
    sql = """
        SELECT
            id,
            name,
            description,
            vec_distance_cosine(description_embedding, :query_embedding) as distance
        FROM personas
        WHERE id = :id
    """
    conditions = {"id": persona_id, "query_embedding": query_embedding}
    persona_distance = await self._exec_vec_db_query_to_pydantic(
        sql, conditions, PersonaDistance
    )
    # NOTE(review): the result is indexed without a guard; presumably an unknown
    # persona_id yields an empty result and an IndexError here. Confirm with callers.
    return persona_distance[0]

async def get_all_personas(self) -> List[Persona]:
"""
Get all the personas.
Expand Down
2 changes: 1 addition & 1 deletion src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def nd_array_custom_serializer(x):
NdArray = Annotated[
np.ndarray,
BeforeValidator(nd_array_custom_before_validator),
PlainSerializer(nd_array_custom_serializer, return_type=str),
PlainSerializer(nd_array_custom_serializer, return_type=np.ndarray),
]

VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$")
Expand Down
4 changes: 4 additions & 0 deletions src/codegate/muxing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class MuxMatcherType(str, Enum):
fim_filename = "fim_filename"
# Match based on chat request type. It will match if the request type is chat
chat_filename = "chat_filename"
# Match the user messages to the persona description
persona_description = "persona_description"
# Match the system prompt to the persona description
sys_prompt_persona_desc = "sys_prompt_persona_desc"


class MuxRule(pydantic.BaseModel):
Expand Down
103 changes: 85 additions & 18 deletions src/codegate/muxing/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self):
self._n_gpu = conf.chat_model_n_gpu_layers
self._persona_threshold = conf.persona_threshold
self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold
self._distances_weight_factor = conf.distances_weight_factor
self._db_recorder = DbRecorder()
self._db_reader = DbReader()

Expand Down Expand Up @@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str:

return text

async def _embed_texts(self, texts: List[str]) -> np.ndarray:
    """
    Helper function to embed a batch of texts using the inference engine.

    Every text is cleaned with _clean_text_for_embedding before being embedded,
    and all texts are sent to the engine in a single call.

    Args:
        texts: The raw texts to embed.

    Returns:
        A float32 NumPy array built from the list of embeddings returned by the
        engine (presumably one row per input text, in the same order).
    """
    cleaned_texts = [self._clean_text_for_embedding(text) for text in texts]
    # .embed returns a list of embeddings
    embed_list = await self._inference_engine.embed(
        self._embeddings_model, cleaned_texts, n_gpu_layers=self._n_gpu
    )
    logger.debug("Text embedded in semantic routing", num_texts=len(texts))
    # Make sure we hand back the appropriate type for the distance calculations
    return np.array(embed_list, dtype=np.float32)

async def _is_persona_description_diff(
self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
Expand Down Expand Up @@ -142,7 +142,8 @@ async def _validate_persona_description(
Validate the persona description by embedding the text and checking if it is
different enough from existing personas.
"""
emb_persona_desc = await self._embed_text(persona_desc)
emb_persona_desc_list = await self._embed_texts([persona_desc])
emb_persona_desc = emb_persona_desc_list[0]
if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
raise PersonaSimilarDescriptionError(
"The persona description is too similar to existing personas."
Expand Down Expand Up @@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None:
await self._db_recorder.delete_persona(persona.id)
logger.info(f"Deleted persona {persona_name} from the database.")

async def check_persona_match(self, persona_name: str, query: str) -> bool:
async def _get_cosine_distance(self, emb_queries: np.ndarray, emb_persona: np.ndarray) -> float:
"""
Check if the query matches the persona description. A vector similarity
search is performed between the query and the persona description.
Calculate the cosine distance between the queries embeddings and persona embedding.
Persona embedding is a single vector of length M
Queries embeddings is a matrix of shape (N, M)
N is the number of queries. User messages in this case.
M is the number of dimensions in the embedding

Defintion of cosine distance: 1 - cosine similarity
[Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)

NOTE: Experimented by individually querying SQLite for each query, but as the number
of queries increases, the performance is better with NumPy. If the number of queries
is small the performance is onpar. Hence the decision to use NumPy.
"""
# Handle the case where we have a single query (single user message)
if emb_queries.ndim == 1:
emb_queries = emb_queries.reshape(1, -1)

emb_queries_norm = np.linalg.norm(emb_queries, axis=1)
persona_embed_norm = np.linalg.norm(emb_persona)
cosine_similarities = np.dot(emb_queries, emb_persona.T) / (
emb_queries_norm * persona_embed_norm
)
# We could also use directly cosine_similarities but we get the distance to match
# the behavior of SQLite function vec_distance_cosine
cosine_distances = 1 - cosine_similarities
return cosine_distances

async def _weight_distances(self, distances: np.ndarray) -> np.ndarray:
"""
Weights the received distances, with later positions being more important and the
last position unchanged. The reasoning is that the distances correspond to user
messages, with the last message being the most recent and therefore the most
important.

Args:
distances: NumPy array of float values between 0 and 2
weight_factor: Factor that determines how quickly weights increase (0-1)
Lower values create a steeper importance curve. 1 makes
all weights equal.

Returns:
Weighted distances as a NumPy array
"""
# Get array length
n = len(distances)

# Create positions array in reverse order (n-1, n-2, ..., 1, 0)
# This makes the last element have position 0
positions = np.arange(start=n - 1, stop=-1, step=-1, dtype=np.float32)

# Create weights - now the last element (position 0) gets weight 1
weights = self._distances_weight_factor**positions

# Apply weights by dividing distances
# Smaller weight -> larger effective distance
weighted_distances = distances / weights
return weighted_distances

async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool:
    """
    Check if the queries match the persona description. A vector similarity
    search is performed between the queries and the persona description.
    0 means the vectors are identical, 1 means they are orthogonal and 2 means
    they point in opposite directions.
    See
    [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)

    The vectors are compared using cosine similarity implemented in _get_cosine_distance.
    The distances are then weighted by _weight_distances before being compared to
    the persona threshold.

    Args:
        persona_name: Name of the persona to match against.
        queries: The texts to compare with the persona description.

    Returns:
        True if at least one weighted distance is below the persona threshold,
        False otherwise (including when there are no queries).

    Raises:
        PersonaDoesNotExistError: If no persona with the given name exists.
    """
    persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name)
    if not persona_embed:
        raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

    # Without queries there is nothing to compare, and an empty embedding array
    # cannot be multiplied with the persona embedding.
    if not queries:
        return False

    emb_queries = await self._embed_texts(queries)
    cosine_distances = await self._get_cosine_distance(
        emb_queries, persona_embed.description_embedding
    )
    logger.debug("Cosine distances calculated", cosine_distances=cosine_distances)

    weighted_distances = await self._weight_distances(cosine_distances)
    logger.info("Weighted distances to persona", weighted_distances=weighted_distances)

    return bool(np.any(weighted_distances < self._persona_threshold))
Loading
Loading