
Commit c593ad4

cleaned up a bit (#10)
1 parent d1f5818 commit c593ad4

File tree: 5 files changed, +39 -45 lines changed

frontend/app/utils/search.ts (+1)

@@ -30,6 +30,7 @@ export const handleSearch = async (
     `${apiUrl}/search`,
     {
       query: query,
+      top_k: 40,
     },
     {
       headers: {
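Note: the hard-coded top_k: 40 replaces the default that this commit removes from the backend config (N_RESULTS_TO_RETURN). For reference, a minimal sketch of the equivalent request made from Python; the localhost address is an assumption for local testing, since in practice apiUrl comes from the frontend's environment:

import requests

# Hypothetical local address; the real `apiUrl` is injected via frontend config.
response = requests.post(
    "http://localhost:8000/api/search",
    json={"query": "dataframe library", "top_k": 40},  # mirrors the TypeScript payload
    timeout=10,
)
response.raise_for_status()
print(response.json()["matches"])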

pypi_scout/api/data_loader.py (+15 -2)

@@ -19,6 +19,7 @@ def load_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
         else:
             raise ValueError(f"Unexpected value found for STORAGE_BACKEND: {self.config.STORAGE_BACKEND}")  # noqa: TRY003
 
+        df_embeddings = self._drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages)
         return df_packages, df_embeddings
 
     def _load_local_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
@@ -56,10 +57,22 @@ def _load_blob_dataset(self) -> Tuple[pl.DataFrame, pl.DataFrame]:
 
         return df_packages, df_embeddings
 
-    def _log_packages_dataset_info(self, df_packages: pl.DataFrame) -> None:
+    @staticmethod
+    def _log_packages_dataset_info(df_packages: pl.DataFrame) -> None:
         logging.info(f"Finished loading the `packages` dataset. Number of rows in dataset: {len(df_packages):,}")
         logging.info(df_packages.describe())
 
-    def _log_embeddings_dataset_info(self, df_embeddings: pl.DataFrame) -> None:
+    @staticmethod
+    def _log_embeddings_dataset_info(df_embeddings: pl.DataFrame) -> None:
         logging.info(f"Finished loading the `embeddings` dataset. Number of rows in dataset: {len(df_embeddings):,}")
         logging.info(df_embeddings.describe())
+
+    @staticmethod
+    def _drop_rows_from_embeddings_that_do_not_appear_in_packages(df_embeddings, df_packages):
+        # Keep only the packages in the vector dataset that also occur in the packages dataset.
+        # In theory this should never drop anything, but it is a useful fail-safe against issues in the API.
+        logging.info("Dropping packages in the `embeddings` dataset that do not occur in the `packages` dataset...")
+        logging.info(f"Number of rows before dropping: {len(df_embeddings):,}")
+        df_embeddings = df_embeddings.join(df_packages, on="name", how="semi")
+        logging.info(f"Number of rows after dropping: {len(df_embeddings):,}")
+        return df_embeddings
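Note: the new fail-safe is a polars semi join, which keeps only the rows of the left frame whose key appears in the right frame and adds none of the right frame's columns. A minimal sketch with made-up data:

import polars as pl

# "ghost-package" has an embedding but no corresponding packages row.
df_embeddings = pl.DataFrame({
    "name": ["numpy", "ghost-package"],
    "embedding": [[0.1, 0.2], [0.3, 0.4]],
})
df_packages = pl.DataFrame({
    "name": ["numpy"],
    "weekly_downloads": [180_000_000],
})

result = df_embeddings.join(df_packages, on="name", how="semi")
print(result["name"].to_list())  # ["numpy"]; the orphaned row is dropped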

pypi_scout/api/main.py (+4 -40)

@@ -1,40 +1,35 @@
 import logging
 
-import polars as pl
 from dotenv import load_dotenv
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
 from sentence_transformers import SentenceTransformer
 from slowapi import Limiter, _rate_limit_exceeded_handler
 from slowapi.errors import RateLimitExceeded
 from slowapi.util import get_remote_address
 from starlette.requests import Request
 
 from pypi_scout.api.data_loader import ApiDataLoader
+from pypi_scout.api.models import QueryModel, SearchResponse
 from pypi_scout.config import Config
 from pypi_scout.embeddings.simple_vector_database import SimpleVectorDatabase
 from pypi_scout.utils.logging import setup_logging
 from pypi_scout.utils.score_calculator import calculate_score
 
-# Setup logging
 setup_logging()
 logging.info("Initializing backend...")
 
-# Initialize limiter
 limiter = Limiter(key_func=get_remote_address)
 app = FastAPI()
 app.state.limiter = limiter
 app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
 
-# Load environment variables and configuration
 load_dotenv()
 config = Config()
 
-# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # Temporary wildcard for testing
+    allow_origins=["*"],
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -44,28 +39,9 @@
 df_packages, df_embeddings = data_loader.load_dataset()
 
 model = SentenceTransformer(config.EMBEDDINGS_MODEL_NAME)
-
 vector_database = SimpleVectorDatabase(embeddings_model=model, df_embeddings=df_embeddings)
 
 
-class QueryModel(BaseModel):
-    query: str
-    top_k: int = config.N_RESULTS_TO_RETURN
-
-
-class Match(BaseModel):
-    name: str
-    summary: str
-    similarity: float
-    weekly_downloads: int
-
-
-class SearchResponse(BaseModel):
-    matches: list[Match]
-    warning: bool = False
-    warning_message: str = None
-
-
 @app.post("/api/search", response_model=SearchResponse)
 @limiter.limit("4/minute")
 async def search(query: QueryModel, request: Request):
@@ -75,7 +51,7 @@ async def search(query: QueryModel, request: Request):
     The top_k packages with the highest score are returned.
     """
 
-    if query.top_k > 60:
+    if query.top_k > 100:
         raise HTTPException(status_code=400, detail="top_k cannot be larger than 100.")
 
     logging.info(f"Searching for similar projects. Query: '{query.query}'")
@@ -85,18 +61,6 @@ async def search(query: QueryModel, request: Request):
         f"Fetched the {len(df_matches)} most similar projects. Calculating the weighted scores and filtering..."
     )
 
-    warning = False
-    warning_message = ""
-    matches_missing_in_local_dataset = df_matches.filter(pl.col("weekly_downloads").is_null())["name"].to_list()
-    if matches_missing_in_local_dataset:
-        warning = True
-        warning_message = (
-            f"The following entries have 'None' for 'weekly_downloads': {matches_missing_in_local_dataset}. "
-            "These entries were found in the vector database but not in the local dataset and have been excluded from the results."
-        )
-        logging.error(warning_message)
-        df_matches = df_matches.filter(~pl.col("name").is_in(matches_missing_in_local_dataset))
-
     df_matches = calculate_score(
         df_matches, weight_similarity=config.WEIGHT_SIMILARITY, weight_weekly_downloads=config.WEIGHT_WEEKLY_DOWNLOADS
     )
@@ -107,4 +71,4 @@ async def search(query: QueryModel, request: Request):
 
     logging.info(f"Returning the {len(df_matches)} best matches.")
     df_matches = df_matches.select(["name", "similarity", "summary", "weekly_downloads"])
-    return SearchResponse(matches=df_matches.to_dicts(), warning=warning, warning_message=warning_message)
+    return SearchResponse(matches=df_matches.to_dicts())
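Note: with the warning-handling block removed, the endpoint's failure modes reduce to the top_k guard and pydantic validation. A minimal sketch of both paths with FastAPI's TestClient (assuming the datasets and embedding model load locally, since importing the module initializes them):

from fastapi.testclient import TestClient

from pypi_scout.api.main import app  # import triggers dataset and model loading

client = TestClient(app)

# Values above the new ceiling are rejected by the endpoint's own guard.
resp = client.post("/api/search", json={"query": "http client", "top_k": 101})
assert resp.status_code == 400

# top_k is now required (QueryModel no longer has a default), so FastAPI
# rejects a body without it during request validation.
resp = client.post("/api/search", json={"query": "http client"})
assert resp.status_code == 422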

pypi_scout/api/models.py (+19, new file)

@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+
+class QueryModel(BaseModel):
+    query: str
+    top_k: int
+
+
+class Match(BaseModel):
+    name: str
+    summary: str
+    similarity: float
+    weekly_downloads: int
+
+
+class SearchResponse(BaseModel):
+    matches: list[Match]
+    warning: bool = False
+    warning_message: str = None
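Note: a short usage sketch of the extracted models (the package values are made up). FastAPI validates request bodies against QueryModel and serializes endpoint return values through SearchResponse:

from pypi_scout.api.models import Match, QueryModel, SearchResponse

# Request side: both fields are now required, since top_k lost its default.
query = QueryModel(query="fast dataframes", top_k=40)

# Response side: nested dicts or Match instances both validate.
response = SearchResponse(
    matches=[Match(name="polars", summary="Blazingly fast DataFrames", similarity=0.87, weekly_downloads=1_000_000)]
)
print(response.matches[0].name)  # polars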

pypi_scout/config.py (-3)

@@ -38,9 +38,6 @@ class Config:
     # Google Drive file ID for downloading the raw dataset.
     GOOGLE_FILE_ID = "1IDJvCsq1gz0yUSXgff13pMl3nUk7zJzb"
 
-    # Number of top results to return for a query.
-    N_RESULTS_TO_RETURN = 40
-
     # Fraction of the dataset to include in the vector database. This value determines the portion of top packages
     # (sorted by weekly downloads) to include. Increase this value to include a larger portion of the dataset, up to 1.0 (100%).
     # For reference, a value of 0.25 corresponds to including all PyPI packages with at least approximately 650 weekly downloads

0 commit comments