Skip to content

Commit ffd932c

Browse files
add a persona matcher to the system prompt and unit tests
1 parent 714a630 commit ffd932c

File tree

3 files changed

+165
-21
lines changed

3 files changed

+165
-21
lines changed

src/codegate/muxing/models.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ class MuxMatcherType(str, Enum):
3232
fim_filename = "fim_filename"
3333
# Match based on chat request type. It will match if the request type is chat
3434
chat_filename = "chat_filename"
35-
# Match the request content to the persona description
35+
# Match the user messages to the persona description
3636
persona_description = "persona_description"
37+
# Match the system prompt to the persona description
38+
sys_prompt_persona_desc = "sys_prompt_persona_desc"
3739

3840

3941
class MuxRule(pydantic.BaseModel):

src/codegate/muxing/rulematcher.py

+60-17
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def create(db_mux_rule: db_models.MuxRule, route: ModelRoute) -> MuxingRuleMatch
8383
mux_models.MuxMatcherType.filename_match: FileMuxingRuleMatcher,
8484
mux_models.MuxMatcherType.fim_filename: RequestTypeAndFileMuxingRuleMatcher,
8585
mux_models.MuxMatcherType.chat_filename: RequestTypeAndFileMuxingRuleMatcher,
86-
mux_models.MuxMatcherType.persona_description: PersonaDescriptionMuxingRuleMatcher,
86+
mux_models.MuxMatcherType.persona_description: UserMsgsPersonaDescMuxMatcher,
87+
mux_models.MuxMatcherType.sys_prompt_persona_desc: SysPromptPersonaDescMuxMatcher,
8788
}
8889

8990
try:
@@ -173,12 +174,42 @@ async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
173174
return is_rule_matched
174175

175176

176-
class PersonaDescriptionMuxingRuleMatcher(MuxingRuleMatcher):
177+
class PersonaDescMuxMatcher(MuxingRuleMatcher):
177178
"""Muxing rule to match the request content to a persona description."""
178179

179-
def _get_user_messages_from_body(self, body: Dict) -> List[str]:
180+
@abstractmethod
181+
def _get_queries_for_persona_match(self, body: Dict) -> List[str]:
182+
"""
183+
Get the queries to use for persona matching.
184+
"""
185+
pass
186+
187+
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
188+
"""
189+
Return True if the persona description matches the extracted queries.
190+
191+
The queries are extracted from the body and will depend on the type of matcher.
192+
1. UserMsgsPersonaDescMuxMatcher: Extracts queries from the user messages in the body.
193+
2. SysPromptPersonaDescMuxMatcher: Extracts queries from the system messages in the body.
194+
"""
195+
queries = self._get_queries_for_persona_match(thing_to_match.body)
196+
if not queries:
197+
return False
198+
199+
persona_manager = PersonaManager()
200+
is_persona_matched = await persona_manager.check_persona_match(
201+
persona_name=self._mux_rule.matcher, queries=queries
202+
)
203+
if is_persona_matched:
204+
logger.info("Persona rule matched", persona=self._mux_rule.matcher)
205+
return is_persona_matched
206+
207+
208+
class UserMsgsPersonaDescMuxMatcher(PersonaDescMuxMatcher):
209+
210+
def _get_queries_for_persona_match(self, body: Dict) -> List[str]:
180211
"""
181-
Get the user messages from the body to use as queries.
212+
Get the queries from the user messages in the body.
182213
"""
183214
user_messages = []
184215
for msg in body.get("messages", []):
@@ -194,22 +225,34 @@ def _get_user_messages_from_body(self, body: Dict) -> List[str]:
194225
user_messages.append(msgs_content)
195226
return user_messages
196227

197-
async def match(self, thing_to_match: mux_models.ThingToMatchMux) -> bool:
228+
229+
class SysPromptPersonaDescMuxMatcher(PersonaDescMuxMatcher):
230+
231+
def _get_queries_for_persona_match(self, body: Dict) -> List[str]:
198232
"""
199-
Return True if the matcher is the persona description matched with the
200-
user messages.
233+
Get the queries from the system messages in the body.
201234
"""
202-
user_messages = self._get_user_messages_from_body(thing_to_match.body)
203-
if not user_messages:
204-
return False
235+
system_messages = []
236+
for msg in body.get("messages", []):
237+
if msg.get("role", "") in ["system", "developer"]:
238+
msgs_content = msg.get("content")
239+
if not msgs_content:
240+
continue
241+
if isinstance(msgs_content, list):
242+
for msg_content in msgs_content:
243+
if msg_content.get("type", "") == "text":
244+
system_messages.append(msg_content.get("text", ""))
245+
elif isinstance(msgs_content, str):
246+
system_messages.append(msgs_content)
205247

206-
persona_manager = PersonaManager()
207-
is_persona_matched = await persona_manager.check_persona_match(
208-
persona_name=self._mux_rule.matcher, queries=user_messages
209-
)
210-
if is_persona_matched:
211-
logger.info("Persona rule matched", persona=self._mux_rule.matcher)
212-
return is_persona_matched
248+
# Handle the Anthropic-style system prompt, passed as a top-level "system" field
249+
anthropic_sys_prompt = body.get("system")
250+
if anthropic_sys_prompt:
251+
system_messages.append(anthropic_sys_prompt)
252+
253+
# In an ideal world, the length of system_messages should be 1. Returning the list
254+
# to handle any edge cases and to not break parent function's signature.
255+
return system_messages
213256

214257

215258
class MuxingRulesinWorkspaces:

tests/muxing/test_rulematcher.py

+102-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from unittest.mock import MagicMock
1+
from typing import Dict, List
2+
from unittest.mock import AsyncMock, MagicMock, patch
23

34
import pytest
45

@@ -151,6 +152,100 @@ async def test_request_file_matcher(
151152
)
152153

153154

155+
# We mock PersonaManager because it's tested in /tests/persona/test_manager.py
156+
MOCK_PERSONA_MANAGER = AsyncMock()
157+
MOCK_PERSONA_MANAGER.check_persona_match.return_value = True
158+
159+
160+
@pytest.mark.asyncio
161+
@pytest.mark.parametrize(
162+
"body, expected_queries",
163+
[
164+
({"messages": [{"role": "system", "content": "Youre helpful"}]}, []),
165+
({"messages": [{"role": "user", "content": "hello"}]}, ["hello"]),
166+
(
167+
{"messages": [{"role": "user", "content": [{"type": "text", "text": "hello_dict"}]}]},
168+
["hello_dict"],
169+
),
170+
],
171+
)
172+
async def test_user_msgs_persona_desc_matcher(body: Dict, expected_queries: List[str]):
173+
mux_rule = mux_models.MuxRule(
174+
provider_id="1",
175+
model="fake-gpt",
176+
matcher_type="user_messages_persona_desc",
177+
matcher="foo_persona",
178+
)
179+
muxing_rule_matcher = rulematcher.UserMsgsPersonaDescMuxMatcher(mocked_route_openai, mux_rule)
180+
181+
mocked_thing_to_match = mux_models.ThingToMatchMux(
182+
body=body,
183+
url_request_path="/chat/completions",
184+
is_fim_request=False,
185+
client_type="generic",
186+
)
187+
188+
resulting_queries = muxing_rule_matcher._get_queries_for_persona_match(body)
189+
assert set(resulting_queries) == set(expected_queries)
190+
191+
with patch("codegate.muxing.rulematcher.PersonaManager", return_value=MOCK_PERSONA_MANAGER):
192+
result = await muxing_rule_matcher.match(mocked_thing_to_match)
193+
194+
if expected_queries:
195+
assert result is True
196+
else:
197+
assert result is False
198+
199+
200+
@pytest.mark.asyncio
201+
@pytest.mark.parametrize(
202+
"body, expected_queries",
203+
[
204+
({"messages": [{"role": "system", "content": "Youre helpful"}]}, ["Youre helpful"]),
205+
({"messages": [{"role": "user", "content": "hello"}]}, []),
206+
(
207+
{
208+
"messages": [
209+
{"role": "system", "content": "Youre helpful"},
210+
{"role": "user", "content": "hello"},
211+
]
212+
},
213+
["Youre helpful"],
214+
),
215+
(
216+
{"messages": [{"role": "user", "content": "hello"}], "system": "Anthropic system"},
217+
["Anthropic system"],
218+
),
219+
],
220+
)
221+
async def test_sys_prompt_persona_desc_matcher(body: Dict, expected_queries: List[str]):
222+
mux_rule = mux_models.MuxRule(
223+
provider_id="1",
224+
model="fake-gpt",
225+
matcher_type="sys_prompt_persona_desc",
226+
matcher="foo_persona",
227+
)
228+
muxing_rule_matcher = rulematcher.SysPromptPersonaDescMuxMatcher(mocked_route_openai, mux_rule)
229+
230+
mocked_thing_to_match = mux_models.ThingToMatchMux(
231+
body=body,
232+
url_request_path="/chat/completions",
233+
is_fim_request=False,
234+
client_type="generic",
235+
)
236+
237+
resulting_queries = muxing_rule_matcher._get_queries_for_persona_match(body)
238+
assert set(resulting_queries) == set(expected_queries)
239+
240+
with patch("codegate.muxing.rulematcher.PersonaManager", return_value=MOCK_PERSONA_MANAGER):
241+
result = await muxing_rule_matcher.match(mocked_thing_to_match)
242+
243+
if expected_queries:
244+
assert result is True
245+
else:
246+
assert result is False
247+
248+
154249
@pytest.mark.parametrize(
155250
"matcher_type, expected_class",
156251
[
@@ -159,8 +254,12 @@ async def test_request_file_matcher(
159254
(mux_models.MuxMatcherType.fim_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher),
160255
(mux_models.MuxMatcherType.chat_filename, rulematcher.RequestTypeAndFileMuxingRuleMatcher),
161256
(
162-
mux_models.MuxMatcherType.persona_description,
163-
rulematcher.PersonaDescriptionMuxingRuleMatcher,
257+
mux_models.MuxMatcherType.user_messages_persona_desc,
258+
rulematcher.UserMsgsPersonaDescMuxMatcher,
259+
),
260+
(
261+
mux_models.MuxMatcherType.sys_prompt_persona_desc,
262+
rulematcher.SysPromptPersonaDescMuxMatcher,
164263
),
165264
("invalid_matcher", None),
166265
],

0 commit comments

Comments
 (0)