feat: add generate support to sagemaker_server #8047

Merged
merged 4 commits into from
Mar 4, 2025
16 changes: 15 additions & 1 deletion docker/sagemaker/serve
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2021-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -140,6 +140,20 @@ if [ "${is_mme_mode}" = false ] && [ -f "${SAGEMAKER_MODEL_REPO}/config.pbtxt" ]
    exit 1
fi

# Validate SAGEMAKER_TRITON_INFERENCE_TYPE if set
if [ -n "$SAGEMAKER_TRITON_INFERENCE_TYPE" ]; then
    case "$SAGEMAKER_TRITON_INFERENCE_TYPE" in
        "infer"|"generate"|"generate_stream")
            # Valid value, continue
            ;;
        *)
            echo "ERROR: Invalid SAGEMAKER_TRITON_INFERENCE_TYPE '${SAGEMAKER_TRITON_INFERENCE_TYPE}'"
            echo " Must be one of: infer, generate, generate_stream"
            exit 1
            ;;
    esac
fi

if [ "${is_mme_mode}" = false ] && [ -n "$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then
if [ -d "${SAGEMAKER_MODEL_REPO}/$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then
SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --load-model=${SAGEMAKER_TRITON_DEFAULT_MODEL_NAME}"
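
For readers who want to reason about the validation outside of bash, the check added to `serve` can be restated as a minimal Python sketch. The function name `validate_inference_type` and the `VALID_INFERENCE_TYPES` constant are hypothetical and not part of this PR; only the environment variable name and the three accepted values come from the diff above.

```python
import os

# The three values accepted by the case statement in docker/sagemaker/serve.
VALID_INFERENCE_TYPES = ("infer", "generate", "generate_stream")


def validate_inference_type(env=None):
    """Return the configured inference type, or None if unset or empty.

    Hypothetical helper: raises ValueError for any other value, where the
    serve script prints an error and exits with status 1.
    """
    env = os.environ if env is None else env
    value = env.get("SAGEMAKER_TRITON_INFERENCE_TYPE", "")
    if not value:
        # Mirrors `[ -n "$VAR" ]`: unset and empty both skip validation.
        return None
    if value not in VALID_INFERENCE_TYPES:
        raise ValueError(
            f"Invalid SAGEMAKER_TRITON_INFERENCE_TYPE '{value}'. "
            f"Must be one of: {', '.join(VALID_INFERENCE_TYPES)}"
        )
    return value
```

As in the script, an unset or empty variable is not an error; only a non-empty value outside the allowed set is rejected.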
4 changes: 2 additions & 2 deletions qa/L0_http/test.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -649,7 +649,7 @@ wait $SERVER_PID
# https://github.com/mpetazzoni/sseclient
pip install sseclient-py

SERVER_ARGS="--model-repository=`pwd`/generate_models"
SERVER_ARGS="--model-repository=`pwd`/../python_models/generate_models"
SERVER_LOG="./inference_server_generate_endpoint_test.log"
CLIENT_LOG="./generate_endpoint_test.log"
run_server
103 changes: 103 additions & 0 deletions qa/L0_sagemaker/sagemaker_generate_stream_test.py
@@ -0,0 +1,103 @@
#!/usr/bin/python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys

sys.path.append("../common")

import json
import os
import sys
import unittest

import requests
import sseclient
import test_util as tu


class SageMakerGenerateStreamTest(tu.TestResultCollector):
    def setUp(self):
        SAGEMAKER_BIND_TO_PORT = os.getenv("SAGEMAKER_BIND_TO_PORT", "8080")
        self.url_ = "http://localhost:{}/invocations".format(SAGEMAKER_BIND_TO_PORT)

    def generate_stream(self, inputs, stream=False):
        headers = {"Accept": "text/event-stream"}
        # stream=True used to indicate response can be iterated over, which
        # should be the common setting for generate_stream.
        # For correctness test case, stream=False so that we can re-examine
        # the response content.
        return requests.post(
            self.url_,
            data=inputs if isinstance(inputs, str) else json.dumps(inputs),
            headers=headers,
            stream=stream,
        )

    def generate_stream_expect_success(self, inputs, expected_output, rep_count):
        r = self.generate_stream(inputs)
        r.raise_for_status()
        self.check_sse_responses(r, [{"TEXT": expected_output}] * rep_count)

    def check_sse_responses(self, res, expected_res):
        # Validate SSE format
        self.assertIn("Content-Type", res.headers)
        self.assertEqual(
            "text/event-stream; charset=utf-8", res.headers["Content-Type"]
        )

        # SSE format (data: []) is hard to parse, use helper library for simplicity
        client = sseclient.SSEClient(res)
        res_count = 0
        for event in client.events():
            # Parse event data, join events into a single response
            data = json.loads(event.data)
            for key, value in expected_res[res_count].items():
                self.assertIn(key, data)
                self.assertEqual(value, data[key])
            res_count += 1
        self.assertEqual(len(expected_res), res_count)
        # Make sure there is no message in the wrong form
        for remaining in client._read():
            self.assertTrue(
                remaining.startswith(b"data:"),
                f"SSE response not formed properly, got: {remaining}",
            )
            self.assertTrue(
                remaining.endswith(b"\n\n"),
                f"SSE response not formed properly, got: {remaining}",
            )

    def test_generate_stream(self):
        # Setup text-based input
        text = "hello world"
        rep_count = 3
        inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count}
        self.generate_stream_expect_success(inputs, text, rep_count)


if __name__ == "__main__":
    unittest.main()
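
The test above delegates SSE parsing to `sseclient`. To illustrate what `check_sse_responses` verifies, here is a minimal, dependency-free sketch of the framing it expects: each event is a `data:` line terminated by a blank line. The function `parse_sse_data_events` and the sample bytes are hypothetical. The exact bytes Triton emits are an assumption here, and the sketch handles only `\n` line endings and `data:` fields, not the full SSE grammar.

```python
import json


def parse_sse_data_events(raw: bytes):
    """Decode the JSON payload of each `data:` event in an SSE byte stream.

    Hypothetical helper for illustration; assumes UTF-8 and "\n" line endings.
    """
    events = []
    # A blank line ("\n\n") terminates each event.
    for block in raw.decode("utf-8").split("\n\n"):
        if not block.strip():
            continue
        data_lines = []
        for line in block.split("\n"):
            if line.startswith("data:"):
                payload = line[len("data:"):]
                # SSE drops a single optional space after the colon.
                if payload.startswith(" "):
                    payload = payload[1:]
                data_lines.append(payload)
        if data_lines:
            # Multiple data lines within one event are joined with newlines.
            events.append(json.loads("\n".join(data_lines)))
    return events


# Assumed wire format: three events, matching REPETITION=3 in the test above.
sample = b'data: {"TEXT": "hello world"}\n\n' * 3
events = parse_sse_data_events(sample)
```

With the sample input, `events` holds three dictionaries that each map `"TEXT"` to `"hello world"`, which is the same shape the test passes to `check_sse_responses` as `expected_res`.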
68 changes: 68 additions & 0 deletions qa/L0_sagemaker/sagemaker_generate_test.py
@@ -0,0 +1,68 @@
#!/usr/bin/python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys

sys.path.append("../common")

import json
import os
import sys
import unittest

import requests
import test_util as tu


class SageMakerGenerateTest(tu.TestResultCollector):
    def setUp(self):
        SAGEMAKER_BIND_TO_PORT = os.getenv("SAGEMAKER_BIND_TO_PORT", "8080")
        self.url_ = "http://localhost:{}/invocations".format(SAGEMAKER_BIND_TO_PORT)

    def generate(self, inputs):
        return requests.post(
            self.url_, data=inputs if isinstance(inputs, str) else json.dumps(inputs)
        )

    def test_generate(self):
        # Setup text-based input
        text = "hello world"
        inputs = {"PROMPT": text, "STREAM": False}

        r = self.generate(inputs)
        r.raise_for_status()

        self.assertIn("Content-Type", r.headers)
        self.assertEqual(r.headers["Content-Type"], "application/json")

        data = r.json()
        self.assertIn("TEXT", data)
        self.assertEqual(text, data["TEXT"])


if __name__ == "__main__":
    unittest.main()
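
For trying the endpoint by hand without the `requests` dependency, the same request can be sketched with the standard library. `build_invocation_request` is a hypothetical helper, and the explicit `Content-Type` header is an assumption (the test above does not set one). The URL shape, the `SAGEMAKER_BIND_TO_PORT` default, and the `PROMPT`/`STREAM` fields come from the test.

```python
import json
import os
import urllib.request


def build_invocation_request(prompt, stream=False, port=None):
    """Build (but do not send) a POST to the SageMaker /invocations endpoint.

    Hypothetical helper; the payload keys match the mock_llm test inputs.
    """
    port = port or os.getenv("SAGEMAKER_BIND_TO_PORT", "8080")
    body = json.dumps({"PROMPT": prompt, "STREAM": stream}).encode("utf-8")
    return urllib.request.Request(
        f"http://localhost:{port}/invocations",
        data=body,
        # Assumption: the test itself sends no explicit Content-Type.
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_invocation_request("hello world", stream=False, port="8080")
```

Sending it with `urllib.request.urlopen(req)` requires a server started through `serve` with `SAGEMAKER_TRITON_INFERENCE_TYPE=generate`, as in `test.sh`.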
116 changes: 114 additions & 2 deletions qa/L0_sagemaker/test.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2025, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -56,8 +56,12 @@ rm -f *.out

SAGEMAKER_TEST=sagemaker_test.py
SAGEMAKER_MULTI_MODEL_TEST=sagemaker_multi_model_test.py
SAGEMAKER_GENERATE_TEST=sagemaker_generate_test.py
SAGEMAKER_GENERATE_STREAM_TEST=sagemaker_generate_stream_test.py
MULTI_MODEL_UNIT_TEST_COUNT=7
UNIT_TEST_COUNT=9
GENERATE_UNIT_TEST_COUNT=1
GENERATE_STREAM_UNIT_TEST_COUNT=1
CLIENT_LOG="./client.log"

DATADIR=/data/inferenceserver/${REPO_VERSION}
@@ -74,6 +78,10 @@ mkdir models && \
rm -r models/sm_model/2 && rm -r models/sm_model/3 && \
sed -i "s/onnx_int32_int32_int32/sm_model/" models/sm_model/config.pbtxt

mkdir -p models/mock_llm/1 && \
cp ../python_models/generate_models/mock_llm/1/model.py models/mock_llm/1 && \
cp ../python_models/generate_models/mock_llm/config.pbtxt models/mock_llm

# Use SageMaker's ping endpoint to check server status
# Wait until server health endpoint shows ready. Sets WAIT_RET to 0 on
# success, 1 on failure
@@ -259,12 +267,115 @@ else
fi
set -e

unset SAGEMAKER_SAFE_PORT_RANGE
kill $SERVER_PID
wait $SERVE_PID

# Start server with LLM and set inference type to generate
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME=mock_llm
export SAGEMAKER_TRITON_INFERENCE_TYPE=generate
serve > $SERVER_LOG 2>&1 &
SERVE_PID=$!
# $! is the PID of the serve script, so look up the Triton PID separately
sleep 1
SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'`
sagemaker_wait_for_server_ready $SERVER_PID 10
if [ "$WAIT_RET" != "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    kill $SERVER_PID || true
    cat $SERVER_LOG
    exit 1
fi

# Inference with generate inference type
set +e
python $SAGEMAKER_GENERATE_TEST SageMakerGenerateTest >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
    echo -e "\n***\n*** Test Failed\n***"
    cat $CLIENT_LOG
    RET=1
else
    check_test_results $TEST_RESULT_FILE $GENERATE_UNIT_TEST_COUNT
    if [ $? -ne 0 ]; then
        cat $CLIENT_LOG
        echo -e "\n***\n*** Test Result Verification Failed\n***"
        RET=1
    fi
fi
set -e

unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME
unset SAGEMAKER_TRITON_INFERENCE_TYPE

kill $SERVER_PID
wait $SERVE_PID

# Start server with LLM and set inference type to generate_stream
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME=mock_llm
export SAGEMAKER_TRITON_INFERENCE_TYPE=generate_stream
serve > $SERVER_LOG 2>&1 &
SERVE_PID=$!
# $! is the PID of the serve script, so look up the Triton PID separately
sleep 1
SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'`
sagemaker_wait_for_server_ready $SERVER_PID 10
if [ "$WAIT_RET" != "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    kill $SERVER_PID || true
    cat $SERVER_LOG
    exit 1
fi

# Helper library to parse SSE events
# https://github.com/mpetazzoni/sseclient
pip install sseclient-py

# Inference with generate_stream inference type
set +e
python $SAGEMAKER_GENERATE_STREAM_TEST SageMakerGenerateStreamTest >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
    echo -e "\n***\n*** Test Failed\n***"
    cat $CLIENT_LOG
    RET=1
else
    check_test_results $TEST_RESULT_FILE $GENERATE_STREAM_UNIT_TEST_COUNT
    if [ $? -ne 0 ]; then
        cat $CLIENT_LOG
        echo -e "\n***\n*** Test Result Verification Failed\n***"
        RET=1
    fi
fi
set -e

unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME
unset SAGEMAKER_TRITON_INFERENCE_TYPE

kill $SERVER_PID
wait $SERVE_PID

# Test serve with incorrect inference type
export SAGEMAKER_TRITON_INFERENCE_TYPE=incorrect_inference_type
serve > $SERVER_LOG 2>&1 &
SERVE_PID=$!
# $! is the PID of the serve script, so look up the Triton PID separately
sleep 1
SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'`
if [ -n "$SERVER_PID" ]; then
    echo -e "\n***\n*** Expected $SERVER to fail to start\n***"
    kill $SERVER_PID || true
    cat $SERVER_LOG
    RET=1
else
    # Run grep inside the condition so that `set -e` does not abort the
    # script before the failure message is printed
    if ! grep "ERROR: Invalid SAGEMAKER_TRITON_INFERENCE_TYPE" $SERVER_LOG; then
        echo -e "\n***\n*** Failed. Expected error on incorrect inference type\n***"
        RET=1
    fi
fi
unset SAGEMAKER_TRITON_INFERENCE_TYPE

unset SAGEMAKER_SAFE_PORT_RANGE
unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME

# Test serve with incorrect model name
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME=incorrect_model_name
serve > $SERVER_LOG 2>&1 &
@@ -288,6 +399,7 @@ fi
unset SAGEMAKER_TRITON_DEFAULT_MODEL_NAME

# Test serve with SAGEMAKER_TRITON_DEFAULT_MODEL_NAME unset, but containing single model directory
rm -rf models/mock_llm
serve > $SERVER_LOG 2>&1 &
SERVE_PID=$!
# Obtain Triton PID in such way as $! will return the script PID
Contributor

If this is the same model taken from the L0_http generate tests, can you make it so there's only one copy of the model and both tests use it?

Contributor Author

Yes, it is. Let me refactor L0_http to use the new location then.

@@ -1,4 +1,4 @@
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -1,4 +1,4 @@
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions