1
- from unittest .mock import MagicMock
1
+ from typing import Dict , List
2
+ from unittest .mock import AsyncMock , MagicMock , patch
2
3
3
4
import pytest
4
5
@@ -151,6 +152,100 @@ async def test_request_file_matcher(
151
152
)
152
153
153
154
155
+ # We mock PersonaManager because it's tested in /tests/persona/test_manager.py
156
+ MOCK_PERSONA_MANAGER = AsyncMock ()
157
+ MOCK_PERSONA_MANAGER .check_persona_match .return_value = True
158
+
159
+
160
+ @pytest .mark .asyncio
161
+ @pytest .mark .parametrize (
162
+ "body, expected_queries" ,
163
+ [
164
+ ({"messages" : [{"role" : "system" , "content" : "Youre helpful" }]}, []),
165
+ ({"messages" : [{"role" : "user" , "content" : "hello" }]}, ["hello" ]),
166
+ (
167
+ {"messages" : [{"role" : "user" , "content" : [{"type" : "text" , "text" : "hello_dict" }]}]},
168
+ ["hello_dict" ],
169
+ ),
170
+ ],
171
+ )
172
+ async def test_user_msgs_persona_desc_matcher (body : Dict , expected_queries : List [str ]):
173
+ mux_rule = mux_models .MuxRule (
174
+ provider_id = "1" ,
175
+ model = "fake-gpt" ,
176
+ matcher_type = "user_messages_persona_desc" ,
177
+ matcher = "foo_persona" ,
178
+ )
179
+ muxing_rule_matcher = rulematcher .UserMsgsPersonaDescMuxMatcher (mocked_route_openai , mux_rule )
180
+
181
+ mocked_thing_to_match = mux_models .ThingToMatchMux (
182
+ body = body ,
183
+ url_request_path = "/chat/completions" ,
184
+ is_fim_request = False ,
185
+ client_type = "generic" ,
186
+ )
187
+
188
+ resulting_queries = muxing_rule_matcher ._get_queries_for_persona_match (body )
189
+ assert set (resulting_queries ) == set (expected_queries )
190
+
191
+ with patch ("codegate.muxing.rulematcher.PersonaManager" , return_value = MOCK_PERSONA_MANAGER ):
192
+ result = await muxing_rule_matcher .match (mocked_thing_to_match )
193
+
194
+ if expected_queries :
195
+ assert result is True
196
+ else :
197
+ assert result is False
198
+
199
+
200
+ @pytest .mark .asyncio
201
+ @pytest .mark .parametrize (
202
+ "body, expected_queries" ,
203
+ [
204
+ ({"messages" : [{"role" : "system" , "content" : "Youre helpful" }]}, ["Youre helpful" ]),
205
+ ({"messages" : [{"role" : "user" , "content" : "hello" }]}, []),
206
+ (
207
+ {
208
+ "messages" : [
209
+ {"role" : "system" , "content" : "Youre helpful" },
210
+ {"role" : "user" , "content" : "hello" },
211
+ ]
212
+ },
213
+ ["Youre helpful" ],
214
+ ),
215
+ (
216
+ {"messages" : [{"role" : "user" , "content" : "hello" }], "system" : "Anthropic system" },
217
+ ["Anthropic system" ],
218
+ ),
219
+ ],
220
+ )
221
+ async def test_sys_prompt_persona_desc_matcher (body : Dict , expected_queries : List [str ]):
222
+ mux_rule = mux_models .MuxRule (
223
+ provider_id = "1" ,
224
+ model = "fake-gpt" ,
225
+ matcher_type = "sys_prompt_persona_desc" ,
226
+ matcher = "foo_persona" ,
227
+ )
228
+ muxing_rule_matcher = rulematcher .SysPromptPersonaDescMuxMatcher (mocked_route_openai , mux_rule )
229
+
230
+ mocked_thing_to_match = mux_models .ThingToMatchMux (
231
+ body = body ,
232
+ url_request_path = "/chat/completions" ,
233
+ is_fim_request = False ,
234
+ client_type = "generic" ,
235
+ )
236
+
237
+ resulting_queries = muxing_rule_matcher ._get_queries_for_persona_match (body )
238
+ assert set (resulting_queries ) == set (expected_queries )
239
+
240
+ with patch ("codegate.muxing.rulematcher.PersonaManager" , return_value = MOCK_PERSONA_MANAGER ):
241
+ result = await muxing_rule_matcher .match (mocked_thing_to_match )
242
+
243
+ if expected_queries :
244
+ assert result is True
245
+ else :
246
+ assert result is False
247
+
248
+
154
249
@pytest .mark .parametrize (
155
250
"matcher_type, expected_class" ,
156
251
[
@@ -159,8 +254,12 @@ async def test_request_file_matcher(
159
254
(mux_models .MuxMatcherType .fim_filename , rulematcher .RequestTypeAndFileMuxingRuleMatcher ),
160
255
(mux_models .MuxMatcherType .chat_filename , rulematcher .RequestTypeAndFileMuxingRuleMatcher ),
161
256
(
162
- mux_models .MuxMatcherType .persona_description ,
163
- rulematcher .PersonaDescriptionMuxingRuleMatcher ,
257
+ mux_models .MuxMatcherType .user_messages_persona_desc ,
258
+ rulematcher .UserMsgsPersonaDescMuxMatcher ,
259
+ ),
260
+ (
261
+ mux_models .MuxMatcherType .sys_prompt_persona_desc ,
262
+ rulematcher .SysPromptPersonaDescMuxMatcher ,
164
263
),
165
264
("invalid_matcher" , None ),
166
265
],
0 commit comments