Skip to content

Commit 4510be8

Browse files
authored
Add pickling of dict_keys, dict_values, dict_items (#384)
1 parent da2a604 commit 4510be8

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

Diff for: cloudpickle/cloudpickle.py

+12
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,15 @@ def _get_bases(typ):
828828
# For regular class objects
829829
bases_attr = '__bases__'
830830
return getattr(typ, bases_attr)
831+
832+
833+
def _make_dict_keys(obj):
834+
return dict.fromkeys(obj).keys()
835+
836+
837+
def _make_dict_values(obj):
838+
return {i: _ for i, _ in enumerate(obj)}.values()
839+
840+
841+
def _make_dict_items(obj):
842+
return obj.items()

Diff for: cloudpickle/cloudpickle_fast.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
guards present in cloudpickle.py that were written to handle PyPy specificities
1111
are not present in cloudpickle_fast.py
1212
"""
13+
import _collections_abc
1314
import abc
1415
import copyreg
1516
import io
@@ -33,8 +34,8 @@
3334
_typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType,
3435
_is_parametrized_type_hint, PYPY, cell_set,
3536
parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
36-
builtin_code_type
37-
37+
builtin_code_type,
38+
_make_dict_keys, _make_dict_values, _make_dict_items,
3839
)
3940

4041

@@ -400,6 +401,24 @@ def _class_reduce(obj):
400401
return NotImplemented
401402

402403

404+
def _dict_keys_reduce(obj):
405+
# Safer not to ship the full dict as sending the rest might
406+
# be unintended and could potentially cause leaking of
407+
# sensitive information
408+
return _make_dict_keys, (list(obj), )
409+
410+
411+
def _dict_values_reduce(obj):
412+
# Safer not to ship the full dict as sending the rest might
413+
# be unintended and could potentially cause leaking of
414+
# sensitive information
415+
return _make_dict_values, (list(obj), )
416+
417+
418+
def _dict_items_reduce(obj):
419+
return _make_dict_items, (dict(obj), )
420+
421+
403422
# COLLECTIONS OF OBJECTS STATE SETTERS
404423
# ------------------------------------
405424
# state setters are called at unpickling time, once the object is created and
@@ -473,6 +492,10 @@ class CloudPickler(Pickler):
473492
_dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
474493
_dispatch_table[weakref.WeakSet] = _weakset_reduce
475494
_dispatch_table[typing.TypeVar] = _typevar_reduce
495+
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
496+
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
497+
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
498+
476499

477500
dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
478501

Diff for: tests/cloudpickle_test.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division
22

3+
import _collections_abc
34
import abc
45
import collections
56
import base64
@@ -31,7 +32,7 @@
3132
# tests should be skipped if these modules are not available
3233
import numpy as np
3334
import scipy.special as spp
34-
except ImportError:
35+
except (ImportError, RuntimeError):
3536
np = None
3637
spp = None
3738

@@ -207,6 +208,24 @@ def test_memoryview(self):
207208
self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
208209
buffer_obj.tobytes())
209210

211+
def test_dict_keys(self):
212+
keys = {"a": 1, "b": 2}.keys()
213+
results = pickle_depickle(keys)
214+
self.assertEqual(results, keys)
215+
assert isinstance(results, _collections_abc.dict_keys)
216+
217+
def test_dict_values(self):
218+
values = {"a": 1, "b": 2}.values()
219+
results = pickle_depickle(values)
220+
self.assertEqual(sorted(results), sorted(values))
221+
assert isinstance(results, _collections_abc.dict_values)
222+
223+
def test_dict_items(self):
224+
items = {"a": 1, "b": 2}.items()
225+
results = pickle_depickle(items)
226+
self.assertEqual(results, items)
227+
assert isinstance(results, _collections_abc.dict_items)
228+
210229
def test_sliced_and_non_contiguous_memoryview(self):
211230
buffer_obj = memoryview(b"Hello!" * 3)[2:15:2]
212231
self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),

0 commit comments

Comments
 (0)