Skip to content

Commit e0cd33b

Browse files
pierreglaserogrisel
authored andcommitted
MNT integrate pickle.save_global logic in cloudpickle natively (#273)
1 parent 6176193 commit e0cd33b

File tree

3 files changed

+132
-85
lines changed

3 files changed

+132
-85
lines changed

Diff for: cloudpickle/cloudpickle.py

+100-84
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import opcode
5151
import operator
5252
import pickle
53+
import platform
5354
import struct
5455
import sys
5556
import traceback
@@ -92,6 +93,7 @@
9293
string_types = (str,)
9394
PY3 = True
9495
PY2 = False
96+
from importlib._bootstrap import _find_spec
9597

9698

9799
def _ensure_tracking(class_def):
@@ -123,6 +125,69 @@ def _getattribute(obj, name):
123125
return getattr(obj, name, None), None
124126

125127

128+
def _whichmodule(obj, name):
129+
"""Find the module an object belongs to.
130+
131+
This function differs from ``pickle.whichmodule`` in two ways:
132+
- it does not mangle the cases where obj's module is __main__ and obj was
133+
not found in any module.
134+
- Errors arising during module introspection are ignored, as those errors
135+
are considered unwanted side effects.
136+
"""
137+
module_name = getattr(obj, '__module__', None)
138+
if module_name is not None:
139+
return module_name
140+
# Protect the iteration by using a list copy of sys.modules against dynamic
141+
# modules that trigger imports of other modules upon calls to getattr.
142+
for module_name, module in list(sys.modules.items()):
143+
if module_name == '__main__' or module is None:
144+
continue
145+
try:
146+
if _getattribute(module, name)[0] is obj:
147+
return module_name
148+
except Exception:
149+
pass
150+
return None
151+
152+
153+
def _is_global(obj, name=None):
154+
"""Determine if obj can be pickled as attribute of a file-backed module"""
155+
if name is None:
156+
name = getattr(obj, '__qualname__', None)
157+
if name is None:
158+
name = getattr(obj, '__name__', None)
159+
160+
module_name = _whichmodule(obj, name)
161+
162+
if module_name is None:
163+
# In this case, obj.__module__ is None AND obj was not found in any
164+
# imported module. obj is thus treated as dynamic.
165+
return False
166+
167+
if module_name == "__main__":
168+
return False
169+
170+
module = sys.modules.get(module_name, None)
171+
if module is None:
172+
# The main reason why obj's module would not be imported is that this
173+
# module has been dynamically created, using for example
174+
# types.ModuleType. The other possibility is that module was removed
175+
# from sys.modules after obj was created/imported. But this case is not
176+
# supported, as the standard pickle does not support it either.
177+
return False
178+
179+
# module has been added to sys.modules, but it can still be dynamic.
180+
if _is_dynamic(module):
181+
return False
182+
183+
try:
184+
obj2, parent = _getattribute(module, name)
185+
except AttributeError:
186+
# obj was not found inside the module it points to
187+
return False
188+
return obj2 is obj
189+
190+
126191
def _make_cell_set_template_code():
127192
"""Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
128193
@@ -236,10 +301,6 @@ def cell_set(cell, value):
236301
EXTENDED_ARG = dis.EXTENDED_ARG
237302

238303

239-
def islambda(func):
240-
return getattr(func, '__name__') == '<lambda>'
241-
242-
243304
_BUILTIN_TYPE_NAMES = {}
244305
for k, v in types.__dict__.items():
245306
if type(v) is type:
@@ -392,61 +453,9 @@ def save_function(self, obj, name=None):
392453
Determines what kind of function obj is (e.g. lambda, defined at
393454
interactive prompt, etc) and handles the pickling appropriately.
394455
"""
395-
write = self.write
396-
397-
if name is None:
398-
name = getattr(obj, '__qualname__', None)
399-
if name is None:
400-
name = getattr(obj, '__name__', None)
401-
try:
402-
# whichmodule() could fail, see
403-
# https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling
404-
modname = pickle.whichmodule(obj, name)
405-
except Exception:
406-
modname = None
407-
# print('which gives %s %s %s' % (modname, obj, name))
408-
try:
409-
themodule = sys.modules[modname]
410-
except KeyError:
411-
# eval'd items such as namedtuple give invalid items for their function __module__
412-
modname = '__main__'
413-
414-
if modname == '__main__':
415-
themodule = None
416-
417-
try:
418-
lookedup_by_name, _ = _getattribute(themodule, name)
419-
except Exception:
420-
lookedup_by_name = None
421-
422-
if themodule:
423-
if lookedup_by_name is obj:
424-
return self.save_global(obj, name)
425-
426-
# if func is lambda, def'ed at prompt, is in main, or is nested, then
427-
# we'll pickle the actual function object rather than simply saving a
428-
# reference (as is done in default pickler), via save_function_tuple.
429-
if (islambda(obj)
430-
or getattr(obj.__code__, 'co_filename', None) == '<stdin>'
431-
or themodule is None):
432-
self.save_function_tuple(obj)
433-
return
434-
else:
435-
# func is nested
436-
if lookedup_by_name is None or lookedup_by_name is not obj:
437-
self.save_function_tuple(obj)
438-
return
439-
440-
if obj.__dict__:
441-
# essentially save_reduce, but workaround needed to avoid recursion
442-
self.save(_restore_attr)
443-
write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
444-
self.memoize(obj)
445-
self.save(obj.__dict__)
446-
write(pickle.TUPLE + pickle.REDUCE)
447-
else:
448-
write(pickle.GLOBAL + modname + '\n' + name + '\n')
449-
self.memoize(obj)
456+
if not _is_global(obj, name=name):
457+
return self.save_function_tuple(obj)
458+
return Pickler.save_global(self, obj, name=name)
450459

451460
dispatch[types.FunctionType] = save_function
452461

@@ -801,23 +810,15 @@ def save_global(self, obj, name=None, pack=struct.pack):
801810
return self.save_reduce(type, (Ellipsis,), obj=obj)
802811
elif obj is type(NotImplemented):
803812
return self.save_reduce(type, (NotImplemented,), obj=obj)
804-
805-
if obj.__module__ == "__main__":
806-
return self.save_dynamic_class(obj)
807-
808-
try:
809-
return Pickler.save_global(self, obj, name=name)
810-
except Exception:
811-
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
812-
if obj in _BUILTIN_TYPE_NAMES:
813-
return self.save_reduce(
814-
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
815-
816-
typ = type(obj)
817-
if typ is not obj and isinstance(obj, (type, types.ClassType)):
818-
return self.save_dynamic_class(obj)
819-
820-
raise
813+
elif obj in _BUILTIN_TYPE_NAMES:
814+
return self.save_reduce(
815+
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
816+
elif name is not None:
817+
Pickler.save_global(self, obj, name=name)
818+
elif not _is_global(obj, name=name):
819+
self.save_dynamic_class(obj)
820+
else:
821+
Pickler.save_global(self, obj, name=name)
821822

822823
dispatch[type] = save_global
823824
dispatch[types.ClassType] = save_global
@@ -1085,13 +1086,6 @@ def dynamic_subimport(name, vars):
10851086
return mod
10861087

10871088

1088-
# restores function attributes
1089-
def _restore_attr(obj, attr):
1090-
for key, val in attr.items():
1091-
setattr(obj, key, val)
1092-
return obj
1093-
1094-
10951089
def _gen_ellipsis():
10961090
return Ellipsis
10971091

@@ -1298,7 +1292,29 @@ def _is_dynamic(module):
12981292
return False
12991293

13001294
if hasattr(module, '__spec__'):
1301-
return module.__spec__ is None
1295+
if module.__spec__ is not None:
1296+
return False
1297+
1298+
# In PyPy, Some built-in modules such as _codecs can have their
1299+
# __spec__ attribute set to None despite being imported. For such
1300+
# modules, the ``_find_spec`` utility of the standard library is used.
1301+
parent_name = module.__name__.rpartition('.')[0]
1302+
if parent_name: # pragma: no cover
1303+
# This code handles the case where an imported package (and not
1304+
# module) remains with __spec__ set to None. It is however untested
1305+
# as no package in the PyPy stdlib has __spec__ set to None after
1306+
# it is imported.
1307+
try:
1308+
parent = sys.modules[parent_name]
1309+
except KeyError:
1310+
msg = "parent {!r} not in sys.modules"
1311+
raise ImportError(msg.format(parent_name))
1312+
else:
1313+
pkgpath = parent.__path__
1314+
else:
1315+
pkgpath = None
1316+
return _find_spec(module.__name__, pkgpath, module) is None
1317+
13021318
else:
13031319
# Backward compat for Python 2
13041320
import imp

Diff for: tests/cloudpickle_test.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,15 @@ def method(self, x):
478478
mod1, mod2 = pickle_depickle([mod, mod])
479479
self.assertEqual(id(mod1), id(mod2))
480480

481+
# Ensure proper pickling of mod's functions when module "looks" like a
482+
# file-backed module even though it is not:
483+
try:
484+
sys.modules['mod'] = mod
485+
depickled_f = pickle_depickle(mod.f, protocol=self.protocol)
486+
self.assertEqual(mod.f(5), depickled_f(5))
487+
finally:
488+
sys.modules.pop('mod', None)
489+
481490
def test_module_locals_behavior(self):
482491
# Makes sure that a local function defined in another module is
483492
# correctly serialized. This notably checks that the globals are
@@ -621,6 +630,10 @@ def test_is_dynamic_module(self):
621630
dynamic_module = types.ModuleType('dynamic_module')
622631
assert _is_dynamic(dynamic_module)
623632

633+
if platform.python_implementation() == 'PyPy':
634+
import _codecs
635+
assert not _is_dynamic(_codecs)
636+
624637
def test_Ellipsis(self):
625638
self.assertEqual(Ellipsis,
626639
pickle_depickle(Ellipsis, protocol=self.protocol))
@@ -1023,7 +1036,7 @@ def __init__(self, x):
10231036
self.assertEqual(set(weakset), {depickled1, depickled2})
10241037

10251038
def test_faulty_module(self):
1026-
for module_name in ['_faulty_module', '_missing_module', None]:
1039+
for module_name in ['_missing_module', None]:
10271040
class FaultyModule(object):
10281041
def __getattr__(self, name):
10291042
# This throws an exception while looking up within
@@ -1794,6 +1807,15 @@ def f(a, /, b=1):
17941807
""".format(protocol=self.protocol)
17951808
assert_run_python_script(textwrap.dedent(code))
17961809

1810+
def test___reduce___returns_string(self):
1811+
# Non regression test for objects with a __reduce__ method returning a
1812+
# string, meaning "save by attribute using save_global"
1813+
from .mypkg import some_singleton
1814+
assert some_singleton.__reduce__() == "some_singleton"
1815+
depickled_singleton = pickle_depickle(
1816+
some_singleton, protocol=self.protocol)
1817+
assert depickled_singleton is some_singleton
1818+
17971819
class Protocol2CloudPickleTest(CloudPickleTest):
17981820

17991821
protocol = 2

Diff for: tests/mypkg/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,12 @@
44
def package_function():
55
"""Function living inside a package, not a simple module"""
66
return "hello from a package!"
7+
8+
9+
class _SingletonClass(object):
10+
def __reduce__(self):
11+
# This reducer is only valid for the top level "some_singleton" object.
12+
return "some_singleton"
13+
14+
15+
some_singleton = _SingletonClass()

0 commit comments

Comments
 (0)