Skip to content

Commit 2ee7c3c

Browse files
chayimdvora-h
andauthored
Type hint improvements (#2952)
* Some type hints * fixed callable[T] * con * more connectios * restoring dev reqs * Update redis/commands/search/suggestion.py Co-authored-by: dvora-h <[email protected]> * Update redis/commands/core.py Co-authored-by: dvora-h <[email protected]> * Update redis/commands/search/suggestion.py Co-authored-by: dvora-h <[email protected]> * Update redis/commands/search/commands.py Co-authored-by: dvora-h <[email protected]> * Update redis/client.py Co-authored-by: dvora-h <[email protected]> * Update redis/commands/search/suggestion.py Co-authored-by: dvora-h <[email protected]> * Update redis/connection.py Co-authored-by: dvora-h <[email protected]> * Update redis/connection.py Co-authored-by: dvora-h <[email protected]> * Update redis/connection.py Co-authored-by: dvora-h <[email protected]> * Update redis/connection.py Co-authored-by: dvora-h <[email protected]> * Update redis/client.py Co-authored-by: dvora-h <[email protected]> * Update redis/client.py Co-authored-by: dvora-h <[email protected]> * Apply suggestions from code review Co-authored-by: dvora-h <[email protected]> * linters --------- Co-authored-by: dvora-h <[email protected]>
1 parent 56b254e commit 2ee7c3c

32 files changed

+289
-276
lines changed

redis/_parsers/base.py

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646

4747

4848
class BaseParser(ABC):
49-
5049
EXCEPTION_CLASSES = {
5150
"ERR": {
5251
"max number of clients reached": ConnectionError,

redis/_parsers/resp3.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,8 @@ async def _read_response(
243243
]
244244
res = self.push_handler_func(response)
245245
if not push_request:
246-
return await (
247-
self._read_response(
248-
disable_decoding=disable_decoding, push_request=push_request
249-
)
246+
return await self._read_response(
247+
disable_decoding=disable_decoding, push_request=push_request
250248
)
251249
else:
252250
return res

redis/asyncio/connection.py

-1
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,6 @@ def __init__(
11551155
queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated
11561156
**connection_kwargs,
11571157
):
1158-
11591158
super().__init__(
11601159
connection_class=connection_class,
11611160
max_connections=max_connections,

redis/client.py

+68-50
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import time
55
import warnings
66
from itertools import chain
7-
from typing import Optional, Type
7+
from typing import Any, Callable, Dict, List, Optional, Type, Union
88

9+
from redis._parsers.encoders import Encoder
910
from redis._parsers.helpers import (
1011
_RedisCallbacks,
1112
_RedisCallbacksRESP2,
@@ -49,7 +50,7 @@
4950
class CaseInsensitiveDict(dict):
5051
"Case insensitive dict implementation. Assumes string keys only."
5152

52-
def __init__(self, data):
53+
def __init__(self, data: Dict[str, str]) -> None:
5354
for k, v in data.items():
5455
self[k.upper()] = v
5556

@@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
9394
"""
9495

9596
@classmethod
96-
def from_url(cls, url, **kwargs):
97+
def from_url(cls, url: str, **kwargs) -> None:
9798
"""
9899
Return a Redis client object configured from the given URL
99100
@@ -202,7 +203,7 @@ def __init__(
202203
redis_connect_func=None,
203204
credential_provider: Optional[CredentialProvider] = None,
204205
protocol: Optional[int] = 2,
205-
):
206+
) -> None:
206207
"""
207208
Initialize a new Redis client.
208209
To specify a retry policy for specific errors, first set
@@ -309,14 +310,14 @@ def __init__(
309310
else:
310311
self.response_callbacks.update(_RedisCallbacksRESP2)
311312

312-
def __repr__(self):
313+
def __repr__(self) -> str:
313314
return f"{type(self).__name__}<{repr(self.connection_pool)}>"
314315

315-
def get_encoder(self):
316+
def get_encoder(self) -> "Encoder":
316317
"""Get the connection pool's encoder"""
317318
return self.connection_pool.get_encoder()
318319

319-
def get_connection_kwargs(self):
320+
def get_connection_kwargs(self) -> Dict:
320321
"""Get the connection's key-word arguments"""
321322
return self.connection_pool.connection_kwargs
322323

@@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None:
327328
self.get_connection_kwargs().update({"retry": retry})
328329
self.connection_pool.set_retry(retry)
329330

330-
def set_response_callback(self, command, callback):
331+
def set_response_callback(self, command: str, callback: Callable) -> None:
331332
"""Set a custom Response Callback"""
332333
self.response_callbacks[command] = callback
333334

334-
def load_external_module(self, funcname, func):
335+
def load_external_module(self, funcname, func) -> None:
335336
"""
336337
This function can be used to add externally defined redis modules,
337338
and their namespaces to the redis client.
@@ -354,7 +355,7 @@ def load_external_module(self, funcname, func):
354355
"""
355356
setattr(self, funcname, func)
356357

357-
def pipeline(self, transaction=True, shard_hint=None):
358+
def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline":
358359
"""
359360
Return a new pipeline object that can queue multiple commands for
360361
later execution. ``transaction`` indicates whether all commands
@@ -366,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None):
366367
self.connection_pool, self.response_callbacks, transaction, shard_hint
367368
)
368369

369-
def transaction(self, func, *watches, **kwargs):
370+
def transaction(
371+
self, func: Callable[["Pipeline"], None], *watches, **kwargs
372+
) -> None:
370373
"""
371374
Convenience method for executing the callable `func` as a transaction
372375
while watching all keys specified in `watches`. The 'func' callable
@@ -390,13 +393,13 @@ def transaction(self, func, *watches, **kwargs):
390393

391394
def lock(
392395
self,
393-
name,
394-
timeout=None,
395-
sleep=0.1,
396-
blocking=True,
397-
blocking_timeout=None,
398-
lock_class=None,
399-
thread_local=True,
396+
name: str,
397+
timeout: Optional[float] = None,
398+
sleep: float = 0.1,
399+
blocking: bool = True,
400+
blocking_timeout: Optional[float] = None,
401+
lock_class: Union[None, Any] = None,
402+
thread_local: bool = True,
400403
):
401404
"""
402405
Return a new Lock object using key ``name`` that mimics
@@ -648,9 +651,9 @@ def __init__(
648651
self,
649652
connection_pool,
650653
shard_hint=None,
651-
ignore_subscribe_messages=False,
652-
encoder=None,
653-
push_handler_func=None,
654+
ignore_subscribe_messages: bool = False,
655+
encoder: Optional["Encoder"] = None,
656+
push_handler_func: Union[None, Callable[[str], None]] = None,
654657
):
655658
self.connection_pool = connection_pool
656659
self.shard_hint = shard_hint
@@ -672,13 +675,13 @@ def __init__(
672675
_set_info_logger()
673676
self.reset()
674677

675-
def __enter__(self):
678+
def __enter__(self) -> "PubSub":
676679
return self
677680

678-
def __exit__(self, exc_type, exc_value, traceback):
681+
def __exit__(self, exc_type, exc_value, traceback) -> None:
679682
self.reset()
680683

681-
def __del__(self):
684+
def __del__(self) -> None:
682685
try:
683686
# if this object went out of scope prior to shutting down
684687
# subscriptions, close the connection manually before
@@ -687,7 +690,7 @@ def __del__(self):
687690
except Exception:
688691
pass
689692

690-
def reset(self):
693+
def reset(self) -> None:
691694
if self.connection:
692695
self.connection.disconnect()
693696
self.connection._deregister_connect_callback(self.on_connect)
@@ -702,10 +705,10 @@ def reset(self):
702705
self.pending_unsubscribe_patterns = set()
703706
self.subscribed_event.clear()
704707

705-
def close(self):
708+
def close(self) -> None:
706709
self.reset()
707710

708-
def on_connect(self, connection):
711+
def on_connect(self, connection) -> None:
709712
"Re-subscribe to any channels and patterns previously subscribed to"
710713
# NOTE: for python3, we can't pass bytestrings as keyword arguments
711714
# so we need to decode channel/pattern names back to unicode strings
@@ -731,7 +734,7 @@ def on_connect(self, connection):
731734
self.ssubscribe(**shard_channels)
732735

733736
@property
734-
def subscribed(self):
737+
def subscribed(self) -> bool:
735738
"""Indicates if there are subscriptions to any channels or patterns"""
736739
return self.subscribed_event.is_set()
737740

@@ -757,7 +760,7 @@ def execute_command(self, *args):
757760
self.clean_health_check_responses()
758761
self._execute(connection, connection.send_command, *args, **kwargs)
759762

760-
def clean_health_check_responses(self):
763+
def clean_health_check_responses(self) -> None:
761764
"""
762765
If any health check responses are present, clean them
763766
"""
@@ -775,7 +778,7 @@ def clean_health_check_responses(self):
775778
)
776779
ttl -= 1
777780

778-
def _disconnect_raise_connect(self, conn, error):
781+
def _disconnect_raise_connect(self, conn, error) -> None:
779782
"""
780783
Close the connection and raise an exception
781784
if retry_on_timeout is not set or the error
@@ -826,7 +829,7 @@ def try_read():
826829
return None
827830
return response
828831

829-
def is_health_check_response(self, response):
832+
def is_health_check_response(self, response) -> bool:
830833
"""
831834
Check if the response is a health check response.
832835
If there are no subscriptions redis responds to PING command with a
@@ -837,7 +840,7 @@ def is_health_check_response(self, response):
837840
self.health_check_response_b, # If there wasn't
838841
]
839842

840-
def check_health(self):
843+
def check_health(self) -> None:
841844
conn = self.connection
842845
if conn is None:
843846
raise RuntimeError(
@@ -849,7 +852,7 @@ def check_health(self):
849852
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
850853
self.health_check_response_counter += 1
851854

852-
def _normalize_keys(self, data):
855+
def _normalize_keys(self, data) -> Dict:
853856
"""
854857
normalize channel/pattern names to be either bytes or strings
855858
based on whether responses are automatically decoded. this saves us
@@ -983,7 +986,9 @@ def listen(self):
983986
if response is not None:
984987
yield response
985988

986-
def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
989+
def get_message(
990+
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
991+
):
987992
"""
988993
Get the next message if one is available, otherwise None.
989994
@@ -1012,7 +1017,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
10121017

10131018
get_sharded_message = get_message
10141019

1015-
def ping(self, message=None):
1020+
def ping(self, message: Union[str, None] = None) -> bool:
10161021
"""
10171022
Ping the Redis server
10181023
"""
@@ -1093,7 +1098,12 @@ def handle_message(self, response, ignore_subscribe_messages=False):
10931098

10941099
return message
10951100

1096-
def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
1101+
def run_in_thread(
1102+
self,
1103+
sleep_time: int = 0,
1104+
daemon: bool = False,
1105+
exception_handler: Optional[Callable] = None,
1106+
) -> "PubSubWorkerThread":
10971107
for channel, handler in self.channels.items():
10981108
if handler is None:
10991109
raise PubSubError(f"Channel: '{channel}' has no handler registered")
@@ -1114,15 +1124,23 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
11141124

11151125

11161126
class PubSubWorkerThread(threading.Thread):
1117-
def __init__(self, pubsub, sleep_time, daemon=False, exception_handler=None):
1127+
def __init__(
1128+
self,
1129+
pubsub,
1130+
sleep_time: float,
1131+
daemon: bool = False,
1132+
exception_handler: Union[
1133+
Callable[[Exception, "PubSub", "PubSubWorkerThread"], None], None
1134+
] = None,
1135+
):
11181136
super().__init__()
11191137
self.daemon = daemon
11201138
self.pubsub = pubsub
11211139
self.sleep_time = sleep_time
11221140
self.exception_handler = exception_handler
11231141
self._running = threading.Event()
11241142

1125-
def run(self):
1143+
def run(self) -> None:
11261144
if self._running.is_set():
11271145
return
11281146
self._running.set()
@@ -1137,7 +1155,7 @@ def run(self):
11371155
self.exception_handler(e, pubsub, self)
11381156
pubsub.close()
11391157

1140-
def stop(self):
1158+
def stop(self) -> None:
11411159
# trip the flag so the run loop exits. the run loop will
11421160
# close the pubsub connection, which disconnects the socket
11431161
# and returns the connection to the pool.
@@ -1175,7 +1193,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint)
11751193
self.watching = False
11761194
self.reset()
11771195

1178-
def __enter__(self):
1196+
def __enter__(self) -> "Pipeline":
11791197
return self
11801198

11811199
def __exit__(self, exc_type, exc_value, traceback):
@@ -1187,14 +1205,14 @@ def __del__(self):
11871205
except Exception:
11881206
pass
11891207

1190-
def __len__(self):
1208+
def __len__(self) -> int:
11911209
return len(self.command_stack)
11921210

1193-
def __bool__(self):
1211+
def __bool__(self) -> bool:
11941212
"""Pipeline instances should always evaluate to True"""
11951213
return True
11961214

1197-
def reset(self):
1215+
def reset(self) -> None:
11981216
self.command_stack = []
11991217
self.scripts = set()
12001218
# make sure to reset the connection state in the event that we were
@@ -1217,11 +1235,11 @@ def reset(self):
12171235
self.connection_pool.release(self.connection)
12181236
self.connection = None
12191237

1220-
def close(self):
1238+
def close(self) -> None:
12211239
"""Close the pipeline"""
12221240
self.reset()
12231241

1224-
def multi(self):
1242+
def multi(self) -> None:
12251243
"""
12261244
Start a transactional block of the pipeline after WATCH commands
12271245
are issued. End the transactional block with `execute`.
@@ -1239,7 +1257,7 @@ def execute_command(self, *args, **kwargs):
12391257
return self.immediate_execute_command(*args, **kwargs)
12401258
return self.pipeline_execute_command(*args, **kwargs)
12411259

1242-
def _disconnect_reset_raise(self, conn, error):
1260+
def _disconnect_reset_raise(self, conn, error) -> None:
12431261
"""
12441262
Close the connection, reset watching state and
12451263
raise an exception if we were watching,
@@ -1282,7 +1300,7 @@ def immediate_execute_command(self, *args, **options):
12821300
lambda error: self._disconnect_reset_raise(conn, error),
12831301
)
12841302

1285-
def pipeline_execute_command(self, *args, **options):
1303+
def pipeline_execute_command(self, *args, **options) -> "Pipeline":
12861304
"""
12871305
Stage a command to be executed when execute() is next called
12881306
@@ -1297,7 +1315,7 @@ def pipeline_execute_command(self, *args, **options):
12971315
self.command_stack.append((args, options))
12981316
return self
12991317

1300-
def _execute_transaction(self, connection, commands, raise_on_error):
1318+
def _execute_transaction(self, connection, commands, raise_on_error) -> List:
13011319
cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
13021320
all_cmds = connection.pack_commands(
13031321
[args for args, options in cmds if EMPTY_RESPONSE not in options]
@@ -1415,7 +1433,7 @@ def load_scripts(self):
14151433
if not exist:
14161434
s.sha = immediate("SCRIPT LOAD", s.script)
14171435

1418-
def _disconnect_raise_reset(self, conn, error):
1436+
def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None:
14191437
"""
14201438
Close the connection, raise an exception if we were watching,
14211439
and raise an exception if TimeoutError is not part of retry_on_error,
@@ -1477,6 +1495,6 @@ def watch(self, *names):
14771495
raise RedisError("Cannot issue a WATCH after a MULTI")
14781496
return self.execute_command("WATCH", *names)
14791497

1480-
def unwatch(self):
1498+
def unwatch(self) -> bool:
14811499
"""Unwatches all previously specified keys"""
14821500
return self.watching and self.execute_command("UNWATCH") or True

0 commit comments

Comments
 (0)