Skip to content

Commit a6fd02d

Browse files
authored
Merge pull request #409 from litui/feature/azure-deepseek-example
Feature: additional example provider for DeepSeek R1 API on Azure.
2 parents db29eb2 + 7a8fa1a commit a6fd02d

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import List, Union, Generator, Iterator
2+
from pydantic import BaseModel
3+
import requests
4+
import os
5+
6+
7+
class Pipeline:
8+
class Valves(BaseModel):
9+
# You can add your custom valves here.
10+
AZURE_DEEPSEEKR1_API_KEY: str
11+
AZURE_DEEPSEEKR1_ENDPOINT: str
12+
AZURE_DEEPSEEKR1_API_VERSION: str
13+
14+
def __init__(self):
15+
self.type = "manifold"
16+
self.name = "Azure "
17+
self.valves = self.Valves(
18+
**{
19+
"AZURE_DEEPSEEKR1_API_KEY": os.getenv("AZURE_DEEPSEEKR1_API_KEY", "your-azure-deepseek-r1-api-key-here"),
20+
"AZURE_DEEPSEEKR1_ENDPOINT": os.getenv("AZURE_DEEPSEEKR1_ENDPOINT", "your-azure-deepseek-r1-endpoint-here"),
21+
"AZURE_DEEPSEEKR1_API_VERSION": os.getenv("AZURE_DEEPSEEKR1_API_VERSION", "2024-05-01-preview"),
22+
}
23+
)
24+
self.set_pipelines()
25+
pass
26+
27+
def set_pipelines(self):
28+
models = ['DeepSeek-R1']
29+
model_names = ['DeepSeek-R1']
30+
self.pipelines = [
31+
{"id": model, "name": name} for model, name in zip(models, model_names)
32+
]
33+
print(f"azure_deepseek_r1_pipeline - models: {self.pipelines}")
34+
pass
35+
36+
async def on_valves_updated(self):
37+
self.set_pipelines()
38+
39+
async def on_startup(self):
40+
# This function is called when the server is started.
41+
print(f"on_startup:{__name__}")
42+
pass
43+
44+
async def on_shutdown(self):
45+
# This function is called when the server is stopped.
46+
print(f"on_shutdown:{__name__}")
47+
pass
48+
49+
def pipe(
50+
self, user_message: str, model_id: str, messages: List[dict], body: dict
51+
) -> Union[str, Generator, Iterator]:
52+
# This is where you can add your custom pipelines like RAG.
53+
print(f"pipe:{__name__}")
54+
55+
print(messages)
56+
print(user_message)
57+
58+
headers = {
59+
"api-key": self.valves.AZURE_DEEPSEEKR1_API_KEY,
60+
"Content-Type": "application/json",
61+
}
62+
63+
url = f"{self.valves.AZURE_DEEPSEEKR1_ENDPOINT}/models/chat/completions?api-version={self.valves.AZURE_DEEPSEEKR1_API_VERSION}"
64+
65+
print(body)
66+
67+
allowed_params = {'messages', 'temperature', 'role', 'content', 'contentPart', 'contentPartImage',
68+
'enhancements', 'dataSources', 'n', 'stream', 'stop', 'max_tokens', 'presence_penalty',
69+
'frequency_penalty', 'logit_bias', 'user', 'function_call', 'funcions', 'tools',
70+
'tool_choice', 'top_p', 'log_probs', 'top_logprobs', 'response_format', 'seed', 'model'}
71+
# remap user field
72+
if "user" in body and not isinstance(body["user"], str):
73+
body["user"] = body["user"]["id"] if "id" in body["user"] else str(body["user"])
74+
# Fill in model field as per Azure's api requirements
75+
body["model"] = model_id
76+
filtered_body = {k: v for k, v in body.items() if k in allowed_params}
77+
# log fields that were filtered out as a single line
78+
if len(body) != len(filtered_body):
79+
print(f"Dropped params: {', '.join(set(body.keys()) - set(filtered_body.keys()))}")
80+
81+
try:
82+
r = requests.post(
83+
url=url,
84+
json=filtered_body,
85+
headers=headers,
86+
stream=True,
87+
)
88+
89+
r.raise_for_status()
90+
if body["stream"]:
91+
return r.iter_lines()
92+
else:
93+
return r.json()
94+
except Exception as e:
95+
if r:
96+
text = r.text
97+
return f"Error: {e} ({text})"
98+
else:
99+
return f"Error: {e}"

0 commit comments

Comments
 (0)