1
1
# These tests test the behaviours expected of wsproto in when the
2
2
# connection is a client.
3
+ from typing import List , Optional , Tuple
3
4
4
5
import h11
5
6
import pytest
6
7
7
8
from wsproto import WSConnection
8
9
from wsproto .connection import CLIENT
9
- from wsproto .events import AcceptConnection , RejectConnection , RejectData , Request
10
+ from wsproto .events import (
11
+ AcceptConnection ,
12
+ Event ,
13
+ RejectConnection ,
14
+ RejectData ,
15
+ Request ,
16
+ )
17
+ from wsproto .extensions import Extension
10
18
from wsproto .frame_protocol import CloseReason
19
+ from wsproto .typing import Headers
11
20
from wsproto .utilities import (
12
21
generate_accept_token ,
13
22
normed_header_dict ,
16
25
from .helpers import FakeExtension
17
26
18
27
19
- def _make_connection_request (request ):
20
- # type: (Request) -> h11.Request
28
+ def _make_connection_request (request : Request ) -> h11 .Request :
21
29
client = WSConnection (CLIENT )
22
30
server = h11 .Connection (h11 .SERVER )
23
31
server .receive_data (client .send (request ))
24
32
return server .next_event ()
25
33
26
34
27
- def test_connection_request ():
35
+ def test_connection_request () -> None :
28
36
request = _make_connection_request (Request (host = "localhost" , target = "/" ))
29
37
30
38
assert request .http_version == b"1.1"
@@ -38,7 +46,7 @@ def test_connection_request():
38
46
assert b"sec-websocket-key" in headers
39
47
40
48
41
- def test_connection_request_additional_headers ():
49
+ def test_connection_request_additional_headers () -> None :
42
50
request = _make_connection_request (
43
51
Request (
44
52
host = "localhost" ,
@@ -52,40 +60,40 @@ def test_connection_request_additional_headers():
52
60
assert headers [b"x-bar" ] == b"Foo"
53
61
54
62
55
- def test_connection_request_simple_extension ():
63
+ def test_connection_request_simple_extension () -> None :
56
64
extension = FakeExtension (offer_response = True )
57
65
request = _make_connection_request (
58
- Request (host = "localhost" , target = "/" , extensions = [extension ])
66
+ Request (host = "localhost" , target = "/" , extensions = [extension ]) # type: ignore
59
67
)
60
68
61
69
headers = normed_header_dict (request .headers )
62
70
assert headers [b"sec-websocket-extensions" ] == extension .name .encode ("ascii" )
63
71
64
72
65
- def test_connection_request_simple_extension_no_offer ():
73
+ def test_connection_request_simple_extension_no_offer () -> None :
66
74
extension = FakeExtension (offer_response = False )
67
75
request = _make_connection_request (
68
- Request (host = "localhost" , target = "/" , extensions = [extension ])
76
+ Request (host = "localhost" , target = "/" , extensions = [extension ]) # type: ignore
69
77
)
70
78
71
79
headers = normed_header_dict (request .headers )
72
80
assert b"sec-websocket-extensions" not in headers
73
81
74
82
75
- def test_connection_request_parametrised_extension ():
83
+ def test_connection_request_parametrised_extension () -> None :
76
84
extension = FakeExtension (offer_response = "parameter1=value1; parameter2=value2" )
77
85
request = _make_connection_request (
78
- Request (host = "localhost" , target = "/" , extensions = [extension ])
86
+ Request (host = "localhost" , target = "/" , extensions = [extension ]) # type: ignore
79
87
)
80
88
81
89
headers = normed_header_dict (request .headers )
82
90
assert headers [b"sec-websocket-extensions" ] == b"%s; %s" % (
83
91
extension .name .encode ("ascii" ),
84
- extension .offer_response .encode ("ascii" ),
92
+ extension .offer_response .encode ("ascii" ), # type: ignore
85
93
)
86
94
87
95
88
- def test_connection_request_subprotocols ():
96
+ def test_connection_request_subprotocols () -> None :
89
97
request = _make_connection_request (
90
98
Request (host = "localhost" , target = "/" , subprotocols = ["one" , "two" ])
91
99
)
@@ -95,12 +103,12 @@ def test_connection_request_subprotocols():
95
103
96
104
97
105
def _make_handshake (
98
- response_status ,
99
- response_headers ,
100
- subprotocols = None ,
101
- extensions = None ,
102
- auto_accept_key = True ,
103
- ):
106
+ response_status : int ,
107
+ response_headers : Headers ,
108
+ subprotocols : Optional [ List [ str ]] = None ,
109
+ extensions : Optional [ List [ Extension ]] = None ,
110
+ auto_accept_key : bool = True ,
111
+ ) -> List [ Event ] :
104
112
client = WSConnection (CLIENT )
105
113
server = h11 .Connection (h11 .SERVER )
106
114
server .receive_data (
@@ -130,22 +138,22 @@ def _make_handshake(
130
138
return list (client .events ())
131
139
132
140
133
- def test_handshake ():
141
+ def test_handshake () -> None :
134
142
events = _make_handshake (
135
143
101 , [(b"connection" , b"Upgrade" ), (b"upgrade" , b"WebSocket" )]
136
144
)
137
145
assert events == [AcceptConnection ()]
138
146
139
147
140
- def test_broken_handshake ():
148
+ def test_broken_handshake () -> None :
141
149
events = _make_handshake (
142
150
102 , [(b"connection" , b"Upgrade" ), (b"upgrade" , b"WebSocket" )]
143
151
)
144
152
assert isinstance (events [0 ], RejectConnection )
145
153
assert events [0 ].status_code == 102
146
154
147
155
148
- def test_handshake_extra_accept_headers ():
156
+ def test_handshake_extra_accept_headers () -> None :
149
157
events = _make_handshake (
150
158
101 ,
151
159
[(b"connection" , b"Upgrade" ), (b"upgrade" , b"WebSocket" ), (b"X-Foo" , b"bar" )],
@@ -154,20 +162,20 @@ def test_handshake_extra_accept_headers():
154
162
155
163
156
164
@pytest .mark .parametrize ("extra_headers" , [[], [(b"connection" , b"Keep-Alive" )]])
157
- def test_handshake_response_broken_connection_header (extra_headers ) :
165
+ def test_handshake_response_broken_connection_header (extra_headers : Headers ) -> None :
158
166
with pytest .raises (RemoteProtocolError ) as excinfo :
159
167
events = _make_handshake (101 , [(b"upgrade" , b"WebSocket" )] + extra_headers )
160
168
assert str (excinfo .value ) == "Missing header, 'Connection: Upgrade'"
161
169
162
170
163
171
@pytest .mark .parametrize ("extra_headers" , [[], [(b"upgrade" , b"h2" )]])
164
- def test_handshake_response_broken_upgrade_header (extra_headers ) :
172
+ def test_handshake_response_broken_upgrade_header (extra_headers : Headers ) -> None :
165
173
with pytest .raises (RemoteProtocolError ) as excinfo :
166
174
events = _make_handshake (101 , [(b"connection" , b"Upgrade" )] + extra_headers )
167
175
assert str (excinfo .value ) == "Missing header, 'Upgrade: WebSocket'"
168
176
169
177
170
- def test_handshake_response_missing_websocket_key_header ():
178
+ def test_handshake_response_missing_websocket_key_header () -> None :
171
179
with pytest .raises (RemoteProtocolError ) as excinfo :
172
180
events = _make_handshake (
173
181
101 ,
@@ -177,7 +185,7 @@ def test_handshake_response_missing_websocket_key_header():
177
185
assert str (excinfo .value ) == "Bad accept token"
178
186
179
187
180
- def test_handshake_with_subprotocol ():
188
+ def test_handshake_with_subprotocol () -> None :
181
189
events = _make_handshake (
182
190
101 ,
183
191
[
@@ -190,7 +198,7 @@ def test_handshake_with_subprotocol():
190
198
assert events == [AcceptConnection (subprotocol = "one" )]
191
199
192
200
193
- def test_handshake_bad_subprotocol ():
201
+ def test_handshake_bad_subprotocol () -> None :
194
202
with pytest .raises (RemoteProtocolError ) as excinfo :
195
203
events = _make_handshake (
196
204
101 ,
@@ -203,7 +211,7 @@ def test_handshake_bad_subprotocol():
203
211
assert str (excinfo .value ) == "unrecognized subprotocol new"
204
212
205
213
206
- def test_handshake_with_extension ():
214
+ def test_handshake_with_extension () -> None :
207
215
extension = FakeExtension (offer_response = True )
208
216
events = _make_handshake (
209
217
101 ,
@@ -217,7 +225,7 @@ def test_handshake_with_extension():
217
225
assert events == [AcceptConnection (extensions = [extension ])]
218
226
219
227
220
- def test_handshake_bad_extension ():
228
+ def test_handshake_bad_extension () -> None :
221
229
with pytest .raises (RemoteProtocolError ) as excinfo :
222
230
events = _make_handshake (
223
231
101 ,
@@ -230,15 +238,17 @@ def test_handshake_bad_extension():
230
238
assert str (excinfo .value ) == "unrecognized extension bad"
231
239
232
240
233
- def test_protocol_error ():
241
+ def test_protocol_error () -> None :
234
242
client = WSConnection (CLIENT )
235
243
client .send (Request (host = "localhost" , target = "/" ))
236
244
with pytest .raises (RemoteProtocolError ) as excinfo :
237
245
client .receive_data (b"broken nonsense\r \n \r \n " )
238
246
assert str (excinfo .value ) == "Bad HTTP message"
239
247
240
248
241
- def _make_handshake_rejection (status_code , body = None ):
249
+ def _make_handshake_rejection (
250
+ status_code : int , body : Optional [bytes ] = None
251
+ ) -> List [Event ]:
242
252
client = WSConnection (CLIENT )
243
253
server = h11 .Connection (h11 .SERVER )
244
254
server .receive_data (client .send (Request (host = "localhost" , target = "/" )))
@@ -255,7 +265,7 @@ def _make_handshake_rejection(status_code, body=None):
255
265
return list (client .events ())
256
266
257
267
258
- def test_handshake_rejection ():
268
+ def test_handshake_rejection () -> None :
259
269
events = _make_handshake_rejection (400 )
260
270
assert events == [
261
271
RejectConnection (
@@ -265,7 +275,7 @@ def test_handshake_rejection():
265
275
]
266
276
267
277
268
- def test_handshake_rejection_with_body ():
278
+ def test_handshake_rejection_with_body () -> None :
269
279
events = _make_handshake_rejection (400 , b"Hello" )
270
280
assert events == [
271
281
RejectConnection (
0 commit comments