Commit 6219ffb

fixed gpt, added ollama
1 parent: 14ec1cb

3 files changed: +37 -49

Diff for: .env.example

+2 -2
@@ -4,7 +4,7 @@
 
 #ANTHROPIC_API_KEY=sk-1234
 ANTHROPIC_BASE_URL=https://api.anthropic.com
-ANTHROPIC_MODEL=claude-3-5-sonnet-20240620
+ANTHROPIC_MODEL=claude-3-5-sonnet-latest
 
 
 # For usage with OpenAI (-l gpt)
@@ -13,7 +13,7 @@ ANTHROPIC_MODEL=claude-3-5-sonnet-20240620
 
 #OPENAI_API_KEY=sk-1234
 OPENAI_BASE_URL=https://api.openai.com/v1
-OPENAI_MODEL=gpt-4o-2024-08-06
+OPENAI_MODEL=chatgpt-4o-latest
 
 
 # For usage with Ollama (-l ollama)
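The model pins above move from dated snapshots to the -latest aliases. This diff doesn't show how vulnhuntr loads .env, so as a hedged sketch, assuming python-dotenv, the variables can be loaded and checked like this:

```python
import os

from dotenv import load_dotenv  # assumption: python-dotenv; the loading mechanism isn't shown in this diff

load_dotenv()  # populate os.environ from .env in the working directory

# Same fallbacks initialize_llm() uses when the variables are unset
print(os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620"))
print(os.getenv("OPENAI_MODEL", "gpt-4o-2024-08-06"))
```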

Diff for: vulnhuntr/LLMs.py

+27 -43
@@ -42,7 +42,7 @@ def _validate_response(self, response_text: str, response_model: BaseModel) -> B
             response_text = self.prefill + response_text
             return response_model.model_validate_json(response_text)
         except ValidationError as e:
-            log.warning("Response validation failed", exc_info=e)
+            log.warning("[-] Response validation failed\n", exc_info=e)
             raise LLMError("Validation failed") from e
             # try:
             #     response_clean_attempt = response_text.split('{', 1)[1]
@@ -75,10 +75,10 @@ def chat(self, user_prompt: str, response_model: BaseModel = None, max_tokens: i
         return response_text
 
 class Claude(LLM):
-    def __init__(self, model: str, system_prompt: str = "") -> None:
+    def __init__(self, model: str, base_url: str, system_prompt: str = "") -> None:
         super().__init__(system_prompt)
         # API key is retrieved from an environment variable by default
-        self.client = anthropic.Anthropic(max_retries=3, base_url=os.getenv("ANTHROPIC_BASE_URL", "https://api.anthropic.com"))
+        self.client = anthropic.Anthropic(max_retries=3, base_url=base_url)
         self.model = model
 
     def create_messages(self, user_prompt: str) -> List[Dict[str, str]]:
@@ -111,32 +111,31 @@ def get_response(self, response: Dict[str, Any]) -> str:
 
 
 class ChatGPT(LLM):
-    def __init__(self, model: str, system_prompt: str = "") -> None:
+    def __init__(self, model: str, base_url: str, system_prompt: str = "") -> None:
         super().__init__(system_prompt)
-        # Retrieves API key and API Endpoint if specified from an environment variable
-        self.client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", f"https://api.openai.com/v1"))
+        self.client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=base_url)
         self.model = model
 
     def create_messages(self, user_prompt: str) -> List[Dict[str, str]]:
-        messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}]
+        messages = [{"role": "system", "content": self.system_prompt},
+                    {"role": "user", "content": user_prompt}]
         return messages
 
-    def send_message(self, messages: List[Dict[str, str]], max_tokens: int, response_model) -> Dict[str, Any]:
+    def send_message(self, messages: List[Dict[str, str]], max_tokens: int, response_model=None) -> Dict[str, Any]:
         try:
-            # For analyzing files and context code, use the beta endpoint and parse so we can feed it the pydantic model
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "max_tokens": max_tokens,
+            }
+
+            # Add response format configuration if a model is provided
             if response_model:
-                return self.client.beta.chat.completions.parse(
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    response_format=response_model
-                )
-            else:
-                return self.client.chat.completions.create(
-                    model=self.model,
-                    messages=messages,
-                    max_tokens=max_tokens,
-                )
+                params["response_format"] = {
+                    "type": "json_object"
+                }
+
+            return self.client.chat.completions.create(**params)
         except openai.APIConnectionError as e:
             raise APIConnectionError("The server could not be reached") from e
         except openai.RateLimitError as e:
@@ -146,37 +145,24 @@ def send_message(self, messages: List[Dict[str, str]], max_tokens: int, response
         except Exception as e:
             raise LLMError(f"An unexpected error occurred: {str(e)}") from e
 
-    def _clean_response(self, response: str) -> str:
-        # Step 1: Remove markdown code block wrappers
-        cleaned_text = response.strip('```json\n').strip('```')
-        # Step 2: Correctly handle newlines and escaped characters
-        cleaned_text = cleaned_text.replace('\n', '').replace('\\\'', '\'')
-        # Step 3: Replace escaped double quotes with regular double quotes
-        cleaned_text = cleaned_text.replace('\\"', '"')
-
-        return cleaned_text.replace('\n', '')
-
     def get_response(self, response: Dict[str, Any]) -> str:
         response = response.choices[0].message.content
-        cleaned_response = self._clean_response(response)
-        return cleaned_response
+        return response
 
 
 class Ollama(LLM):
-    def __init__(self, model: str, system_prompt: str = "") -> None:
+    def __init__(self, model: str, base_url: str, system_prompt: str = "") -> None:
         super().__init__(system_prompt)
-        self.api_url = "http://localhost:11434/api/chat"
+        self.api_url = base_url
         self.model = model
 
-    def create_messages(self, user_prompt: str) -> List[Dict[str, str]]:
-        messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}]
-        return messages
+    def create_messages(self, user_prompt: str) -> str:
+        return user_prompt
 
     def send_message(self, user_prompt: str, max_tokens: int, response_model: BaseModel) -> Dict[str, Any]:
         payload = {
             "model": self.model,
-            #"messages": messages,
-            "messages": "hello",
+            "prompt": user_prompt,
             "options": {
                 "temperature": 1,
                 "system": self.system_prompt,
@@ -195,12 +181,10 @@ def send_message(self, user_prompt: str, max_tokens: int, response_model: BaseMo
         else:
             raise APIStatusError(e.response.status_code, e.response.json()) from e
 
-
     def get_response(self, response: Dict[str, Any]) -> str:
-        response = response.json()['message']['content']
+        response = response.json()['response']
         return response
 
-
     def _log_response(self, response: Dict[str, Any]) -> None:
         log.debug("Received chat response", extra={"usage": "Ollama"})
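The rewritten ChatGPT.send_message drops the beta client.beta.chat.completions.parse call in favor of plain JSON mode: the pydantic response_model now only decides whether response_format={"type": "json_object"} is added, and schema validation happens afterwards in _validate_response. A minimal standalone sketch of that call pattern (model name and prompt contents are placeholder assumptions):

```python
import os

import openai

client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

params = {
    "model": "gpt-4o-2024-08-06",
    "messages": [
        # JSON mode requires the word "JSON" to appear somewhere in the messages
        {"role": "system", "content": "Reply only with a JSON object."},
        {"role": "user", "content": "Summarize this function's behavior."},
    ],
    "max_tokens": 512,
}
# Equivalent of the commit's response_model branch
params["response_format"] = {"type": "json_object"}

completion = client.chat.completions.create(**params)
print(completion.choices[0].message.content)  # raw JSON string; pydantic validates it later
```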

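Ollama support likewise switches from the chat-style endpoint (and the leftover "messages": "hello" debug payload) to /api/generate, which takes a single prompt string and returns the generated text under the response key, matching the new get_response. A hedged sketch of the same request shape using requests (the local URL and llama3 are the commit's defaults; stream is disabled explicitly here, an assumption, since the truncated diff doesn't show it and response.json()['response'] implies a single JSON object):

```python
import requests

payload = {
    "model": "llama3",
    "prompt": "Reply with one short sentence.",
    "options": {"temperature": 1},  # mirrors the options block in send_message
    "stream": False,  # one JSON object back instead of a token stream
}
resp = requests.post("http://127.0.0.1:11434/api/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])  # generated text
```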
Diff for: vulnhuntr/__main__.py

+8 -4
@@ -282,17 +282,21 @@ def extract_between_tags(tag: str, string: str, strip: bool = False) -> list[str
     return ext_list
 
 def initialize_llm(llm_arg: str, system_prompt: str = "") -> Claude | ChatGPT | Ollama:
+    llm_arg = llm_arg.lower()
     if llm_arg == 'claude':
         anth_model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20240620")
-        llm = Claude(anth_model, system_prompt)
+        anth_base_url = os.getenv("ANTHROPIC_BASE_URL", "https://api.anthropic.com")
+        llm = Claude(anth_model, anth_base_url, system_prompt)
     elif llm_arg == 'gpt':
         openai_model = os.getenv("OPENAI_MODEL", "gpt-4o-2024-08-06")
-        llm = ChatGPT(openai_model, system_prompt)
+        openai_base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+        llm = ChatGPT(openai_model, openai_base_url, system_prompt)
     elif llm_arg == 'ollama':
         ollama_model = os.getenv("OLLAMA_MODEL", "llama3")
-        llm = Ollama(ollama_model, system_prompt)
+        ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://127.0.0.1:11434/api/generate")
+        llm = Ollama(ollama_model, ollama_base_url, system_prompt)
     else:
-        raise ValueError(f"Invalid LLM argument: {llm_arg}")
+        raise ValueError(f"Invalid LLM argument: {llm_arg}\nValid options are: claude, gpt, ollama")
     return llm
 
 def print_readable(report: Response) -> None:
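With base URLs now resolved inside initialize_llm, pointing any backend at a proxy or self-hosted endpoint is purely an environment change, and the new .lower() call makes the selector case-insensitive. A usage sketch, assuming initialize_llm is importable from vulnhuntr.__main__ (the proxy URL is hypothetical):

```python
import os

from vulnhuntr.__main__ import initialize_llm  # assumption: importable as a module

# Hypothetical OpenAI-compatible proxy; falls back to api.openai.com/v1 if unset
os.environ["OPENAI_BASE_URL"] = "http://localhost:8000/v1"
os.environ["OPENAI_MODEL"] = "gpt-4o-2024-08-06"

llm = initialize_llm("GPT", system_prompt="You are a security researcher.")  # lowered to 'gpt'
reply = llm.chat("Summarize the risks of unvalidated redirects.")
```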
