Skip to content

Commit 4bb0c2c

Browse files
committed
bot: Start implementing new FastAPI-based server
1 parent 664221f commit 4bb0c2c

24 files changed

+1132
-53
lines changed

.vscode/settings.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
"--disable=too-many-return-statements",
1313
"--disable=too-many-branches"
1414
],
15-
"editor.formatOnSave": true
15+
"editor.formatOnSave": true,
16+
"editor.defaultFormatter": "charliermarsh.ruff"
1617
}

Makefile

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ install:
1010
dev-frontend: config.yml blanco.db
1111
poetry run python -m bot.dev_server
1212

13+
dev-backend: config.yml blanco.db
14+
poetry run python -m bot.api.main
15+
1316
dev: config.yml blanco.db
1417
poetry run python -m bot.main
1518

bot/api/depends/database.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import TYPE_CHECKING
2+
3+
from fastapi import HTTPException, Request
4+
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
5+
6+
if TYPE_CHECKING:
7+
from bot.database import Database
8+
9+
10+
def database_dependency(request: Request) -> 'Database':
11+
"""
12+
FastAPI dependency to get the database object.
13+
14+
Args:
15+
request (web.Request): The request.
16+
17+
Returns:
18+
Database: The database object.
19+
"""
20+
21+
state = request.app.state
22+
if not hasattr(state, 'database'):
23+
raise HTTPException(
24+
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection'
25+
)
26+
27+
database: 'Database' = state.database
28+
if database is None:
29+
raise HTTPException(
30+
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail='No database connection'
31+
)
32+
33+
return database

bot/api/depends/session.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import TYPE_CHECKING, Optional
2+
3+
from fastapi import Depends, HTTPException, Request
4+
from starlette.status import HTTP_401_UNAUTHORIZED
5+
6+
from .database import database_dependency
7+
8+
if TYPE_CHECKING:
9+
from bot.api.utils.session import SessionManager
10+
from bot.database import Database
11+
from bot.models.oauth import OAuth
12+
13+
14+
EXPECTED_AUTH_SCHEME = 'Bearer'
15+
EXPECTED_AUTH_PARTS = 2
16+
17+
18+
def session_dependency(
19+
request: Request, db: 'Database' = Depends(database_dependency)
20+
) -> 'OAuth':
21+
"""
22+
FastAPI dependency to get the requesting user's info.
23+
24+
Args:
25+
request (web.Request): The request.
26+
27+
Returns:
28+
OAuth: The info for the current Discord user.
29+
"""
30+
31+
authorization = request.headers.get('Authorization')
32+
if authorization is None:
33+
raise HTTPException(
34+
status_code=HTTP_401_UNAUTHORIZED, detail='No authorization header'
35+
)
36+
37+
parts = authorization.split()
38+
if len(parts) != EXPECTED_AUTH_PARTS:
39+
raise HTTPException(
40+
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization header'
41+
)
42+
43+
scheme, token = parts
44+
if scheme != EXPECTED_AUTH_SCHEME:
45+
raise HTTPException(
46+
status_code=HTTP_401_UNAUTHORIZED, detail='Invalid authorization scheme'
47+
)
48+
49+
session_manager: 'SessionManager' = request.app.state.session_manager
50+
session = session_manager.decode_session(token)
51+
if session is None:
52+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='Invalid session')
53+
54+
user: Optional['OAuth'] = db.get_oauth('discord', session.user_id)
55+
if user is None:
56+
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail='User not found')
57+
58+
return user

bot/api/extension.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Nextcord extension that runs the API server for the bot
3+
"""
4+
5+
from typing import TYPE_CHECKING
6+
7+
from .main import run_app
8+
9+
if TYPE_CHECKING:
10+
from bot.utils.blanco import BlancoBot
11+
12+
13+
def setup(bot: 'BlancoBot'):
14+
"""
15+
Run the API server within the bot's existing event loop.
16+
"""
17+
run_app(bot.loop, bot.database)

bot/api/main.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
Main module for the API server.
3+
"""
4+
5+
from asyncio import set_event_loop
6+
from contextlib import asynccontextmanager
7+
from logging import INFO
8+
from typing import TYPE_CHECKING, Any, Optional
9+
10+
from fastapi import FastAPI
11+
from uvicorn import Config, Server, run
12+
from uvicorn.config import LOGGING_CONFIG
13+
14+
from bot.database import Database
15+
from bot.utils.config import config as bot_config
16+
from bot.utils.logger import DATE_FMT_STR, LOG_FMT_COLOR, create_logger
17+
18+
from .routes.account import account_router
19+
from .routes.oauth import oauth_router
20+
from .utils.session import SessionManager
21+
22+
if TYPE_CHECKING:
23+
from asyncio import AbstractEventLoop
24+
25+
26+
_database: Optional[Database] = None
27+
28+
29+
@asynccontextmanager
30+
async def lifespan(app: FastAPI):
31+
logger = create_logger('api.lifespan')
32+
33+
if _database is None:
34+
logger.warn('Manually creating database connection')
35+
database = Database(bot_config.db_file)
36+
else:
37+
logger.info('Connecting to database from FastAPI')
38+
database = _database
39+
40+
app.state.database = database
41+
app.state.session_manager = SessionManager(database)
42+
yield
43+
44+
45+
app = FastAPI(lifespan=lifespan)
46+
app.include_router(account_router)
47+
app.include_router(oauth_router)
48+
49+
50+
@app.get('/')
51+
async def health_check():
52+
return {'status': 'ok'}
53+
54+
55+
def _get_log_config() -> dict[str, Any]:
56+
log_config = LOGGING_CONFIG
57+
log_config['formatters']['default']['fmt'] = LOG_FMT_COLOR[INFO]
58+
log_config['formatters']['default']['datefmt'] = DATE_FMT_STR
59+
log_config['formatters']['access']['fmt'] = LOG_FMT_COLOR[INFO]
60+
61+
return log_config
62+
63+
64+
def run_app(loop: 'AbstractEventLoop', db: Database):
65+
"""
66+
Run the API server in the bot's event loop.
67+
"""
68+
global _database # noqa: PLW0603
69+
_database = db
70+
71+
set_event_loop(loop)
72+
73+
config = Config(
74+
app=app,
75+
loop=loop, # type: ignore
76+
host='0.0.0.0',
77+
port=bot_config.server_port,
78+
log_config=_get_log_config(),
79+
)
80+
server = Server(config)
81+
82+
loop.create_task(server.serve())
83+
84+
85+
if __name__ == '__main__':
86+
run(
87+
app='bot.api.main:app',
88+
host='127.0.0.1',
89+
port=bot_config.server_port,
90+
reload=True,
91+
reload_dirs=['bot/api'],
92+
log_config=_get_log_config(),
93+
)

bot/api/models/account.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class AccountResponse(BaseModel):
7+
username: str = Field(description="The user's username.")
8+
spotify_logged_in: bool = Field(
9+
description='Whether the user is logged in to Spotify.'
10+
)
11+
spotify_username: Optional[str] = Field(
12+
default=None, description="The user's Spotify username, if logged in."
13+
)
14+
lastfm_logged_in: bool = Field(
15+
description='Whether the user is logged in to Last.fm.'
16+
)
17+
lastfm_username: Optional[str] = Field(
18+
default=None, description="The user's Last.fm username, if logged in."
19+
)

bot/api/models/oauth.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class OAuthResponse(BaseModel):
7+
session_id: str = Field(description='The session ID for the user.')
8+
jwt: str = Field(description='The JSON Web Token for the user.')
9+
10+
11+
class DiscordUser(BaseModel):
12+
id: int = Field(description='The user ID.')
13+
username: str = Field(description='The username.')
14+
discriminator: str = Field(description='The discriminator.')
15+
avatar: Optional[str] = Field(
16+
default=None, description='The avatar hash, if the user has one.'
17+
)

bot/api/models/session.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic import BaseModel
2+
3+
4+
class Session(BaseModel):
5+
user_id: int
6+
session_id: str
7+
expiration_time: int

bot/api/routes/account/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from fastapi import APIRouter
2+
3+
from .login import get_login_url as route_login
4+
from .me import get_logged_in_user as route_me
5+
6+
account_router = APIRouter(prefix='/account', tags=['account'])
7+
account_router.add_api_route('/login', route_login, methods=['GET'])
8+
account_router.add_api_route('/me', route_me, methods=['GET'])

bot/api/routes/account/login.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from secrets import token_urlsafe
2+
3+
from fastapi import HTTPException
4+
from fastapi.responses import RedirectResponse
5+
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
6+
from yarl import URL
7+
8+
from bot.utils.config import config as bot_config
9+
10+
11+
async def get_login_url() -> RedirectResponse:
12+
oauth_id = bot_config.discord_oauth_id
13+
base_url = bot_config.base_url
14+
15+
if oauth_id is None or base_url is None:
16+
raise HTTPException(
17+
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
18+
detail='Missing Discord OAuth ID or base URL',
19+
)
20+
21+
state = token_urlsafe(16)
22+
23+
url = URL.build(
24+
scheme='https',
25+
host='discord.com',
26+
path='/api/oauth2/authorize',
27+
query={
28+
'client_id': oauth_id,
29+
'response_type': 'code',
30+
'scope': 'identify guilds email',
31+
'redirect_uri': f'{base_url}/oauth/discord',
32+
'state': state,
33+
'prompt': 'none',
34+
},
35+
)
36+
37+
response = RedirectResponse(url=str(url))
38+
response.set_cookie('state', state, httponly=True, samesite='lax')
39+
return response

bot/api/routes/account/me.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
Route for getting the current user's account information.
3+
"""
4+
5+
from typing import TYPE_CHECKING, Optional
6+
7+
from fastapi import Depends
8+
9+
from bot.api.depends.database import database_dependency
10+
from bot.api.depends.session import session_dependency
11+
from bot.api.models.account import AccountResponse
12+
13+
if TYPE_CHECKING:
14+
from bot.database import Database
15+
from bot.models.oauth import LastfmAuth, OAuth
16+
17+
18+
async def get_logged_in_user(
19+
user: 'OAuth' = Depends(session_dependency),
20+
db: 'Database' = Depends(database_dependency),
21+
) -> AccountResponse:
22+
spotify_username = None
23+
spotify: Optional['OAuth'] = db.get_oauth('spotify', user.user_id)
24+
if spotify is not None:
25+
spotify_username = spotify.username
26+
27+
lastfm_username = None
28+
lastfm: Optional['LastfmAuth'] = db.get_lastfm_credentials(user.user_id)
29+
if lastfm is not None:
30+
lastfm_username = lastfm.username
31+
32+
return AccountResponse(
33+
username=user.username,
34+
spotify_logged_in=spotify is not None,
35+
spotify_username=spotify_username,
36+
lastfm_logged_in=lastfm is not None,
37+
lastfm_username=lastfm_username,
38+
)

bot/api/routes/oauth/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from fastapi import APIRouter
2+
3+
from .discord import discord_oauth as route_discord
4+
5+
oauth_router = APIRouter(prefix='/oauth', tags=['oauth'])
6+
oauth_router.add_api_route('/discord', route_discord, methods=['GET'])

0 commit comments

Comments
 (0)