|
4 | 4 | from quart import Quart
|
5 | 5 | from starlette.applications import Starlette
|
6 | 6 | from starlette.responses import PlainTextResponse
|
| 7 | +from typing_extensions import Literal |
7 | 8 |
|
8 | 9 | from mangum import Mangum
|
9 | 10 | from mangum.exceptions import LifespanFailure
|
| 11 | +from mangum.types import Receive, Scope, Send |
10 | 12 |
|
11 | 13 |
|
12 | 14 | @pytest.mark.parametrize(
|
@@ -209,6 +211,54 @@ async def app(scope, receive, send):
|
209 | 211 | handler(mock_aws_api_gateway_event, {})
|
210 | 212 |
|
211 | 213 |
|
| 214 | +@pytest.mark.parametrize( |
| 215 | + "mock_aws_api_gateway_event,lifespan", |
| 216 | + [(["GET", None, None], "auto"), (["GET", None, None], "on")], |
| 217 | + indirect=["mock_aws_api_gateway_event"], |
| 218 | +) |
| 219 | +def test_lifespan_state(mock_aws_api_gateway_event, lifespan: Literal["on", "auto"]) -> None: |
| 220 | + startup_complete = False |
| 221 | + shutdown_complete = False |
| 222 | + |
| 223 | + async def app(scope: Scope, receive: Receive, send: Send): |
| 224 | + nonlocal startup_complete, shutdown_complete |
| 225 | + |
| 226 | + if scope["type"] == "lifespan": |
| 227 | + while True: |
| 228 | + message = await receive() |
| 229 | + if message["type"] == "lifespan.startup": |
| 230 | + scope["state"].update({"test_key": b"Hello, world!"}) |
| 231 | + await send({"type": "lifespan.startup.complete"}) |
| 232 | + startup_complete = True |
| 233 | + elif message["type"] == "lifespan.shutdown": |
| 234 | + await send({"type": "lifespan.shutdown.complete"}) |
| 235 | + shutdown_complete = True |
| 236 | + return |
| 237 | + |
| 238 | + if scope["type"] == "http": |
| 239 | + await send( |
| 240 | + { |
| 241 | + "type": "http.response.start", |
| 242 | + "status": 200, |
| 243 | + "headers": [[b"content-type", b"text/plain; charset=utf-8"]], |
| 244 | + } |
| 245 | + ) |
| 246 | + await send({"type": "http.response.body", "body": scope["state"]["test_key"]}) |
| 247 | + |
| 248 | + handler = Mangum(app, lifespan=lifespan) |
| 249 | + response = handler(mock_aws_api_gateway_event, {}) |
| 250 | + |
| 251 | + assert startup_complete |
| 252 | + assert shutdown_complete |
| 253 | + assert response == { |
| 254 | + "statusCode": 200, |
| 255 | + "isBase64Encoded": False, |
| 256 | + "headers": {"content-type": "text/plain; charset=utf-8"}, |
| 257 | + "multiValueHeaders": {}, |
| 258 | + "body": "Hello, world!", |
| 259 | + } |
| 260 | + |
| 261 | + |
212 | 262 | @pytest.mark.parametrize("mock_aws_api_gateway_event", [["GET", None, None]], indirect=True)
|
213 | 263 | def test_starlette_lifespan(mock_aws_api_gateway_event) -> None:
|
214 | 264 | startup_complete = False
|
|
0 commit comments