1
1
import logging
from typing import Optional

import polars as pl
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.requests import Request

from pypi_scout.api.data_loader import ApiDataLoader
from pypi_scout.api.models import QueryModel, SearchResponse
from pypi_scout.config import Config
from pypi_scout.embeddings.simple_vector_database import SimpleVectorDatabase
from pypi_scout.utils.logging import setup_logging
from pypi_scout.utils.score_calculator import calculate_score
19
18
20
- # Setup logging
21
19
setup_logging ()
22
20
logging .info ("Initializing backend..." )
23
21
24
- # Initialize limiter
25
22
limiter = Limiter (key_func = get_remote_address )
26
23
app = FastAPI ()
27
24
app .state .limiter = limiter
28
25
app .add_exception_handler (RateLimitExceeded , _rate_limit_exceeded_handler )
29
26
30
- # Load environment variables and configuration
31
27
load_dotenv ()
32
28
config = Config ()
33
29
34
- # Add CORS middleware
35
30
app .add_middleware (
36
31
CORSMiddleware ,
37
- allow_origins = ["*" ], # Temporary wildcard for testing
32
+ allow_origins = ["*" ],
38
33
allow_credentials = True ,
39
34
allow_methods = ["*" ],
40
35
allow_headers = ["*" ],
44
39
df_packages , df_embeddings = data_loader .load_dataset ()
45
40
46
41
model = SentenceTransformer (config .EMBEDDINGS_MODEL_NAME )
47
-
48
42
vector_database = SimpleVectorDatabase (embeddings_model = model , df_embeddings = df_embeddings )
49
43
50
44
51
class QueryModel(BaseModel):
    """Request body for the ``/api/search`` endpoint."""

    # Free-text search query describing the package the user is looking for.
    query: str
    # Maximum number of matches to return; defaults to the configured value.
    top_k: int = config.N_RESULTS_TO_RETURN
56
class Match(BaseModel):
    """A single package returned by the search endpoint."""

    # Package name on PyPI.
    name: str
    # Short one-line description of the package.
    summary: str
    # Similarity score between the query and this package, as produced by
    # the vector search.
    similarity: float
    # Weekly download count from the local dataset.
    weekly_downloads: int
63
class SearchResponse(BaseModel):
    """Response body for the ``/api/search`` endpoint."""

    # Packages matching the query.
    matches: list[Match]
    # True when a non-fatal problem occurred while assembling the results
    # (e.g. entries had to be excluded).
    warning: bool = False
    # Human-readable explanation of the warning; None when there is no warning.
    # Fixed annotation: the previous `str = None` default contradicted its
    # annotation (implicit Optional is rejected by pydantic v2 and flagged
    # by type checkers). `Optional[str]` keeps the same runtime default.
    warning_message: Optional[str] = None
69
45
@app .post ("/api/search" , response_model = SearchResponse )
70
46
@limiter .limit ("4/minute" )
71
47
async def search (query : QueryModel , request : Request ):
@@ -75,7 +51,7 @@ async def search(query: QueryModel, request: Request):
75
51
The top_k packages with the highest score are returned.
76
52
"""
77
53
78
- if query .top_k > 60 :
54
+ if query .top_k > 100 :
79
55
raise HTTPException (status_code = 400 , detail = "top_k cannot be larger than 100." )
80
56
81
57
logging .info (f"Searching for similar projects. Query: '{ query .query } '" )
@@ -85,18 +61,6 @@ async def search(query: QueryModel, request: Request):
85
61
f"Fetched the { len (df_matches )} most similar projects. Calculating the weighted scores and filtering..."
86
62
)
87
63
88
- warning = False
89
- warning_message = ""
90
- matches_missing_in_local_dataset = df_matches .filter (pl .col ("weekly_downloads" ).is_null ())["name" ].to_list ()
91
- if matches_missing_in_local_dataset :
92
- warning = True
93
- warning_message = (
94
- f"The following entries have 'None' for 'weekly_downloads': { matches_missing_in_local_dataset } . "
95
- "These entries were found in the vector database but not in the local dataset and have been excluded from the results."
96
- )
97
- logging .error (warning_message )
98
- df_matches = df_matches .filter (~ pl .col ("name" ).is_in (matches_missing_in_local_dataset ))
99
-
100
64
df_matches = calculate_score (
101
65
df_matches , weight_similarity = config .WEIGHT_SIMILARITY , weight_weekly_downloads = config .WEIGHT_WEEKLY_DOWNLOADS
102
66
)
@@ -107,4 +71,4 @@ async def search(query: QueryModel, request: Request):
107
71
108
72
logging .info (f"Returning the { len (df_matches )} best matches." )
109
73
df_matches = df_matches .select (["name" , "similarity" , "summary" , "weekly_downloads" ])
110
- return SearchResponse (matches = df_matches .to_dicts (), warning = warning , warning_message = warning_message )
74
+ return SearchResponse (matches = df_matches .to_dicts ())
0 commit comments