Skip to content

Commit 631b930

Browse files
KludexFurqanHabibi
andauthoredSep 26, 2024
Add support for lifespan state (#337)
* Add support for lifespan state * Add lifespan state to http connection * Add test * Revert import changes * Fix imports * Fix imports * Improve a bit the test * Use typing_extensions.Literal --------- Co-authored-by: Muhammad Furqan Habibi <[email protected]>
1 parent 5a7121d commit 631b930

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed
 

‎mangum/adapter.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
6060

6161
def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict[str, Any]:
6262
handler = self.infer(event, context)
63+
scope = handler.scope
6364
with ExitStack() as stack:
6465
if self.lifespan in ("auto", "on"):
6566
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
6667
stack.enter_context(lifespan_cycle)
68+
scope.update({"state": lifespan_cycle.lifespan_state.copy()})
6769

68-
http_cycle = HTTPCycle(handler.scope, handler.body)
70+
http_cycle = HTTPCycle(scope, handler.body)
6971
http_response = http_cycle(self.app)
7072

7173
return handler(http_response)

‎mangum/protocols/lifespan.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import enum
55
import logging
66
from types import TracebackType
7+
from typing import Any
78

89
from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage
910
from mangum.types import ASGI, LifespanMode, Message
@@ -63,6 +64,7 @@ def __init__(self, app: ASGI, lifespan: LifespanMode) -> None:
6364
self.startup_event: asyncio.Event = asyncio.Event()
6465
self.shutdown_event: asyncio.Event = asyncio.Event()
6566
self.logger = logging.getLogger("mangum.lifespan")
67+
self.lifespan_state: dict[str, Any] = {}
6668

6769
def __enter__(self) -> None:
6870
"""Runs the event loop for application startup."""
@@ -82,7 +84,7 @@ async def run(self) -> None:
8284
"""Calls the application with the `lifespan` connection scope."""
8385
try:
8486
await self.app(
85-
{"type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}},
87+
{"type": "lifespan", "asgi": {"spec_version": "2.0", "version": "3.0"}, "state": self.lifespan_state},
8688
self.receive,
8789
self.send,
8890
)
@@ -101,7 +103,7 @@ async def receive(self) -> Message:
101103
if self.state is LifespanCycleState.CONNECTING:
102104
# Connection established. The next event returned by the queue will be
103105
# `lifespan.startup` to inform the application that the connection is
104-
# ready to receive lfiespan messages.
106+
# ready to receive lifespan messages.
105107
self.state = LifespanCycleState.STARTUP
106108

107109
elif self.state is LifespanCycleState.STARTUP:

‎tests/test_lifespan.py

+50
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
from quart import Quart
55
from starlette.applications import Starlette
66
from starlette.responses import PlainTextResponse
7+
from typing_extensions import Literal
78

89
from mangum import Mangum
910
from mangum.exceptions import LifespanFailure
11+
from mangum.types import Receive, Scope, Send
1012

1113

1214
@pytest.mark.parametrize(
@@ -209,6 +211,54 @@ async def app(scope, receive, send):
209211
handler(mock_aws_api_gateway_event, {})
210212

211213

214+
@pytest.mark.parametrize(
215+
"mock_aws_api_gateway_event,lifespan",
216+
[(["GET", None, None], "auto"), (["GET", None, None], "on")],
217+
indirect=["mock_aws_api_gateway_event"],
218+
)
219+
def test_lifespan_state(mock_aws_api_gateway_event, lifespan: Literal["on", "auto"]) -> None:
220+
startup_complete = False
221+
shutdown_complete = False
222+
223+
async def app(scope: Scope, receive: Receive, send: Send):
224+
nonlocal startup_complete, shutdown_complete
225+
226+
if scope["type"] == "lifespan":
227+
while True:
228+
message = await receive()
229+
if message["type"] == "lifespan.startup":
230+
scope["state"].update({"test_key": b"Hello, world!"})
231+
await send({"type": "lifespan.startup.complete"})
232+
startup_complete = True
233+
elif message["type"] == "lifespan.shutdown":
234+
await send({"type": "lifespan.shutdown.complete"})
235+
shutdown_complete = True
236+
return
237+
238+
if scope["type"] == "http":
239+
await send(
240+
{
241+
"type": "http.response.start",
242+
"status": 200,
243+
"headers": [[b"content-type", b"text/plain; charset=utf-8"]],
244+
}
245+
)
246+
await send({"type": "http.response.body", "body": scope["state"]["test_key"]})
247+
248+
handler = Mangum(app, lifespan=lifespan)
249+
response = handler(mock_aws_api_gateway_event, {})
250+
251+
assert startup_complete
252+
assert shutdown_complete
253+
assert response == {
254+
"statusCode": 200,
255+
"isBase64Encoded": False,
256+
"headers": {"content-type": "text/plain; charset=utf-8"},
257+
"multiValueHeaders": {},
258+
"body": "Hello, world!",
259+
}
260+
261+
212262
@pytest.mark.parametrize("mock_aws_api_gateway_event", [["GET", None, None]], indirect=True)
213263
def test_starlette_lifespan(mock_aws_api_gateway_event) -> None:
214264
startup_complete = False

0 commit comments

Comments
 (0)