Skip to content

MNT integrate pickle.save_global logic in cloudpickle natively #273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 100 additions & 84 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import opcode
import operator
import pickle
import platform
import struct
import sys
import traceback
Expand Down Expand Up @@ -92,6 +93,7 @@
string_types = (str,)
PY3 = True
PY2 = False
from importlib._bootstrap import _find_spec


def _ensure_tracking(class_def):
Expand Down Expand Up @@ -123,6 +125,69 @@ def _getattribute(obj, name):
return getattr(obj, name, None), None


def whichmodule(obj, name):
    """Return the name of the module that ``obj`` belongs to, or None.

    Unlike ``pickle.whichmodule``:
    - no ``'__main__'`` fallback is applied: when ``obj`` declares no
      ``__module__`` and is not found in any imported module, None is
      returned instead.
    - any exception raised while introspecting a module is swallowed, as
      such errors are considered unwanted side effects of the lookup.
    """
    declared = getattr(obj, '__module__', None)
    if declared is not None:
        return declared

    # Iterate over a snapshot of sys.modules: dynamic modules may trigger
    # imports of other modules (and thus mutate sys.modules) upon getattr.
    snapshot = list(sys.modules.items())
    for candidate_name, candidate in snapshot:
        if candidate_name == '__main__' or candidate is None:
            continue
        try:
            found = _getattribute(candidate, name)[0]
        except Exception:
            # Best-effort lookup: a faulty module must not break pickling.
            continue
        if found is obj:
            return candidate_name
    return None


def _is_global(obj, name=None):
    """Tell whether obj can be pickled as an attribute of a file-backed module.

    ``name`` defaults to ``obj.__qualname__`` and then to ``obj.__name__``.
    Returns True only when the module reported by ``whichmodule`` is
    imported, is not ``__main__``, is not dynamic, and exposes ``obj`` itself
    under ``name``.
    """
    for candidate_attr in ('__qualname__', '__name__'):
        if name is not None:
            break
        name = getattr(obj, candidate_attr, None)

    owner_name = whichmodule(obj, name)

    # None means obj.__module__ is None AND obj was not found in any imported
    # module: obj is thus treated as dynamic. Objects living in "__main__"
    # are never considered global either.
    if owner_name is None or owner_name == "__main__":
        return False

    owner = sys.modules.get(owner_name)
    if owner is None:
        # The main reason why obj's module would not be imported is that it
        # has been dynamically created, using for example types.ModuleType.
        # The other possibility is that the module was removed from
        # sys.modules after obj was created/imported, which is not supported
        # (the standard pickle does not support it either).
        return False

    # Being registered in sys.modules does not rule out a dynamic module.
    if _is_dynamic(owner):
        return False

    try:
        resolved, _ = _getattribute(owner, name)
    except AttributeError:
        # obj was not found inside the module it points to
        return False
    return resolved is obj


def _make_cell_set_template_code():
"""Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF

Expand Down Expand Up @@ -236,10 +301,6 @@ def cell_set(cell, value):
EXTENDED_ARG = dis.EXTENDED_ARG


def islambda(func):
    """Return True when ``func`` is named ``'<lambda>'``.

    The check is based solely on the ``__name__`` attribute. Objects that do
    not expose ``__name__`` at all are reported as non-lambdas instead of
    raising ``AttributeError``.
    """
    return getattr(func, '__name__', None) == '<lambda>'


_BUILTIN_TYPE_NAMES = {}
for k, v in types.__dict__.items():
if type(v) is type:
Expand Down Expand Up @@ -392,61 +453,9 @@ def save_function(self, obj, name=None):
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
write = self.write

if name is None:
name = getattr(obj, '__qualname__', None)
if name is None:
name = getattr(obj, '__name__', None)
try:
# whichmodule() could fail, see
# https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling
modname = pickle.whichmodule(obj, name)
except Exception:
modname = None
# print('which gives %s %s %s' % (modname, obj, name))
try:
themodule = sys.modules[modname]
except KeyError:
# eval'd items such as namedtuple give invalid items for their function __module__
modname = '__main__'

if modname == '__main__':
themodule = None

try:
lookedup_by_name, _ = _getattribute(themodule, name)
except Exception:
lookedup_by_name = None

if themodule:
if lookedup_by_name is obj:
return self.save_global(obj, name)

# if func is lambda, def'ed at prompt, is in main, or is nested, then
# we'll pickle the actual function object rather than simply saving a
# reference (as is done in default pickler), via save_function_tuple.
if (islambda(obj)
or getattr(obj.__code__, 'co_filename', None) == '<stdin>'
or themodule is None):
self.save_function_tuple(obj)
return
else:
# func is nested
if lookedup_by_name is None or lookedup_by_name is not obj:
self.save_function_tuple(obj)
return

if obj.__dict__:
# essentially save_reduce, but workaround needed to avoid recursion
self.save(_restore_attr)
write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
self.save(obj.__dict__)
write(pickle.TUPLE + pickle.REDUCE)
else:
write(pickle.GLOBAL + modname + '\n' + name + '\n')
self.memoize(obj)
if not _is_global(obj, name=name):
return self.save_function_tuple(obj)
return Pickler.save_global(self, obj, name=name)

dispatch[types.FunctionType] = save_function

Expand Down Expand Up @@ -801,23 +810,15 @@ def save_global(self, obj, name=None, pack=struct.pack):
return self.save_reduce(type, (Ellipsis,), obj=obj)
elif obj is type(NotImplemented):
return self.save_reduce(type, (NotImplemented,), obj=obj)

if obj.__module__ == "__main__":
return self.save_dynamic_class(obj)

try:
return Pickler.save_global(self, obj, name=name)
except Exception:
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
if obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)

typ = type(obj)
if typ is not obj and isinstance(obj, (type, types.ClassType)):
return self.save_dynamic_class(obj)

raise
elif obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
elif name is not None:
Pickler.save_global(self, obj, name=name)
elif not _is_global(obj, name=name):
self.save_dynamic_class(obj)
else:
Pickler.save_global(self, obj, name=name)

dispatch[type] = save_global
dispatch[types.ClassType] = save_global
Expand Down Expand Up @@ -1085,13 +1086,6 @@ def dynamic_subimport(name, vars):
return mod


def _restore_attr(obj, attr):
    """Set every ``key -> value`` pair of ``attr`` on ``obj`` and return it.

    Used to restore function attributes.
    """
    for attr_name, attr_value in attr.items():
        setattr(obj, attr_name, attr_value)
    return obj


def _gen_ellipsis():
    """Return the ``Ellipsis`` singleton."""
    return ...

Expand Down Expand Up @@ -1298,7 +1292,29 @@ def _is_dynamic(module):
return False

if hasattr(module, '__spec__'):
return module.__spec__ is None
if module.__spec__ is not None:
return False

# In PyPy, some built-in modules such as _codecs can have their
# __spec__ attribute set to None despite being imported. For such
# modules, the ``_find_spec`` utility of the standard library is used.
parent_name = module.__name__.rpartition('.')[0]
if parent_name: # pragma: no cover
# This code handles the case where an imported package (and not
# module) remains with __spec__ set to None. It is however untested
# as no package in the PyPy stdlib has __spec__ set to None after
# it is imported.
try:
parent = sys.modules[parent_name]
except KeyError:
msg = "parent {!r} not in sys.modules"
raise ImportError(msg.format(parent_name))
else:
pkgpath = parent.__path__
else:
pkgpath = None
return _find_spec(module.__name__, pkgpath, module) is None

else:
# Backward compat for Python 2
import imp
Expand Down
24 changes: 23 additions & 1 deletion tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,15 @@ def method(self, x):
mod1, mod2 = pickle_depickle([mod, mod])
self.assertEqual(id(mod1), id(mod2))

# Ensure proper pickling of mod's functions when module "looks" like a
# file-backed module even though it is not:
try:
sys.modules['mod'] = mod
depickled_f = pickle_depickle(mod.f, protocol=self.protocol)
self.assertEqual(mod.f(5), depickled_f(5))
finally:
sys.modules.pop('mod', None)

def test_module_locals_behavior(self):
# Makes sure that a local function defined in another module is
# correctly serialized. This notably checks that the globals are
Expand Down Expand Up @@ -621,6 +630,10 @@ def test_is_dynamic_module(self):
dynamic_module = types.ModuleType('dynamic_module')
assert _is_dynamic(dynamic_module)

if platform.python_implementation() == 'PyPy':
import _codecs
assert not _is_dynamic(_codecs)

def test_Ellipsis(self):
self.assertEqual(Ellipsis,
pickle_depickle(Ellipsis, protocol=self.protocol))
Expand Down Expand Up @@ -1023,7 +1036,7 @@ def __init__(self, x):
self.assertEqual(set(weakset), {depickled1, depickled2})

def test_faulty_module(self):
for module_name in ['_faulty_module', '_missing_module', None]:
for module_name in ['_missing_module', None]:
class FaultyModule(object):
def __getattr__(self, name):
# This throws an exception while looking up within
Expand Down Expand Up @@ -1794,6 +1807,15 @@ def f(a, /, b=1):
""".format(protocol=self.protocol)
assert_run_python_script(textwrap.dedent(code))

def test___reduce___returns_string(self):
    # Non-regression test for objects whose __reduce__ method returns a
    # string, meaning "save by attribute using save_global".
    from .mypkg import some_singleton
    reduced = some_singleton.__reduce__()
    assert reduced == "some_singleton"
    # The round trip must hand back the very same object, not a copy.
    roundtripped = pickle_depickle(some_singleton, protocol=self.protocol)
    assert roundtripped is some_singleton

class Protocol2CloudPickleTest(CloudPickleTest):

protocol = 2
Expand Down
9 changes: 9 additions & 0 deletions tests/mypkg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,12 @@
def package_function():
    """Return a fixed greeting; lives in a package, not a simple module."""
    greeting = "hello from a package!"
    return greeting


class _SingletonClass(object):
    """Type of the module-level ``some_singleton`` object.

    Its ``__reduce__`` returns a plain string rather than a tuple, which
    asks the pickler to save the instance by attribute name.
    """

    def __reduce__(self):
        # This reducer is only valid for the top level "some_singleton" object.
        singleton_name = "some_singleton"
        return singleton_name


some_singleton = _SingletonClass()