Skip to content

Commit e0ad635

Browse files
suquarkogrisel
andauthored
Add the missing buffer_callback argument (#308)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent f4ce61f commit e0ad635

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

Diff for: CHANGES.md

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
over the module list
2020
([PR #322](https://github.com/cloudpipe/cloudpickle/pull/322)).
2121

22+
- Add support for out-of-band pickling (Python 3.8 and later).
23+
https://docs.python.org/3/library/pickle.html#example
24+
([issue #308](https://github.com/cloudpipe/cloudpickle/pull/308))
25+
2226
1.2.2
2327
=====
2428

Diff for: cloudpickle/cloudpickle_fast.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
# Shorthands similar to pickle.dump/pickle.dumps
37-
def dump(obj, file, protocol=None):
37+
def dump(obj, file, protocol=None, buffer_callback=None):
3838
"""Serialize obj as bytes streamed into file
3939
4040
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -44,10 +44,10 @@ def dump(obj, file, protocol=None):
4444
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
4545
compatibility with older versions of Python.
4646
"""
47-
CloudPickler(file, protocol=protocol).dump(obj)
47+
CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback).dump(obj)
4848

4949

50-
def dumps(obj, protocol=None):
50+
def dumps(obj, protocol=None, buffer_callback=None):
5151
"""Serialize obj as a string of bytes allocated in memory
5252
5353
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -58,7 +58,7 @@ def dumps(obj, protocol=None):
5858
compatibility with older versions of Python.
5959
"""
6060
with io.BytesIO() as file:
61-
cp = CloudPickler(file, protocol=protocol)
61+
cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
6262
cp.dump(obj)
6363
return file.getvalue()
6464

@@ -421,10 +421,10 @@ class CloudPickler(Pickler):
421421
dispatch[types.MappingProxyType] = _mappingproxy_reduce
422422
dispatch[weakref.WeakSet] = _weakset_reduce
423423

424-
def __init__(self, file, protocol=None):
424+
def __init__(self, file, protocol=None, buffer_callback=None):
425425
if protocol is None:
426426
protocol = DEFAULT_PROTOCOL
427-
Pickler.__init__(self, file, protocol=protocol)
427+
Pickler.__init__(self, file, protocol=protocol, buffer_callback=buffer_callback)
428428
# map functions __globals__ attribute ids, to ensure that functions
429429
# sharing the same global namespace at pickling time also share their
430430
# global namespace at unpickling time.

Diff for: tests/cloudpickle_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,22 @@ def __getattr__(self, name):
20522052
with pytest.raises(pickle.PicklingError, match='recursion'):
20532053
cloudpickle.dumps(a)
20542054

2055+
def test_out_of_band_buffers(self):
2056+
if self.protocol < 5:
2057+
pytest.skip("Need Pickle Protocol 5 or later")
2058+
np = pytest.importorskip("numpy")
2059+
2060+
class LocallyDefinedClass:
2061+
data = np.zeros(10)
2062+
2063+
data_instance = LocallyDefinedClass()
2064+
buffers = []
2065+
pickle_bytes = cloudpickle.dumps(data_instance, protocol=self.protocol,
2066+
buffer_callback=buffers.append)
2067+
assert len(buffers) == 1
2068+
reconstructed = pickle.loads(pickle_bytes, buffers=buffers)
2069+
np.testing.assert_allclose(reconstructed.data, data_instance.data)
2070+
20552071

20562072
class Protocol2CloudPickleTest(CloudPickleTest):
20572073

0 commit comments

Comments
 (0)