Skip to content

Commit 05b134a

Browse files
blktjhrozekJAORMX
authored
Replace litellm with native API implementations. (#1252)
* Replace `litellm` with native API implementations. Refactors client architecture to use native implementations instead of `litellm` dependency. Adds support for OpenAPI, Ollama, OpenRouter, and fixes multiple issues with Anthropic and Copilot providers. Improves message handling and streaming responses. Commit message brought you by Anthropic Claude 3.7. Co-Authored-By: Jakub Hrozek <[email protected]> * Handle API key for ollama servers (#1257) This was missed. Signed-off-by: Juan Antonio Osorio <[email protected]> * Ran `make format`. * Restricted scope of exception handling. This change aims to make it simpler to track down in which step of the pipeline a particulare exception occurred. * Linting/formatting. * Fix bandit. * fix integration tests * trying to fix llamacpp muxing * Final fix for llamacpp muxing. * Minor enhancement to integration test routine. --------- Signed-off-by: Juan Antonio Osorio <[email protected]> Co-authored-by: Jakub Hrozek <[email protected]> Co-authored-by: Juan Antonio Osorio <[email protected]>
1 parent d435949 commit 05b134a

File tree

101 files changed

+6809
-3239
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+6809
-3239
lines changed

prompts/default.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ pii_redacted: |
4646
The context files contain redacted personally identifiable information (PII) that is represented by a UUID encased within <>. For example:
4747
- <123e4567-e89b-12d3-a456-426614174000>
4848
- <2d040296-98e9-4350-84be-fda4336057eb>
49-
If you encounter any PII redacted with a UUID, DO NOT WARN the user about it. Simplt respond to the user request and keep the PII redacted and intact, using the same UUID.
49+
If you encounter any PII redacted with a UUID, DO NOT WARN the user about it. Simply respond to the user request and keep the PII redacted and intact, using the same UUID.
5050
# Security-focused prompts
5151
security_audit: "You are a security expert conducting a thorough code review. Identify potential security vulnerabilities, suggest improvements, and explain security best practices."
5252

@@ -56,6 +56,6 @@ red_team: "You are a red team member conducting a security assessment. Identify
5656
# BlueTeam prompts
5757
blue_team: "You are a blue team member conducting a security assessment. Identify security controls, misconfigurations, and potential vulnerabilities."
5858

59-
# Per client prompts
59+
# Per client prompts
6060
client_prompts:
6161
kodu: "If malicious packages or leaked secrets are found, please end the task, sending the problems found embedded in <attempt_completion><result> tags"

src/codegate/config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
# Default provider URLs
1818
DEFAULT_PROVIDER_URLS = {
19-
"openai": "https://api.openai.com/v1",
20-
"openrouter": "https://openrouter.ai/api/v1",
21-
"anthropic": "https://api.anthropic.com/v1",
19+
"openai": "https://api.openai.com",
20+
"openrouter": "https://openrouter.ai/api",
21+
"anthropic": "https://api.anthropic.com",
2222
"vllm": "http://localhost:8000", # Base URL without /v1 path
2323
"ollama": "http://localhost:11434", # Default Ollama server URL
2424
"lm_studio": "http://localhost:1234",

src/codegate/db/connection.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,17 @@ def does_db_exist(self):
123123
return self._db_path.is_file()
124124

125125

126+
def row_from_model(model: BaseModel) -> dict:
127+
return dict(
128+
id=model.id,
129+
timestamp=model.timestamp,
130+
provider=model.provider,
131+
request=model.request.json(exclude_defaults=True, exclude_unset=True),
132+
type=model.type,
133+
workspace_id=model.workspace_id,
134+
)
135+
136+
126137
class DbRecorder(DbCodeGate):
127138
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
128139
super().__init__(sqlite_path, *args, **kwargs)
@@ -133,7 +144,10 @@ async def _execute_update_pydantic_model(
133144
"""Execute an update or insert command for a Pydantic model."""
134145
try:
135146
async with self._async_db_engine.begin() as conn:
136-
result = await conn.execute(sql_command, model.model_dump())
147+
row = model
148+
if isinstance(model, BaseModel):
149+
row = model.model_dump()
150+
result = await conn.execute(sql_command, row)
137151
row = result.first()
138152
if row is None:
139153
return None
@@ -175,7 +189,8 @@ async def record_request(self, prompt_params: Optional[Prompt] = None) -> Option
175189
RETURNING *
176190
"""
177191
)
178-
recorded_request = await self._execute_update_pydantic_model(prompt_params, sql)
192+
row = row_from_model(prompt_params)
193+
recorded_request = await self._execute_update_pydantic_model(row, sql)
179194
# Uncomment to debug the recorded request
180195
# logger.debug(f"Recorded request: {recorded_request}")
181196
return recorded_request # type: ignore
@@ -194,7 +209,8 @@ async def update_request(
194209
RETURNING *
195210
"""
196211
)
197-
updated_request = await self._execute_update_pydantic_model(prompt_params, sql)
212+
row = row_from_model(prompt_params)
213+
updated_request = await self._execute_update_pydantic_model(row, sql)
198214
# Uncomment to debug the recorded request
199215
# logger.debug(f"Recorded request: {recorded_request}")
200216
return updated_request # type: ignore
@@ -217,7 +233,7 @@ async def record_outputs(
217233
output=first_output.output,
218234
)
219235
full_outputs = []
220-
# Just store the model respnses in the list of JSON objects.
236+
# Just store the model responses in the list of JSON objects.
221237
for output in outputs:
222238
full_outputs.append(output.output)
223239

@@ -341,7 +357,7 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
341357
f"Alerts: {len(context.alerts_raised)}."
342358
)
343359
except Exception as e:
344-
logger.error(f"Failed to record context: {context}.", error=str(e))
360+
logger.error(f"Failed to record context: {context}.", error=str(e), exc_info=e)
345361

346362
async def add_workspace(self, workspace_name: str) -> WorkspaceRow:
347363
"""Add a new workspace to the DB.

src/codegate/db/fim_cache.py

+12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ def __init__(self):
3333

3434
def _extract_message_from_fim_request(self, request: str) -> Optional[str]:
3535
"""Extract the user message from the FIM request"""
36+
### NEW CODE PATH ###
37+
if not isinstance(request, str):
38+
content_message = None
39+
for message in request.get_messages():
40+
for content in message.get_content():
41+
if content_message is None:
42+
content_message = content.get_text()
43+
else:
44+
logger.warning("Expected one user message, found multiple.")
45+
return None
46+
return content_message
47+
3648
try:
3749
parsed_request = json.loads(request)
3850
except Exception as e:

src/codegate/extract_snippets/body_extractor.py

+29-47
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
KoduCodeSnippetExtractor,
1010
OpenInterpreterCodeSnippetExtractor,
1111
)
12+
from codegate.types.common import MessageTypeFilter
1213

1314

1415
class BodyCodeSnippetExtractorError(Exception):
@@ -32,25 +33,22 @@ def _extract_from_user_messages(self, data: dict) -> set[str]:
3233
raise BodyCodeSnippetExtractorError("Code Extractor not set.")
3334

3435
filenames: List[str] = []
35-
for msg in data.get("messages", []):
36-
if msg.get("role", "") == "user":
36+
for msg in data.get_messages(filters=[MessageTypeFilter.USER]):
37+
for content in msg.get_content():
3738
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
38-
msg.get("content")
39+
content.get_text(),
3940
)
4041
filenames.extend(extracted_snippets.keys())
4142
return set(filenames)
4243

4344
def _extract_from_list_user_messages(self, data: dict) -> set[str]:
4445
filenames: List[str] = []
45-
for msg in data.get("messages", []):
46-
if msg.get("role", "") == "user":
47-
msgs_content = msg.get("content", [])
48-
for msg_content in msgs_content:
49-
if msg_content.get("type", "") == "text":
50-
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
51-
msg_content.get("text")
52-
)
53-
filenames.extend(extracted_snippets.keys())
46+
for msg in data.get_messages(filters=[MessageTypeFilter.USER]):
47+
for content in msg.get_content():
48+
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
49+
content.get_text(),
50+
)
51+
filenames.extend(extracted_snippets.keys())
5452
return set(filenames)
5553

5654
@abstractmethod
@@ -93,43 +91,27 @@ class OpenInterpreterBodySnippetExtractor(BodyCodeSnippetExtractor):
9391
def __init__(self):
9492
self._snippet_extractor = OpenInterpreterCodeSnippetExtractor()
9593

96-
def _is_msg_tool_call(self, msg: dict) -> bool:
97-
return msg.get("role", "") == "assistant" and msg.get("tool_calls", [])
98-
99-
def _is_msg_tool_result(self, msg: dict) -> bool:
100-
return msg.get("role", "") == "tool" and msg.get("content", "")
101-
102-
def _extract_args_from_tool_call(self, msg: dict) -> str:
103-
"""
104-
Extract the arguments from the tool call message.
105-
"""
106-
tool_calls = msg.get("tool_calls", [])
107-
if not tool_calls:
108-
return ""
109-
return tool_calls[0].get("function", {}).get("arguments", "")
110-
111-
def _extract_result_from_tool_result(self, msg: dict) -> str:
112-
"""
113-
Extract the result from the tool result message.
114-
"""
115-
return msg.get("content", "")
116-
11794
def extract_unique_filenames(self, data: dict) -> set[str]:
118-
messages = data.get("messages", [])
119-
if not messages:
120-
return set()
121-
12295
filenames: List[str] = []
123-
for i_msg in range(len(messages) - 1):
124-
msg = messages[i_msg]
125-
next_msg = messages[i_msg + 1]
126-
if self._is_msg_tool_call(msg) and self._is_msg_tool_result(next_msg):
127-
tool_args = self._extract_args_from_tool_call(msg)
128-
tool_response = self._extract_result_from_tool_result(next_msg)
129-
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
130-
f"{tool_args}\n{tool_response}"
131-
)
132-
filenames.extend(extracted_snippets.keys())
96+
# Note: the previous version of this code used to analyze
97+
# tool-call and tool-results pairs to ensure that the regex
98+
# matched.
99+
#
100+
# Given it was not a business or functional requirement, but
101+
# rather an technical decision to avoid adding more regexes,
102+
# we decided to analysis contents on a per-message basis, to
103+
# avoid creating more dependency on the behaviour of the
104+
# coding assistant.
105+
#
106+
# We still filter only tool-calls and tool-results.
107+
filters = [MessageTypeFilter.ASSISTANT, MessageTypeFilter.TOOL]
108+
for msg in data.get_messages(filters=filters):
109+
for content in msg.get_content():
110+
if content.get_text() is not None:
111+
extracted_snippets = self._snippet_extractor.extract_unique_snippets(
112+
f"{content.get_text()}\n\nbackwards compatibility"
113+
)
114+
filenames.extend(extracted_snippets.keys())
133115
return set(filenames)
134116

135117

src/codegate/extract_snippets/message_extractor.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,16 @@ def extract_snippets(self, message: str, require_filepath: bool = False) -> List
279279
"""
280280
regexes = self._choose_regex(require_filepath)
281281
# Find all code block matches
282+
if isinstance(message, str):
283+
return [
284+
self._get_snippet_for_match(match)
285+
for regex in regexes
286+
for match in regex.finditer(message)
287+
]
282288
return [
283289
self._get_snippet_for_match(match)
284290
for regex in regexes
285-
for match in regex.finditer(message)
291+
for match in regex.finditer(message.get_text())
286292
]
287293

288294
def extract_unique_snippets(self, message: str) -> Dict[str, CodeSnippet]:

src/codegate/llm_utils/__init__.py

-3
This file was deleted.

0 commit comments

Comments
 (0)