@@ -43,6 +43,7 @@ def __init__(self):
43
43
self ._n_gpu = conf .chat_model_n_gpu_layers
44
44
self ._persona_threshold = conf .persona_threshold
45
45
self ._persona_diff_desc_threshold = conf .persona_diff_desc_threshold
46
+ self ._distances_weight_factor = conf .distances_weight_factor
46
47
self ._db_recorder = DbRecorder ()
47
48
self ._db_reader = DbReader ()
48
49
@@ -99,18 +100,17 @@ def _clean_text_for_embedding(self, text: str) -> str:
99
100
100
101
return text
101
102
102
async def _embed_texts(self, texts: List[str]) -> np.ndarray:
    """
    Embed a batch of texts using the inference engine.

    Each text is first normalised with `_clean_text_for_embedding`, then the
    whole batch is sent to the engine in a single `embed` call.

    Args:
        texts: Raw texts to embed.

    Returns:
        A float32 NumPy array built from the list of embeddings returned by the
        engine (presumably one row per input text, in input order; confirm
        against the inference engine).
    """
    batch = list(map(self._clean_text_for_embedding, texts))
    # .embed returns a list of embeddings, so the full batch goes in one call
    vectors = await self._inference_engine.embed(
        self._embeddings_model, batch, n_gpu_layers=self._n_gpu
    )
    logger.debug("Text embedded in semantic routing", num_texts=len(texts))
    return np.array(vectors, dtype=np.float32)
114
114
115
115
async def _is_persona_description_diff (
116
116
self , emb_persona_desc : np .ndarray , exclude_id : Optional [str ]
@@ -142,7 +142,8 @@ async def _validate_persona_description(
142
142
Validate the persona description by embedding the text and checking if it is
143
143
different enough from existing personas.
144
144
"""
145
- emb_persona_desc = await self ._embed_text (persona_desc )
145
+ emb_persona_desc_list = await self ._embed_texts ([persona_desc ])
146
+ emb_persona_desc = emb_persona_desc_list [0 ]
146
147
if not await self ._is_persona_description_diff (emb_persona_desc , exclude_id ):
147
148
raise PersonaSimilarDescriptionError (
148
149
"The persona description is too similar to existing personas."
@@ -217,21 +218,87 @@ async def delete_persona(self, persona_name: str) -> None:
217
218
await self ._db_recorder .delete_persona (persona .id )
218
219
logger .info (f"Deleted persona { persona_name } from the database." )
219
220
220
- async def check_persona_match (self , persona_name : str , query : str ) -> bool :
221
+ async def _get_cosine_distance (self , emb_queries : np . ndarray , emb_persona : np . ndarray ) -> float :
221
222
"""
222
- Check if the query matches the persona description. A vector similarity
223
- search is performed between the query and the persona description.
223
+ Calculate the cosine distance between the queries embeddings and persona embedding.
224
+ Persona embedding is a single vector of length M
225
+ Queries embeddings is a matrix of shape (N, M)
226
+ N is the number of queries. User messages in this case.
227
+ M is the number of dimensions in the embedding
228
+
229
+ Defintion of cosine distance: 1 - cosine similarity
230
+ [Cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
231
+
232
+ NOTE: Experimented by individually querying SQLite for each query, but as the number
233
+ of queries increases, the performance is better with NumPy. If the number of queries
234
+ is small the performance is onpar. Hence the decision to use NumPy.
235
+ """
236
+ # Handle the case where we have a single query (single user message)
237
+ if emb_queries .ndim == 1 :
238
+ emb_queries = emb_queries .reshape (1 , - 1 )
239
+
240
+ emb_queries_norm = np .linalg .norm (emb_queries , axis = 1 )
241
+ persona_embed_norm = np .linalg .norm (emb_persona )
242
+ cosine_similarities = np .dot (emb_queries , emb_persona .T ) / (
243
+ emb_queries_norm * persona_embed_norm
244
+ )
245
+ # We could also use directly cosine_similarities but we get the distance to match
246
+ # the behavior of SQLite function vec_distance_cosine
247
+ cosine_distances = 1 - cosine_similarities
248
+ return cosine_distances
249
+
250
+ async def _weight_distances (self , distances : np .ndarray ) -> np .ndarray :
251
+ """
252
+ Weights the received distances, with later positions being more important and the
253
+ last position unchanged. The reasoning is that the distances correspond to user
254
+ messages, with the last message being the most recent and therefore the most
255
+ important.
256
+
257
+ Args:
258
+ distances: NumPy array of float values between 0 and 2
259
+ weight_factor: Factor that determines how quickly weights increase (0-1)
260
+ Lower values create a steeper importance curve. 1 makes
261
+ all weights equal.
262
+
263
+ Returns:
264
+ Weighted distances as a NumPy array
265
+ """
266
+ # Get array length
267
+ n = len (distances )
268
+
269
+ # Create positions array in reverse order (n-1, n-2, ..., 1, 0)
270
+ # This makes the last element have position 0
271
+ positions = np .arange (n - 1 , - 1 , - 1 )
272
+
273
+ # Create weights - now the last element (position 0) gets weight 1
274
+ weights = self ._distances_weight_factor ** positions
275
+
276
+ # Apply weights by dividing distances
277
+ # Smaller weight -> larger effective distance
278
+ weighted_distances = distances / weights
279
+ return weighted_distances
280
+
281
async def check_persona_match(self, persona_name: str, queries: List[str]) -> bool:
    """
    Check if the queries match the persona description. A vector similarity
    search is performed between the queries and the persona description.

    The vectors are compared using the cosine distance implemented in
    _get_cosine_distance: 0 means the vectors point in the same direction,
    1 means they are orthogonal and 2 means they point in opposite directions.
    The distances are then weighted with _weight_distances, and the persona is
    considered a match when at least one weighted distance is below
    `self._persona_threshold`.

    Args:
        persona_name: Name of the persona to compare the queries against.
        queries: Texts to compare with the persona description. The weighting
            step treats the last entry as the most recent message.

    Returns:
        True if any weighted distance is below the persona threshold, False
        otherwise (including when `queries` is empty).

    Raises:
        PersonaDoesNotExistError: If no persona embedding is found for
            `persona_name`.
    """
    persona_embed = await self._db_reader.get_persona_embed_by_name(persona_name)
    if not persona_embed:
        raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

    # With no queries there is nothing to embed or compare; an empty batch
    # would otherwise fail with a shape mismatch in the distance computation.
    if not queries:
        return False

    emb_queries = await self._embed_texts(queries)
    cosine_distances = await self._get_cosine_distance(
        emb_queries, persona_embed.description_embedding
    )
    logger.debug("Cosine distances calculated", cosine_distances=cosine_distances)

    weighted_distances = await self._weight_distances(cosine_distances)
    logger.info("Weighted distances to persona", weighted_distances=weighted_distances)

    return bool(np.any(weighted_distances < self._persona_threshold))
0 commit comments