diff --git a/runpod/endpoint/asyncio/asyncio_runner.py b/runpod/endpoint/asyncio/asyncio_runner.py index 8b87d7ce..e5c458f5 100644 --- a/runpod/endpoint/asyncio/asyncio_runner.py +++ b/runpod/endpoint/asyncio/asyncio_runner.py @@ -1,4 +1,4 @@ -""" Module for running endpoints asynchronously. """ +"""Module for running endpoints asynchronously.""" # pylint: disable=too-few-public-methods,R0801 @@ -89,9 +89,14 @@ async def stream(self) -> Any: while True: await asyncio.sleep(1) stream_partial = await self._fetch_job(source="stream") - if stream_partial["status"] not in FINAL_STATES: + if ( + stream_partial["status"] not in FINAL_STATES + or len(stream_partial.get("stream", [])) > 0 + ): for chunk in stream_partial.get("stream", []): yield chunk["output"] + elif stream_partial["status"] in FINAL_STATES: + break async def cancel(self) -> dict: """Cancels current job diff --git a/tests/test_endpoint/test_asyncio_runner.py b/tests/test_endpoint/test_asyncio_runner.py index 088a549f..0ec77caf 100644 --- a/tests/test_endpoint/test_asyncio_runner.py +++ b/tests/test_endpoint/test_asyncio_runner.py @@ -1,4 +1,4 @@ -""" Unit tests for the asyncio_runner module. """ +"""Unit tests for the asyncio_runner module.""" # pylint: disable=too-few-public-methods @@ -114,8 +114,6 @@ async def json_side_effect(): outputs = [] async for stream_output in job.stream(): outputs.append(stream_output) - if not responses: # Break the loop when responses are exhausted - break assert outputs == ["OUTPUT1", "OUTPUT2"]