Commit e2a3bcb

feat: limit tokenizers cache size (#3577)
We limit the number of preloaded tokenizers to 3 and the maximum cache size to 50 MB.
1 parent 05e212a commit e2a3bcb
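
The change amounts to bounding the tokenizer cache by both entry count and total bytes, evicting the least recently used entry whenever either budget is exceeded. A minimal standalone sketch of that policy (all names here are illustrative, not from the quivr codebase):

```python
import time
from typing import Any


class SizedLRUCache:
    """Illustrative sketch: a cache bounded by an entry-count and a byte
    budget, evicting the least recently used entry when either is exceeded."""

    def __init__(self, max_bytes: int, max_count: int) -> None:
        self.max_bytes = max_bytes
        self.max_count = max_count
        # key -> (value, size_bytes, last_access_time)
        self._items: dict[str, tuple[Any, int, float]] = {}
        self._total_bytes = 0

    def put(self, key: str, value: Any, size_bytes: int) -> None:
        # Evict LRU entries until both budgets have room for the new one.
        # The `self._items` guard prevents looping forever when a single
        # entry is larger than the whole byte budget.
        while self._items and (
            self._total_bytes + size_bytes > self.max_bytes
            or len(self._items) >= self.max_count
        ):
            lru_key = min(self._items, key=lambda k: self._items[k][2])
            _, freed, _ = self._items.pop(lru_key)
            self._total_bytes -= freed
        self._items[key] = (value, size_bytes, time.time())
        self._total_bytes += size_bytes

    def get(self, key: str) -> Any:
        value, size_bytes, _ = self._items[key]
        # Refresh recency on every hit.
        self._items[key] = (value, size_bytes, time.time())
        return value
```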

File tree

4 files changed: +101 −12 lines changed

core/pyproject.toml (+1)

```diff
@@ -25,6 +25,7 @@ dependencies = [
     "langchain-mistralai>=0.2.3",
     "fasttext-langdetect>=1.0.5",
     "langfuse>=2.57.0",
+    "pympler>=1.1",
 ]
 readme = "README.md"
 requires-python = ">= 3.11"
```
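
pympler is pulled in for the size accounting: `asizeof.asizeof` walks an object's references and reports its cumulative footprint in bytes, where `sys.getsizeof` only measures the outermost object. A quick comparison (the `vocab` dict is just a stand-in for a tokenizer-sized object):

```python
import sys

from pympler import asizeof

# A dict with many nested values: sys.getsizeof sees only the outer
# dict, while asizeof.asizeof follows references and sums everything.
vocab = {f"token_{i}": i for i in range(10_000)}
print(sys.getsizeof(vocab))    # shallow size of the dict object only
print(asizeof.asizeof(vocab))  # deep size: dict + all keys + all values
```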

core/quivr_core/llm/llm_endpoint.py (+96 −12)

```diff
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Union
+from typing import Union, Any
 from urllib.parse import parse_qs, urlparse
 
 import tiktoken
@@ -9,6 +9,8 @@
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from pydantic import SecretStr
+import time
+from pympler import asizeof
 
 from quivr_core.brain.info import LLMInfo
 from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMEndpointConfig
@@ -17,8 +19,17 @@
 logger = logging.getLogger("quivr_core")
 
 
+def get_size(obj: Any, seen: set | None = None) -> int:
+    return asizeof.asizeof(obj)
+
+
 class LLMTokenizer:
-    _cache: dict[int, "LLMTokenizer"] = {}
+    _cache: dict[
+        int, tuple["LLMTokenizer", int, float]
+    ] = {}  # {hash: (tokenizer, size_bytes, last_access_time)}
+    _max_cache_size_mb: int = 50
+    _max_cache_count: int = 3  # Default maximum number of cached tokenizers
+    _current_cache_size: int = 0
 
     def __init__(self, tokenizer_hub: str | None, fallback_tokenizer: str):
         self.tokenizer_hub = tokenizer_hub
@@ -51,33 +62,106 @@ def __init__(self, tokenizer_hub: str | None, fallback_tokenizer: str):
         else:
             self.tokenizer = tiktoken.get_encoding(self.fallback_tokenizer)
 
+        # More accurate size estimation
+        self._size_bytes = get_size(self.tokenizer)
+
     @classmethod
     def load(cls, tokenizer_hub: str, fallback_tokenizer: str):
         cache_key = hash(str(tokenizer_hub))
+
+        # If in cache, update last access time and return
         if cache_key in cls._cache:
-            return cls._cache[cache_key]
+            tokenizer, size, _ = cls._cache[cache_key]
+            cls._cache[cache_key] = (tokenizer, size, time.time())
+            return tokenizer
+
+        # Create new instance
         instance = cls(tokenizer_hub, fallback_tokenizer)
-        cls._cache[cache_key] = instance
+
+        # Check if adding this would exceed either cache limit
+        while (
+            cls._current_cache_size + instance._size_bytes
+            > cls._max_cache_size_mb * 1024 * 1024
+            or len(cls._cache) >= cls._max_cache_count
+        ):
+            # Find least recently used item
+            oldest_key = min(
+                cls._cache.keys(),
+                key=lambda k: cls._cache[k][2],  # last_access_time
+            )
+            # Remove it
+            _, removed_size, _ = cls._cache.pop(oldest_key)
+            cls._current_cache_size -= removed_size
+
+        # Add new instance to cache with current timestamp
+        cls._cache[cache_key] = (instance, instance._size_bytes, time.time())
+        cls._current_cache_size += instance._size_bytes
         return instance
 
     @classmethod
-    def preload_tokenizers(cls):
-        """Preload all available tokenizers from the models configuration into cache."""
+    def set_max_cache_size_mb(cls, size_mb: int):
+        """Set the maximum cache size in megabytes."""
+        cls._max_cache_size_mb = size_mb
+        cls._cleanup_cache()
+
+    @classmethod
+    def set_max_cache_count(cls, count: int):
+        """Set the maximum number of tokenizers to cache."""
+        cls._max_cache_count = count
+        cls._cleanup_cache()
+
+    @classmethod
+    def _cleanup_cache(cls):
+        """Clean up cache when limits are exceeded."""
+        while (
+            cls._current_cache_size > cls._max_cache_size_mb * 1024 * 1024
+            or len(cls._cache) > cls._max_cache_count
+        ):
+            oldest_key = min(cls._cache.keys(), key=lambda k: cls._cache[k][2])
+            _, removed_size, _ = cls._cache.pop(oldest_key)
+            cls._current_cache_size -= removed_size
+
+    @classmethod
+    def preload_tokenizers(cls, models: list[str] | None = None):
+        """Preload tokenizers into cache.
+
+        Args:
+            models: Optional list of model names (e.g. 'gpt-4o', 'claude-3-5-sonnet').
+                If None, preloads all available tokenizers.
+        """
         from quivr_core.rag.entities.config import LLMModelConfig
 
         unique_tokenizer_hubs = set()
 
-        # Collect all unique tokenizer hubs
-        for supplier_models in LLMModelConfig._model_defaults.values():
-            for config in supplier_models.values():
-                if config.tokenizer_hub:
-                    unique_tokenizer_hubs.add(config.tokenizer_hub)
+        # Collect tokenizer hubs based on provided models or all available
+        if models:
+            for model_name in models:
+                # Find matching model configurations
+                for supplier_models in LLMModelConfig._model_defaults.values():
+                    for base_model_name, config in supplier_models.items():
+                        # Check if the model name matches or starts with the base model name
+                        if (
+                            model_name.startswith(base_model_name)
+                            and config.tokenizer_hub
+                        ):
+                            unique_tokenizer_hubs.add(config.tokenizer_hub)
+                            break
+        else:
+            # Original behavior - collect all unique tokenizer hubs
+            for supplier_models in LLMModelConfig._model_defaults.values():
+                for config in supplier_models.values():
+                    if config.tokenizer_hub:
+                        unique_tokenizer_hubs.add(config.tokenizer_hub)
 
         # Load each unique tokenizer
         for hub in unique_tokenizer_hubs:
             try:
                 cls.load(hub, LLMEndpointConfig._FALLBACK_TOKENIZER)
-                logger.info(f"Successfully preloaded tokenizer: {hub}")
+                logger.info(
+                    f"Successfully preloaded tokenizer: {hub}. "
+                    f"Total cache size: {cls._current_cache_size / (1024 * 1024):.2f} MB. "
+                    f"Cache count: {len(cls._cache)}"
+                )
             except Exception as e:
                 logger.warning(f"Failed to preload tokenizer {hub}: {str(e)}")
 
```
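
Taken together, the new class methods let callers adjust the cache budgets before preloading. A sketch of how they might be used (budget values are illustrative, and the import path assumes the module location shown above; the defaults remain 50 MB and 3 entries):

```python
from quivr_core.llm.llm_endpoint import LLMTokenizer

# Raise both budgets before preloading; either limit can trigger
# LRU eviction on its own.
LLMTokenizer.set_max_cache_size_mb(100)
LLMTokenizer.set_max_cache_count(5)

# Preload only the tokenizers needed for specific models; called with
# no argument, every tokenizer hub in the model defaults is loaded.
LLMTokenizer.preload_tokenizers(["gpt-4o", "claude-3-5-sonnet"])
```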

core/requirements-dev.lock (+2)

```diff
@@ -296,6 +296,8 @@ pyflakes==3.2.0
 pygments==2.18.0
     # via ipython
     # via rich
+pympler==1.1
+    # via quivr-core
 pytest==8.3.3
     # via pytest-asyncio
     # via pytest-benchmark
```

core/requirements.lock (+2)

```diff
@@ -214,6 +214,8 @@ pydantic-settings==2.6.1
     # via langchain-community
 pygments==2.18.0
     # via rich
+pympler==1.1
+    # via quivr-core
 python-dateutil==2.8.2
     # via pandas
 python-dotenv==1.0.1
```
