@@ -42,7 +42,7 @@ def _validate_response(self, response_text: str, response_model: BaseModel) -> B
42
42
response_text = self .prefill + response_text
43
43
return response_model .model_validate_json (response_text )
44
44
except ValidationError as e :
45
- log .warning ("Response validation failed" , exc_info = e )
45
+ log .warning ("[-] Response validation failed\n " , exc_info = e )
46
46
raise LLMError ("Validation failed" ) from e
47
47
# try:
48
48
# response_clean_attempt = response_text.split('{', 1)[1]
@@ -75,10 +75,10 @@ def chat(self, user_prompt: str, response_model: BaseModel = None, max_tokens: i
75
75
return response_text
76
76
77
77
class Claude (LLM ):
78
- def __init__ (self , model : str , system_prompt : str = "" ) -> None :
78
+ def __init__ (self , model : str , base_url : str , system_prompt : str = "" ) -> None :
79
79
super ().__init__ (system_prompt )
80
80
# API key is retrieved from an environment variable by default
81
- self .client = anthropic .Anthropic (max_retries = 3 , base_url = os . getenv ( "ANTHROPIC_BASE_URL" , "https://api.anthropic.com" ) )
81
+ self .client = anthropic .Anthropic (max_retries = 3 , base_url = base_url )
82
82
self .model = model
83
83
84
84
def create_messages (self , user_prompt : str ) -> List [Dict [str , str ]]:
@@ -111,32 +111,31 @@ def get_response(self, response: Dict[str, Any]) -> str:
111
111
112
112
113
113
class ChatGPT (LLM ):
114
- def __init__ (self , model : str , system_prompt : str = "" ) -> None :
114
+ def __init__ (self , model : str , base_url : str , system_prompt : str = "" ) -> None :
115
115
super ().__init__ (system_prompt )
116
- # Retrieves API key and API Endpoint if specified from an environment variable
117
- self .client = openai .OpenAI (api_key = os .getenv ("OPENAI_API_KEY" ), base_url = os .getenv ("OPENAI_BASE_URL" , f"https://api.openai.com/v1" ))
116
+ self .client = openai .OpenAI (api_key = os .getenv ("OPENAI_API_KEY" ), base_url = base_url )
118
117
self .model = model
119
118
120
119
def create_messages (self , user_prompt : str ) -> List [Dict [str , str ]]:
121
- messages = [{"role" : "system" , "content" : self .system_prompt }, {"role" : "user" , "content" : user_prompt }]
120
+ messages = [{"role" : "system" , "content" : self .system_prompt },
121
+ {"role" : "user" , "content" : user_prompt }]
122
122
return messages
123
123
124
- def send_message (self , messages : List [Dict [str , str ]], max_tokens : int , response_model ) -> Dict [str , Any ]:
124
+ def send_message (self , messages : List [Dict [str , str ]], max_tokens : int , response_model = None ) -> Dict [str , Any ]:
125
125
try :
126
- # For analyzing files and context code, use the beta endpoint and parse so we can feed it the pydantic model
126
+ params = {
127
+ "model" : self .model ,
128
+ "messages" : messages ,
129
+ "max_tokens" : max_tokens ,
130
+ }
131
+
132
+ # Add response format configuration if a model is provided
127
133
if response_model :
128
- return self .client .beta .chat .completions .parse (
129
- model = self .model ,
130
- messages = messages ,
131
- max_tokens = max_tokens ,
132
- response_format = response_model
133
- )
134
- else :
135
- return self .client .chat .completions .create (
136
- model = self .model ,
137
- messages = messages ,
138
- max_tokens = max_tokens ,
139
- )
134
+ params ["response_format" ] = {
135
+ "type" : "json_object"
136
+ }
137
+
138
+ return self .client .chat .completions .create (** params )
140
139
except openai .APIConnectionError as e :
141
140
raise APIConnectionError ("The server could not be reached" ) from e
142
141
except openai .RateLimitError as e :
@@ -146,37 +145,24 @@ def send_message(self, messages: List[Dict[str, str]], max_tokens: int, response
146
145
except Exception as e :
147
146
raise LLMError (f"An unexpected error occurred: { str (e )} " ) from e
148
147
149
- def _clean_response (self , response : str ) -> str :
150
- # Step 1: Remove markdown code block wrappers
151
- cleaned_text = response .strip ('```json\n ' ).strip ('```' )
152
- # Step 2: Correctly handle newlines and escaped characters
153
- cleaned_text = cleaned_text .replace ('\n ' , '' ).replace ('\\ \' ' , '\' ' )
154
- # Step 3: Replace escaped double quotes with regular double quotes
155
- cleaned_text = cleaned_text .replace ('\\ "' , '"' )
156
-
157
- return cleaned_text .replace ('\n ' , '' )
158
-
159
148
def get_response (self , response : Dict [str , Any ]) -> str :
160
149
response = response .choices [0 ].message .content
161
- cleaned_response = self ._clean_response (response )
162
- return cleaned_response
150
+ return response
163
151
164
152
165
153
class Ollama (LLM ):
166
- def __init__ (self , model : str , system_prompt : str = "" ) -> None :
154
+ def __init__ (self , model : str , base_url : str , system_prompt : str = "" ) -> None :
167
155
super ().__init__ (system_prompt )
168
- self .api_url = "http://localhost:11434/api/chat"
156
+ self .api_url = base_url
169
157
self .model = model
170
158
171
- def create_messages (self , user_prompt : str ) -> List [Dict [str , str ]]:
172
- messages = [{"role" : "system" , "content" : self .system_prompt }, {"role" : "user" , "content" : user_prompt }]
173
- return messages
159
+ def create_messages (self , user_prompt : str ) -> str :
160
+ return user_prompt
174
161
175
162
def send_message (self , user_prompt : str , max_tokens : int , response_model : BaseModel ) -> Dict [str , Any ]:
176
163
payload = {
177
164
"model" : self .model ,
178
- #"messages": messages,
179
- "messages" : "hello" ,
165
+ "prompt" : user_prompt ,
180
166
"options" : {
181
167
"temperature" : 1 ,
182
168
"system" : self .system_prompt ,
@@ -195,12 +181,10 @@ def send_message(self, user_prompt: str, max_tokens: int, response_model: BaseMo
195
181
else :
196
182
raise APIStatusError (e .response .status_code , e .response .json ()) from e
197
183
198
-
199
184
def get_response (self , response : Dict [str , Any ]) -> str :
200
- response = response .json ()['message' ][ 'content ' ]
185
+ response = response .json ()['response ' ]
201
186
return response
202
187
203
-
204
188
def _log_response (self , response : Dict [str , Any ]) -> None :
205
189
log .debug ("Received chat response" , extra = {"usage" : "Ollama" })
206
190
0 commit comments