Skip to content

Commit 96e7cb5

Browse files
authored
feat: add generate support to sagemaker_server (#8047)
1 parent 205f13c commit 96e7cb5

File tree

9 files changed

+332
-11
lines changed

9 files changed

+332
-11
lines changed

docker/sagemaker/serve

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -140,6 +140,20 @@ if [ "${is_mme_mode}" = false ] && [ -f "${SAGEMAKER_MODEL_REPO}/config.pbtxt" ]
140140
exit 1
141141
fi
142142

143+
# Validate SAGEMAKER_TRITON_INFERENCE_TYPE if set
144+
if [ -n "$SAGEMAKER_TRITON_INFERENCE_TYPE" ]; then
145+
case "$SAGEMAKER_TRITON_INFERENCE_TYPE" in
146+
"infer"|"generate"|"generate_stream")
147+
# Valid value, continue
148+
;;
149+
*)
150+
echo "ERROR: Invalid SAGEMAKER_TRITON_INFERENCE_TYPE '${SAGEMAKER_TRITON_INFERENCE_TYPE}'"
151+
echo " Must be one of: infer, generate, generate_stream"
152+
exit 1
153+
;;
154+
esac
155+
fi
156+
143157
if [ "${is_mme_mode}" = false ] && [ -n "$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then
144158
if [ -d "${SAGEMAKER_MODEL_REPO}/$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then
145159
SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --load-model=${SAGEMAKER_TRITON_DEFAULT_MODEL_NAME}"

qa/L0_http/test.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -649,7 +649,7 @@ wait $SERVER_PID
649649
# https://github.com/mpetazzoni/sseclient
650650
pip install sseclient-py
651651

652-
SERVER_ARGS="--model-repository=`pwd`/generate_models"
652+
SERVER_ARGS="--model-repository=`pwd`/../python_models/generate_models"
653653
SERVER_LOG="./inference_server_generate_endpoint_test.log"
654654
CLIENT_LOG="./generate_endpoint_test.log"
655655
run_server
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#!/usr/bin/python
2+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import sys
29+
30+
sys.path.append("../common")
31+
32+
import json
33+
import os
34+
import sys
35+
import unittest
36+
37+
import requests
38+
import sseclient
39+
import test_util as tu
40+
41+
42+
class SageMakerGenerateStreamTest(tu.TestResultCollector):
43+
def setUp(self):
44+
SAGEMAKER_BIND_TO_PORT = os.getenv("SAGEMAKER_BIND_TO_PORT", "8080")
45+
self.url_ = "http://localhost:{}/invocations".format(SAGEMAKER_BIND_TO_PORT)
46+
47+
def generate_stream(self, inputs, stream=False):
48+
headers = {"Accept": "text/event-stream"}
49+
# stream=True used to indicate response can be iterated over, which
50+
# should be the common setting for generate_stream.
51+
# For correctness test case, stream=False so that we can re-examine
52+
# the response content.
53+
return requests.post(
54+
self.url_,
55+
data=inputs if isinstance(inputs, str) else json.dumps(inputs),
56+
headers=headers,
57+
stream=stream,
58+
)
59+
60+
def generate_stream_expect_success(self, inputs, expected_output, rep_count):
61+
r = self.generate_stream(inputs)
62+
r.raise_for_status()
63+
self.check_sse_responses(r, [{"TEXT": expected_output}] * rep_count)
64+
65+
def check_sse_responses(self, res, expected_res):
66+
# Validate SSE format
67+
self.assertIn("Content-Type", res.headers)
68+
self.assertEqual(
69+
"text/event-stream; charset=utf-8", res.headers["Content-Type"]
70+
)
71+
72+
# SSE format (data: []) is hard to parse, use helper library for simplicity
73+
client = sseclient.SSEClient(res)
74+
res_count = 0
75+
for event in client.events():
76+
# Parse event data and compare it against the expected response at the same index
77+
data = json.loads(event.data)
78+
for key, value in expected_res[res_count].items():
79+
self.assertIn(key, data)
80+
self.assertEqual(value, data[key])
81+
res_count += 1
82+
self.assertEqual(len(expected_res), res_count)
83+
# Make sure there is no message in the wrong form
84+
for remaining in client._read():
85+
self.assertTrue(
86+
remaining.startswith(b"data:"),
87+
f"SSE response not formed properly, got: {remaining}",
88+
)
89+
self.assertTrue(
90+
remaining.endswith(b"\n\n"),
91+
f"SSE response not formed properly, got: {remaining}",
92+
)
93+
94+
def test_generate_stream(self):
95+
# Setup text-based input
96+
text = "hello world"
97+
rep_count = 3
98+
inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count}
99+
self.generate_stream_expect_success(inputs, text, rep_count)
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main()
+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/usr/bin/python
2+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import sys
29+
30+
sys.path.append("../common")
31+
32+
import json
33+
import os
34+
import sys
35+
import unittest
36+
37+
import requests
38+
import test_util as tu
39+
40+
41+
class SageMakerGenerateTest(tu.TestResultCollector):
42+
def setUp(self):
43+
SAGEMAKER_BIND_TO_PORT = os.getenv("SAGEMAKER_BIND_TO_PORT", "8080")
44+
self.url_ = "http://localhost:{}/invocations".format(SAGEMAKER_BIND_TO_PORT)
45+
46+
def generate(self, inputs):
47+
return requests.post(
48+
self.url_, data=inputs if isinstance(inputs, str) else json.dumps(inputs)
49+
)
50+
51+
def test_generate(self):
52+
# Setup text-based input
53+
text = "hello world"
54+
inputs = {"PROMPT": text, "STREAM": False}
55+
56+
r = self.generate(inputs)
57+
r.raise_for_status()
58+
59+
self.assertIn("Content-Type", r.headers)
60+
self.assertEqual(r.headers["Content-Type"], "application/json")
61+
62+
data = r.json()
63+
self.assertIn("TEXT", data)
64+
self.assertEqual(text, data["TEXT"])
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main()

qa/L0_sagemaker/test.sh

+114-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
2+
# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -56,8 +56,12 @@ rm -f *.out
5656

5757
SAGEMAKER_TEST=sagemaker_test.py
5858
SAGEMAKER_MULTI_MODEL_TEST=sagemaker_multi_model_test.py
59+
SAGEMAKER_GENERATE_TEST=sagemaker_generate_test.py
60+
SAGEMAKER_GENERATE_STREAM_TEST=sagemaker_generate_stream_test.py
5961
MULTI_MODEL_UNIT_TEST_COUNT=7
6062
UNIT_TEST_COUNT=9
63+
GENERATE_UNIT_TEST_COUNT=1
64+
GENERATE_STREAM_UNIT_TEST_COUNT=1
6165
CLIENT_LOG="./client.log"
6266

6367
DATADIR=/data/inferenceserver/${REPO_VERSION}
@@ -74,6 +78,10 @@ mkdir models && \
7478
rm -r models/sm_model/2 && rm -r models/sm_model/3 && \
7579
sed -i "s/onnx_int32_int32_int32/sm_model/" models/sm_model/config.pbtxt
7680

81+
mkdir -p models/mock_llm/1 && \
82+
cp ../python_models/generate_models/mock_llm/1/model.py models/mock_llm/1 && \
83+
cp ../python_models/generate_models/mock_llm/config.pbtxt models/mock_llm
84+
7785
# Use SageMaker's ping endpoint to check server status
7886
# Wait until server health endpoint shows ready. Sets WAIT_RET to 0 on
7987
# success, 1 on failure
@@ -259,12 +267,115 @@ else
259267
fi
260268
set -e
261269

262-
unset SAGEMAKER_SAFE_PORT_RANGE
270+
kill $SERVER_PID
271+
wait $SERVE_PID
272+
273+
# Start server with LLM and set inference type to generate
274+
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME=mock_llm
275+
export SAGEMAKER_TRITON_INFERENCE_TYPE=generate
276+
serve > $SERVER_LOG 2>&1 &
277+
SERVE_PID=$!
278+
# Obtain the Triton PID separately, because $! returns the PID of the serve script
279+
sleep 1
280+
SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'`
281+
sagemaker_wait_for_server_ready $SERVER_PID 10
282+
if [ "$WAIT_RET" != "0" ]; then
283+
echo -e "\n***\n*** Failed to start $SERVER\n***"
284+
kill $SERVER_PID || true
285+
cat $SERVER_LOG
286+
exit 1
287+
fi
288+
289+
# Inference with generate inference type
290+
set +e
291+
python $SAGEMAKER_GENERATE_TEST SageMakerGenerateTest >>$CLIENT_LOG 2>&1
292+
if [ $? -ne 0 ]; then
293+
echo -e "\n***\n*** Test Failed\n***"
294+
cat $CLIENT_LOG
295+
RET=1
296+
else
297+
check_test_results $TEST_RESULT_FILE $GENERATE_UNIT_TEST_COUNT
298+
if [ $? -ne 0 ]; then
299+
cat $CLIENT_LOG
300+
echo -e "\n***\n*** Test Result Verification Failed\n***"
301+
RET=1
302+
fi
303+
fi
304+
set -e
305+
306+
unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME
307+
unset SAGEMAKER_TRITON_INFERENCE_TYPE
308+
309+
kill $SERVER_PID
310+
wait $SERVE_PID
311+
312+
# Start server with LLM and set inference type to generate_stream
313+
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME=mock_llm
314+
export SAGEMAKER_TRITON_INFERENCE_TYPE=generate_stream
315+
serve > $SERVER_LOG 2>&1 &
316+
SERVE_PID=$!
317+
# Obtain the Triton PID separately, because $! returns the PID of the serve script
318+
sleep 1
319+
SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'`
320+
sagemaker_wait_for_server_ready $SERVER_PID 10
321+
if [ "$WAIT_RET" != "0" ]; then
322+
echo -e "\n***\n*** Failed to start $SERVER\n***"
323+
kill $SERVER_PID || true
324+
cat $SERVER_LOG
325+
exit 1
326+
fi
327+
328+
# Helper library to parse SSE events
329+
# https://github.com/mpetazzoni/sseclient
330+
pip install sseclient-py
331+
332+
# Inference with generate_stream inference type
333+
set +e
334+
python $SAGEMAKER_GENERATE_STREAM_TEST SageMakerGenerateStreamTest >>$CLIENT_LOG 2>&1
335+
if [ $? -ne 0 ]; then
336+
echo -e "\n***\n*** Test Failed\n***"
337+
cat $CLIENT_LOG
338+
RET=1
339+
else
340+
check_test_results $TEST_RESULT_FILE $GENERATE_STREAM_UNIT_TEST_COUNT
341+
if [ $? -ne 0 ]; then
342+
cat $CLIENT_LOG
343+
echo -e "\n***\n*** Test Result Verification Failed\n***"
344+
RET=1
345+
fi
346+
fi
347+
set -e
348+
263349
unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME
350+
unset SAGEMAKER_TRITON_INFERENCE_TYPE
264351

265352
kill $SERVER_PID
266353
wait $SERVE_PID
267354

355+
# Test serve with incorrect inference type
356+
export SAGEMAKER_TRITON_INFERENCE_TYPE=incorrect_inference_type
357+
serve > $SERVER_LOG 2>&1 &
358+
SERVE_PID=$!
359+
# Obtain the Triton PID separately, because $! returns the PID of the serve script
360+
sleep 1
361+
SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'`
362+
if [ -n "$SERVER_PID" ]; then
363+
echo -e "\n***\n*** Expect failed to start $SERVER\n***"
364+
kill $SERVER_PID || true
365+
cat $SERVER_LOG
366+
RET=1
367+
else
368+
grep "ERROR: Invalid SAGEMAKER_TRITON_INFERENCE_TYPE" $SERVER_LOG
369+
if [ $? -ne 0 ]; then
370+
echo -e "\n***\n*** Failed. Expected error on incorrect inference type\n***"
371+
RET=1
372+
fi
373+
fi
374+
unset SAGEMAKER_TRITON_INFERENCE_TYPE
375+
376+
unset SAGEMAKER_SAFE_PORT_RANGE
377+
unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME
378+
268379
# Test serve with incorrect model name
269380
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME=incorrect_model_name
270381
serve > $SERVER_LOG 2>&1 &
@@ -288,6 +399,7 @@ fi
288399
unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME
289400

290401
# Test serve with SAGEMAKER_TRITON_DEFAULT_MODEL_NAME unset, but containing single model directory
402+
rm -rf models/mock_llm
291403
serve > $SERVER_LOG 2>&1 &
292404
SERVE_PID=$!
293405
# Obtain the Triton PID separately, because $! returns the PID of the serve script

qa/L0_http/generate_models/mock_llm/1/model.py qa/python_models/generate_models/mock_llm/1/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Redistribution and use in source and binary forms, with or without
44
# modification, are permitted provided that the following conditions

qa/L0_http/generate_models/mock_llm/config.pbtxt qa/python_models/generate_models/mock_llm/config.pbtxt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Redistribution and use in source and binary forms, with or without
44
# modification, are permitted provided that the following conditions

0 commit comments

Comments
 (0)