Skip to content

Commit db15a91

Browse files
committed
Add docstrings and mypy annotations
1 parent db2f1a1 commit db15a91

10 files changed

+153
-142
lines changed

.github/workflows/main.yml

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ jobs:
2727
- name: Lint with black
2828
run: |
2929
poetry run black --check .
30+
- name: Check with mypy
31+
run: |
32+
poetry run mypy .
3033
- name: Run tests
3134
run: |
3235
poetry run pytest tests

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
config.json
55
*.egg-info
66
__pycache__
7+
.mypy_cache

.pre-commit-config.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ repos:
1616
rev: 3.7.9
1717
hooks:
1818
- id: flake8
19+
- repo: https://github.com/pre-commit/mirrors-mypy
20+
rev: v0.761
21+
hooks:
22+
- id: mypy

mypy.ini

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[mypy]
2+
python_version = 3.8
3+
4+
[mypy-discord]
5+
ignore_missing_imports = true
6+
7+
[mypy-pylru]
8+
ignore_missing_imports = true
9+
10+
[mypy-pytest]
11+
ignore_missing_imports = true

oumodulesbot/backend.py

+37-5
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import json
33
import logging
44
import re
5+
from typing import Mapping, Optional, Tuple
56

67
from .ou_utils import MODULE_CODE_RE_TEMPLATE, get_module_url
78

89
logger = logging.getLogger(__name__)
910

11+
CacheItem = Tuple[str, Optional[str]] # title, url
12+
1013

1114
class OUModulesBackend:
1215

@@ -17,9 +20,18 @@ class OUModulesBackend:
1720

1821
def __init__(self):
1922
with open("cache.json", "r") as f:
20-
self.cache = json.load(f)
23+
cache_json = json.load(f)
24+
self.cache: Mapping[str, CacheItem] = {
25+
k: tuple(v) for k, v in cache_json.items()
26+
}
27+
28+
async def get_module_title(self, code: str) -> Optional[str]:
29+
"""
30+
Returns a module title for given code, if available.
2131
22-
async def get_module_title(self, code):
32+
Tries a lookup in cache, and if it fails then it attempts to query
33+
the Open University Digital Archive.
34+
"""
2335
code = code.upper()
2436

2537
# 1. Try cached title:
@@ -36,17 +48,37 @@ async def get_module_title(self, code):
3648
response = await client.get(url_template.format(code))
3749
html = response.content.decode("utf-8")
3850
except Exception:
39-
return
51+
return None
4052
title = self.MODULE_TITLE_RE.findall(html)
4153
return title[0].replace("!", "") if title else None
4254

43-
async def _is_module_url(self, url, code):
55+
async def _is_module_url(self, url: str, code: str) -> bool:
56+
"""
57+
Check if given URL looks like a valid URL for a given module code,
58+
i.e. resolves to 200, and doesn't redirect away to a different page.
59+
60+
OU redirects to places like /courses/ sometimes which is a masked way
61+
of saying '404'. However 301 redirects don't always indicate that
62+
modules aren't available, because they sometimes point to a different
63+
page for the same module.
64+
65+
Thus a compromise is used here by allowing redirects, but only if the
66+
destination page URL includes the module code.
67+
"""
4468
async with httpx.AsyncClient() as client:
4569
response = await client.head(url, allow_redirects=True)
4670
correct_redirect = code.lower() in str(response.url).lower()
4771
return correct_redirect and response.status_code == 200
4872

49-
async def get_module_url(self, code):
73+
async def get_module_url(self, code: str) -> Optional[str]:
74+
"""
75+
Return module's URL for given module code, if available.
76+
77+
Tries to lookup in cache, and if it fails then it tries to provide
78+
the URL with a well-known template (see get_module_url from ou_utils).
79+
80+
The template-based URL is returned only if it passes a HTTP check.
81+
"""
5082
code = code.upper()
5183

5284
# 1. Try cached URL

oumodulesbot/ou_utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
MODULE_CODE_RE_TEMPLATE = r"[a-zA-Z]{1,6}[0-9]{1,3}(?:-[a-zA-Z]{1,5})?"
22

33

4-
def get_module_level(module_code):
4+
def get_module_level(module_code: str) -> int:
55
for c in module_code:
66
if c.isdigit():
77
return int(c)
8+
raise ValueError(f"Invalid module code: {module_code}")
89

910

10-
def get_module_url(module_code):
11+
def get_module_url(module_code: str) -> str:
1112
if get_module_level(module_code) == 0:
1213
template = "http://www.open.ac.uk/courses/short-courses/{}"
1314
elif get_module_level(module_code) == 8:

oumodulesbot/oumodulesbot.py

+45-13
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import namedtuple
2+
from typing import Iterable, List, Sequence
13
import json
24
import logging
35
import os
@@ -12,6 +14,7 @@
1214
logger = logging.getLogger(__name__)
1315

1416
replies_cache = pylru.lrucache(1000)
17+
Module = namedtuple("Module", "code,title")
1518

1619

1720
class OUModulesBot(discord.Client):
@@ -23,8 +26,12 @@ def __init__(self, *args, **kwargs):
2326
super().__init__(*args, **kwargs)
2427
self.backend = OUModulesBackend()
2528

26-
async def do_mentions(self, message):
27-
modules = []
29+
async def process_mentions(self, message: discord.Message) -> None:
30+
"""
31+
Process module code mentions from given `message`, and reply with
32+
thieir names/URLs if any were found.
33+
"""
34+
modules: List[Module] = []
2835
any_found = False
2936
for module in self.MENTION_RE.findall(message.content)[
3037
: self.MODULES_COUNT_LIMIT
@@ -33,14 +40,23 @@ async def do_mentions(self, message):
3340
title = await self.backend.get_module_title(module_code)
3441
if title:
3542
any_found = True
36-
modules.append((module_code, title))
43+
modules.append(Module(module_code, title))
3744
else:
38-
modules.append((module_code, "not found"))
45+
modules.append(Module(module_code, "not found"))
3946
if any_found:
4047
# don't spam unless we're sure we at least found some modules
4148
await self.post_modules(message, modules)
4249

43-
async def format_course(self, code, title, for_embed=False):
50+
async def format_course(
51+
self, code: str, title: str, for_embed: bool = False
52+
) -> str:
53+
"""
54+
Return a string describing a module ready for posting to Discord,
55+
for given module `code` and `title`. Appends URL if available.
56+
57+
Uses more compact formatting if `for_embed` is True, which should
58+
be used if multiple modules are presented as part of an embed.
59+
"""
4460
fmt = " * {} " if for_embed else "{}"
4561
fmt_link = " * [{}]({}) " if for_embed else "{} ({})"
4662
url = await self.backend.get_module_url(code)
@@ -53,24 +69,38 @@ async def format_course(self, code, title, for_embed=False):
5369
else:
5470
return "{}: {}".format(code, result)
5571

56-
async def _embed_modules(self, embed, modules):
72+
async def embed_modules(
73+
self, embed: discord.Embed, modules: Iterable[Module]
74+
) -> None:
75+
"""
76+
Adds `embed` fields for each provided module.
77+
"""
5778
for (code, title) in modules:
5879
embed.add_field(
5980
name=code,
6081
value=await self.format_course(code, title, for_embed=True),
6182
inline=True,
6283
)
6384

64-
async def post_modules(self, message, modules):
85+
async def post_modules(
86+
self, message: discord.Message, modules: Sequence[Module]
87+
) -> None:
88+
"""
89+
Create or update a bot message for given users's input message,
90+
and a list of modules.
91+
92+
Message is updated instead of created if the input was already replied
93+
to, which means this time the input was edited.
94+
"""
6595
modify_message = None
6696
if message.id in replies_cache:
6797
modify_message = replies_cache[message.id]
6898

6999
embed = discord.Embed()
70100
if len(modules) > 1:
71101
content = " " # force removal when modifying
72-
await self._embed_modules(embed, modules)
73-
elif len(modules) > 0:
102+
await self.embed_modules(embed, modules)
103+
elif len(modules) == 1:
74104
code, title = modules[0]
75105
content = await self.format_course(code, title)
76106
else:
@@ -88,11 +118,13 @@ async def post_modules(self, message, modules):
88118
content, embed=embed if len(modules) > 1 else None
89119
)
90120

91-
async def on_message(self, message):
92-
await self.do_mentions(message)
121+
async def on_message(self, message: discord.Message) -> None:
122+
await self.process_mentions(message)
93123

94-
async def on_message_edit(self, before, after):
95-
await self.do_mentions(after)
124+
async def on_message_edit(
125+
self, before: discord.Message, after: discord.Message
126+
) -> None:
127+
await self.process_mentions(after)
96128

97129

98130
def main():

0 commit comments

Comments
 (0)