
Commit bf8d9e8

Final fix for llamacpp muxing.
1 parent 6405e64 commit bf8d9e8

2 files changed: +80 -57 lines


src/codegate/muxing/router.py

+9 -3
@@ -4,6 +4,7 @@
 from fastapi import APIRouter, HTTPException, Request
 from fastapi.responses import StreamingResponse
 
+import codegate.providers.llamacpp.completion_handler as llamacpp
 from codegate.clients.detector import DetectClient
 from codegate.db.models import ProviderType
 from codegate.muxing import models as mux_models
@@ -148,9 +149,14 @@ async def route_to_dest_provider(
             from_openai = anthropic_from_openai
             to_openai = anthropic_to_openai
         case ProviderType.llamacpp:
-            completion_function = provider._completion_handler.execute_completion
-            from_openai = identity
-            to_openai = identity
+            if is_fim_request:
+                completion_function = llamacpp.complete
+                from_openai = identity
+                to_openai = identity
+            else:
+                completion_function = llamacpp.chat
+                from_openai = identity
+                to_openai = identity
         case ProviderType.ollama:
             if is_fim_request:
                 completion_function = ollama.generate_streaming
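
With this change the muxer no longer routes llama.cpp traffic through the provider's execute_completion method; it picks one of the new module-level helpers depending on whether the request is a fill-in-the-middle (FIM) completion. A minimal sketch of the selection logic, not part of the commit, assuming identity is a simple passthrough converter (llama.cpp already exchanges the OpenAI-style payloads the muxer works with); pick_llamacpp_handlers and the passthrough lambda are hypothetical names:

import codegate.providers.llamacpp.completion_handler as llamacpp

def pick_llamacpp_handlers(is_fim_request: bool):
    # Mirrors the new case arm: FIM requests go to complete(), chat to chat().
    completion_function = llamacpp.complete if is_fim_request else llamacpp.chat
    # Hypothetical stand-in for the router's identity converter.
    passthrough = lambda payload: payload
    return completion_function, passthrough, passthrough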

src/codegate/providers/llamacpp/completion_handler.py

+71 -54
@@ -49,9 +49,75 @@ async def chat_to_async_iterator(
         yield StreamingChatCompletion(**item)
 
 
+ENGINE = LlamaCppInferenceEngine()
+
+
+async def complete(request, api_key, model_path):
+    stream = request.get_stream()
+    full_path = f"{model_path}/{request.get_model()}.gguf"
+    request_dict = request.dict(
+        exclude={
+            "best_of",
+            "frequency_penalty",
+            "n",
+            "stream_options",
+            "user",
+        }
+    )
+
+    response = await ENGINE.complete(
+        full_path,
+        Config.get_config().chat_model_n_ctx,
+        Config.get_config().chat_model_n_gpu_layers,
+        **request_dict,
+    )
+
+    if stream:
+        return completion_to_async_iterator(response)
+    # TODO fix this code path is broken
+    return LegacyCompletion(**response)
+
+
+async def chat(request, api_key, model_path):
+    stream = request.get_stream()
+    full_path = f"{model_path}/{request.get_model()}.gguf"
+    request_dict = request.dict(
+        exclude={
+            "audio",
+            "frequency_penalty",
+            "include_reasoning",
+            "metadata",
+            "max_completion_tokens",
+            "modalities",
+            "n",
+            "parallel_tool_calls",
+            "prediction",
+            "prompt",
+            "reasoning_effort",
+            "service_tier",
+            "store",
+            "stream_options",
+            "user",
+        }
+    )
+
+    response = await ENGINE.chat(
+        full_path,
+        Config.get_config().chat_model_n_ctx,
+        Config.get_config().chat_model_n_gpu_layers,
+        **request_dict,
+    )
+
+    if stream:
+        return chat_to_async_iterator(response)
+    else:
+        # TODO fix this code path is broken
+        return StreamingChatCompletion(**response)
+
+
 class LlamaCppCompletionHandler(BaseCompletionHandler):
     def __init__(self, base_url):
-        self.inference_engine = LlamaCppInferenceEngine()
+        self.inference_engine = ENGINE
         self.base_url = base_url
 
     async def execute_completion(
@@ -65,64 +131,15 @@ async def execute_completion(
         """
        Execute the completion request with inference engine API
         """
-        model_path = f"{self.base_url}/{request.get_model()}.gguf"
-
         # Create a copy of the request dict and remove stream_options
         # Reason - Request error as JSON:
         # {'error': "Llama.create_completion() got an unexpected keyword argument 'stream_options'"}
         if is_fim_request:
-            request_dict = request.dict(
-                exclude={
-                    "best_of",
-                    "frequency_penalty",
-                    "n",
-                    "stream_options",
-                    "user",
-                }
-            )
-
-            response = await self.inference_engine.complete(
-                model_path,
-                Config.get_config().chat_model_n_ctx,
-                Config.get_config().chat_model_n_gpu_layers,
-                **request_dict,
-            )
-
-            if stream:
-                return completion_to_async_iterator(response)
-            return LegacyCompletion(**response)
+            # base_url == model_path in this case
+            return await complete(request, api_key, self.base_url)
         else:
-            request_dict = request.dict(
-                exclude={
-                    "audio",
-                    "frequency_penalty",
-                    "include_reasoning",
-                    "metadata",
-                    "max_completion_tokens",
-                    "modalities",
-                    "n",
-                    "parallel_tool_calls",
-                    "prediction",
-                    "prompt",
-                    "reasoning_effort",
-                    "service_tier",
-                    "store",
-                    "stream_options",
-                    "user",
-                }
-            )
-
-            response = await self.inference_engine.chat(
-                model_path,
-                Config.get_config().chat_model_n_ctx,
-                Config.get_config().chat_model_n_gpu_layers,
-                **request_dict,
-            )
-
-            if stream:
-                return chat_to_async_iterator(response)
-            else:
-                return StreamingChatCompletion(**response)
+            # base_url == model_path in this case
+            return await chat(request, api_key, self.base_url)
 
     def _create_streaming_response(
         self,
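
For reference, a hedged usage sketch of the new module-level helpers (not part of the commit). It assumes a request object exposing get_stream(), get_model() and dict() as in the diff; "/models" is a hypothetical directory holding <model>.gguf files, and the api_key argument is accepted but unused by the helpers:

import codegate.providers.llamacpp.completion_handler as llamacpp

async def run_llamacpp_chat(request):
    # model_path is a directory; chat() appends "<model>.gguf" to it.
    result = await llamacpp.chat(request, api_key=None, model_path="/models")
    if request.get_stream():
        # Streaming path: result is an async iterator of StreamingChatCompletion chunks.
        async for chunk in result:
            print(chunk)
    else:
        # Non-streaming path: still flagged as broken by the TODO in the diff.
        print(result)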
