Skip to content

Commit 3bd347b

Browse files
authored
Support async Authorizers (#1373)
1 parent 3e08300 commit 3bd347b

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

jupyter_server/auth/authorizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# Distributed under the terms of the Modified BSD License.
1010
from __future__ import annotations
1111

12-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Awaitable
1313

1414
from traitlets import Instance
1515
from traitlets.config import LoggingConfigurable
@@ -44,7 +44,7 @@ class Authorizer(LoggingConfigurable):
4444

4545
def is_authorized(
4646
self, handler: JupyterHandler, user: User, action: str, resource: str
47-
) -> bool:
47+
) -> Awaitable[bool] | bool:
4848
"""A method to determine if ``user`` is authorized to perform ``action``
4949
(read, write, or execute) on the ``resource`` type.
5050

jupyter_server/auth/decorator.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
"""
33
# Copyright (c) Jupyter Development Team.
44
# Distributed under the terms of the Modified BSD License.
5+
import asyncio
56
from functools import wraps
67
from typing import Any, Callable, Optional, TypeVar, Union, cast
78

9+
from jupyter_core.utils import ensure_async
810
from tornado.log import app_log
911
from tornado.web import HTTPError
1012

@@ -42,7 +44,7 @@ def authorized(
4244

4345
def wrapper(method):
4446
@wraps(method)
45-
def inner(self, *args, **kwargs):
47+
async def inner(self, *args, **kwargs):
4648
# default values for action, resource
4749
nonlocal action
4850
nonlocal resource
@@ -61,8 +63,15 @@ def inner(self, *args, **kwargs):
6163
raise HTTPError(status_code=403, log_message=message)
6264
# If the user is allowed to do this action,
6365
# call the method.
64-
if self.authorizer.is_authorized(self, user, action, resource):
65-
return method(self, *args, **kwargs)
66+
authorized = await ensure_async(
67+
self.authorizer.is_authorized(self, user, action, resource)
68+
)
69+
if authorized:
70+
out = method(self, *args, **kwargs)
71+
# If the method is a coroutine, await it
72+
if asyncio.iscoroutine(out):
73+
return await out
74+
return out
6675
# else raise an exception.
6776
else:
6877
raise HTTPError(status_code=403, log_message=message)

tests/auth/test_authorizer.py

+48
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
"""Tests for authorization"""
2+
import asyncio
23
import json
34
import os
5+
from typing import Awaitable
46

57
import pytest
68
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
79
from nbformat import writes
810
from nbformat.v4 import new_notebook
11+
from traitlets import Bool
912

13+
from jupyter_server.auth.authorizer import Authorizer
14+
from jupyter_server.auth.identity import User
15+
from jupyter_server.base.handlers import JupyterHandler
1016
from jupyter_server.services.security import csp_report_uri
1117

1218

@@ -217,3 +223,45 @@ async def test_authorized_requests(
217223

218224
code = await send_request(url, body=body, method=method)
219225
assert code in expected_codes
226+
227+
228+
class AsyncAuthorizerTest(Authorizer):
229+
"""Test that an asynchronous authorizer would still work."""
230+
231+
called = Bool(False)
232+
233+
async def mock_async_fetch(self) -> True:
234+
"""Mock an async fetch"""
235+
# Mock a hang for a half a second.
236+
await asyncio.sleep(0.5)
237+
return True
238+
239+
async def is_authorized(
240+
self, handler: JupyterHandler, user: User, action: str, resource: str
241+
) -> Awaitable[bool]:
242+
response = await self.mock_async_fetch()
243+
self.called = True
244+
return response
245+
246+
247+
@pytest.mark.parametrize(
248+
"jp_server_config,",
249+
[
250+
{
251+
"ServerApp": {"authorizer_class": AsyncAuthorizerTest},
252+
"jpserver_extensions": {"jupyter_server_terminals": True},
253+
}
254+
],
255+
)
256+
async def test_async_authorizer(
257+
request,
258+
io_loop,
259+
send_request,
260+
tmp_path,
261+
jp_serverapp,
262+
):
263+
code = await send_request("/api/status", method="GET")
264+
assert code == 200
265+
# Ensure that the authorizor method finished its request.
266+
assert hasattr(jp_serverapp.authorizer, "called")
267+
assert jp_serverapp.authorizer.called is True

0 commit comments

Comments
 (0)