Skip to content

Commit f7be1f3

Browse files
committed
allow passing a custom ArqRedis class to create_pool
1 parent 7a911f3 commit f7be1f3

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

arq/connections.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass
55
from datetime import datetime, timedelta
66
from operator import attrgetter
7-
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast
7+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, TypeVar, Union, cast
88
from urllib.parse import parse_qs, urlparse
99
from uuid import uuid4
1010

@@ -217,6 +217,9 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
217217
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
218218

219219

220+
TArqRedis = TypeVar('TArqRedis', bound=ArqRedis)
221+
222+
220223
async def create_pool(
221224
settings_: Optional[RedisSettings] = None,
222225
*,
@@ -225,6 +228,7 @@ async def create_pool(
225228
job_deserializer: Optional[Deserializer] = None,
226229
default_queue_name: str = default_queue_name,
227230
expires_extra_ms: int = expires_extra_ms,
231+
arq_redis_cls: Type[TArqRedis] = ArqRedis,
228232
) -> ArqRedis:
229233
"""
230234
Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails.
@@ -238,19 +242,19 @@ async def create_pool(
238242

239243
if settings.sentinel:
240244

241-
def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis:
245+
def pool_factory(*args: Any, **kwargs: Any) -> TArqRedis:
242246
client = Sentinel( # type: ignore[misc]
243247
*args,
244248
sentinels=settings.host,
245249
ssl=settings.ssl,
246250
**kwargs,
247251
)
248-
redis = client.master_for(settings.sentinel_master, redis_class=ArqRedis)
249-
return cast(ArqRedis, redis)
252+
redis = client.master_for(settings.sentinel_master, redis_class=arq_redis_cls)
253+
return cast(TArqRedis, redis)
250254

251255
else:
252256
pool_factory = functools.partial(
253-
ArqRedis,
257+
arq_redis_cls,
254258
host=settings.host,
255259
port=settings.port,
256260
unix_socket_path=settings.unix_socket_path,
@@ -312,8 +316,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any])
312316
clients_connected = info_clients.get('connected_clients', '?')
313317

314318
log_func(
315-
f'redis_version={redis_version} '
316-
f'mem_usage={mem_usage} '
317-
f'clients_connected={clients_connected} '
318-
f'db_keys={key_count}'
319+
f'redis_version={redis_version} mem_usage={mem_usage} clients_connected={clients_connected} db_keys={key_count}'
319320
)

0 commit comments

Comments
 (0)