Skip to content

Commit 7b31ced

Browse files
pcmoritzogrisel
authored andcommitted
Fix pickling dataclasses (#245)
1 parent 54463b6 commit 7b31ced

File tree

3 files changed

+39
-11
lines changed

3 files changed

+39
-11
lines changed

Diff for: CHANGES.md

+7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
0.8.0
2+
=====
3+
4+
- Add support for pickling interactively defined dataclasses.
5+
([issue #245](https://github.com/cloudpipe/cloudpickle/pull/245))
6+
7+
18
0.7.0
29
=====
310

Diff for: cloudpickle/cloudpickle.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
6464

6565

66-
if sys.version < '3':
66+
if sys.version_info[0] < 3: # pragma: no branch
6767
from pickle import Pickler
6868
try:
6969
from cStringIO import StringIO
@@ -128,7 +128,7 @@ def inner(value):
128128
# NOTE: we are marking the cell variable as a free variable intentionally
129129
# so that we simulate an inner function instead of the outer function. This
130130
# is what gives us the ``nonlocal`` behavior in a Python 2 compatible way.
131-
if not PY3:
131+
if not PY3: # pragma: no branch
132132
return types.CodeType(
133133
co.co_argcount,
134134
co.co_nlocals,
@@ -229,14 +229,14 @@ def _factory():
229229
}
230230

231231

232-
if sys.version_info < (3, 4):
232+
if sys.version_info < (3, 4): # pragma: no branch
233233
def _walk_global_ops(code):
234234
"""
235235
Yield (opcode, argument number) tuples for all
236236
global-referencing instructions in *code*.
237237
"""
238238
code = getattr(code, 'co_code', b'')
239-
if not PY3:
239+
if not PY3: # pragma: no branch
240240
code = map(ord, code)
241241

242242
n = len(code)
@@ -293,7 +293,7 @@ def save_memoryview(self, obj):
293293

294294
dispatch[memoryview] = save_memoryview
295295

296-
if not PY3:
296+
if not PY3: # pragma: no branch
297297
def save_buffer(self, obj):
298298
self.save(str(obj))
299299

@@ -315,7 +315,7 @@ def save_codeobject(self, obj):
315315
"""
316316
Save a code object
317317
"""
318-
if PY3:
318+
if PY3: # pragma: no branch
319319
args = (
320320
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
321321
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames,
@@ -393,7 +393,7 @@ def save_function(self, obj, name=None):
393393
# So we pickle them here using save_reduce; have to do it differently
394394
# for different python versions.
395395
if not hasattr(obj, '__code__'):
396-
if PY3:
396+
if PY3: # pragma: no branch
397397
rv = obj.__reduce_ex__(self.proto)
398398
else:
399399
if hasattr(obj, '__self__'):
@@ -730,7 +730,7 @@ def save_instancemethod(self, obj):
730730
if obj.__self__ is None:
731731
self.save_reduce(getattr, (obj.im_class, obj.__name__))
732732
else:
733-
if PY3:
733+
if PY3: # pragma: no branch
734734
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj)
735735
else:
736736
self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__),
@@ -783,7 +783,7 @@ def save_inst(self, obj):
783783
save(stuff)
784784
write(pickle.BUILD)
785785

786-
if not PY3:
786+
if not PY3: # pragma: no branch
787787
dispatch[types.InstanceType] = save_inst
788788

789789
def save_property(self, obj):
@@ -883,7 +883,7 @@ def save_not_implemented(self, obj):
883883

884884
try: # Python 2
885885
dispatch[file] = save_file
886-
except NameError: # Python 3
886+
except NameError: # Python 3 # pragma: no branch
887887
dispatch[io.TextIOWrapper] = save_file
888888

889889
dispatch[type(Ellipsis)] = save_ellipsis
@@ -904,6 +904,12 @@ def save_root_logger(self, obj):
904904

905905
dispatch[logging.RootLogger] = save_root_logger
906906

907+
if hasattr(types, "MappingProxyType"): # pragma: no branch
908+
def save_mappingproxy(self, obj):
909+
self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj)
910+
911+
dispatch[types.MappingProxyType] = save_mappingproxy
912+
907913
"""Special functions for Add-on libraries"""
908914
def inject_addons(self):
909915
"""Plug in system. Register additional pickling functions if modules already loaded"""
@@ -1213,7 +1219,7 @@ def _getobject(modname, attribute):
12131219

12141220
""" Use copy_reg to extend global pickle definitions """
12151221

1216-
if sys.version_info < (3, 4):
1222+
if sys.version_info < (3, 4): # pragma: no branch
12171223
method_descriptor = type(str.upper)
12181224

12191225
def _reduce_method_descriptor(obj):

Diff for: tests/cloudpickle_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,21 @@ def __init__(self):
13441344
with pytest.raises(AttributeError):
13451345
obj.non_registered_attribute = 1
13461346

1347+
@unittest.skipIf(not hasattr(types, "MappingProxyType"),
1348+
"Old versions of Python do not have this type.")
1349+
def test_mappingproxy(self):
1350+
mp = types.MappingProxyType({"some_key": "some value"})
1351+
assert mp == pickle_depickle(mp, protocol=self.protocol)
1352+
1353+
def test_dataclass(self):
1354+
dataclasses = pytest.importorskip("dataclasses")
1355+
1356+
DataClass = dataclasses.make_dataclass('DataClass', [('x', int)])
1357+
data = DataClass(x=42)
1358+
1359+
pickle_depickle(DataClass, protocol=self.protocol)
1360+
assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42
1361+
13471362

13481363
class Protocol2CloudPickleTest(CloudPickleTest):
13491364

0 commit comments

Comments
 (0)