
Commit ffd6f6c

Committed Jun 29, 2024
Initial version
1 parent d3e95bc commit ffd6f6c

8 files changed: +882 -0 lines changed
 

‎.dockerignore

+2
@@ -0,0 +1,2 @@
.vscode/
__pycache__/

‎.vscode/settings.json

+3
@@ -0,0 +1,3 @@
{
    "python.analysis.typeCheckingMode": "basic"
}

‎Dockerfile

+33
@@ -0,0 +1,33 @@
# This Dockerfile constructs an image that can be used directly
# to run the OpenAI-compatible Triton Inference Server proxy.

# prepare basic build environment
# https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
FROM nvcr.io/nvidia/tritonserver:24.05-trtllm-python-py3 AS build
WORKDIR /opt/tritonserver/openai

# To build TensorRT-LLM engines, see https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama for one such example
# Use `python3 -c "import tensorrt_llm"` to find the TensorRT-LLM version used by Triton Inference Server
ARG TENSORRT_LLM_VERSION="v0.9.0"
ARG VENV="pyinstaller"

# install build and runtime dependencies
# pyinstaller bundles `app.py` into a single fat executable with all necessary dependencies included
COPY *.py .
COPY requirements.txt requirements.txt

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update -y \
    && apt-get install -y python3.10-venv \
    && python3 -m venv ${VENV} \
    && ${VENV}/bin/python3 -m pip install --upgrade --requirement requirements.txt \
    && ${VENV}/bin/python3 -m pip install --upgrade pyinstaller \
    && ${VENV}/bin/pyinstaller --onefile --paths=. --clean app.py \
    && git clone --depth 1 --branch ${TENSORRT_LLM_VERSION} https://github.com/NVIDIA/TensorRT-LLM.git /opt/tritonserver/third-party-src/TensorRT-LLM \
    && git clone --depth 1 --branch ${TENSORRT_LLM_VERSION} https://github.com/triton-inference-server/tensorrtllm_backend.git /opt/tritonserver/third-party-src/tensorrtllm_backend

FROM nvcr.io/nvidia/tritonserver:24.05-trtllm-python-py3
COPY --from=build --chown=triton-server:triton-server /opt/tritonserver/openai/dist/app /opt/tritonserver/bin/tritonopenaiserver
COPY --from=build --chown=triton-server:triton-server /opt/tritonserver/third-party-src/ /opt/tritonserver/third-party-src/

EXPOSE 11434/tcp

‎README.md

+156
@@ -0,0 +1,156 @@
# Triton Inference Server OpenAI-compatible API proxy
[This project](https://github.com/visitsb/triton-inference-server-openai-api) provides an OpenAI API-compatible proxy for NVIDIA [Triton Inference Server](https://www.nvidia.com/en-us/ai-data-science/products/triton-inference-server/). More specifically, LLMs on NVIDIA GPUs can benefit from high-performance inference with the [TensorRT-LLM](https://developer.nvidia.com/tensorrt#inference) backend running on [Triton Inference Server compared to using llama.cpp](https://jan.ai/post/benchmarking-nvidia-tensorrt-llm#key-findings).

Triton Inference Server supports [HTTP/REST and GRPC inference protocols](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md) based on the community-developed [KServe protocol](https://github.com/kserve/kserve/tree/master/docs/predict-api/v2), but those are not usable with existing OpenAI API clients.

This proxy bridges that gap. It currently supports only the **text** generation [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) endpoints, which are suitable for use with [Open WebUI](https://docs.openwebui.com/) or similar OpenAI clients-
```text
GET|POST /v1/models (or /models)
GET /v1/models/{model} (or /models/{model})
POST /v1/chat/completions (or /v1/completions) streaming supported
```

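As a concrete example, once a model is served behind the proxy, any OpenAI-style HTTP client can talk to these endpoints. The sketch below uses plain Python `requests`; the host/port (`localhost:11434`) and the model name (`llama3`) are assumptions taken from the examples later in this README, so substitute your own values.

```python
# Minimal sketch: non-streaming chat completion against the proxy.
# Assumes the proxy is reachable on localhost:11434 and a model named "llama3" is deployed.
import requests

response = requests.post(
    "http://localhost:11434/v1/chat/completions",
    json={
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello! What can you do?"}],
        "max_tokens": 128,
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```
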
## Usage
**Recommended:** Use the pre-published [Docker image](https://hub.docker.com/repository/docker/visitsb/tritonserver)
```bash
docker image pull visitsb/tritonserver:24.05-trtllm-python-py3
```

Alternatively, use the `Dockerfile` to build a local image. The proxy is built on top of the existing [Triton Inference Server](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver) Docker image, which already includes the TensorRT-LLM backend.

```bash
# Pull upstream NVIDIA docker image
docker image pull nvcr.io/nvidia/tritonserver:24.05-trtllm-python-py3
# Clone this repository
git clone <this repository>
cd triton-inference-server-openai-api
# Build your custom docker image with the proxy bundled
docker buildx build --no-cache --tag myimages/tritonserver:24.05-trtllm-python-py3 .
```

Once your image is pulled (or built locally), you can run it directly using Docker-
```bash
# Run Triton Inference Server along with the proxy, as shown in the `sh -c` command
docker run --rm --tty --interactive \
  --gpus all --shm-size 4g --memory 32g \
  --cpuset-cpus 0-3 --publish 11434:11434/tcp \
  --volume <your Triton models folder>:/models:rw \
  --name triton \
  visitsb/tritonserver:24.05-trtllm-python-py3 \
  sh -c '/opt/tritonserver/bin/tritonserver \
    --model-store /models/mymodel/model \
    & /opt/tritonserver/bin/tritonopenaiserver \
    --tokenizer_dir /models/mymodel/tokenizer'
```

Alternatively, using `docker-compose.yml`-
```yaml
triton:
  image: visitsb/tritonserver:24.05-trtllm-python-py3
  command: >
    sh -c '/opt/tritonserver/bin/tritonserver --model-store /models/mymodel/model & /opt/tritonserver/bin/tritonopenaiserver --tokenizer_dir /models/mymodel/tokenizer'
  ports:
    - "11434:11434/tcp" # OpenAI API Proxy
    - "8000:8000/tcp"   # HTTP
    - "8001:8001/tcp"   # GRPC
    - "8080:8080/tcp"   # Sagemaker, Vertex
    - "8002:8002/tcp"   # Prometheus metrics
  volumes:
    - <your Triton models folder>:/models:rw
  shm_size: "4G"
  deploy:
    resources:
      limits:
        memory: 32G
      reservations:
        memory: 8G
        devices:
          - driver: nvidia
            count: all
            capabilities: [compute,video,utility]
  ulimits:
    stack: 67108864
    memlock:
      soft: -1
      hard: -1
```

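Once the stack is up, a quick way to confirm the proxy is serving is to list the models it exposes. This is a minimal Python sketch, assuming the default `11434:11434/tcp` port mapping shown above.

```python
# Readiness check: list the models exposed by the proxy (GET /v1/models).
import requests

resp = requests.get("http://localhost:11434/v1/models", timeout=10)
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"])
```
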
## Performance
Using [GenAI-Perf](https://github.com/triton-inference-server/client/tree/main/src/c%2B%2B/perf_analyzer/genai-perf) to measure performance for [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) on an [NVIDIA RTX 4090 GPU](https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/), the following was observed-

Test: [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) evaluated using NVIDIA [GenAI-Perf](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/client/src/c%2B%2B/perf_analyzer/genai-perf/docs/tutorial.html#openai-chat-completions-api). For the llama.cpp evaluation, [QuantFactory/Meta-Llama-3-8B-GGUF](https://huggingface.co/QuantFactory/Meta-Llama-3-8B-GGUF) - `Meta-Llama-3-8B.Q8_0.gguf` was used.

```text
Backend          Loaded model size    GPU Util  Tokens/sec
---------------  -------------------  --------  ----------
TensorRT (gRPC)  15879MiB / 24564MiB  91%       97.04
TensorRT (HTTP)  15879MiB / 24564MiB  91%       56.73
llama.cpp        9491MiB / 24564MiB   74%       70.23
```

In summary, TensorRT over gRPC clearly outperforms llama.cpp, while TensorRT over HTTP (the path this proxy uses) lands somewhat below llama.cpp but in the same ballpark.

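As a quick arithmetic check of those claims, the relative throughput can be computed directly from the summary table above (no new measurements, just ratios of the numbers already shown)-

```python
# Relative output-token throughput, using the figures from the summary table above.
throughput = {
    "TensorRT (gRPC)": 97.04,
    "TensorRT (HTTP)": 56.73,
    "llama.cpp": 70.23,
}

baseline = throughput["llama.cpp"]
for backend, tokens_per_sec in throughput.items():
    # >1.00 means faster than llama.cpp, <1.00 means slower.
    print(f"{backend:16s} {tokens_per_sec / baseline:.2f}x vs llama.cpp")
```
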
The raw performance numbers are as below-
#### TensorRT (gRPC)
```text
[INFO] genai_perf.wrapper:135 - Running Perf Analyzer : 'perf_analyzer -m llama3 --async --service-kind triton -u triton:8001 --measurement-interval 4000 --stability-percentage 999 -i grpc --streaming --shape max_tokens:1 --shape text_input:1 --concurrency-range 1'
LLM Metrics
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Statistic            ┃    avg ┃    min ┃     max ┃    p99 ┃     p90 ┃    p75 ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ Request latency (ns) │ 1,081… │ 1,048… │ 1,311,… │ 1,284… │ 1,083,… │ 1,064… │
│ Num output token     │    105 │    100 │     110 │    110 │     109 │    107 │
│ Num input token      │    200 │    200 │     200 │    200 │     200 │    200 │
└──────────────────────┴────────┴────────┴─────────┴────────┴─────────┴────────┘
Output token throughput (per sec): 97.04
Request throughput (per sec): 0.92
```

#### TensorRT (HTTP) via this OpenAI API Proxy
```text
[INFO] genai_perf.wrapper:135 - Running Perf Analyzer : 'perf_analyzer -m llama3 --async --endpoint v1/chat/completions --service-kind openai -u triton:11434 --measurement-interval 4000 --stability-percentage 999 -i http --concurrency-range 1'
LLM Metrics
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Statistic            ┃    avg ┃    min ┃     max ┃    p99 ┃     p90 ┃    p75 ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ Request latency (ns) │ 2,033… │ 1,732… │ 3,856,… │ 3,723… │ 2,525,… │ 1,802… │
│ Num output token     │    115 │    110 │     121 │    121 │     120 │    119 │
│ Num input token      │    200 │    200 │     200 │    200 │     200 │    200 │
└──────────────────────┴────────┴────────┴─────────┴────────┴─────────┴────────┘
Output token throughput (per sec): 56.73
Request throughput (per sec): 0.49
```

#### llama.cpp
```text
[INFO] genai_perf.wrapper:135 - Running Perf Analyzer : 'perf_analyzer -m llama3 --async --endpoint v1/chat/completions --service-kind openai -u llama:11434 --measurement-interval 4000 --stability-percentage 999 -i http --concurrency-range 1'
LLM Metrics
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Statistic            ┃    avg ┃    min ┃     max ┃    p99 ┃     p90 ┃    p75 ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ Request latency (ns) │ 1,656… │ 1,596… │ 1,822,… │ 1,810… │ 1,701,… │ 1,649… │
│ Num output token     │    116 │    104 │     149 │    147 │     132 │    118 │
│ Num input token      │    200 │    200 │     200 │    200 │     200 │    200 │
└──────────────────────┴────────┴────────┴─────────┴────────┴─────────┴────────┘
Output token throughput (per sec): 70.23
Request throughput (per sec): 0.60
```

**Note** This proxy currently uses TensorRT over HTTP, so the performance numbers above should be considered relative. Performance will vary for TensorRT-LLM models based on the [build and deployment options](https://github.com/triton-inference-server/tensorrtllm_backend?tab=readme-ov-file#using-the-tensorrt-llm-backend) used.

Additional optimizations like speculative sampling and FP8 quantization can further improve throughput. For more on the throughput levels that are possible with TensorRT-LLM for different combinations of model, hardware, and workload, see the [official benchmarks](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/perf-overview.md).

## Build and deploy your own models
The image includes the [TensorRT-LLM toolbox](https://github.com/NVIDIA/TensorRT-LLM.git) and [backend](https://github.com/triton-inference-server/tensorrtllm_backend.git) for building your own TensorRT-LLM models. Both can be found under `/opt/tritonserver/third-party-src/` inside the Docker image.

The basic steps to build a TensorRT model are outlined [here](https://github.com/triton-inference-server/tensorrtllm_backend?tab=readme-ov-file#using-the-tensorrt-llm-backend) and essentially involve-
1. Downloading a [Hugging Face model](https://huggingface.co/models) of your choice,
2. Converting it to a TensorRT format, and
3. Building a compiled model that can be deployed on Triton Inference Server.

Additionally, you can use the steps in the [TensorRT-LLM quick start guide](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html#retrieve-the-model-weights) to build your TensorRT model. Once your model is built, you can [deploy](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html#deploy-with-triton-inference-server) it and use it through the OpenAI API proxy.

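As an illustration of step 1 only, the sketch below fetches weights from Hugging Face with `huggingface_hub`; the model id and target directory are placeholder assumptions, and the convert/build commands for steps 2 and 3 are version-specific, so follow the linked guides for those.

```python
# Step 1 (illustrative): download Hugging Face weights locally with huggingface_hub.
# The repo id and local_dir are examples only; gated models also require an HF access token.
from huggingface_hub import snapshot_download

weights_dir = snapshot_download(
    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
    local_dir="/models/Meta-Llama-3-8B-Instruct-hf",
)
print(f"Weights downloaded to: {weights_dir}")

# Steps 2-3 (convert to a TensorRT-LLM checkpoint, then build the engine) use the tooling
# shipped under /opt/tritonserver/third-party-src/ -- see the links above for the exact
# commands matching your TensorRT-LLM version.
```
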
## Further references
- [Benchmarking NVIDIA TensorRT-LLM](https://jan.ai/post/benchmarking-nvidia-tensorrt-llm) - TensorRT-LLM was 30-70% faster than [llama.cpp](https://github.com/ggerganov/llama.cpp) on the same hardware, consumed less memory on consecutive runs (with marginally higher GPU VRAM utilization), and produced compiled models that are 20%+ smaller than their llama.cpp counterparts.
- [Use Llama 3 with NVIDIA TensorRT-LLM and Triton Inference Server](https://docs.lxp.lu/howto/llama3-triton/) - A 30-minute tutorial showing how to use TensorRT-LLM to build TensorRT engines with state-of-the-art optimizations for efficient inference on NVIDIA GPUs, using the Llama 3 model as an example.
- [Serverless TensorRT-LLM (LLaMA 3 8B)](https://modal.com/docs/examples/trtllm_llama) - A similar guide showing how to use the TensorRT-LLM framework to serve Meta's LLaMA 3 8B model at a total throughput of roughly 4,500 output tokens per second on a single NVIDIA A100 40GB GPU.

‎__init__.py

Whitespace-only changes.

‎app.py

+471
Large diffs are not rendered by default.

‎protocol.py

+210
@@ -0,0 +1,210 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py
# +StreamOptions
# ChatCompletionRequest stream_options: Optional[StreamOptions] = None
# ChatCompletionStreamResponse usage: Optional[UsageInfo]
from typing import Literal, Optional, List, Dict, Any, Union

import time

import shortuuid
from pydantic import BaseModel, Field


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = True
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "fastchat"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = []


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class StreamOptions(BaseModel):
    include_usage: Optional[bool] = False


class ChatCompletionRequest(BaseModel):
    model: str
    messages: Union[
        str,
        List[Dict[str, str]],
        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
    ]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = -1
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    stream_options: Optional[StreamOptions] = None


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = None


class TokenCheckRequestItem(BaseModel):
    model: str
    prompt: str
    max_tokens: int


class TokenCheckRequest(BaseModel):
    prompts: List[TokenCheckRequestItem]


class TokenCheckResponseItem(BaseModel):
    fits: bool
    tokenCount: int
    contextLength: int


class TokenCheckResponse(BaseModel):
    prompts: List[TokenCheckResponseItem]


class EmbeddingsRequest(BaseModel):
    model: Optional[str] = None
    engine: Optional[str] = None
    input: Union[str, List[Any]]
    user: Optional[str] = None
    encoding_format: Optional[str] = None


class EmbeddingsResponse(BaseModel):
    object: str = "list"
    data: List[Dict[str, Any]]
    model: str
    usage: UsageInfo


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[str, List[Any]]
    suffix: Optional[str] = None
    temperature: Optional[float] = 0.7
    n: Optional[int] = 1
    max_tokens: Optional[int] = 16
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    top_p: Optional[float] = 1.0
    top_k: Optional[int] = -1
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    use_beam_search: Optional[bool] = False
    best_of: Optional[int] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]

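A small, hypothetical usage sketch of the models above (not part of the committed code): construct a streaming `ChatCompletionRequest` and parse one streamed chunk. It assumes pydantic v2 method names; on pydantic v1 use `.json()` / `.parse_raw()` instead.

```python
# Illustrative only: exercise the request/response models defined in protocol.py.
from protocol import ChatCompletionRequest, ChatCompletionStreamResponse, StreamOptions

# A streaming request that also asks for usage info in the final chunk.
request = ChatCompletionRequest(
    model="llama3",
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options=StreamOptions(include_usage=True),
)
print(request.model_dump_json())

# Parse one chunk body as it would arrive from a streaming /v1/chat/completions response.
chunk = ChatCompletionStreamResponse.model_validate_json(
    '{"model": "llama3", "choices": '
    '[{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}'
)
print(chunk.choices[0].delta.content)
```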

‎requirements.txt

+7
@@ -0,0 +1,7 @@
fastapi
uvicorn
requests
shortuuid
transformers
tokenizers
semantic-text-splitter
