Commit d6e0ed4
feat: ensuring that max_context_tokens is never larger than what supported by models (#3519)
1 parent e384a0a commit d6e0ed4

File tree

1 file changed: core/quivr_core/rag/entities/config.py (+100 −33)
```diff
@@ -75,89 +75,139 @@ class DefaultModelSuppliers(str, Enum):
 
 
 class LLMConfig(QuivrBaseConfig):
-    context: int | None = None
+    max_context_tokens: int | None = None
+    max_output_tokens: int | None = None
     tokenizer_hub: str | None = None
 
 
 class LLMModelConfig:
     _model_defaults: Dict[DefaultModelSuppliers, Dict[str, LLMConfig]] = {
         DefaultModelSuppliers.OPENAI: {
-            "gpt-4o": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4o-mini": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4o"),
-            "gpt-4-turbo": LLMConfig(context=128000, tokenizer_hub="Xenova/gpt-4"),
-            "gpt-4": LLMConfig(context=8192, tokenizer_hub="Xenova/gpt-4"),
+            "gpt-4o": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=16384,
+                tokenizer_hub="Xenova/gpt-4o",
+            ),
+            "gpt-4o-mini": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=16384,
+                tokenizer_hub="Xenova/gpt-4o",
+            ),
+            "gpt-4-turbo": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/gpt-4",
+            ),
+            "gpt-4": LLMConfig(
+                max_context_tokens=8192,
+                max_output_tokens=8192,
+                tokenizer_hub="Xenova/gpt-4",
+            ),
             "gpt-3.5-turbo": LLMConfig(
-                context=16385, tokenizer_hub="Xenova/gpt-3.5-turbo"
+                max_context_tokens=16385,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/gpt-3.5-turbo",
             ),
             "text-embedding-3-large": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
             "text-embedding-3-small": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
             "text-embedding-ada-002": LLMConfig(
-                context=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
+                max_context_tokens=8191, tokenizer_hub="Xenova/text-embedding-ada-002"
             ),
         },
         DefaultModelSuppliers.ANTHROPIC: {
             "claude-3-5-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=8192,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-opus": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-sonnet": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-3-haiku": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-2-1": LLMConfig(
-                context=200000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=200000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-2-0": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=100000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
             "claude-instant-1-2": LLMConfig(
-                context=100000, tokenizer_hub="Xenova/claude-tokenizer"
+                max_context_tokens=100000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/claude-tokenizer",
             ),
         },
+        # Unclear for LLAMA models...
+        # see https://huggingface.co/meta-llama/Llama-3.1-405B-Instruct/discussions/6
         DefaultModelSuppliers.META: {
             "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192,
+                max_output_tokens=2048,
+                tokenizer_hub="Xenova/llama3-tokenizer-new",
             ),
-            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
             "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.GROQ: {
-            "llama-3.1": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer"
+            "llama-3.3-70b": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=32768,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
+            ),
+            "llama-3.1-70b": LLMConfig(
+                max_context_tokens=128000,
+                max_output_tokens=32768,
+                tokenizer_hub="Xenova/Meta-Llama-3.1-Tokenizer",
             ),
             "llama-3": LLMConfig(
-                context=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
+                max_context_tokens=8192, tokenizer_hub="Xenova/llama3-tokenizer-new"
             ),
-            "llama-2": LLMConfig(context=4096, tokenizer_hub="Xenova/llama2-tokenizer"),
             "code-llama": LLMConfig(
-                context=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
+                max_context_tokens=16384, tokenizer_hub="Xenova/llama-code-tokenizer"
             ),
         },
         DefaultModelSuppliers.MISTRAL: {
             "mistral-large": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/mistral-tokenizer-v3",
             ),
             "mistral-small": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/mistral-tokenizer-v3",
             ),
             "mistral-nemo": LLMConfig(
-                context=128000, tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer"
+                max_context_tokens=128000,
+                max_output_tokens=4096,
+                tokenizer_hub="Xenova/Mistral-Nemo-Instruct-Tokenizer",
             ),
             "codestral": LLMConfig(
-                context=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
+                max_context_tokens=32000, tokenizer_hub="Xenova/mistral-tokenizer-v3"
             ),
         },
     }
```
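For reference, the new per-model limits can be read straight off the defaults table above. A minimal sketch, assuming the package is importable as `quivr_core.rag.entities.config` (inferred from the changed file's path) and noting that `_model_defaults` is a private attribute, so this is illustrative only:

```python
# Illustrative only: look up the per-model limits added in this commit.
# The import path is assumed from the path of the file being changed.
from quivr_core.rag.entities.config import DefaultModelSuppliers, LLMModelConfig

gpt4o = LLMModelConfig._model_defaults[DefaultModelSuppliers.OPENAI]["gpt-4o"]
print(gpt4o.max_context_tokens, gpt4o.max_output_tokens)  # 128000 16384
```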
```diff
@@ -193,13 +243,12 @@ def get_llm_model_config(
 class LLMEndpointConfig(QuivrBaseConfig):
     supplier: DefaultModelSuppliers = DefaultModelSuppliers.OPENAI
     model: str = "gpt-4o"
-    context_length: int | None = None
     tokenizer_hub: str | None = None
     llm_base_url: str | None = None
     env_variable_name: str | None = None
     llm_api_key: str | None = None
-    max_context_tokens: int = 2000
-    max_output_tokens: int = 2000
+    max_context_tokens: int = 10000
+    max_output_tokens: int = 4000
     temperature: float = 0.7
     streaming: bool = True
     prompt: CustomPromptsModel | None = None
```
```diff
@@ -240,7 +289,25 @@ def set_llm_model_config(self):
             self.supplier, self.model
         )
         if llm_model_config:
-            self.context_length = llm_model_config.context
+            if llm_model_config.max_context_tokens:
+                _max_context_tokens = (
+                    llm_model_config.max_context_tokens
+                    - llm_model_config.max_output_tokens
+                    if llm_model_config.max_output_tokens
+                    else llm_model_config.max_context_tokens
+                )
+                if self.max_context_tokens > _max_context_tokens:
+                    logger.warning(
+                        f"Lowering max_context_tokens from {self.max_context_tokens} to {_max_context_tokens}"
+                    )
+                    self.max_context_tokens = _max_context_tokens
+            if llm_model_config.max_output_tokens:
+                if self.max_output_tokens > llm_model_config.max_output_tokens:
+                    logger.warning(
+                        f"Lowering max_output_tokens from {self.max_output_tokens} to {llm_model_config.max_output_tokens}"
+                    )
+                    self.max_output_tokens = llm_model_config.max_output_tokens
+
             self.tokenizer_hub = llm_model_config.tokenizer_hub
 
     def set_llm_model(self, model: str):
```
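The effect of the new clamp in `set_llm_model_config` can be reproduced in isolation. The sketch below reimplements the capping rule shown in the hunk above; it is not the quivr_core API, and names such as `ModelLimits` and `cap_to_model_limits` are made up for illustration:

```python
# Standalone sketch of the capping rule introduced in this commit (not the
# actual quivr_core code; values are taken from the model defaults above).
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelLimits:
    max_context_tokens: Optional[int] = None
    max_output_tokens: Optional[int] = None


def cap_to_model_limits(
    requested_context: int, requested_output: int, limits: ModelLimits
) -> tuple[int, int]:
    """Lower the requested budgets so they never exceed the model's limits."""
    context, output = requested_context, requested_output
    if limits.max_context_tokens:
        # The context budget leaves room for the completion when the model
        # publishes a maximum output size, mirroring the diff above.
        ceiling = (
            limits.max_context_tokens - limits.max_output_tokens
            if limits.max_output_tokens
            else limits.max_context_tokens
        )
        context = min(context, ceiling)
    if limits.max_output_tokens:
        output = min(output, limits.max_output_tokens)
    return context, output


# With the META "llama-3" defaults above (8192 context, 2048 output), the new
# endpoint defaults of 10000 / 4000 get lowered to 6144 / 2048.
print(cap_to_model_limits(10000, 4000, ModelLimits(8192, 2048)))     # (6144, 2048)
# With "gpt-4o" (128000 / 16384) nothing changes.
print(cap_to_model_limits(10000, 4000, ModelLimits(128000, 16384)))  # (10000, 4000)
```

The context budget is reduced by the model's maximum output tokens, presumably so that the prompt and the completion together always fit inside the model's advertised window.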
