Skip to content

Commit bfa44fe

Browse files
bluetechpgjones
authored andcommitted
Enable mypy strict-optional
This ensures that values are properly checked for None.
1 parent db92ff8 commit bfa44fe

10 files changed

+54
-31
lines changed

Diff for: setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ disallow_untyped_defs = True
4646
; implicit_reexport = False
4747
no_implicit_optional = True
4848
strict_equality = True
49-
strict_optional = False
49+
strict_optional = True
5050
warn_redundant_casts = True
5151
# warn_return_any = True
5252
warn_unused_configs = True

Diff for: test/helpers.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@ def __init__(
1717
self.accept_response = accept_response
1818

1919
def offer(self) -> Union[bool, str]:
20+
assert self.offer_response is not None
2021
return self.offer_response
2122

2223
def finalize(self, offer: str) -> None:
2324
self.accepted_offer = offer
2425

2526
def accept(self, offer: str) -> Union[bool, str]:
27+
assert self.accept_response is not None
2628
self.offered = offer
2729
return self.accept_response

Diff for: test/test_extensions.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,33 @@ def test_offer(self) -> None:
1313

1414
def test_accept(self) -> None:
1515
ext = wpext.Extension()
16-
assert ext.accept(None) is None
16+
offer = "myext"
17+
assert ext.accept(offer) is None
1718

1819
def test_finalize(self) -> None:
1920
ext = wpext.Extension()
20-
ext.finalize(None)
21+
offer = "myext"
22+
ext.finalize(offer)
2123

2224
def test_frame_inbound_header(self) -> None:
2325
ext = wpext.Extension()
24-
result = ext.frame_inbound_header(None, None, None, None)
26+
result = ext.frame_inbound_header(None, None, None, None) # type: ignore
2527
assert result == fp.RsvBits(False, False, False)
2628

2729
def test_frame_inbound_payload_data(self) -> None:
2830
ext = wpext.Extension()
2931
data = b""
30-
assert ext.frame_inbound_payload_data(None, data) == data
32+
assert ext.frame_inbound_payload_data(None, data) == data # type: ignore
3133

3234
def test_frame_inbound_complete(self) -> None:
3335
ext = wpext.Extension()
34-
assert ext.frame_inbound_complete(None, None) is None
36+
assert ext.frame_inbound_complete(None, None) is None # type: ignore
3537

3638
def test_frame_outbound(self) -> None:
3739
ext = wpext.Extension()
3840
rsv = fp.RsvBits(True, True, True)
3941
data = b""
40-
assert ext.frame_outbound(None, None, rsv, data, None) == (rsv, data)
42+
assert ext.frame_outbound(None, None, rsv, data, None) == ( # type: ignore
43+
rsv,
44+
data,
45+
)

Diff for: test/test_frame_protocol.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,10 @@ def test_long_text_message(self) -> None:
917917
assert frame.payload == payload
918918

919919
def _close_test(
920-
self, code: Optional[int], reason: str = None, reason_bytes: bytes = None
920+
self,
921+
code: Optional[int],
922+
reason: Optional[str] = None,
923+
reason_bytes: Optional[bytes] = None,
921924
) -> None:
922925
payload = b""
923926
if code:

Diff for: test/test_permessage_deflate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def test_client_inbound_compressed_multiple_data_frames(self, client: bool) -> N
296296
data += result4
297297

298298
result5 = ext.frame_inbound_complete(proto, True)
299-
assert not isinstance(result5, fp.CloseReason)
299+
assert isinstance(result5, bytes)
300300
data += result5
301301

302302
assert data == payload

Diff for: wsproto/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def send(self, event: Event) -> bytes:
6464
data += self.connection.send(event)
6565
return data
6666

67-
def receive_data(self, data: bytes) -> None:
67+
def receive_data(self, data: Optional[bytes]) -> None:
6868
"""
6969
Feed network data into the connection instance.
7070

Diff for: wsproto/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def send(self, event: Event) -> bytes:
109109
raise LocalProtocolError("Event {} cannot be sent.".format(event))
110110
return data
111111

112-
def receive_data(self, data: bytes) -> None:
112+
def receive_data(self, data: Optional[bytes]) -> None:
113113
"""
114114
Pass some received data to the connection for handling.
115115

Diff for: wsproto/extensions.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313

1414

1515
class Extension:
16-
name: Optional[str] = None
16+
name: str
1717

1818
def enabled(self) -> bool:
1919
return False
2020

2121
def offer(self) -> Union[bool, str]:
2222
pass
2323

24-
def accept(self, offer: str) -> Union[bool, str]:
24+
def accept(self, offer: str) -> Optional[Union[bool, str]]:
2525
pass
2626

2727
def finalize(self, offer: str) -> None:
@@ -123,7 +123,7 @@ def finalize(self, offer: str) -> None:
123123

124124
self._enabled = True
125125

126-
def _parse_params(self, params: str) -> Tuple[int, int]:
126+
def _parse_params(self, params: str) -> Tuple[Optional[int], Optional[int]]:
127127
client_max_window_bits = None
128128
server_max_window_bits = None
129129

@@ -198,6 +198,7 @@ def frame_inbound_payload_data(
198198
) -> Union[bytes, CloseReason]:
199199
if not self._inbound_compressed or not self._inbound_is_compressible:
200200
return data
201+
assert self._decompressor is not None
201202

202203
try:
203204
return self._decompressor.decompress(bytes(data))
@@ -215,6 +216,7 @@ def frame_inbound_complete(
215216
if not self._inbound_compressed:
216217
self._inbound_compressed = None
217218
return None
219+
assert self._decompressor is not None
218220

219221
try:
220222
data = self._decompressor.decompress(b"\x00\x00\xff\xff")

Diff for: wsproto/frame_protocol.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ class Header(NamedTuple):
211211
rsv: RsvBits
212212
opcode: Opcode
213213
payload_len: int
214-
masking_key: bytes
214+
masking_key: Optional[bytes]
215215

216216

217217
class Frame(NamedTuple):
@@ -238,7 +238,7 @@ def _truncate_utf8(data: bytes, nbytes: int) -> bytes:
238238

239239

240240
class Buffer:
241-
def __init__(self, initial_bytes: bytes = None) -> None:
241+
def __init__(self, initial_bytes: Optional[bytes] = None) -> None:
242242
self.buffer = bytearray()
243243
self.bytes_used = 0
244244
if initial_bytes:
@@ -255,7 +255,7 @@ def consume_at_most(self, nbytes: int) -> bytes:
255255
self.bytes_used += len(data)
256256
return data
257257

258-
def consume_exactly(self, nbytes: int) -> bytes:
258+
def consume_exactly(self, nbytes: int) -> Optional[bytes]:
259259
if len(self.buffer) - self.bytes_used < nbytes:
260260
return None
261261

@@ -297,7 +297,10 @@ def process_frame(self, frame: Frame) -> Frame:
297297
data = frame.payload
298298
else:
299299
assert isinstance(frame.payload, (bytes, bytearray))
300-
data = self.decode_payload(frame.payload, finished)
300+
try:
301+
data = self.decoder.decode(frame.payload, finished)
302+
except UnicodeDecodeError as exc:
303+
raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
301304

302305
frame = Frame(self.opcode, data, frame.frame_finished, finished)
303306

@@ -307,15 +310,11 @@ def process_frame(self, frame: Frame) -> Frame:
307310

308311
return frame
309312

310-
def decode_payload(self, data: bytes, finished: bool) -> str:
311-
try:
312-
return self.decoder.decode(data, finished)
313-
except UnicodeDecodeError as exc:
314-
raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
315-
316313

317314
class FrameDecoder:
318-
def __init__(self, client: bool, extensions: List["Extension"] = None) -> None:
315+
def __init__(
316+
self, client: bool, extensions: Optional[List["Extension"]] = None
317+
) -> None:
319318
self.client = client
320319
self.extensions = extensions or []
321320

@@ -330,10 +329,14 @@ def __init__(self, client: bool, extensions: List["Extension"] = None) -> None:
330329
def receive_bytes(self, data: bytes) -> None:
331330
self.buffer.feed(data)
332331

333-
def process_buffer(self) -> Frame:
332+
def process_buffer(self) -> Optional[Frame]:
334333
if not self.header:
335334
if not self.parse_header():
336335
return None
336+
# parse_header() sets these.
337+
assert self.header is not None
338+
assert self.masker is not None
339+
assert self.effective_opcode is not None
337340

338341
if len(self.buffer) < self.payload_required:
339342
return None
@@ -398,8 +401,8 @@ def parse_header(self) -> bool:
398401
raise ParseFailed("Invalid attempt to fragment control frame")
399402

400403
has_mask = bool(data[1] & MASK_MASK)
401-
payload_len = data[1] & PAYLOAD_LEN_MASK
402-
payload_len = self.parse_extended_payload_length(opcode, payload_len)
404+
payload_len_short = data[1] & PAYLOAD_LEN_MASK
405+
payload_len = self.parse_extended_payload_length(opcode, payload_len_short)
403406
if payload_len is None:
404407
self.buffer.rollback()
405408
return False
@@ -429,7 +432,9 @@ def parse_header(self) -> bool:
429432
self.payload_consumed = 0
430433
return True
431434

432-
def parse_extended_payload_length(self, opcode: Opcode, payload_len: int) -> int:
435+
def parse_extended_payload_length(
436+
self, opcode: Opcode, payload_len: int
437+
) -> Optional[int]:
433438
if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
434439
raise ParseFailed("Control frame with payload len > 125")
435440
if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
@@ -518,7 +523,7 @@ def _process_close(self, frame: Frame) -> Frame:
518523

519524
return Frame(frame.opcode, data, frame.frame_finished, frame.message_finished)
520525

521-
def _parse_more_gen(self) -> Generator[Frame, None, None]:
526+
def _parse_more_gen(self) -> Generator[Optional[Frame], None, None]:
522527
# Consume as much as we can from self._buffer, yielding events, and
523528
# then yield None when we need more data. Or raise ParseFailed.
524529

Diff for: wsproto/handshake.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def send(self, event: Event) -> bytes:
102102
)
103103
return data
104104

105-
def receive_data(self, data: bytes) -> None:
105+
def receive_data(self, data: Optional[bytes]) -> None:
106106
"""Receive data from the remote.
107107
108108
A list of events that the remote peer triggered by sending
@@ -243,6 +243,8 @@ def _process_connection_request(self, event: h11.Request) -> Request: # noqa: M
243243
return self._initiating_request
244244

245245
def _accept(self, event: AcceptConnection) -> bytes:
246+
# _accept is always called after _process_connection_request.
247+
assert self._initiating_request is not None
246248
request_headers = normed_header_dict(self._initiating_request.extra_headers)
247249

248250
nonce = request_headers[b"sec-websocket-key"]
@@ -354,6 +356,10 @@ def _initiate_connection(self, request: Request) -> bytes:
354356
def _establish_client_connection(
355357
self, event: h11.InformationalResponse
356358
) -> AcceptConnection: # noqa: MC0001
359+
# _establish_client_connection is always called after _initiate_connection.
360+
assert self._initiating_request is not None
361+
assert self._nonce is not None
362+
357363
accept = None
358364
connection_tokens = None
359365
accepts: List[str] = []

0 commit comments

Comments
 (0)