Skip to content

Commit dc716cc

Browse files
GreyElainafrankie567
authored andcommitted
Allow use custom session class
1 parent 5971a48 commit dc716cc

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

httpx_ws/_api.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
JSONMode = typing.Literal["text", "binary"]
3030
TaskFunction = typing.TypeVar("TaskFunction")
3131
TaskResult = typing.TypeVar("TaskResult")
32+
SyncSession = typing.TypeVar("SyncSession", bound="WebSocketSession")
33+
AsyncSession = typing.TypeVar("AsyncSession", bound="AsyncWebSocketSession")
3234

3335
DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536
3436
DEFAULT_QUEUE_SIZE = 512
@@ -1074,23 +1076,25 @@ def _connect_ws(
10741076
float
10751077
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS,
10761078
subprotocols: typing.Optional[list[str]] = None,
1079+
session_class: type[SyncSession] = WebSocketSession,
10771080
**kwargs: typing.Any,
1078-
) -> typing.Generator[WebSocketSession, None, None]:
1081+
) -> typing.Generator[SyncSession, None, None]:
10791082
headers = kwargs.pop("headers", {})
10801083
headers.update(_get_headers(subprotocols))
10811084

10821085
with client.stream("GET", url, headers=headers, **kwargs) as response:
10831086
if response.status_code != 101:
10841087
raise WebSocketUpgradeError(response)
10851088

1086-
with WebSocketSession(
1089+
session = session_class(
10871090
response.extensions["network_stream"],
10881091
max_message_size_bytes=max_message_size_bytes,
10891092
queue_size=queue_size,
10901093
keepalive_ping_interval_seconds=keepalive_ping_interval_seconds,
10911094
keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds,
10921095
response=response,
1093-
) as session:
1096+
)
1097+
with session:
10941098
yield session
10951099

10961100

@@ -1108,8 +1112,9 @@ def connect_ws(
11081112
float
11091113
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS,
11101114
subprotocols: typing.Optional[list[str]] = None,
1115+
session_class: type[SyncSession] = WebSocketSession,
11111116
**kwargs: typing.Any,
1112-
) -> typing.Generator[WebSocketSession, None, None]:
1117+
) -> typing.Generator[SyncSession, None, None]:
11131118
"""
11141119
Start a sync WebSocket session.
11151120
@@ -1176,6 +1181,7 @@ def connect_ws(
11761181
keepalive_ping_interval_seconds=keepalive_ping_interval_seconds,
11771182
keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds,
11781183
subprotocols=subprotocols,
1184+
session_class=session_class,
11791185
**kwargs,
11801186
) as websocket:
11811187
yield websocket
@@ -1188,6 +1194,7 @@ def connect_ws(
11881194
keepalive_ping_interval_seconds=keepalive_ping_interval_seconds,
11891195
keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds,
11901196
subprotocols=subprotocols,
1197+
session_class=session_class,
11911198
**kwargs,
11921199
) as websocket:
11931200
yield websocket
@@ -1207,23 +1214,25 @@ async def _aconnect_ws(
12071214
float
12081215
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS,
12091216
subprotocols: typing.Optional[list[str]] = None,
1217+
session_class: type[AsyncSession] = AsyncWebSocketSession,
12101218
**kwargs: typing.Any,
1211-
) -> typing.AsyncGenerator[AsyncWebSocketSession, None]:
1219+
) -> typing.AsyncGenerator[AsyncSession, None]:
12121220
headers = kwargs.pop("headers", {})
12131221
headers.update(_get_headers(subprotocols))
12141222

12151223
async with client.stream("GET", url, headers=headers, **kwargs) as response:
12161224
if response.status_code != 101:
12171225
raise WebSocketUpgradeError(response)
12181226

1219-
async with AsyncWebSocketSession(
1227+
session = session_class(
12201228
response.extensions["network_stream"],
12211229
max_message_size_bytes=max_message_size_bytes,
12221230
queue_size=queue_size,
12231231
keepalive_ping_interval_seconds=keepalive_ping_interval_seconds,
12241232
keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds,
12251233
response=response,
1226-
) as session:
1234+
)
1235+
async with session:
12271236
yield session
12281237

12291238

@@ -1241,8 +1250,9 @@ async def aconnect_ws(
12411250
float
12421251
] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS,
12431252
subprotocols: typing.Optional[list[str]] = None,
1253+
session_class: type[AsyncSession] = AsyncWebSocketSession,
12441254
**kwargs: typing.Any,
1245-
) -> typing.AsyncGenerator[AsyncWebSocketSession, None]:
1255+
) -> typing.AsyncGenerator[AsyncSession, None]:
12461256
"""
12471257
Start an async WebSocket session.
12481258
@@ -1309,6 +1319,7 @@ async def aconnect_ws(
13091319
keepalive_ping_interval_seconds=keepalive_ping_interval_seconds,
13101320
keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds,
13111321
subprotocols=subprotocols,
1322+
session_class=session_class,
13121323
**kwargs,
13131324
) as websocket:
13141325
yield websocket
@@ -1321,6 +1332,7 @@ async def aconnect_ws(
13211332
keepalive_ping_interval_seconds=keepalive_ping_interval_seconds,
13221333
keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds,
13231334
subprotocols=subprotocols,
1335+
session_class=session_class,
13241336
**kwargs,
13251337
) as websocket:
13261338
yield websocket

0 commit comments

Comments
 (0)