4
4
from dataclasses import dataclass
5
5
from datetime import datetime , timedelta
6
6
from operator import attrgetter
7
- from typing import TYPE_CHECKING , Any , Callable , List , Optional , Tuple , Union , cast
7
+ from typing import TYPE_CHECKING , Any , Callable , List , Optional , Tuple , Type , TypeVar , Union , cast
8
8
from urllib .parse import parse_qs , urlparse
9
9
from uuid import uuid4
10
10
@@ -217,6 +217,9 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
217
217
return await asyncio .gather (* [self ._get_job_def (job_id , int (score )) for job_id , score in jobs ])
218
218
219
219
220
+ TArqRedis = TypeVar ('TArqRedis' , bound = ArqRedis )
221
+
222
+
220
223
async def create_pool (
221
224
settings_ : Optional [RedisSettings ] = None ,
222
225
* ,
@@ -225,6 +228,7 @@ async def create_pool(
225
228
job_deserializer : Optional [Deserializer ] = None ,
226
229
default_queue_name : str = default_queue_name ,
227
230
expires_extra_ms : int = expires_extra_ms ,
231
+ arq_redis_cls : Type [TArqRedis ] = ArqRedis ,
228
232
) -> ArqRedis :
229
233
"""
230
234
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
@@ -238,19 +242,19 @@ async def create_pool(
238
242
239
243
if settings .sentinel :
240
244
241
- def pool_factory (* args : Any , ** kwargs : Any ) -> ArqRedis :
245
+ def pool_factory (* args : Any , ** kwargs : Any ) -> TArqRedis :
242
246
client = Sentinel ( # type: ignore[misc]
243
247
* args ,
244
248
sentinels = settings .host ,
245
249
ssl = settings .ssl ,
246
250
** kwargs ,
247
251
)
248
- redis = client .master_for (settings .sentinel_master , redis_class = ArqRedis )
249
- return cast (ArqRedis , redis )
252
+ redis = client .master_for (settings .sentinel_master , redis_class = arq_redis_cls )
253
+ return cast (TArqRedis , redis )
250
254
251
255
else :
252
256
pool_factory = functools .partial (
253
- ArqRedis ,
257
+ arq_redis_cls ,
254
258
host = settings .host ,
255
259
port = settings .port ,
256
260
unix_socket_path = settings .unix_socket_path ,
@@ -312,8 +316,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any])
312
316
clients_connected = info_clients .get ('connected_clients' , '?' )
313
317
314
318
log_func (
315
- f'redis_version={ redis_version } '
316
- f'mem_usage={ mem_usage } '
317
- f'clients_connected={ clients_connected } '
318
- f'db_keys={ key_count } '
319
+ f'redis_version={ redis_version } mem_usage={ mem_usage } clients_connected={ clients_connected } db_keys={ key_count } '
319
320
)
0 commit comments