From 2b4b13ca9d08f8e234fa4cd09e71ecb05d6feb40 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 6 Mar 2025 18:05:35 +0200 Subject: [PATCH 1/5] Wire in Persona matching to muxing Closes: #1220 This PR: - Changes the way the queries are checked to determine if they match a persona description. Before we were using SQLite now we use Numpy - Adds a new MuxMatcherType called `persona_description` that checks if the user messages in a request match a persona description --- ...126-e4c05d7591a8_add_installation_table.py | 2 - src/codegate/cli.py | 2 +- src/codegate/config.py | 3 + src/codegate/db/connection.py | 41 ++++--- src/codegate/muxing/models.py | 2 + src/codegate/muxing/persona.py | 103 +++++++++++++++--- src/codegate/muxing/rulematcher.py | 44 ++++++++ tests/muxing/test_persona.py | 62 ++++++++++- 8 files changed, 213 insertions(+), 46 deletions(-) diff --git a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py index 775e3967..9e2b6c13 100644 --- a/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py +++ b/migrations/versions/2025_03_05_2126-e4c05d7591a8_add_installation_table.py @@ -9,8 +9,6 @@ from typing import Sequence, Union from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. 
revision: str = "e4c05d7591a8" diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 1ae3f9c2..5c08821c 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -16,8 +16,8 @@ from codegate.config import Config, ConfigurationError from codegate.db.connection import ( init_db_sync, - init_session_if_not_exists, init_instance, + init_session_if_not_exists, ) from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.sensitive_data.manager import SensitiveDataManager diff --git a/src/codegate/config.py b/src/codegate/config.py index 179ec4d3..d7c5eb63 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -65,6 +65,9 @@ class Config: # The value 0.3 was found through experimentation. See /tests/muxing/test_semantic_router.py # It's the threshold value to determine if a persona description is similar to existing personas persona_diff_desc_threshold = 0.3 + # Weight factor for distances in the persona description similarity calculation. Check + # the function _weight_distances for more details. Range is [0, 1]. 
+ distances_weight_factor = 0.8 # Provider URLs with defaults provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy()) diff --git a/src/codegate/db/connection.py b/src/codegate/db/connection.py index 3f439aea..0ddbf4bb 100644 --- a/src/codegate/db/connection.py +++ b/src/codegate/db/connection.py @@ -617,7 +617,7 @@ async def init_instance(self) -> None: await self._execute_with_no_return(sql, instance.model_dump()) except IntegrityError as e: logger.debug(f"Exception type: {type(e)}") - raise AlreadyExistsError(f"Instance already initialized.") + raise AlreadyExistsError("Instance already initialized.") class DbReader(DbCodeGate): @@ -1059,6 +1059,24 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]: ) return personas[0] if personas else None + async def get_persona_embed_by_name(self, persona_name: str) -> Optional[PersonaEmbedding]: + """ + Get a persona by name. + """ + sql = text( + """ + SELECT + id, name, description, description_embedding + FROM personas + WHERE name = :name + """ + ) + conditions = {"name": persona_name} + personas = await self._exec_select_conditions_to_pydantic( + PersonaEmbedding, sql, conditions, should_raise=True + ) + return personas[0] if personas else None + async def get_distance_to_existing_personas( self, query_embedding: np.ndarray, exclude_id: Optional[str] ) -> List[PersonaDistance]: @@ -1086,27 +1104,6 @@ async def get_distance_to_existing_personas( ) return persona_distances - async def get_distance_to_persona( - self, persona_id: str, query_embedding: np.ndarray - ) -> PersonaDistance: - """ - Get the distance between a persona and a query embedding. 
- """ - sql = """ - SELECT - id, - name, - description, - vec_distance_cosine(description_embedding, :query_embedding) as distance - FROM personas - WHERE id = :id - """ - conditions = {"id": persona_id, "query_embedding": query_embedding} - persona_distance = await self._exec_vec_db_query_to_pydantic( - sql, conditions, PersonaDistance - ) - return persona_distance[0] - async def get_all_personas(self) -> List[Persona]: """ Get all the personas. diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 5637c5b8..2f070265 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -32,6 +32,8 @@ class MuxMatcherType(str, Enum): fim_filename = "fim_filename" # Match based on chat request type. It will match if the request type is chat chat_filename = "chat_filename" + # Match the request content to the persona description + persona_description = "persona_description" class MuxRule(pydantic.BaseModel): diff --git a/src/codegate/muxing/persona.py b/src/codegate/muxing/persona.py index ac21205c..4620fb45 100644 --- a/src/codegate/muxing/persona.py +++ b/src/codegate/muxing/persona.py @@ -43,6 +43,7 @@ def __init__(self): self._n_gpu = conf.chat_model_n_gpu_layers self._persona_threshold = conf.persona_threshold self._persona_diff_desc_threshold = conf.persona_diff_desc_threshold + self._distances_weight_factor = conf.distances_weight_factor self._db_recorder = DbRecorder() self._db_reader = DbReader() @@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str: return text - async def _embed_text(self, text: str) -> np.ndarray: + async def _embed_texts(self, texts: List[str]) -> np.ndarray: """ Helper function to embed text using the inference engine. 
""" - cleaned_text = self._clean_text_for_embedding(text) + cleaned_texts = [self._clean_text_for_embedding(text) for text in texts] # .embed returns a list of embeddings embed_list = await self._inference_engine.embed( - self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu + self._embeddings_model, cleaned_texts, n_gpu_layers=self._n_gpu ) - # Use only the first entry in the list and make sure we have the appropriate type - logger.debug("Text embedded in semantic routing", text=cleaned_text[:50]) - return np.array(embed_list[0], dtype=np.float32) + logger.debug("Text embedded in semantic routing", num_texts=len(texts)) + return np.array(embed_list, dtype=np.float32) async def _is_persona_description_diff( self, emb_persona_desc: np.ndarray, exclude_id: Optional[str] @@ -142,7 +142,8 @@ async def _validate_persona_description( Validate the persona description by embedding the text and checking if it is different enough from existing personas. """ - emb_persona_desc = await self._embed_text(persona_desc) + emb_persona_desc_list = await self._embed_texts([persona_desc]) + emb_persona_desc = emb_persona_desc_list[0] if not await self._is_persona_description_diff(emb_persona_desc, exclude_id): raise PersonaSimilarDescriptionError( "The persona description is too similar to existing personas." @@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None: await self._db_recorder.delete_persona(persona.id) logger.info(f"Deleted persona {persona_name} from the database.") - async def check_persona_match(self, persona_name: str, query: str) -> bool: + async def _get_cosine_distance(self, emb_queries: np.ndarray, emb_persona: np.ndarray) -> float: """ - Check if the query matches the persona description. A vector similarity - search is performed between the query and the persona description. + Calculate the cosine distance between the queries embeddings and persona embedding. 
+ Persona embedding is a single vector of length M + Queries embeddings is a matrix of shape (N, M) + N is the number of queries. User messages in this case. + M is the number of dimensions in the embedding + + Definition of cosine distance: 1 - cosine similarity + [Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) + + NOTE: Experimented by individually querying SQLite for each query, but as the number + of queries increases, the performance is better with NumPy. If the number of queries + is small the performance is on par. Hence the decision to use NumPy. + """ + # Handle the case where we have a single query (single user message) + if emb_queries.ndim == 1: + emb_queries = emb_queries.reshape(1, -1) + + emb_queries_norm = np.linalg.norm(emb_queries, axis=1) + persona_embed_norm = np.linalg.norm(emb_persona) + cosine_similarities = np.dot(emb_queries, emb_persona.T) / ( + emb_queries_norm * persona_embed_norm + ) + # We could also use directly cosine_similarities but we get the distance to match + # the behavior of SQLite function vec_distance_cosine + cosine_distances = 1 - cosine_similarities + return cosine_distances + + async def _weight_distances(self, distances: np.ndarray) -> np.ndarray: + """ + Weights the received distances, with later positions being more important and the + last position unchanged. The reasoning is that the distances correspond to user + messages, with the last message being the most recent and therefore the most + important. + + Args: + distances: NumPy array of float values between 0 and 2 + weight_factor: Factor that determines how quickly weights increase (0-1) + Lower values create a steeper importance curve. 1 makes + all weights equal. 
+ + Returns: + Weighted distances as a NumPy array + """ + # Get array length + n = len(distances) + + # Create positions array in reverse order (n-1, n-2, ..., 1, 0) + # This makes the last element have position 0 + positions = np.arange(n - 1, -1, -1) + + # Create weights - now the last element (position 0) gets weight 1 + weights = self._distances_weight_factor**positions + + # Apply weights by dividing distances + # Smaller weight -> larger effective distance + weighted_distances = distances / weights + return weighted_distances + + async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool: + """ + Check if the queries match the persona description. A vector similarity + search is performed between the queries and the persona description. 0 means the vectors are identical, 2 means they are orthogonal. - See - [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine) + + The vectors are compared using cosine similarity implemented in _get_cosine_distance. 
""" - persona = await self._db_reader.get_persona_by_name(persona_name) - if not persona: + persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name) + if not persona_embed: raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.") - emb_query = await self._embed_text(query) - persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query) - logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance) - if persona_distance.distance < self._persona_threshold: + emb_queries = await self._embed_texts(queries) + cosine_distances = await self._get_cosine_distance( + emb_queries, persona_embed.description_embedding + ) + logger.debug("Cosine distances calculated", cosine_distances=cosine_distances) + + weighted_distances = await self._weight_distances(cosine_distances) + logger.info("Weighted distances to persona", weighted_distances=weighted_distances) + + if np.any(weighted_distances < self._persona_threshold): return True return False diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index d41eb2ce..0787553f 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -11,6 +11,7 @@ from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError from codegate.extract_snippets.factory import BodyCodeExtractorFactory from codegate.muxing import models as mux_models +from codegate.muxing.persona import PersonaManager logger = structlog.get_logger("codegate") @@ -82,6 +83,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher, mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher, mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher, + mux_models.MuxMatcherType.persona_description: PersonaDescriptionMuxingRuleMatcher, } try: @@ -171,6 +173,48 @@ def match(self, 
thing_to_match: mux_models.ThingToMatchMux) -> bool: return is_rule_matched +class PersonaDescriptionMuxingRuleMatcher(MuxingRuleMatcher): + """Muxing rule to match the request content to a persona description.""" + + def _get_user_messages_from_body(self, body: Dict) -> List[str]: + """ + Get the user messages from the body to use as queries. + """ + user_messages = [] + for msg in body.get("messages", []): + if msg.get("role", "") == "user": + msgs_content = msg.get("content") + if not msgs_content: + continue + if isinstance(msgs_content, list): + for msg_content in msgs_content: + if msg_content.get("type", "") == "text": + user_messages.append(msg_content.get("text", "")) + elif isinstance(msgs_content, str): + user_messages.append(msgs_content) + return user_messages + + def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + """ + Return True if the matcher is the persona description matched with the + user messages. + """ + user_messages = self._get_user_messages_from_body(thing_to_match.body) + if not user_messages: + return False + + persona_manager = PersonaManager() + is_persona_matched = persona_manager.check_persona_match( + persona_name=self._mux_rule.matcher, queries=user_messages + ) + logger.info( + "Persona rule matched", + matcher=self._mux_rule.matcher, + is_persona_matched=is_persona_matched, + ) + return is_persona_matched + + class MuxingRulesinWorkspaces: """A thread safe dictionary to store the muxing rules in workspaces.""" diff --git a/tests/muxing/test_persona.py b/tests/muxing/test_persona.py index fd0003c9..2cfbbe3c 100644 --- a/tests/muxing/test_persona.py +++ b/tests/muxing/test_persona.py @@ -90,7 +90,7 @@ async def test_persona_not_exist_match(semantic_router_mocked_db: PersonaManager persona_name = "test_persona" query = "test_query" with pytest.raises(PersonaDoesNotExistError): - await semantic_router_mocked_db.check_persona_match(persona_name, query) + await 
semantic_router_mocked_db.check_persona_match(persona_name, [query]) class PersonaMatchTest(BaseModel): @@ -333,11 +333,39 @@ async def test_check_persona_pass_match( # Check for the queries that should pass for query in persona_match_test.pass_queries: match = await semantic_router_mocked_db.check_persona_match( - persona_match_test.persona_name, query + persona_match_test.persona_name, [query] ) assert match is True +@pytest.mark.asyncio +@pytest.mark.parametrize( + "persona_match_test", + [ + simple_persona, + architect, + coder, + devops_sre, + ], +) +async def test_check_persona_pass_match_vector( + semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest +): + """Test checking persona match.""" + await semantic_router_mocked_db.add_persona( + persona_match_test.persona_name, persona_match_test.persona_desc + ) + + # We disable the weighting between distances since these are no user messages that + # need to be weighted differently, they all are weighted the same. 
+ semantic_router_mocked_db._distances_weight_factor = 1.0 + # Check for match passing the entire list + match = await semantic_router_mocked_db.check_persona_match( + persona_match_test.persona_name, persona_match_test.pass_queries + ) + assert match is True + + @pytest.mark.asyncio @pytest.mark.parametrize( "persona_match_test", @@ -359,11 +387,39 @@ async def test_check_persona_fail_match( # Check for the queries that should fail for query in persona_match_test.fail_queries: match = await semantic_router_mocked_db.check_persona_match( - persona_match_test.persona_name, query + persona_match_test.persona_name, [query] ) assert match is False +@pytest.mark.asyncio +@pytest.mark.parametrize( + "persona_match_test", + [ + simple_persona, + architect, + coder, + devops_sre, + ], +) +async def test_check_persona_fail_match_vector( + semantic_router_mocked_db: PersonaManager, persona_match_test: PersonaMatchTest +): + """Test checking persona match.""" + await semantic_router_mocked_db.add_persona( + persona_match_test.persona_name, persona_match_test.persona_desc + ) + + # We disable the weighting between distances since these are no user messages that + # need to be weighted differently, they all are weighted the same. 
+ semantic_router_mocked_db._distances_weight_factor = 1.0 + # Check for match passing the entire list + match = await semantic_router_mocked_db.check_persona_match( + persona_match_test.persona_name, persona_match_test.fail_queries + ) + assert match is False + + @pytest.mark.asyncio @pytest.mark.parametrize( "personas", From 6f6fe6201fc222c2d81c84c5c1ab8b41cd39c412 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Fri, 7 Mar 2025 10:57:09 +0200 Subject: [PATCH 2/5] convert muxing rule match to coroutines --- src/codegate/muxing/rulematcher.py | 21 +++++++++------------ tests/muxing/test_rulematcher.py | 19 +++++++++++++------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index 0787553f..adacf68e 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -61,7 +61,7 @@ def __init__(self, route: ModelRoute, mux_rule: mux_models.MuxRule): self._mux_rule = mux_rule @abstractmethod - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """Return True if the rule matches the thing_to_match.""" pass @@ -97,7 +97,7 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch class CatchAllMuxingRuleMatcher(MuxingRuleMatcher): """A catch all muxing rule matcher.""" - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: logger.info("Catch all rule matched") return True @@ -132,7 +132,7 @@ def _is_matcher_in_filenames(self, detected_client: ClientType, data: dict) -> b ) return is_filename_match - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """ Return True if the matcher is in one of the request filenames. 
""" @@ -156,7 +156,7 @@ def _is_request_type_match(self, is_fim_request: bool) -> bool: return True return False - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """ Return True if the matcher is in one of the request filenames and if the request type matches the MuxMatcherType. @@ -194,7 +194,7 @@ def _get_user_messages_from_body(self, body: Dict) -> List[str]: user_messages.append(msgs_content) return user_messages - def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: """ Return True if the matcher is the persona description matched with the user messages. @@ -204,14 +204,11 @@ def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: return False persona_manager = PersonaManager() - is_persona_matched = persona_manager.check_persona_match( + is_persona_matched = await persona_manager.check_persona_match( persona_name=self._mux_rule.matcher, queries=user_messages ) - logger.info( - "Persona rule matched", - matcher=self._mux_rule.matcher, - is_persona_matched=is_persona_matched, - ) + if is_persona_matched: + logger.info("Persona rule matched", persona=self._mux_rule.matcher) return is_persona_matched @@ -258,7 +255,7 @@ async def get_match_for_active_workspace( try: rules = await self.get_ws_rules(self._active_workspace) for rule in rules: - if rule.match(thing_to_match): + if await rule.match(thing_to_match): return rule.destination() return None except KeyError: diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index 7e551525..ad4c420a 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -24,6 +24,7 @@ ) +@pytest.mark.asyncio @pytest.mark.parametrize( "matcher_blob, thing_to_match", [ @@ -40,12 +41,13 @@ ), ], ) -def test_catch_all(matcher_blob, thing_to_match): +async def 
test_catch_all(matcher_blob, thing_to_match): muxing_rule_matcher = rulematcher.CatchAllMuxingRuleMatcher(mocked_route_openai, matcher_blob) # It should always match - assert muxing_rule_matcher.match(thing_to_match) is True + assert await muxing_rule_matcher.match(thing_to_match) is True +@pytest.mark.asyncio @pytest.mark.parametrize( "matcher, filenames_to_match, expected_bool", [ @@ -60,7 +62,7 @@ def test_catch_all(matcher_blob, thing_to_match): ("*.ts", ["main.tsx", "test.tsx"], False), # Extension no match ], ) -def test_file_matcher( +async def test_file_matcher( matcher, filenames_to_match, expected_bool, @@ -81,9 +83,10 @@ def test_file_matcher( is_fim_request=False, client_type="generic", ) - assert muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool + assert await muxing_rule_matcher.match(mocked_thing_to_match) is expected_bool +@pytest.mark.asyncio @pytest.mark.parametrize( "matcher, filenames_to_match, expected_bool_filenames", [ @@ -107,7 +110,7 @@ def test_file_matcher( (True, "chat_filename", False), # No match ], ) -def test_request_file_matcher( +async def test_request_file_matcher( matcher, filenames_to_match, expected_bool_filenames, @@ -143,7 +146,7 @@ def test_request_file_matcher( ) is expected_bool_filenames ) - assert muxing_rule_matcher.match(mocked_thing_to_match) is ( + assert await muxing_rule_matcher.match(mocked_thing_to_match) is ( expected_bool_request and expected_bool_filenames ) @@ -155,6 +158,10 @@ def test_request_file_matcher( (mux_models.MuxMatcherType.filename_match, rulematcher.FileMuxingRuleMatcher), (mux_models.MuxMatcherType.fim_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), (mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), + ( + mux_models.MuxMatcherType.persona_description, + rulematcher.PersonaDescriptionMuxingRuleMatcher, + ), ("invalid_matcher", None), ], ) From 6c46ef74d0af1a5668bed6619d19aa7d78e2e91c Mon Sep 17 00:00:00 2001 From: Alejandro 
Ponce Date: Fri, 7 Mar 2025 11:00:31 +0200 Subject: [PATCH 3/5] use named parameters in numpy arange function --- src/codegate/muxing/persona.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegate/muxing/persona.py b/src/codegate/muxing/persona.py index 4620fb45..46d88288 100644 --- a/src/codegate/muxing/persona.py +++ b/src/codegate/muxing/persona.py @@ -268,7 +268,7 @@ async def _weight_distances(self, distances: np.ndarray) -> np.ndarray: # Create positions array in reverse order (n-1, n-2, ..., 1, 0) # This makes the last element have position 0 - positions = np.arange(n - 1, -1, -1) + positions = np.arange(start=n - 1, stop=-1, step=-1, dtype=np.float32) # Create weights - now the last element (position 0) gets weight 1 weights = self._distances_weight_factor**positions From 714a63066da82cf761d1eb1ed9862b2f844fbb05 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Fri, 7 Mar 2025 11:10:20 +0200 Subject: [PATCH 4/5] fixed Pydantic warning --- src/codegate/db/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 07c4c8ed..04171b88 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -276,7 +276,7 @@ def nd_array_custom_serializer(x): NdArray = Annotated[ np.ndarray, BeforeValidator(nd_array_custom_before_validator), - PlainSerializer(nd_array_custom_serializer, return_type=str), + PlainSerializer(nd_array_custom_serializer, return_type=np.ndarray), ] VALID_PERSONA_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_ -]+$") From 4eb38f68b300a2ceb636fbe1f48d40baa104f263 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Fri, 7 Mar 2025 12:11:43 +0200 Subject: [PATCH 5/5] add a persona matcher to the system prompt and unit tests --- src/codegate/muxing/models.py | 4 +- src/codegate/muxing/rulematcher.py | 77 ++++++++++++++++----- tests/muxing/test_rulematcher.py | 103 ++++++++++++++++++++++++++++- 3 files changed, 164 insertions(+), 20 
deletions(-) diff --git a/src/codegate/muxing/models.py b/src/codegate/muxing/models.py index 2f070265..d923986c 100644 --- a/src/codegate/muxing/models.py +++ b/src/codegate/muxing/models.py @@ -32,8 +32,10 @@ class MuxMatcherType(str, Enum): fim_filename = "fim_filename" # Match based on chat request type. It will match if the request type is chat chat_filename = "chat_filename" - # Match the request content to the persona description + # Match the user messages to the persona description persona_description = "persona_description" + # Match the system prompt to the persona description + sys_prompt_persona_desc = "sys_prompt_persona_desc" class MuxRule(pydantic.BaseModel): diff --git a/src/codegate/muxing/rulematcher.py b/src/codegate/muxing/rulematcher.py index adacf68e..80e7daa4 100644 --- a/src/codegate/muxing/rulematcher.py +++ b/src/codegate/muxing/rulematcher.py @@ -83,7 +83,8 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher, mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher, mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher, - mux_models.MuxMatcherType.persona_description: PersonaDescriptionMuxingRuleMatcher, + mux_models.MuxMatcherType.persona_description: UserMsgsPersonaDescMuxMatcher, + mux_models.MuxMatcherType.sys_prompt_persona_desc: SysPromptPersonaDescMuxMatcher, } try: @@ -173,12 +174,42 @@ async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: return is_rule_matched -class PersonaDescriptionMuxingRuleMatcher(MuxingRuleMatcher): +class PersonaDescMuxMatcher(MuxingRuleMatcher): """Muxing rule to match the request content to a persona description.""" - def _get_user_messages_from_body(self, body: Dict) -> List[str]: + @abstractmethod + def _get_queries_for_persona_match(self, body: Dict) -> List[str]: + """ + Get the queries to use for persona matching. 
+ """ + pass + + async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + """ + Return True if the matcher is the persona description matched with the queries. + + The queries are extracted from the body and will depend on the type of matcher. + 1. UserMsgsPersonaDescMuxMatcher: Extracts queries from the user messages in the body. + 2. SysPromptPersonaDescMuxMatcher: Extracts queries from the system messages in the body. + """ + queries = self._get_queries_for_persona_match(thing_to_match.body) + if not queries: + return False + + persona_manager = PersonaManager() + is_persona_matched = await persona_manager.check_persona_match( + persona_name=self._mux_rule.matcher, queries=queries + ) + if is_persona_matched: + logger.info("Persona rule matched", persona=self._mux_rule.matcher) + return is_persona_matched + + +class UserMsgsPersonaDescMuxMatcher(PersonaDescMuxMatcher): + + def _get_queries_for_persona_match(self, body: Dict) -> List[str]: """ - Get the user messages from the body to use as queries. + Get the queries from the user messages in the body. """ user_messages = [] for msg in body.get("messages", []): @@ -194,22 +225,34 @@ def _get_user_messages_from_body(self, body: Dict) -> List[str]: user_messages.append(msgs_content) return user_messages - async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool: + +class SysPromptPersonaDescMuxMatcher(PersonaDescMuxMatcher): + + def _get_queries_for_persona_match(self, body: Dict) -> List[str]: """ - Return True if the matcher is the persona description matched with the - user messages. + Get the queries from the system messages in the body. 
""" - user_messages = self._get_user_messages_from_body(thing_to_match.body) - if not user_messages: - return False + system_messages = [] + for msg in body.get("messages", []): + if msg.get("role", "") in ["system", "developer"]: + msgs_content = msg.get("content") + if not msgs_content: + continue + if isinstance(msgs_content, list): + for msg_content in msgs_content: + if msg_content.get("type", "") == "text": + system_messages.append(msg_content.get("text", "")) + elif isinstance(msgs_content, str): + system_messages.append(msgs_content) - persona_manager = PersonaManager() - is_persona_matched = await persona_manager.check_persona_match( - persona_name=self._mux_rule.matcher, queries=user_messages - ) - if is_persona_matched: - logger.info("Persona rule matched", persona=self._mux_rule.matcher) - return is_persona_matched + # Handling the anthropic system prompt + anthropic_sys_prompt = body.get("system") + if anthropic_sys_prompt: + system_messages.append(anthropic_sys_prompt) + + # In an ideal world, the length of system_messages should be 1. Returnin the list + # to handle any edge cases and to not break parent function's signature. 
+ return system_messages class MuxingRulesinWorkspaces: diff --git a/tests/muxing/test_rulematcher.py b/tests/muxing/test_rulematcher.py index ad4c420a..981e3885 100644 --- a/tests/muxing/test_rulematcher.py +++ b/tests/muxing/test_rulematcher.py @@ -1,4 +1,5 @@ -from unittest.mock import MagicMock +from typing import Dict, List +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -151,6 +152,100 @@ async def test_request_file_matcher( ) +# We mock PersonaManager because it's tested in /tests/persona/test_manager.py +MOCK_PERSONA_MANAGER = AsyncMock() +MOCK_PERSONA_MANAGER.check_persona_match.return_value = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "body, expected_queries", + [ + ({"messages": [{"role": "system", "content": "Youre helpful"}]}, []), + ({"messages": [{"role": "user", "content": "hello"}]}, ["hello"]), + ( + {"messages": [{"role": "user", "content": [{"type": "text", "text": "hello_dict"}]}]}, + ["hello_dict"], + ), + ], +) +async def test_user_msgs_persona_desc_matcher(body: Dict, expected_queries: List[str]): + mux_rule = mux_models.MuxRule( + provider_id="1", + model="fake-gpt", + matcher_type="persona_description", + matcher="foo_persona", + ) + muxing_rule_matcher = rulematcher.UserMsgsPersonaDescMuxMatcher(mocked_route_openai, mux_rule) + + mocked_thing_to_match = mux_models.ThingToMatchMux( + body=body, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ) + + resulting_queries = muxing_rule_matcher._get_queries_for_persona_match(body) + assert set(resulting_queries) == set(expected_queries) + + with patch("codegate.muxing.rulematcher.PersonaManager", return_value=MOCK_PERSONA_MANAGER): + result = await muxing_rule_matcher.match(mocked_thing_to_match) + + if expected_queries: + assert result is True + else: + assert result is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "body, expected_queries", + [ + ({"messages": [{"role": "system", "content": "Youre 
helpful"}]}, ["Youre helpful"]), + ({"messages": [{"role": "user", "content": "hello"}]}, []), + ( + { + "messages": [ + {"role": "system", "content": "Youre helpful"}, + {"role": "user", "content": "hello"}, + ] + }, + ["Youre helpful"], + ), + ( + {"messages": [{"role": "user", "content": "hello"}], "system": "Anthropic system"}, + ["Anthropic system"], + ), + ], +) +async def test_sys_prompt_persona_desc_matcher(body: Dict, expected_queries: List[str]): + mux_rule = mux_models.MuxRule( + provider_id="1", + model="fake-gpt", + matcher_type="sys_prompt_persona_desc", + matcher="foo_persona", + ) + muxing_rule_matcher = rulematcher.SysPromptPersonaDescMuxMatcher(mocked_route_openai, mux_rule) + + mocked_thing_to_match = mux_models.ThingToMatchMux( + body=body, + url_request_path="/chat/completions", + is_fim_request=False, + client_type="generic", + ) + + resulting_queries = muxing_rule_matcher._get_queries_for_persona_match(body) + assert set(resulting_queries) == set(expected_queries) + + with patch("codegate.muxing.rulematcher.PersonaManager", return_value=MOCK_PERSONA_MANAGER): + result = await muxing_rule_matcher.match(mocked_thing_to_match) + + if expected_queries: + assert result is True + else: + assert result is False + + @pytest.mark.parametrize( "matcher_type, expected_class", [ @@ -160,7 +255,11 @@ async def test_request_file_matcher( (mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher), ( mux_models.MuxMatcherType.persona_description, - rulematcher.PersonaDescriptionMuxingRuleMatcher, + rulematcher.UserMsgsPersonaDescMuxMatcher, + ), + ( + mux_models.MuxMatcherType.sys_prompt_persona_desc, + rulematcher.SysPromptPersonaDescMuxMatcher, ), ("invalid_matcher", None), ],