4
4
import time
5
5
import warnings
6
6
from itertools import chain
7
- from typing import Optional , Type
7
+ from typing import Any , Callable , Dict , List , Optional , Type , Union
8
8
9
+ from redis ._parsers .encoders import Encoder
9
10
from redis ._parsers .helpers import (
10
11
_RedisCallbacks ,
11
12
_RedisCallbacksRESP2 ,
49
50
class CaseInsensitiveDict (dict ):
50
51
"Case insensitive dict implementation. Assumes string keys only."
51
52
52
- def __init__ (self , data ) :
53
+ def __init__ (self , data : Dict [ str , str ]) -> None :
53
54
for k , v in data .items ():
54
55
self [k .upper ()] = v
55
56
@@ -93,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands):
93
94
"""
94
95
95
96
@classmethod
96
- def from_url (cls , url , ** kwargs ):
97
+ def from_url (cls , url : str , ** kwargs ) -> None :
97
98
"""
98
99
Return a Redis client object configured from the given URL
99
100
@@ -202,7 +203,7 @@ def __init__(
202
203
redis_connect_func = None ,
203
204
credential_provider : Optional [CredentialProvider ] = None ,
204
205
protocol : Optional [int ] = 2 ,
205
- ):
206
+ ) -> None :
206
207
"""
207
208
Initialize a new Redis client.
208
209
To specify a retry policy for specific errors, first set
@@ -309,14 +310,14 @@ def __init__(
309
310
else :
310
311
self .response_callbacks .update (_RedisCallbacksRESP2 )
311
312
312
- def __repr__ (self ):
313
+ def __repr__ (self ) -> str :
313
314
return f"{ type (self ).__name__ } <{ repr (self .connection_pool )} >"
314
315
315
- def get_encoder (self ):
316
+ def get_encoder (self ) -> "Encoder" :
316
317
"""Get the connection pool's encoder"""
317
318
return self .connection_pool .get_encoder ()
318
319
319
- def get_connection_kwargs (self ):
320
+ def get_connection_kwargs (self ) -> Dict :
320
321
"""Get the connection's key-word arguments"""
321
322
return self .connection_pool .connection_kwargs
322
323
@@ -327,11 +328,11 @@ def set_retry(self, retry: "Retry") -> None:
327
328
self .get_connection_kwargs ().update ({"retry" : retry })
328
329
self .connection_pool .set_retry (retry )
329
330
330
- def set_response_callback (self , command , callback ) :
331
+ def set_response_callback (self , command : str , callback : Callable ) -> None :
331
332
"""Set a custom Response Callback"""
332
333
self .response_callbacks [command ] = callback
333
334
334
- def load_external_module (self , funcname , func ):
335
+ def load_external_module (self , funcname , func ) -> None :
335
336
"""
336
337
This function can be used to add externally defined redis modules,
337
338
and their namespaces to the redis client.
@@ -354,7 +355,7 @@ def load_external_module(self, funcname, func):
354
355
"""
355
356
setattr (self , funcname , func )
356
357
357
- def pipeline (self , transaction = True , shard_hint = None ):
358
+ def pipeline (self , transaction = True , shard_hint = None ) -> "Pipeline" :
358
359
"""
359
360
Return a new pipeline object that can queue multiple commands for
360
361
later execution. ``transaction`` indicates whether all commands
@@ -366,7 +367,9 @@ def pipeline(self, transaction=True, shard_hint=None):
366
367
self .connection_pool , self .response_callbacks , transaction , shard_hint
367
368
)
368
369
369
- def transaction (self , func , * watches , ** kwargs ):
370
+ def transaction (
371
+ self , func : Callable [["Pipeline" ], None ], * watches , ** kwargs
372
+ ) -> None :
370
373
"""
371
374
Convenience method for executing the callable `func` as a transaction
372
375
while watching all keys specified in `watches`. The 'func' callable
@@ -390,13 +393,13 @@ def transaction(self, func, *watches, **kwargs):
390
393
391
394
def lock (
392
395
self ,
393
- name ,
394
- timeout = None ,
395
- sleep = 0.1 ,
396
- blocking = True ,
397
- blocking_timeout = None ,
398
- lock_class = None ,
399
- thread_local = True ,
396
+ name : str ,
397
+ timeout : Optional [ float ] = None ,
398
+ sleep : float = 0.1 ,
399
+ blocking : bool = True ,
400
+ blocking_timeout : Optional [ float ] = None ,
401
+ lock_class : Union [ None , Any ] = None ,
402
+ thread_local : bool = True ,
400
403
):
401
404
"""
402
405
Return a new Lock object using key ``name`` that mimics
@@ -648,9 +651,9 @@ def __init__(
648
651
self ,
649
652
connection_pool ,
650
653
shard_hint = None ,
651
- ignore_subscribe_messages = False ,
652
- encoder = None ,
653
- push_handler_func = None ,
654
+ ignore_subscribe_messages : bool = False ,
655
+ encoder : Optional [ "Encoder" ] = None ,
656
+ push_handler_func : Union [ None , Callable [[ str ], None ]] = None ,
654
657
):
655
658
self .connection_pool = connection_pool
656
659
self .shard_hint = shard_hint
@@ -672,13 +675,13 @@ def __init__(
672
675
_set_info_logger ()
673
676
self .reset ()
674
677
675
- def __enter__ (self ):
678
+ def __enter__ (self ) -> "PubSub" :
676
679
return self
677
680
678
- def __exit__ (self , exc_type , exc_value , traceback ):
681
+ def __exit__ (self , exc_type , exc_value , traceback ) -> None :
679
682
self .reset ()
680
683
681
- def __del__ (self ):
684
+ def __del__ (self ) -> None :
682
685
try :
683
686
# if this object went out of scope prior to shutting down
684
687
# subscriptions, close the connection manually before
@@ -687,7 +690,7 @@ def __del__(self):
687
690
except Exception :
688
691
pass
689
692
690
- def reset (self ):
693
+ def reset (self ) -> None :
691
694
if self .connection :
692
695
self .connection .disconnect ()
693
696
self .connection ._deregister_connect_callback (self .on_connect )
@@ -702,10 +705,10 @@ def reset(self):
702
705
self .pending_unsubscribe_patterns = set ()
703
706
self .subscribed_event .clear ()
704
707
705
- def close (self ):
708
+ def close (self ) -> None :
706
709
self .reset ()
707
710
708
- def on_connect (self , connection ):
711
+ def on_connect (self , connection ) -> None :
709
712
"Re-subscribe to any channels and patterns previously subscribed to"
710
713
# NOTE: for python3, we can't pass bytestrings as keyword arguments
711
714
# so we need to decode channel/pattern names back to unicode strings
@@ -731,7 +734,7 @@ def on_connect(self, connection):
731
734
self .ssubscribe (** shard_channels )
732
735
733
736
@property
734
- def subscribed (self ):
737
+ def subscribed (self ) -> bool :
735
738
"""Indicates if there are subscriptions to any channels or patterns"""
736
739
return self .subscribed_event .is_set ()
737
740
@@ -757,7 +760,7 @@ def execute_command(self, *args):
757
760
self .clean_health_check_responses ()
758
761
self ._execute (connection , connection .send_command , * args , ** kwargs )
759
762
760
- def clean_health_check_responses (self ):
763
+ def clean_health_check_responses (self ) -> None :
761
764
"""
762
765
If any health check responses are present, clean them
763
766
"""
@@ -775,7 +778,7 @@ def clean_health_check_responses(self):
775
778
)
776
779
ttl -= 1
777
780
778
- def _disconnect_raise_connect (self , conn , error ):
781
+ def _disconnect_raise_connect (self , conn , error ) -> None :
779
782
"""
780
783
Close the connection and raise an exception
781
784
if retry_on_timeout is not set or the error
@@ -826,7 +829,7 @@ def try_read():
826
829
return None
827
830
return response
828
831
829
- def is_health_check_response (self , response ):
832
+ def is_health_check_response (self , response ) -> bool :
830
833
"""
831
834
Check if the response is a health check response.
832
835
If there are no subscriptions redis responds to PING command with a
@@ -837,7 +840,7 @@ def is_health_check_response(self, response):
837
840
self .health_check_response_b , # If there wasn't
838
841
]
839
842
840
- def check_health (self ):
843
+ def check_health (self ) -> None :
841
844
conn = self .connection
842
845
if conn is None :
843
846
raise RuntimeError (
@@ -849,7 +852,7 @@ def check_health(self):
849
852
conn .send_command ("PING" , self .HEALTH_CHECK_MESSAGE , check_health = False )
850
853
self .health_check_response_counter += 1
851
854
852
- def _normalize_keys (self , data ):
855
+ def _normalize_keys (self , data ) -> Dict :
853
856
"""
854
857
normalize channel/pattern names to be either bytes or strings
855
858
based on whether responses are automatically decoded. this saves us
@@ -983,7 +986,9 @@ def listen(self):
983
986
if response is not None :
984
987
yield response
985
988
986
- def get_message (self , ignore_subscribe_messages = False , timeout = 0.0 ):
989
+ def get_message (
990
+ self , ignore_subscribe_messages : bool = False , timeout : float = 0.0
991
+ ):
987
992
"""
988
993
Get the next message if one is available, otherwise None.
989
994
@@ -1012,7 +1017,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
1012
1017
1013
1018
get_sharded_message = get_message
1014
1019
1015
- def ping (self , message = None ):
1020
+ def ping (self , message : Union [ str , None ] = None ) -> bool :
1016
1021
"""
1017
1022
Ping the Redis server
1018
1023
"""
@@ -1093,7 +1098,12 @@ def handle_message(self, response, ignore_subscribe_messages=False):
1093
1098
1094
1099
return message
1095
1100
1096
- def run_in_thread (self , sleep_time = 0 , daemon = False , exception_handler = None ):
1101
+ def run_in_thread (
1102
+ self ,
1103
+ sleep_time : int = 0 ,
1104
+ daemon : bool = False ,
1105
+ exception_handler : Optional [Callable ] = None ,
1106
+ ) -> "PubSubWorkerThread" :
1097
1107
for channel , handler in self .channels .items ():
1098
1108
if handler is None :
1099
1109
raise PubSubError (f"Channel: '{ channel } ' has no handler registered" )
@@ -1114,15 +1124,23 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
1114
1124
1115
1125
1116
1126
class PubSubWorkerThread (threading .Thread ):
1117
- def __init__ (self , pubsub , sleep_time , daemon = False , exception_handler = None ):
1127
+ def __init__ (
1128
+ self ,
1129
+ pubsub ,
1130
+ sleep_time : float ,
1131
+ daemon : bool = False ,
1132
+ exception_handler : Union [
1133
+ Callable [[Exception , "PubSub" , "PubSubWorkerThread" ], None ], None
1134
+ ] = None ,
1135
+ ):
1118
1136
super ().__init__ ()
1119
1137
self .daemon = daemon
1120
1138
self .pubsub = pubsub
1121
1139
self .sleep_time = sleep_time
1122
1140
self .exception_handler = exception_handler
1123
1141
self ._running = threading .Event ()
1124
1142
1125
- def run (self ):
1143
+ def run (self ) -> None :
1126
1144
if self ._running .is_set ():
1127
1145
return
1128
1146
self ._running .set ()
@@ -1137,7 +1155,7 @@ def run(self):
1137
1155
self .exception_handler (e , pubsub , self )
1138
1156
pubsub .close ()
1139
1157
1140
- def stop (self ):
1158
+ def stop (self ) -> None :
1141
1159
# trip the flag so the run loop exits. the run loop will
1142
1160
# close the pubsub connection, which disconnects the socket
1143
1161
# and returns the connection to the pool.
@@ -1175,7 +1193,7 @@ def __init__(self, connection_pool, response_callbacks, transaction, shard_hint)
1175
1193
self .watching = False
1176
1194
self .reset ()
1177
1195
1178
- def __enter__ (self ):
1196
+ def __enter__ (self ) -> "Pipeline" :
1179
1197
return self
1180
1198
1181
1199
def __exit__ (self , exc_type , exc_value , traceback ):
@@ -1187,14 +1205,14 @@ def __del__(self):
1187
1205
except Exception :
1188
1206
pass
1189
1207
1190
- def __len__ (self ):
1208
+ def __len__ (self ) -> int :
1191
1209
return len (self .command_stack )
1192
1210
1193
- def __bool__ (self ):
1211
+ def __bool__ (self ) -> bool :
1194
1212
"""Pipeline instances should always evaluate to True"""
1195
1213
return True
1196
1214
1197
- def reset (self ):
1215
+ def reset (self ) -> None :
1198
1216
self .command_stack = []
1199
1217
self .scripts = set ()
1200
1218
# make sure to reset the connection state in the event that we were
@@ -1217,11 +1235,11 @@ def reset(self):
1217
1235
self .connection_pool .release (self .connection )
1218
1236
self .connection = None
1219
1237
1220
- def close (self ):
1238
+ def close (self ) -> None :
1221
1239
"""Close the pipeline"""
1222
1240
self .reset ()
1223
1241
1224
- def multi (self ):
1242
+ def multi (self ) -> None :
1225
1243
"""
1226
1244
Start a transactional block of the pipeline after WATCH commands
1227
1245
are issued. End the transactional block with `execute`.
@@ -1239,7 +1257,7 @@ def execute_command(self, *args, **kwargs):
1239
1257
return self .immediate_execute_command (* args , ** kwargs )
1240
1258
return self .pipeline_execute_command (* args , ** kwargs )
1241
1259
1242
- def _disconnect_reset_raise (self , conn , error ):
1260
+ def _disconnect_reset_raise (self , conn , error ) -> None :
1243
1261
"""
1244
1262
Close the connection, reset watching state and
1245
1263
raise an exception if we were watching,
@@ -1282,7 +1300,7 @@ def immediate_execute_command(self, *args, **options):
1282
1300
lambda error : self ._disconnect_reset_raise (conn , error ),
1283
1301
)
1284
1302
1285
- def pipeline_execute_command (self , * args , ** options ):
1303
+ def pipeline_execute_command (self , * args , ** options ) -> "Pipeline" :
1286
1304
"""
1287
1305
Stage a command to be executed when execute() is next called
1288
1306
@@ -1297,7 +1315,7 @@ def pipeline_execute_command(self, *args, **options):
1297
1315
self .command_stack .append ((args , options ))
1298
1316
return self
1299
1317
1300
- def _execute_transaction (self , connection , commands , raise_on_error ):
1318
+ def _execute_transaction (self , connection , commands , raise_on_error ) -> List :
1301
1319
cmds = chain ([(("MULTI" ,), {})], commands , [(("EXEC" ,), {})])
1302
1320
all_cmds = connection .pack_commands (
1303
1321
[args for args , options in cmds if EMPTY_RESPONSE not in options ]
@@ -1415,7 +1433,7 @@ def load_scripts(self):
1415
1433
if not exist :
1416
1434
s .sha = immediate ("SCRIPT LOAD" , s .script )
1417
1435
1418
- def _disconnect_raise_reset (self , conn , error ) :
1436
+ def _disconnect_raise_reset (self , conn : Redis , error : Exception ) -> None :
1419
1437
"""
1420
1438
Close the connection, raise an exception if we were watching,
1421
1439
and raise an exception if TimeoutError is not part of retry_on_error,
@@ -1477,6 +1495,6 @@ def watch(self, *names):
1477
1495
raise RedisError ("Cannot issue a WATCH after a MULTI" )
1478
1496
return self .execute_command ("WATCH" , * names )
1479
1497
1480
- def unwatch (self ):
1498
+ def unwatch (self ) -> bool :
1481
1499
"""Unwatches all previously specified keys"""
1482
1500
return self .watching and self .execute_command ("UNWATCH" ) or True
0 commit comments