Skip to content

Commit 876aef6

Browse files
authored
Support middlewares for subscriptions (#221)
1 parent a5a2a65 commit 876aef6

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/graphql/execution/execute.py

+2
Original file line numberDiff line numberDiff line change
@@ -2043,6 +2043,7 @@ def subscribe(
20432043
type_resolver: GraphQLTypeResolver | None = None,
20442044
subscribe_field_resolver: GraphQLFieldResolver | None = None,
20452045
execution_context_class: type[ExecutionContext] | None = None,
2046+
middleware: MiddlewareManager | None = None,
20462047
) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]:
20472048
"""Create a GraphQL subscription.
20482049
@@ -2082,6 +2083,7 @@ def subscribe(
20822083
field_resolver,
20832084
type_resolver,
20842085
subscribe_field_resolver,
2086+
middleware=middleware,
20852087
)
20862088

20872089
# Return early errors if execution context failed.

tests/execution/test_middleware.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import inspect
12
from typing import Awaitable, cast
23

34
import pytest
4-
from graphql.execution import Middleware, MiddlewareManager, execute
5+
from graphql.execution import Middleware, MiddlewareManager, execute, subscribe
56
from graphql.language.parser import parse
67
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
78

@@ -236,6 +237,45 @@ async def resolve(self, next_, *args, **kwargs):
236237
result = await awaitable_result
237238
assert result.data == {"field": "devloseR"}
238239

240+
@pytest.mark.asyncio()
241+
async def subscription_simple():
242+
async def bar_resolve(_obj, _info):
243+
yield "bar"
244+
yield "oof"
245+
246+
test_type = GraphQLObjectType(
247+
"Subscription",
248+
{
249+
"bar": GraphQLField(
250+
GraphQLString,
251+
resolve=lambda message, _info: message,
252+
subscribe=bar_resolve,
253+
),
254+
},
255+
)
256+
doc = parse("subscription { bar }")
257+
258+
async def reverse_middleware(next_, value, info, **kwargs):
259+
awaitable_maybe = next_(value, info, **kwargs)
260+
return awaitable_maybe[::-1]
261+
262+
noop_type = GraphQLObjectType(
263+
"Noop",
264+
{"noop": GraphQLField(GraphQLString)},
265+
)
266+
schema = GraphQLSchema(query=noop_type, subscription=test_type)
267+
268+
agen = subscribe(
269+
schema,
270+
doc,
271+
middleware=MiddlewareManager(reverse_middleware),
272+
)
273+
assert inspect.isasyncgen(agen)
274+
data = (await agen.__anext__()).data
275+
assert data == {"bar": "rab"}
276+
data = (await agen.__anext__()).data
277+
assert data == {"bar": "foo"}
278+
239279
def describe_without_manager():
240280
def no_middleware():
241281
doc = parse("{ field }")

0 commit comments

Comments
 (0)