29
29
JSONMode = typing .Literal ["text" , "binary" ]
30
30
TaskFunction = typing .TypeVar ("TaskFunction" )
31
31
TaskResult = typing .TypeVar ("TaskResult" )
32
+ SyncSession = typing .TypeVar ("SyncSession" , bound = "WebSocketSession" )
33
+ AsyncSession = typing .TypeVar ("AsyncSession" , bound = "AsyncWebSocketSession" )
32
34
33
35
DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536
34
36
DEFAULT_QUEUE_SIZE = 512
@@ -1074,23 +1076,25 @@ def _connect_ws(
1074
1076
float
1075
1077
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS ,
1076
1078
subprotocols : typing .Optional [list [str ]] = None ,
1079
+ session_class : type [SyncSession ] = WebSocketSession ,
1077
1080
** kwargs : typing .Any ,
1078
- ) -> typing .Generator [WebSocketSession , None , None ]:
1081
+ ) -> typing .Generator [SyncSession , None , None ]:
1079
1082
headers = kwargs .pop ("headers" , {})
1080
1083
headers .update (_get_headers (subprotocols ))
1081
1084
1082
1085
with client .stream ("GET" , url , headers = headers , ** kwargs ) as response :
1083
1086
if response .status_code != 101 :
1084
1087
raise WebSocketUpgradeError (response )
1085
1088
1086
- with WebSocketSession (
1089
+ session = session_class (
1087
1090
response .extensions ["network_stream" ],
1088
1091
max_message_size_bytes = max_message_size_bytes ,
1089
1092
queue_size = queue_size ,
1090
1093
keepalive_ping_interval_seconds = keepalive_ping_interval_seconds ,
1091
1094
keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds ,
1092
1095
response = response ,
1093
- ) as session :
1096
+ )
1097
+ with session :
1094
1098
yield session
1095
1099
1096
1100
@@ -1108,8 +1112,9 @@ def connect_ws(
1108
1112
float
1109
1113
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS ,
1110
1114
subprotocols : typing .Optional [list [str ]] = None ,
1115
+ session_class : type [SyncSession ] = WebSocketSession ,
1111
1116
** kwargs : typing .Any ,
1112
- ) -> typing .Generator [WebSocketSession , None , None ]:
1117
+ ) -> typing .Generator [SyncSession , None , None ]:
1113
1118
"""
1114
1119
Start a sync WebSocket session.
1115
1120
@@ -1176,6 +1181,7 @@ def connect_ws(
1176
1181
keepalive_ping_interval_seconds = keepalive_ping_interval_seconds ,
1177
1182
keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds ,
1178
1183
subprotocols = subprotocols ,
1184
+ session_class = session_class ,
1179
1185
** kwargs ,
1180
1186
) as websocket :
1181
1187
yield websocket
@@ -1188,6 +1194,7 @@ def connect_ws(
1188
1194
keepalive_ping_interval_seconds = keepalive_ping_interval_seconds ,
1189
1195
keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds ,
1190
1196
subprotocols = subprotocols ,
1197
+ session_class = session_class ,
1191
1198
** kwargs ,
1192
1199
) as websocket :
1193
1200
yield websocket
@@ -1207,23 +1214,25 @@ async def _aconnect_ws(
1207
1214
float
1208
1215
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS ,
1209
1216
subprotocols : typing .Optional [list [str ]] = None ,
1217
+ session_class : type [AsyncSession ] = AsyncWebSocketSession ,
1210
1218
** kwargs : typing .Any ,
1211
- ) -> typing .AsyncGenerator [AsyncWebSocketSession , None ]:
1219
+ ) -> typing .AsyncGenerator [AsyncSession , None ]:
1212
1220
headers = kwargs .pop ("headers" , {})
1213
1221
headers .update (_get_headers (subprotocols ))
1214
1222
1215
1223
async with client .stream ("GET" , url , headers = headers , ** kwargs ) as response :
1216
1224
if response .status_code != 101 :
1217
1225
raise WebSocketUpgradeError (response )
1218
1226
1219
- async with AsyncWebSocketSession (
1227
+ session = session_class (
1220
1228
response .extensions ["network_stream" ],
1221
1229
max_message_size_bytes = max_message_size_bytes ,
1222
1230
queue_size = queue_size ,
1223
1231
keepalive_ping_interval_seconds = keepalive_ping_interval_seconds ,
1224
1232
keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds ,
1225
1233
response = response ,
1226
- ) as session :
1234
+ )
1235
+ async with session :
1227
1236
yield session
1228
1237
1229
1238
@@ -1241,8 +1250,9 @@ async def aconnect_ws(
1241
1250
float
1242
1251
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS ,
1243
1252
subprotocols : typing .Optional [list [str ]] = None ,
1253
+ session_class : type [AsyncSession ] = AsyncWebSocketSession ,
1244
1254
** kwargs : typing .Any ,
1245
- ) -> typing .AsyncGenerator [AsyncWebSocketSession , None ]:
1255
+ ) -> typing .AsyncGenerator [AsyncSession , None ]:
1246
1256
"""
1247
1257
Start an async WebSocket session.
1248
1258
@@ -1309,6 +1319,7 @@ async def aconnect_ws(
1309
1319
keepalive_ping_interval_seconds = keepalive_ping_interval_seconds ,
1310
1320
keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds ,
1311
1321
subprotocols = subprotocols ,
1322
+ session_class = session_class ,
1312
1323
** kwargs ,
1313
1324
) as websocket :
1314
1325
yield websocket
@@ -1321,6 +1332,7 @@ async def aconnect_ws(
1321
1332
keepalive_ping_interval_seconds = keepalive_ping_interval_seconds ,
1322
1333
keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds ,
1323
1334
subprotocols = subprotocols ,
1335
+ session_class = session_class ,
1324
1336
** kwargs ,
1325
1337
) as websocket :
1326
1338
yield websocket
0 commit comments