Skip to content

Commit dcfb914

Browse files
pgjonesKriechi
authored andcommittedMay 15, 2019
Add type hints and include mypy in build
This fully type hints the wsproto codebase and uses mypy to ensure the type hints are added and correct. This has identified some potential bugs, see 582d052 The additional linting disable is because the mypy TYPE_CHECKING is not understood by pylint.
1 parent 588cff1 commit dcfb914

20 files changed

+643
-512
lines changed
 

‎.prospector.yml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pylint:
22
disable:
3+
- cyclic-import
34
- unused-argument
45
- useless-object-inheritance
56

‎setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ combine_as_imports=True
2525
force_grid_wrap=0
2626
include_trailing_comma=True
2727
known_first_party=wsproto, test
28-
known_third_party=h11, pytest
28+
known_third_party=h11, pytest, _pytest
2929
line_length=88
3030
multi_line_output=3
3131
no_lines_before=LOCALFOLDER

‎test/helpers.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,27 @@
1+
from typing import Optional, Union
2+
13
from wsproto.extensions import Extension
24

35

46
class FakeExtension(Extension):
57
name = "fake"
68

7-
def __init__(self, offer_response=None, accept_response=None):
9+
def __init__(
10+
self,
11+
offer_response: Optional[Union[bool, str]] = None,
12+
accept_response: Optional[Union[bool, str]] = None,
13+
) -> None:
814
self.offer_response = offer_response
9-
self.accepted_offer = None
10-
self.offered = None
15+
self.accepted_offer: Optional[str] = None
16+
self.offered: Optional[str] = None
1117
self.accept_response = accept_response
1218

13-
def offer(self):
19+
def offer(self) -> Union[bool, str]:
1420
return self.offer_response
1521

16-
def finalize(self, offer):
22+
def finalize(self, offer: str) -> None:
1723
self.accepted_offer = offer
1824

19-
def accept(self, offer):
25+
def accept(self, offer: str) -> Union[bool, str]:
2026
self.offered = offer
2127
return self.accept_response

‎test/test_client.py

+43-33
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
# These tests test the behaviours expected of wsproto in when the
22
# connection is a client.
3+
from typing import List, Optional, Tuple
34

45
import h11
56
import pytest
67

78
from wsproto import WSConnection
89
from wsproto.connection import CLIENT
9-
from wsproto.events import AcceptConnection, RejectConnection, RejectData, Request
10+
from wsproto.events import (
11+
AcceptConnection,
12+
Event,
13+
RejectConnection,
14+
RejectData,
15+
Request,
16+
)
17+
from wsproto.extensions import Extension
1018
from wsproto.frame_protocol import CloseReason
19+
from wsproto.typing import Headers
1120
from wsproto.utilities import (
1221
generate_accept_token,
1322
normed_header_dict,
@@ -16,15 +25,14 @@
1625
from .helpers import FakeExtension
1726

1827

19-
def _make_connection_request(request):
20-
# type: (Request) -> h11.Request
28+
def _make_connection_request(request: Request) -> h11.Request:
2129
client = WSConnection(CLIENT)
2230
server = h11.Connection(h11.SERVER)
2331
server.receive_data(client.send(request))
2432
return server.next_event()
2533

2634

27-
def test_connection_request():
35+
def test_connection_request() -> None:
2836
request = _make_connection_request(Request(host="localhost", target="/"))
2937

3038
assert request.http_version == b"1.1"
@@ -38,7 +46,7 @@ def test_connection_request():
3846
assert b"sec-websocket-key" in headers
3947

4048

41-
def test_connection_request_additional_headers():
49+
def test_connection_request_additional_headers() -> None:
4250
request = _make_connection_request(
4351
Request(
4452
host="localhost",
@@ -52,40 +60,40 @@ def test_connection_request_additional_headers():
5260
assert headers[b"x-bar"] == b"Foo"
5361

5462

55-
def test_connection_request_simple_extension():
63+
def test_connection_request_simple_extension() -> None:
5664
extension = FakeExtension(offer_response=True)
5765
request = _make_connection_request(
58-
Request(host="localhost", target="/", extensions=[extension])
66+
Request(host="localhost", target="/", extensions=[extension]) # type: ignore
5967
)
6068

6169
headers = normed_header_dict(request.headers)
6270
assert headers[b"sec-websocket-extensions"] == extension.name.encode("ascii")
6371

6472

65-
def test_connection_request_simple_extension_no_offer():
73+
def test_connection_request_simple_extension_no_offer() -> None:
6674
extension = FakeExtension(offer_response=False)
6775
request = _make_connection_request(
68-
Request(host="localhost", target="/", extensions=[extension])
76+
Request(host="localhost", target="/", extensions=[extension]) # type: ignore
6977
)
7078

7179
headers = normed_header_dict(request.headers)
7280
assert b"sec-websocket-extensions" not in headers
7381

7482

75-
def test_connection_request_parametrised_extension():
83+
def test_connection_request_parametrised_extension() -> None:
7684
extension = FakeExtension(offer_response="parameter1=value1; parameter2=value2")
7785
request = _make_connection_request(
78-
Request(host="localhost", target="/", extensions=[extension])
86+
Request(host="localhost", target="/", extensions=[extension]) # type: ignore
7987
)
8088

8189
headers = normed_header_dict(request.headers)
8290
assert headers[b"sec-websocket-extensions"] == b"%s; %s" % (
8391
extension.name.encode("ascii"),
84-
extension.offer_response.encode("ascii"),
92+
extension.offer_response.encode("ascii"), # type: ignore
8593
)
8694

8795

88-
def test_connection_request_subprotocols():
96+
def test_connection_request_subprotocols() -> None:
8997
request = _make_connection_request(
9098
Request(host="localhost", target="/", subprotocols=["one", "two"])
9199
)
@@ -95,12 +103,12 @@ def test_connection_request_subprotocols():
95103

96104

97105
def _make_handshake(
98-
response_status,
99-
response_headers,
100-
subprotocols=None,
101-
extensions=None,
102-
auto_accept_key=True,
103-
):
106+
response_status: int,
107+
response_headers: Headers,
108+
subprotocols: Optional[List[str]] = None,
109+
extensions: Optional[List[Extension]] = None,
110+
auto_accept_key: bool = True,
111+
) -> List[Event]:
104112
client = WSConnection(CLIENT)
105113
server = h11.Connection(h11.SERVER)
106114
server.receive_data(
@@ -130,22 +138,22 @@ def _make_handshake(
130138
return list(client.events())
131139

132140

133-
def test_handshake():
141+
def test_handshake() -> None:
134142
events = _make_handshake(
135143
101, [(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket")]
136144
)
137145
assert events == [AcceptConnection()]
138146

139147

140-
def test_broken_handshake():
148+
def test_broken_handshake() -> None:
141149
events = _make_handshake(
142150
102, [(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket")]
143151
)
144152
assert isinstance(events[0], RejectConnection)
145153
assert events[0].status_code == 102
146154

147155

148-
def test_handshake_extra_accept_headers():
156+
def test_handshake_extra_accept_headers() -> None:
149157
events = _make_handshake(
150158
101,
151159
[(b"connection", b"Upgrade"), (b"upgrade", b"WebSocket"), (b"X-Foo", b"bar")],
@@ -154,20 +162,20 @@ def test_handshake_extra_accept_headers():
154162

155163

156164
@pytest.mark.parametrize("extra_headers", [[], [(b"connection", b"Keep-Alive")]])
157-
def test_handshake_response_broken_connection_header(extra_headers):
165+
def test_handshake_response_broken_connection_header(extra_headers: Headers) -> None:
158166
with pytest.raises(RemoteProtocolError) as excinfo:
159167
events = _make_handshake(101, [(b"upgrade", b"WebSocket")] + extra_headers)
160168
assert str(excinfo.value) == "Missing header, 'Connection: Upgrade'"
161169

162170

163171
@pytest.mark.parametrize("extra_headers", [[], [(b"upgrade", b"h2")]])
164-
def test_handshake_response_broken_upgrade_header(extra_headers):
172+
def test_handshake_response_broken_upgrade_header(extra_headers: Headers) -> None:
165173
with pytest.raises(RemoteProtocolError) as excinfo:
166174
events = _make_handshake(101, [(b"connection", b"Upgrade")] + extra_headers)
167175
assert str(excinfo.value) == "Missing header, 'Upgrade: WebSocket'"
168176

169177

170-
def test_handshake_response_missing_websocket_key_header():
178+
def test_handshake_response_missing_websocket_key_header() -> None:
171179
with pytest.raises(RemoteProtocolError) as excinfo:
172180
events = _make_handshake(
173181
101,
@@ -177,7 +185,7 @@ def test_handshake_response_missing_websocket_key_header():
177185
assert str(excinfo.value) == "Bad accept token"
178186

179187

180-
def test_handshake_with_subprotocol():
188+
def test_handshake_with_subprotocol() -> None:
181189
events = _make_handshake(
182190
101,
183191
[
@@ -190,7 +198,7 @@ def test_handshake_with_subprotocol():
190198
assert events == [AcceptConnection(subprotocol="one")]
191199

192200

193-
def test_handshake_bad_subprotocol():
201+
def test_handshake_bad_subprotocol() -> None:
194202
with pytest.raises(RemoteProtocolError) as excinfo:
195203
events = _make_handshake(
196204
101,
@@ -203,7 +211,7 @@ def test_handshake_bad_subprotocol():
203211
assert str(excinfo.value) == "unrecognized subprotocol new"
204212

205213

206-
def test_handshake_with_extension():
214+
def test_handshake_with_extension() -> None:
207215
extension = FakeExtension(offer_response=True)
208216
events = _make_handshake(
209217
101,
@@ -217,7 +225,7 @@ def test_handshake_with_extension():
217225
assert events == [AcceptConnection(extensions=[extension])]
218226

219227

220-
def test_handshake_bad_extension():
228+
def test_handshake_bad_extension() -> None:
221229
with pytest.raises(RemoteProtocolError) as excinfo:
222230
events = _make_handshake(
223231
101,
@@ -230,15 +238,17 @@ def test_handshake_bad_extension():
230238
assert str(excinfo.value) == "unrecognized extension bad"
231239

232240

233-
def test_protocol_error():
241+
def test_protocol_error() -> None:
234242
client = WSConnection(CLIENT)
235243
client.send(Request(host="localhost", target="/"))
236244
with pytest.raises(RemoteProtocolError) as excinfo:
237245
client.receive_data(b"broken nonsense\r\n\r\n")
238246
assert str(excinfo.value) == "Bad HTTP message"
239247

240248

241-
def _make_handshake_rejection(status_code, body=None):
249+
def _make_handshake_rejection(
250+
status_code: int, body: Optional[bytes] = None
251+
) -> List[Event]:
242252
client = WSConnection(CLIENT)
243253
server = h11.Connection(h11.SERVER)
244254
server.receive_data(client.send(Request(host="localhost", target="/")))
@@ -255,7 +265,7 @@ def _make_handshake_rejection(status_code, body=None):
255265
return list(client.events())
256266

257267

258-
def test_handshake_rejection():
268+
def test_handshake_rejection() -> None:
259269
events = _make_handshake_rejection(400)
260270
assert events == [
261271
RejectConnection(
@@ -265,7 +275,7 @@ def test_handshake_rejection():
265275
]
266276

267277

268-
def test_handshake_rejection_with_body():
278+
def test_handshake_rejection_with_body() -> None:
269279
events = _make_handshake_rejection(400, b"Hello")
270280
assert events == [
271281
RejectConnection(

‎test/test_connection.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@pytest.mark.parametrize("client_sends", [True, False])
2323
@pytest.mark.parametrize("final", [True, False])
24-
def test_send_message(client_sends, final):
24+
def test_send_message(client_sends: bool, final: bool) -> None:
2525
client = Connection(CLIENT)
2626
server = Connection(SERVER)
2727

@@ -33,7 +33,9 @@ def test_send_message(client_sends, final):
3333
remote = client
3434

3535
data = b"x" * 23
36-
remote.receive_data(local.send(BytesMessage(data=data, message_finished=final)))
36+
remote.receive_data(
37+
local.send(BytesMessage(data=data, message_finished=final)) # type: ignore
38+
)
3739
event = next(remote.events())
3840
assert isinstance(event, BytesMessage)
3941
assert event.data == data
@@ -45,7 +47,7 @@ def test_send_message(client_sends, final):
4547
"code, reason",
4648
[(CloseReason.NORMAL_CLOSURE, "bye"), (CloseReason.GOING_AWAY, "👋👋")],
4749
)
48-
def test_closure(client_sends, code, reason):
50+
def test_closure(client_sends: bool, code: CloseReason, reason: str) -> None:
4951
client = Connection(CLIENT)
5052
server = Connection(SERVER)
5153

@@ -75,7 +77,7 @@ def test_closure(client_sends, code, reason):
7577
assert local.state is ConnectionState.CLOSED
7678

7779

78-
def test_abnormal_closure():
80+
def test_abnormal_closure() -> None:
7981
client = Connection(CLIENT)
8082
client.receive_data(None)
8183
event = next(client.events())
@@ -84,15 +86,15 @@ def test_abnormal_closure():
8486
assert client.state is ConnectionState.CLOSED
8587

8688

87-
def test_close_whilst_closing():
89+
def test_close_whilst_closing() -> None:
8890
client = Connection(CLIENT)
8991
client.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))
9092
with pytest.raises(LocalProtocolError):
9193
client.send(CloseConnection(code=CloseReason.NORMAL_CLOSURE))
9294

9395

9496
@pytest.mark.parametrize("client_sends", [True, False])
95-
def test_ping_pong(client_sends):
97+
def test_ping_pong(client_sends: bool) -> None:
9698
client = Connection(CLIENT)
9799
server = Connection(SERVER)
98100

@@ -115,7 +117,7 @@ def test_ping_pong(client_sends):
115117
assert event.payload == payload
116118

117119

118-
def test_unsolicited_pong():
120+
def test_unsolicited_pong() -> None:
119121
client = Connection(CLIENT)
120122
server = Connection(SERVER)
121123

@@ -127,20 +129,22 @@ def test_unsolicited_pong():
127129

128130

129131
@pytest.mark.parametrize("split_message", [True, False])
130-
def test_data(split_message):
132+
def test_data(split_message: bool) -> None:
131133
client = Connection(CLIENT)
132134
server = Connection(SERVER)
133135

134136
data = "ƒñö®∂😎"
135137
server.receive_data(
136-
client.send(TextMessage(data=data, message_finished=not split_message))
138+
client.send(
139+
TextMessage(data=data, message_finished=not split_message) # type: ignore
140+
)
137141
)
138142
event = next(server.events())
139143
assert isinstance(event, TextMessage)
140144
assert event.message_finished is not split_message
141145

142146

143-
def test_frame_protocol_gets_fed_garbage():
147+
def test_frame_protocol_gets_fed_garbage() -> None:
144148
client = Connection(CLIENT)
145149

146150
payload = b"x" * 23
@@ -152,13 +156,13 @@ def test_frame_protocol_gets_fed_garbage():
152156
assert event.code == CloseReason.PROTOCOL_ERROR
153157

154158

155-
def test_send_invalid_event():
159+
def test_send_invalid_event() -> None:
156160
client = Connection(CLIENT)
157161
with pytest.raises(LocalProtocolError):
158162
client.send(Request(target="/", host="wsproto"))
159163

160164

161-
def test_receive_data_when_closed():
165+
def test_receive_data_when_closed() -> None:
162166
client = Connection(CLIENT)
163167
client._state = ConnectionState.CLOSED
164168
with pytest.raises(LocalProtocolError):

0 commit comments

Comments
 (0)
Please sign in to comment.