Skip to content

Commit 2b4b13c

Browse files
Wire in Persona matching to muxing
Closes: #1220 This PR: - Changes the way the queries are checked to determine if they match a persona description. Before we were using SQLite now we use Numpy - Adds a new MuxMatcherType called `persona_description` that checks if the user messages in a request match a persona description
1 parent 96aa48d commit 2b4b13c

File tree

8 files changed

+213
-46
lines changed

8 files changed

+213
-46
lines changed

migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py

-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from typing import Sequence, Union
1010

1111
from alembic import op
12-
import sqlalchemy as sa
13-
1412

1513
# revision identifiers, used by Alembic.
1614
revision: str = "e4c05d7591a8"

src/codegate/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from codegate.config import Config, ConfigurationError
1717
from codegate.db.connection import (
1818
init_db_sync,
19-
init_session_if_not_exists,
2019
init_instance,
20+
init_session_if_not_exists,
2121
)
2222
from codegate.pipeline.factory import PipelineFactory
2323
from codegate.pipeline.sensitive_data.manager import SensitiveDataManager

src/codegate/config.py

+3
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class Config:
6565
# The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py
6666
# It's the threshold value to determine if a persona description is similar to existing personas
6767
persona_diff_desc_threshold = 0.3
68+
# Weight factor for distances in the persona description similarity calculation. Check
69+
# the function _weight_distances for more details. Range is [0, 1].
70+
distances_weight_factor = 0.8
6871

6972
# Provider URLs with defaults
7073
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())

src/codegate/db/connection.py

+19-22
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ async def init_instance(self) -> None:
617617
await self._execute_with_no_return(sql, instance.model_dump())
618618
except IntegrityError as e:
619619
logger.debug(f"Exception type: {type(e)}")
620-
raise AlreadyExistsError(f"Instance already initialized.")
620+
raise AlreadyExistsError("Instance already initialized.")
621621

622622

623623
class DbReader(DbCodeGate):
@@ -1059,6 +1059,24 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
10591059
)
10601060
return personas[0] if personas else None
10611061

1062+
async def get_persona_embed_by_name(self, persona_name: str) -> Optional[PersonaEmbedding]:
1063+
"""
1064+
Get a persona by name.
1065+
"""
1066+
sql = text(
1067+
"""
1068+
SELECT
1069+
id, name, description, description_embedding
1070+
FROM personas
1071+
WHERE name = :name
1072+
"""
1073+
)
1074+
conditions = {"name": persona_name}
1075+
personas = await self._exec_select_conditions_to_pydantic(
1076+
PersonaEmbedding, sql, conditions, should_raise=True
1077+
)
1078+
return personas[0] if personas else None
1079+
10621080
async def get_distance_to_existing_personas(
10631081
self, query_embedding: np.ndarray, exclude_id: Optional[str]
10641082
) -> List[PersonaDistance]:
@@ -1086,27 +1104,6 @@ async def get_distance_to_existing_personas(
10861104
)
10871105
return persona_distances
10881106

1089-
async def get_distance_to_persona(
1090-
self, persona_id: str, query_embedding: np.ndarray
1091-
) -> PersonaDistance:
1092-
"""
1093-
Get the distance between a persona and a query embedding.
1094-
"""
1095-
sql = """
1096-
SELECT
1097-
id,
1098-
name,
1099-
description,
1100-
vec_distance_cosine(description_embedding, :query_embedding) as distance
1101-
FROM personas
1102-
WHERE id = :id
1103-
"""
1104-
conditions = {"id": persona_id, "query_embedding": query_embedding}
1105-
persona_distance = await self._exec_vec_db_query_to_pydantic(
1106-
sql, conditions, PersonaDistance
1107-
)
1108-
return persona_distance[0]
1109-
11101107
async def get_all_personas(self) -> List[Persona]:
11111108
"""
11121109
Get all the personas.

src/codegate/muxing/models.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class MuxMatcherType(str, Enum):
3232
fim_filename = "fim_filename"
3333
# Match based on chat request type. It will match if the request type is chat
3434
chat_filename = "chat_filename"
35+
# Match the request content to the persona description
36+
persona_description = "persona_description"
3537

3638

3739
class MuxRule(pydantic.BaseModel):

src/codegate/muxing/persona.py

+85-18
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self):
4343
self._n_gpu = conf.chat_model_n_gpu_layers
4444
self._persona_threshold = conf.persona_threshold
4545
self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold
46+
self._distances_weight_factor = conf.distances_weight_factor
4647
self._db_recorder = DbRecorder()
4748
self._db_reader = DbReader()
4849

@@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str:
99100

100101
return text
101102

102-
async def _embed_text(self, text: str) -> np.ndarray:
103+
async def _embed_texts(self, texts: List[str]) -> np.ndarray:
103104
"""
104105
Helper function to embed text using the inference engine.
105106
"""
106-
cleaned_text = self._clean_text_for_embedding(text)
107+
cleaned_texts = [self._clean_text_for_embedding(text) for text in texts]
107108
# .embed returns a list of embeddings
108109
embed_list = await self._inference_engine.embed(
109-
self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
110+
self._embeddings_model, cleaned_texts, n_gpu_layers=self._n_gpu
110111
)
111-
# Use only the first entry in the list and make sure we have the appropriate type
112-
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
113-
return np.array(embed_list[0], dtype=np.float32)
112+
logger.debug("Text embedded in semantic routing", num_texts=len(texts))
113+
return np.array(embed_list, dtype=np.float32)
114114

115115
async def _is_persona_description_diff(
116116
self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
@@ -142,7 +142,8 @@ async def _validate_persona_description(
142142
Validate the persona description by embedding the text and checking if it is
143143
different enough from existing personas.
144144
"""
145-
emb_persona_desc = await self._embed_text(persona_desc)
145+
emb_persona_desc_list = await self._embed_texts([persona_desc])
146+
emb_persona_desc = emb_persona_desc_list[0]
146147
if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
147148
raise PersonaSimilarDescriptionError(
148149
"The persona description is too similar to existing personas."
@@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None:
217218
await self._db_recorder.delete_persona(persona.id)
218219
logger.info(f"Deleted persona {persona_name} from the database.")
219220

220-
async def check_persona_match(self, persona_name: str, query: str) -> bool:
221+
async def _get_cosine_distance(self, emb_queries: np.ndarray, emb_persona: np.ndarray) -> float:
221222
"""
222-
Check if the query matches the persona description. A vector similarity
223-
search is performed between the query and the persona description.
223+
Calculate the cosine distance between the queries embeddings and persona embedding.
224+
Persona embedding is a single vector of length M
225+
Queries embeddings is a matrix of shape (N, M)
226+
N is the number of queries. User messages in this case.
227+
M is the number of dimensions in the embedding
228+
229+
Defintion of cosine distance: 1 - cosine similarity
230+
[Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
231+
232+
NOTE: Experimented by individually querying SQLite for each query, but as the number
233+
of queries increases, the performance is better with NumPy. If the number of queries
234+
is small the performance is onpar. Hence the decision to use NumPy.
235+
"""
236+
# Handle the case where we have a single query (single user message)
237+
if emb_queries.ndim == 1:
238+
emb_queries = emb_queries.reshape(1, -1)
239+
240+
emb_queries_norm = np.linalg.norm(emb_queries, axis=1)
241+
persona_embed_norm = np.linalg.norm(emb_persona)
242+
cosine_similarities = np.dot(emb_queries, emb_persona.T) / (
243+
emb_queries_norm * persona_embed_norm
244+
)
245+
# We could also use directly cosine_similarities but we get the distance to match
246+
# the behavior of SQLite function vec_distance_cosine
247+
cosine_distances = 1 - cosine_similarities
248+
return cosine_distances
249+
250+
async def _weight_distances(self, distances: np.ndarray) -> np.ndarray:
251+
"""
252+
Weights the received distances, with later positions being more important and the
253+
last position unchanged. The reasoning is that the distances correspond to user
254+
messages, with the last message being the most recent and therefore the most
255+
important.
256+
257+
Args:
258+
distances: NumPy array of float values between 0 and 2
259+
weight_factor: Factor that determines how quickly weights increase (0-1)
260+
Lower values create a steeper importance curve. 1 makes
261+
all weights equal.
262+
263+
Returns:
264+
Weighted distances as a NumPy array
265+
"""
266+
# Get array length
267+
n = len(distances)
268+
269+
# Create positions array in reverse order (n-1, n-2, ..., 1, 0)
270+
# This makes the last element have position 0
271+
positions = np.arange(n - 1, -1, -1)
272+
273+
# Create weights - now the last element (position 0) gets weight 1
274+
weights = self._distances_weight_factor**positions
275+
276+
# Apply weights by dividing distances
277+
# Smaller weight -> larger effective distance
278+
weighted_distances = distances / weights
279+
return weighted_distances
280+
281+
async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool:
282+
"""
283+
Check if the queries match the persona description. A vector similarity
284+
search is performed between the queries and the persona description.
224285
0 means the vectors are identical, 2 means they are orthogonal.
225-
See
226-
[sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
286+
287+
The vectors are compared using cosine similarity implemented in _get_cosine_distance.
227288
"""
228-
persona = await self._db_reader.get_persona_by_name(persona_name)
229-
if not persona:
289+
persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name)
290+
if not persona_embed:
230291
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")
231292

232-
emb_query = await self._embed_text(query)
233-
persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query)
234-
logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance)
235-
if persona_distance.distance < self._persona_threshold:
293+
emb_queries = await self._embed_texts(queries)
294+
cosine_distances = await self._get_cosine_distance(
295+
emb_queries, persona_embed.description_embedding
296+
)
297+
logger.debug("Cosine distances calculated", cosine_distances=cosine_distances)
298+
299+
weighted_distances = await self._weight_distances(cosine_distances)
300+
logger.info("Weighted distances to persona", weighted_distances=weighted_distances)
301+
302+
if np.any(weighted_distances < self._persona_threshold):
236303
return True
237304
return False

src/codegate/muxing/rulematcher.py

+44
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError
1212
from codegate.extract_snippets.factory import BodyCodeExtractorFactory
1313
from codegate.muxing import models as mux_models
14+
from codegate.muxing.persona import PersonaManager
1415

1516
logger = structlog.get_logger("codegate")
1617

@@ -82,6 +83,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch
8283
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
8384
mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher,
8485
mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher,
86+
mux_models.MuxMatcherType.persona_description: PersonaDescriptionMuxingRuleMatcher,
8587
}
8688

8789
try:
@@ -171,6 +173,48 @@ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
171173
return is_rule_matched
172174

173175

176+
class PersonaDescriptionMuxingRuleMatcher(MuxingRuleMatcher):
177+
"""Muxing rule to match the request content to a persona description."""
178+
179+
def _get_user_messages_from_body(self, body: Dict) -> List[str]:
180+
"""
181+
Get the user messages from the body to use as queries.
182+
"""
183+
user_messages = []
184+
for msg in body.get("messages", []):
185+
if msg.get("role", "") == "user":
186+
msgs_content = msg.get("content")
187+
if not msgs_content:
188+
continue
189+
if isinstance(msgs_content, list):
190+
for msg_content in msgs_content:
191+
if msg_content.get("type", "") == "text":
192+
user_messages.append(msg_content.get("text", ""))
193+
elif isinstance(msgs_content, str):
194+
user_messages.append(msgs_content)
195+
return user_messages
196+
197+
def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
198+
"""
199+
Return True if the matcher is the persona description matched with the
200+
user messages.
201+
"""
202+
user_messages = self._get_user_messages_from_body(thing_to_match.body)
203+
if not user_messages:
204+
return False
205+
206+
persona_manager = PersonaManager()
207+
is_persona_matched = persona_manager.check_persona_match(
208+
persona_name=self._mux_rule.matcher, queries=user_messages
209+
)
210+
logger.info(
211+
"Persona rule matched",
212+
matcher=self._mux_rule.matcher,
213+
is_persona_matched=is_persona_matched,
214+
)
215+
return is_persona_matched
216+
217+
174218
class MuxingRulesinWorkspaces:
175219
"""A thread safe dictionary to store the muxing rules in workspaces."""
176220

tests/muxing/test_persona.py

+59-3
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager
9090
persona_name = "test_persona"
9191
query = "test_query"
9292
with pytest.raises(PersonaDoesNotExistError):
93-
await semantic_router_mocked_db.check_persona_match(persona_name, query)
93+
await semantic_router_mocked_db.check_persona_match(persona_name, [query])
9494

9595

9696
class PersonaMatchTest(BaseModel):
@@ -333,11 +333,39 @@ async def test_check_persona_pass_match(
333333
# Check for the queries that should pass
334334
for query in persona_match_test.pass_queries:
335335
match = await semantic_router_mocked_db.check_persona_match(
336-
persona_match_test.persona_name, query
336+
persona_match_test.persona_name, [query]
337337
)
338338
assert match is True
339339

340340

341+
@pytest.mark.asyncio
342+
@pytest.mark.parametrize(
343+
"persona_match_test",
344+
[
345+
simple_persona,
346+
architect,
347+
coder,
348+
devops_sre,
349+
],
350+
)
351+
async def test_check_persona_pass_match_vector(
352+
semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
353+
):
354+
"""Test checking persona match."""
355+
await semantic_router_mocked_db.add_persona(
356+
persona_match_test.persona_name, persona_match_test.persona_desc
357+
)
358+
359+
# We disable the weighting between distances since these are no user messages that
360+
# need to be weighted differently, they all are weighted the same.
361+
semantic_router_mocked_db._distances_weight_factor = 1.0
362+
# Check for match passing the entire list
363+
match = await semantic_router_mocked_db.check_persona_match(
364+
persona_match_test.persona_name, persona_match_test.pass_queries
365+
)
366+
assert match is True
367+
368+
341369
@pytest.mark.asyncio
342370
@pytest.mark.parametrize(
343371
"persona_match_test",
@@ -359,11 +387,39 @@ async def test_check_persona_fail_match(
359387
# Check for the queries that should fail
360388
for query in persona_match_test.fail_queries:
361389
match = await semantic_router_mocked_db.check_persona_match(
362-
persona_match_test.persona_name, query
390+
persona_match_test.persona_name, [query]
363391
)
364392
assert match is False
365393

366394

395+
@pytest.mark.asyncio
396+
@pytest.mark.parametrize(
397+
"persona_match_test",
398+
[
399+
simple_persona,
400+
architect,
401+
coder,
402+
devops_sre,
403+
],
404+
)
405+
async def test_check_persona_fail_match_vector(
406+
semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest
407+
):
408+
"""Test checking persona match."""
409+
await semantic_router_mocked_db.add_persona(
410+
persona_match_test.persona_name, persona_match_test.persona_desc
411+
)
412+
413+
# We disable the weighting between distances since these are no user messages that
414+
# need to be weighted differently, they all are weighted the same.
415+
semantic_router_mocked_db._distances_weight_factor = 1.0
416+
# Check for match passing the entire list
417+
match = await semantic_router_mocked_db.check_persona_match(
418+
persona_match_test.persona_name, persona_match_test.fail_queries
419+
)
420+
assert match is False
421+
422+
367423
@pytest.mark.asyncio
368424
@pytest.mark.parametrize(
369425
"personas",

0 commit comments

Comments
 (0)