Wire in Persona matching to muxing #1244

Draft · wants to merge 5 commits into base: main
Changes from 1 commit
@@ -9,8 +9,6 @@
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "e4c05d7591a8"
2 changes: 1 addition & 1 deletion src/codegate/cli.py
@@ -16,8 +16,8 @@
from codegate.config import Config, ConfigurationError
from codegate.db.connection import (
init_db_sync,
init_session_if_not_exists,
init_instance,
init_session_if_not_exists,
)
from codegate.pipeline.factory import PipelineFactory
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager
3 changes: 3 additions & 0 deletions src/codegate/config.py
@@ -65,6 +65,9 @@ class Config:
# The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py
# It's the threshold value to determine if a persona description is similar to existing personas
persona_diff_desc_threshold = 0.3
# Weight factor applied to the distances when matching queries against a persona description.
# Check the function _weight_distances for more details. Range is (0, 1].
distances_weight_factor = 0.8

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
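
The new distances_weight_factor drives the recency weighting implemented in _weight_distances later in this diff. A minimal sketch of the effect, assuming the 0.8 default and illustrative distances (none of the numbers below come from the PR):

```python
import numpy as np

# Illustrative cosine distances for three user messages, oldest first (not from the PR).
distances = np.array([0.60, 0.80, 0.70])
weight_factor = 0.8  # the distances_weight_factor default above

# Positions in reverse order so the most recent message gets position 0 and weight 1.
positions = np.arange(len(distances) - 1, -1, -1)  # [2, 1, 0]
weights = weight_factor**positions                 # [0.64, 0.8, 1.0]

# Dividing by a smaller weight inflates the distance of older messages, so only
# recent messages can realistically fall below the persona match threshold.
weighted_distances = distances / weights           # [0.9375, 1.0, 0.7]
print(weighted_distances)
```
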
41 changes: 19 additions & 22 deletions src/codegate/db/connection.py
@@ -617,7 +617,7 @@ async def init_instance(self) -> None:
await self._execute_with_no_return(sql, instance.model_dump())
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Instance already initialized.")
raise AlreadyExistsError("Instance already initialized.")


class DbReader(DbCodeGate):
@@ -1059,6 +1059,24 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
)
return personas[0] if personas else None

async def get_persona_embed_by_name(self, persona_name: str) -> Optional[PersonaEmbedding]:
"""
Get a persona, including its description embedding, by name.
"""
sql = text(
"""
SELECT
id, name, description, description_embedding
FROM personas
WHERE name = :name
"""
)
conditions = {"name": persona_name}
personas = await self._exec_select_conditions_to_pydantic(
PersonaEmbedding, sql, conditions, should_raise=True
)
return personas[0] if personas else None

async def get_distance_to_existing_personas(
self, query_embedding: np.ndarray, exclude_id: Optional[str]
) -> List[PersonaDistance]:
@@ -1086,27 +1104,6 @@ async def get_distance_to_existing_personas(
)
return persona_distances

async def get_distance_to_persona(
self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
"""
Get the distance between a persona and a query embedding.
"""
sql = """
SELECT
id,
name,
description,
vec_distance_cosine(description_embedding, :query_embedding) as distance
FROM personas
WHERE id = :id
"""
conditions = {"id": persona_id, "query_embedding": query_embedding}
persona_distance = await self._exec_vec_db_query_to_pydantic(
sql, conditions, PersonaDistance
)
return persona_distance[0]

async def get_all_personas(self) -> List[Persona]:
"""
Get all the personas.
2 changes: 2 additions & 0 deletions src/codegate/muxing/models.py
@@ -32,6 +32,8 @@ class MuxMatcherType(str, Enum):
fim_filename = "fim_filename"
# Match based on chat request type. It will match if the request type is chat
chat_filename = "chat_filename"
# Match the request content to the persona description
persona_description = "persona_description"


class MuxRule(pydantic.BaseModel):
103 changes: 85 additions & 18 deletions src/codegate/muxing/persona.py
@@ -43,6 +43,7 @@ def __init__(self):
self._n_gpu = conf.chat_model_n_gpu_layers
self._persona_threshold = conf.persona_threshold
self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold
self._distances_weight_factor = conf.distances_weight_factor
self._db_recorder = DbRecorder()
self._db_reader = DbReader()

@@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str:

return text

async def _embed_text(self, text: str) -> np.ndarray:
async def _embed_texts(self, texts: List[str]) -> np.ndarray:
"""
Helper function to embed text using the inference engine.
"""
cleaned_text = self._clean_text_for_embedding(text)
cleaned_texts = [self._clean_text_for_embedding(text) for text in texts]
# .embed returns a list of embeddings
embed_list = await self._inference_engine.embed(
self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
self._embeddings_model, cleaned_texts, n_gpu_layers=self._n_gpu
)
# Use only the first entry in the list and make sure we have the appropriate type
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
return np.array(embed_list[0], dtype=np.float32)
logger.debug("Text embedded in semantic routing", num_texts=len(texts))
return np.array(embed_list, dtype=np.float32)

async def _is_persona_description_diff(
self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
@@ -142,7 +142,8 @@ async def _validate_persona_description(
Validate the persona description by embedding the text and checking if it is
different enough from existing personas.
"""
emb_persona_desc = await self._embed_text(persona_desc)
emb_persona_desc_list = await self._embed_texts([persona_desc])
emb_persona_desc = emb_persona_desc_list[0]
if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
raise PersonaSimilarDescriptionError(
"The persona description is too similar to existing personas."
@@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None:
await self._db_recorder.delete_persona(persona.id)
logger.info(f"Deleted persona {persona_name} from the database.")

async def check_persona_match(self, persona_name: str, query: str) -> bool:
async def _get_cosine_distance(self, emb_queries: np.ndarray, emb_persona: np.ndarray) -> np.ndarray:
"""
Check if the query matches the persona description. A vector similarity
search is performed between the query and the persona description.
Calculate the cosine distances between the query embeddings and the persona embedding.
The persona embedding is a single vector of length M.
The query embeddings form a matrix of shape (N, M), where
N is the number of queries (user messages in this case) and
M is the number of dimensions in the embedding.

Definition of cosine distance: 1 - cosine similarity
[Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)

NOTE: We experimented with querying SQLite individually for each query, but as the
number of queries grows, NumPy performs better. If the number of queries is small
the performance is on par. Hence the decision to use NumPy.
"""
# Handle the case where we have a single query (single user message)
if emb_queries.ndim == 1:
emb_queries = emb_queries.reshape(1, -1)

emb_queries_norm = np.linalg.norm(emb_queries, axis=1)
persona_embed_norm = np.linalg.norm(emb_persona)
cosine_similarities = np.dot(emb_queries, emb_persona.T) / (
emb_queries_norm * persona_embed_norm
)
# We could use cosine_similarities directly, but we compute the distance to match
# the behavior of the SQLite function vec_distance_cosine
cosine_distances = 1 - cosine_similarities
return cosine_distances

async def _weight_distances(self, distances: np.ndarray) -> np.ndarray:
"""
Weights the received distances, with later positions being more important and the
last position unchanged. The reasoning is that the distances correspond to user
messages, with the last message being the most recent and therefore the most
important.

Args:
distances: NumPy array of float values between 0 and 2

Note: the weighting uses self._distances_weight_factor (range (0, 1]).
Lower values create a steeper importance curve; 1 makes all weights equal.

Returns:
Weighted distances as a NumPy array
"""
# Get array length
n = len(distances)

# Create positions array in reverse order (n-1, n-2, ..., 1, 0)
# This makes the last element have position 0
positions = np.arange(n - 1, -1, -1)

# Create weights - now the last element (position 0) gets weight 1
weights = self._distances_weight_factor**positions

# Apply weights by dividing distances
# Smaller weight -> larger effective distance
weighted_distances = distances / weights
return weighted_distances

async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool:
"""
Check if the queries match the persona description. A vector similarity
search is performed between the queries and the persona description.
0 means the vectors are identical, 1 means they are orthogonal, and 2 means they are opposite.
See
[sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)

The vectors are compared using cosine similarity implemented in _get_cosine_distance.
"""
persona = await self._db_reader.get_persona_by_name(persona_name)
if not persona:
persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name)
if not persona_embed:
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

emb_query = await self._embed_text(query)
persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query)
logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance)
if persona_distance.distance < self._persona_threshold:
emb_queries = await self._embed_texts(queries)
cosine_distances = await self._get_cosine_distance(
emb_queries, persona_embed.description_embedding
)
logger.debug("Cosine distances calculated", cosine_distances=cosine_distances)

weighted_distances = await self._weight_distances(cosine_distances)
logger.info("Weighted distances to persona", weighted_distances=weighted_distances)

if np.any(weighted_distances < self._persona_threshold):
return True
return False
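
Putting the new pieces of check_persona_match together: the user messages are embedded into an (N, M) matrix, cosine distances to the persona embedding are computed in NumPy, the distances are re-weighted so recent messages dominate, and the rule matches if any weighted distance falls below the persona threshold. The following is a self-contained sketch of that flow, not the module itself; the toy 3-dimensional vectors and the 0.75 threshold are assumptions (the real threshold is Config.persona_threshold, which this diff does not show):

```python
import numpy as np


def cosine_distances(emb_queries: np.ndarray, emb_persona: np.ndarray) -> np.ndarray:
    """1 - cosine similarity, matching the convention of sqlite-vec's vec_distance_cosine."""
    if emb_queries.ndim == 1:
        emb_queries = emb_queries.reshape(1, -1)
    query_norms = np.linalg.norm(emb_queries, axis=1)
    persona_norm = np.linalg.norm(emb_persona)
    return 1 - np.dot(emb_queries, emb_persona) / (query_norms * persona_norm)


def weight_distances(distances: np.ndarray, factor: float = 0.8) -> np.ndarray:
    """Divide by factor**position so the most recent message keeps its original distance."""
    positions = np.arange(len(distances) - 1, -1, -1)
    return distances / factor**positions


# Toy 3-dimensional embeddings; the real ones come from the inference engine.
persona_vec = np.array([1.0, 0.0, 0.0], dtype=np.float32)
query_vecs = np.array(
    [
        [0.0, 1.0, 0.0],  # older message, unrelated to the persona
        [0.9, 0.1, 0.0],  # most recent message, close to the persona
    ],
    dtype=np.float32,
)

distances = cosine_distances(query_vecs, persona_vec)  # ~[1.0, 0.006]
weighted = weight_distances(distances)                 # ~[1.25, 0.006]

persona_threshold = 0.75  # placeholder; the real value is Config.persona_threshold
print(bool(np.any(weighted < persona_threshold)))      # True: the recent message matches
```
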
44 changes: 44 additions & 0 deletions src/codegate/muxing/rulematcher.py
@@ -11,6 +11,7 @@
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
from codegate.muxing import models as mux_models
from codegate.muxing.persona import PersonaManager

logger = structlog.get_logger("codegate")

@@ -82,6 +83,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher,
mux_models.MuxMatcherType.persona_description: PersonaDescriptionMuxingRuleMatcher,
}

try:
@@ -171,6 +173,48 @@ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
return is_rule_matched


class PersonaDescriptionMuxingRuleMatcher(MuxingRuleMatcher):
"""Muxing rule to match the request content to a persona description."""

def _get_user_messages_from_body(self, body: Dict) -> List[str]:
"""
Get the user messages from the body to use as queries.
"""
user_messages = []
for msg in body.get("messages", []):
if msg.get("role", "") == "user":
msgs_content = msg.get("content")
if not msgs_content:
continue
if isinstance(msgs_content, list):
for msg_content in msgs_content:
if msg_content.get("type", "") == "text":
user_messages.append(msg_content.get("text", ""))
elif isinstance(msgs_content, str):
user_messages.append(msgs_content)
return user_messages

def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
"""
Return True if the persona description referenced by the matcher matches
the user messages.
"""
user_messages = self._get_user_messages_from_body(thing_to_match.body)
if not user_messages:
return False

persona_manager = PersonaManager()
is_persona_matched = persona_manager.check_persona_match(
persona_name=self._mux_rule.matcher, queries=user_messages
)
logger.info(
"Persona rule matched",
matcher=self._mux_rule.matcher,
is_persona_matched=is_persona_matched,
)
return is_persona_matched


class MuxingRulesinWorkspaces:
"""A thread safe dictionary to store the muxing rules in workspaces."""

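For reference, _get_user_messages_from_body walks an OpenAI-style chat body and keeps only the user turns, handling both plain-string content and the list-of-parts form. Below is a standalone sketch of the same extraction with an illustrative request body (not taken from the PR); the extracted strings are what check_persona_match receives as queries:

```python
from typing import Dict, List


def get_user_messages(body: Dict) -> List[str]:
    """Mirror of _get_user_messages_from_body: collect text from user messages only."""
    user_messages: List[str] = []
    for msg in body.get("messages", []):
        if msg.get("role", "") != "user":
            continue
        content = msg.get("content")
        if not content:
            continue
        if isinstance(content, list):
            # Multi-part content: keep only the text parts.
            user_messages.extend(
                part.get("text", "") for part in content if part.get("type", "") == "text"
            )
        elif isinstance(content, str):
            user_messages.append(content)
    return user_messages


# Illustrative OpenAI-style request body (not taken from the PR).
body = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Design a Kubernetes deployment for this service."},
        {"role": "assistant", "content": "Sure, here is a manifest..."},
        {"role": "user", "content": [{"type": "text", "text": "Now add a liveness probe."}]},
    ]
}

print(get_user_messages(body))
# ['Design a Kubernetes deployment for this service.', 'Now add a liveness probe.']
```
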
62 changes: 59 additions & 3 deletions tests/muxing/test_persona.py
@@ -90,7 +90,7 @@ async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager
persona_name = "test_persona"
query = "test_query"
with pytest.raises(PersonaDoesNotExistError):
await semantic_router_mocked_db.check_persona_match(persona_name, query)
await semantic_router_mocked_db.check_persona_match(persona_name, [query])


class PersonaMatchTest(BaseModel):
@@ -333,11 +333,39 @@ async def test_check_persona_pass_match(
# Check for the queries that should pass
for query in persona_match_test.pass_queries:
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, query
persona_match_test.persona_name, [query]
)
assert match is True


@pytest.mark.asyncio
@pytest.mark.parametrize(
"persona_match_test",
[
simple_persona,
architect,
coder,
devops_sre,
],
)
async def test_check_persona_pass_match_vector(
semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
):
"""Test checking persona match."""
await semantic_router_mocked_db.add_persona(
persona_match_test.persona_name, persona_match_test.persona_desc
)

# We disable the weighting between distances since these are not user messages that
# need to be weighted differently; they are all weighted the same.
semantic_router_mocked_db._distances_weight_factor = 1.0
# Check for match passing the entire list
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, persona_match_test.pass_queries
)
assert match is True


@pytest.mark.asyncio
@pytest.mark.parametrize(
"persona_match_test",
@@ -359,11 +387,39 @@ async def test_check_persona_fail_match(
# Check for the queries that should fail
for query in persona_match_test.fail_queries:
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, query
persona_match_test.persona_name, [query]
)
assert match is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
"persona_match_test",
[
simple_persona,
architect,
coder,
devops_sre,
],
)
async def test_check_persona_fail_match_vector(
semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
):
"""Test checking persona match."""
await semantic_router_mocked_db.add_persona(
persona_match_test.persona_name, persona_match_test.persona_desc
)

# We disable the weighting between distances since these are not user messages that
# need to be weighted differently; they are all weighted the same.
semantic_router_mocked_db._distances_weight_factor = 1.0
# Check for match passing the entire list
match = await semantic_router_mocked_db.check_persona_match(
persona_match_test.persona_name, persona_match_test.fail_queries
)
assert match is False


@pytest.mark.asyncio
@pytest.mark.parametrize(
"personas",