From c1ce72c798686696fd11304c3b64eb1e740b5d7b Mon Sep 17 00:00:00 2001 From: Jan Vollmer Date: Sun, 19 Jan 2025 18:26:39 +0100 Subject: [PATCH] allow passing a custom `ArqRedis` class to `create_pool` --- arq/connections.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/arq/connections.py b/arq/connections.py index c1058890..52336747 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from operator import attrgetter -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type, TypeVar, Union, cast from urllib.parse import parse_qs, urlparse from uuid import uuid4 @@ -217,6 +217,9 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef] return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs]) +TArqRedis = TypeVar('TArqRedis', bound=ArqRedis) + + async def create_pool( settings_: Optional[RedisSettings] = None, *, @@ -225,7 +228,8 @@ async def create_pool( job_deserializer: Optional[Deserializer] = None, default_queue_name: str = default_queue_name, expires_extra_ms: int = expires_extra_ms, -) -> ArqRedis: + arq_redis_cls: Type[TArqRedis] = ArqRedis, # type: ignore[assignment] +) -> TArqRedis: """ Create a new redis pool, retrying up to ``conn_retries`` times if the connection fails. @@ -238,19 +242,19 @@ async def create_pool( if settings.sentinel: - def pool_factory(*args: Any, **kwargs: Any) -> ArqRedis: + def pool_factory(*args: Any, **kwargs: Any) -> TArqRedis: client = Sentinel( # type: ignore[misc] *args, sentinels=settings.host, ssl=settings.ssl, **kwargs, ) - redis = client.master_for(settings.sentinel_master, redis_class=ArqRedis) - return cast(ArqRedis, redis) + redis = client.master_for(settings.sentinel_master, redis_class=arq_redis_cls) + return cast(TArqRedis, redis) else: pool_factory = functools.partial( - ArqRedis, + arq_redis_cls, host=settings.host, port=settings.port, unix_socket_path=settings.unix_socket_path, @@ -312,8 +316,5 @@ async def log_redis_info(redis: 'Redis[bytes]', log_func: Callable[[str], Any]) clients_connected = info_clients.get('connected_clients', '?') log_func( - f'redis_version={redis_version} ' - f'mem_usage={mem_usage} ' - f'clients_connected={clients_connected} ' - f'db_keys={key_count}' + f'redis_version={redis_version} mem_usage={mem_usage} clients_connected={clients_connected} db_keys={key_count}' )