Commit 699dc2e

feat: cache tokenizers (#3558)
We now load all tokenizers when the llm_endpoint module is imported, and any initialisation of the LLMEndpoint class will use the cached tokenizers. Closes ENT-402.
1 parent 8a1e5f2 commit 699dc2e
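
As a rough usage sketch of the behaviour described above (hedged: it assumes the LLMTokenizer class and the tokenizer_hub / fallback_tokenizer fields shown in the diff below, plus valid credentials for the default supplier; it is not part of this commit):

# Illustrative sketch only: importing the module runs
# LLMTokenizer.preload_tokenizers(), so creating an endpoint reuses the cache.
from quivr_core.llm.llm_endpoint import LLMEndpoint, LLMTokenizer
from quivr_core.rag.entities.config import LLMEndpointConfig

config = LLMEndpointConfig()  # default model; tokenizer_hub comes from the model defaults
endpoint = LLMEndpoint.from_config(config)

# The endpoint holds the same tokenizer instance that was cached at import time.
assert endpoint.llm_tokenizer is LLMTokenizer.load(
    config.tokenizer_hub, config.fallback_tokenizer
)
print(endpoint.count_tokens("hello world"))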

File tree: 2 files changed (+95 -52 lines changed)


core/quivr_core/llm/llm_endpoint.py (+66 -23)
@@ -17,17 +17,14 @@
 logger = logging.getLogger("quivr_core")
 
 
-class LLMEndpoint:
-    _cache: dict[int, "LLMEndpoint"] = {}
+class LLMTokenizer:
+    _cache: dict[int, "LLMTokenizer"] = {}
 
-    def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
-        self._config = llm_config
-        self._llm = llm
-        self._supports_func_calling = model_supports_function_calling(
-            self._config.model
-        )
+    def __init__(self, tokenizer_hub: str | None, fallback_tokenizer: str):
+        self.tokenizer_hub = tokenizer_hub
+        self.fallback_tokenizer = fallback_tokenizer
 
-        if llm_config.tokenizer_hub:
+        if self.tokenizer_hub:
             # To prevent the warning
             # huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
             os.environ["TOKENIZERS_PARALLELISM"] = (
@@ -36,34 +33,81 @@ def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
                 else os.environ["TOKENIZERS_PARALLELISM"]
             )
             try:
-                from transformers import AutoTokenizer
+                if "text-embedding-ada-002" in self.tokenizer_hub:
+                    from transformers import GPT2TokenizerFast
 
-                self.tokenizer = AutoTokenizer.from_pretrained(llm_config.tokenizer_hub)
+                    self.tokenizer = GPT2TokenizerFast.from_pretrained(
+                        self.tokenizer_hub
+                    )
+                else:
+                    from transformers import AutoTokenizer
+
+                    self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_hub)
             except OSError:  # if we don't manage to connect to huggingface and/or no cached models are present
                 logger.warning(
-                    f"Cannot acces the configured tokenizer from {llm_config.tokenizer_hub}, using the default tokenizer {llm_config.fallback_tokenizer}"
+                    f"Cannot acces the configured tokenizer from {self.tokenizer_hub}, using the default tokenizer {self.fallback_tokenizer}"
                 )
-                self.tokenizer = tiktoken.get_encoding(llm_config.fallback_tokenizer)
+                self.tokenizer = tiktoken.get_encoding(self.fallback_tokenizer)
         else:
-            self.tokenizer = tiktoken.get_encoding(llm_config.fallback_tokenizer)
+            self.tokenizer = tiktoken.get_encoding(self.fallback_tokenizer)
+
+    @classmethod
+    def load(cls, tokenizer_hub: str, fallback_tokenizer: str):
+        cache_key = hash(str(tokenizer_hub))
+        if cache_key in cls._cache:
+            return cls._cache[cache_key]
+        instance = cls(tokenizer_hub, fallback_tokenizer)
+        cls._cache[cache_key] = instance
+        return instance
+
+    @classmethod
+    def preload_tokenizers(cls):
+        """Preload all available tokenizers from the models configuration into cache."""
+        from quivr_core.rag.entities.config import LLMModelConfig
+
+        unique_tokenizer_hubs = set()
+
+        # Collect all unique tokenizer hubs
+        for supplier_models in LLMModelConfig._model_defaults.values():
+            for config in supplier_models.values():
+                if config.tokenizer_hub:
+                    unique_tokenizer_hubs.add(config.tokenizer_hub)
+
+        # Load each unique tokenizer
+        for hub in unique_tokenizer_hubs:
+            try:
+                cls.load(hub, LLMEndpointConfig._FALLBACK_TOKENIZER)
+                logger.info(f"Successfully preloaded tokenizer: {hub}")
+            except Exception as e:
+                logger.warning(f"Failed to preload tokenizer {hub}: {str(e)}")
+
+
+# Preload tokenizers when module is imported
+LLMTokenizer.preload_tokenizers()
+
+
+class LLMEndpoint:
+    def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
+        self._config = llm_config
+        self._llm = llm
+        self._supports_func_calling = model_supports_function_calling(
+            self._config.model
+        )
+
+        self.llm_tokenizer = LLMTokenizer.load(
+            llm_config.tokenizer_hub, llm_config.fallback_tokenizer
+        )
 
     def count_tokens(self, text: str) -> int:
         # Tokenize the input text and return the token count
-        encoding = self.tokenizer.encode(text)
+        encoding = self.llm_tokenizer.tokenizer.encode(text)
         return len(encoding)
 
     def get_config(self):
         return self._config
 
     @classmethod
     def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
-        # Create a cache key from the config
-        cache_key = hash(str(config.model_dump()))
-
-        # Return cached instance if it exists
-        if cache_key in cls._cache:
-            return cls._cache[cache_key]
-
         _llm: Union[AzureChatOpenAI, ChatOpenAI, ChatAnthropic, ChatMistralAI]
         try:
             if config.supplier == DefaultModelSuppliers.AZURE:
@@ -122,7 +166,6 @@ def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
                 temperature=config.temperature,
             )
             instance = cls(llm=_llm, llm_config=config)
-            cls._cache[cache_key] = instance
             return instance
 
         except ImportError as e:
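
For reference, a minimal sketch of how the new LLMTokenizer cache behaves (the hub names are taken from the config diff below; using "cl100k_base" as the fallback encoding is an assumption, not something this diff pins down):

# Illustrative only: load() memoises on hash(str(tokenizer_hub)), so repeated
# calls with the same hub return the same cached instance.
from quivr_core.llm.llm_endpoint import LLMTokenizer

t1 = LLMTokenizer.load("Quivr/gpt-4o", "cl100k_base")
t2 = LLMTokenizer.load("Quivr/gpt-4o", "cl100k_base")
assert t1 is t2  # same hub -> cache hit

t3 = LLMTokenizer.load("Quivr/claude-tokenizer", "cl100k_base")
assert t3 is not t1  # a different hub gets its own cached instance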

core/quivr_core/rag/entities/config.py (+29 -29)
@@ -86,73 +86,73 @@ class LLMModelConfig:
             "gpt-4o": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=16384,
-                tokenizer_hub="Xenova/gpt-4o",
+                tokenizer_hub="Quivr/gpt-4o",
             ),
             "gpt-4o-mini": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=16384,
-                tokenizer_hub="Xenova/gpt-4o",
+                tokenizer_hub="Quivr/gpt-4o",
             ),
             "gpt-4-turbo": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/gpt-4",
+                tokenizer_hub="Quivr/gpt-4",
             ),
             "gpt-4": LLMConfig(
                 max_context_tokens=8192,
                 max_output_tokens=8192,
-                tokenizer_hub="Xenova/gpt-4",
+                tokenizer_hub="Quivr/gpt-4",
             ),
             "gpt-3.5-turbo": LLMConfig(
                 max_context_tokens=16385,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/gpt-3.5-turbo",
+                tokenizer_hub="Quivr/gpt-3.5-turbo",
             ),
             "text-embedding-3-large": LLMConfig(
-                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Quivr/text-embedding-ada-002"
             ),
             "text-embedding-3-small": LLMConfig(
-                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Quivr/text-embedding-ada-002"
             ),
             "text-embedding-ada-002": LLMConfig(
-                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Quivr/text-embedding-ada-002"
             ),
         },
         DefaultModelSuppliers.ANTHROPIC: {
             "claude-3-5-sonnet": LLMConfig(
                 max_context_tokens=200000,
                 max_output_tokens=8192,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
             "claude-3-opus": LLMConfig(
                 max_context_tokens=200000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
             "claude-3-sonnet": LLMConfig(
                 max_context_tokens=200000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
             "claude-3-haiku": LLMConfig(
                 max_context_tokens=200000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
             "claude-2-1": LLMConfig(
                 max_context_tokens=200000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
             "claude-2-0": LLMConfig(
                 max_context_tokens=100000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
             "claude-instant-1-2": LLMConfig(
                 max_context_tokens=100000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/claude-tokenizer",
+                tokenizer_hub="Quivr/claude-tokenizer",
             ),
         },
         # Unclear for LLAMA models...
@@ -161,53 +161,53 @@ class LLMModelConfig:
             "llama-3.1": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
+                tokenizer_hub="Quivr/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
                 max_context_tokens=8192,
                 max_output_tokens=2048,
-                tokenizer_hub="Xenova/llama3-tokenizer-new",
+                tokenizer_hub="Quivr/llama3-tokenizer-new",
             ),
             "code-llama": LLMConfig(
-                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Quivr/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.GROQ: {
             "llama-3.3-70b": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=32768,
-                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
+                tokenizer_hub="Quivr/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3.1-70b": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=32768,
-                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
+                tokenizer_hub="Quivr/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                max_context_tokens=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192, tokenizer_hub="Quivr/llama3-tokenizer-new"
             ),
             "code-llama": LLMConfig(
-                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Quivr/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.MISTRAL: {
             "mistral-large": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/mistral-tokenizer-v3",
+                tokenizer_hub="Quivr/mistral-tokenizer-v3",
             ),
             "mistral-small": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/mistral-tokenizer-v3",
+                tokenizer_hub="Quivr/mistral-tokenizer-v3",
             ),
             "mistral-nemo": LLMConfig(
                 max_context_tokens=128000,
                 max_output_tokens=4096,
-                tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer",
+                tokenizer_hub="Quivr/Mistral-Nemo-Instruct-Tokenizer",
             ),
             "codestral": LLMConfig(
-                max_context_tokens=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=32000, tokenizer_hub="Quivr/mistral-tokenizer-v3"
             ),
         },
     }
@@ -247,9 +247,9 @@ class LLMEndpointConfig(QuivrBaseConfig):
     llm_base_url: str | None = None
     env_variable_name: str | None = None
     llm_api_key: str | None = None
-    max_context_tokens: int = 10000
-    max_output_tokens: int = 4000
-    temperature: float = 0.7
+    max_context_tokens: int = 20000
+    max_output_tokens: int = 4096
+    temperature: float = 0.3
     streaming: bool = True
     prompt: CustomPromptsModel | None = None
