14
14
number_to_string , datetime_normalize , KEY_TO_VAL_STR , short_repr ,
15
15
get_truncate_datetime , dict_ , add_root_to_paths )
16
16
from deepdiff .base import Base
17
+
18
+ try :
19
+ import pandas
20
+ except ImportError :
21
+ pandas = False
22
+
23
+ try :
24
+ import polars
25
+ except ImportError :
26
+ polars = False
27
+
17
28
logger = logging .getLogger (__name__ )
18
29
19
30
UNPROCESSED_KEY = object ()
@@ -139,6 +150,7 @@ def __init__(self,
139
150
ignore_numeric_type_changes = False ,
140
151
ignore_type_subclasses = False ,
141
152
ignore_string_case = False ,
153
+ use_enum_value = False ,
142
154
exclude_obj_callback = None ,
143
155
number_to_string_func = None ,
144
156
ignore_private_variables = True ,
@@ -154,7 +166,7 @@ def __init__(self,
154
166
"exclude_paths, include_paths, exclude_regex_paths, hasher, ignore_repetition, "
155
167
"number_format_notation, apply_hash, ignore_type_in_groups, ignore_string_type_changes, "
156
168
"ignore_numeric_type_changes, ignore_type_subclasses, ignore_string_case "
157
- "number_to_string_func, ignore_private_variables, parent "
169
+ "number_to_string_func, ignore_private_variables, parent, use_enum_value "
158
170
"encodings, ignore_encoding_errors" ) % ', ' .join (kwargs .keys ()))
159
171
if isinstance (hashes , MutableMapping ):
160
172
self .hashes = hashes
@@ -170,6 +182,7 @@ def __init__(self,
170
182
self .exclude_regex_paths = convert_item_or_items_into_compiled_regexes_else_none (exclude_regex_paths )
171
183
self .hasher = default_hasher if hasher is None else hasher
172
184
self .hashes [UNPROCESSED_KEY ] = []
185
+ self .use_enum_value = use_enum_value
173
186
174
187
self .significant_digits = self .get_significant_digits (significant_digits , ignore_numeric_type_changes )
175
188
self .truncate_datetime = get_truncate_datetime (truncate_datetime )
@@ -206,10 +219,10 @@ def __init__(self,
206
219
sha1hex = sha1hex
207
220
208
221
def __getitem__(self, obj, extract_index=0):
    """
    Look up ``obj`` in this instance's ``hashes`` mapping.

    Delegates to ``_getitem``, forwarding ``extract_index`` unchanged and
    passing the instance's ``use_enum_value`` setting so the lookup uses
    the same option the instance was configured with.
    """
    return self._getitem(
        self.hashes,
        obj,
        extract_index=extract_index,
        use_enum_value=self.use_enum_value,
    )
210
223
211
224
@staticmethod
212
- def _getitem (hashes , obj , extract_index = 0 ):
225
+ def _getitem (hashes , obj , extract_index = 0 , use_enum_value = False ):
213
226
"""
214
227
extract_index is zero for hash and 1 for count and None to get them both.
215
228
To keep it backward compatible, we only get the hash by default so it is set to zero by default.
@@ -220,6 +233,8 @@ def _getitem(hashes, obj, extract_index=0):
220
233
key = BoolObj .TRUE
221
234
elif obj is False :
222
235
key = BoolObj .FALSE
236
+ elif use_enum_value and isinstance (obj , Enum ):
237
+ key = obj .value
223
238
224
239
result_n_count = (None , 0 )
225
240
@@ -256,14 +271,14 @@ def get(self, key, default=None, extract_index=0):
256
271
return self .get_key (self .hashes , key , default = default , extract_index = extract_index )
257
272
258
273
@staticmethod
def get_key(hashes, key, default=None, extract_index=0, use_enum_value=False):
    """
    get_key method for the hashes dictionary.
    It can extract the hash for a given key that is already calculated when extract_index=0
    or the count of items that went to building the object when extract_index=1.

    ``use_enum_value`` is forwarded unchanged to ``DeepHash._getitem``.
    If that lookup raises ``KeyError``, ``default`` is returned instead.
    """
    try:
        return DeepHash._getitem(
            hashes,
            key,
            extract_index=extract_index,
            use_enum_value=use_enum_value,
        )
    except KeyError:
        # Behave like dict.get: a missing key is not an error here.
        return default
@@ -444,7 +459,6 @@ def _prep_path(self, obj):
444
459
type_ = obj .__class__ .__name__
445
460
return KEY_TO_VAL_STR .format (type_ , obj )
446
461
447
-
448
462
def _prep_number (self , obj ):
449
463
type_ = "number" if self .ignore_numeric_type_changes else obj .__class__ .__name__
450
464
if self .significant_digits is not None :
@@ -475,12 +489,14 @@ def _prep_tuple(self, obj, parent, parents_ids):
475
489
return result , counts
476
490
477
491
def _hash (self , obj , parent , parents_ids = EMPTY_FROZENSET ):
478
- """The main diff method"""
492
+ """The main hash method"""
479
493
counts = 1
480
494
481
495
if isinstance (obj , bool ):
482
496
obj = self ._prep_bool (obj )
483
497
result = None
498
+ elif self .use_enum_value and isinstance (obj , Enum ):
499
+ obj = obj .value
484
500
else :
485
501
result = not_hashed
486
502
try :
@@ -523,6 +539,19 @@ def _hash(self, obj, parent, parents_ids=EMPTY_FROZENSET):
523
539
elif isinstance (obj , tuple ):
524
540
result , counts = self ._prep_tuple (obj = obj , parent = parent , parents_ids = parents_ids )
525
541
542
+ elif (pandas and isinstance (obj , pandas .DataFrame )):
543
+ def gen ():
544
+ yield ('dtype' , obj .dtypes )
545
+ yield ('index' , obj .index )
546
+ yield from obj .items () # which contains (column name, series tuples)
547
+ result , counts = self ._prep_iterable (obj = gen (), parent = parent , parents_ids = parents_ids )
548
+ elif (polars and isinstance (obj , polars .DataFrame )):
549
+ def gen ():
550
+ yield from obj .columns
551
+ yield from list (obj .schema .items ())
552
+ yield from obj .rows ()
553
+ result , counts = self ._prep_iterable (obj = gen (), parent = parent , parents_ids = parents_ids )
554
+
526
555
elif isinstance (obj , Iterable ):
527
556
result , counts = self ._prep_iterable (obj = obj , parent = parent , parents_ids = parents_ids )
528
557
0 commit comments