Skip to content

Commit 8d7d0f6

Browse files
committed
Change incorrect subscribe return type to a GraphQLError rather than systems error
Replicates graphql/graphql-js@ea1894a
1 parent 47ecdb3 commit 8d7d0f6

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

src/graphql/execution/subscribe.py

+8-11
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,8 @@ async def create_source_event_stream(
145145
return ExecutionResult(data=None, errors=context)
146146

147147
try:
148-
event_stream = await execute_subscription(context)
149-
150-
# Assert field returned an event stream, otherwise yield an error.
151-
if not isinstance(event_stream, AsyncIterable):
152-
raise TypeError(
153-
"Subscription field must return AsyncIterable."
154-
f" Received: {inspect(event_stream)}."
155-
)
156-
return event_stream
157-
148+
return await execute_subscription(context)
158149
except GraphQLError as error:
159-
# Report it as an ExecutionResult, containing only errors and no data.
160150
return ExecutionResult(data=None, errors=[error])
161151

162152

@@ -207,6 +197,13 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
207197
if isinstance(event_stream, Exception):
208198
raise event_stream
209199

200+
# Assert field returned an event stream, otherwise yield an error.
201+
if not isinstance(event_stream, AsyncIterable):
202+
raise GraphQLError(
203+
"Subscription field must return AsyncIterable."
204+
f" Received: {inspect(event_stream)}."
205+
)
206+
210207
return event_stream
211208
except Exception as error:
212209
raise located_error(error, field_nodes, path.as_list())

tests/execution/test_subscribe.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,16 @@ async def should_pass_through_unexpected_errors_thrown_in_subscribe():
354354
@mark.asyncio
355355
@mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
356356
async def throws_an_error_if_subscribe_does_not_return_an_iterator():
357-
with raises(TypeError) as exc_info:
358-
await subscribe_with_bad_fn(lambda _obj, _info: "test")
359-
360-
assert str(exc_info.value) == (
361-
"Subscription field must return AsyncIterable. Received: 'test'."
357+
assert await subscribe_with_bad_fn(lambda _obj, _info: "test") == (
358+
None,
359+
[
360+
{
361+
"message": "Subscription field must return AsyncIterable."
362+
" Received: 'test'.",
363+
"locations": [(1, 16)],
364+
"path": ["foo"],
365+
}
366+
],
362367
)
363368

364369
@mark.asyncio

0 commit comments

Comments
 (0)