Skip to content

Commit f2e9c51

Browse files
committed
MNT use the is_dynamic for both classes and functions
1 parent 3f4d9da commit f2e9c51

File tree

2 files changed

+83
-74
lines changed

2 files changed

+83
-74
lines changed

Diff for: cloudpickle/cloudpickle.py

+69-71
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,67 @@ def _getattribute(obj, name):
123123
return getattr(obj, name, None), None
124124

125125

126+
def whichmodule(obj, name):
    """Return the name of the module ``obj`` belongs to, or None.

    Compared with ``pickle.whichmodule``, this variant:
    - does not mangle the cases where obj's module is __main__ and obj was
      not found in any module;
    - ignores errors raised while introspecting modules, since such errors
      are considered unwanted side effects.

    ``obj.__module__`` is trusted when it is set. Otherwise every entry of
    ``sys.modules`` is searched for an attribute ``name`` that is ``obj``
    itself. None is returned when no module matches.
    """
    declared = getattr(obj, '__module__', None)
    if declared is not None:
        return declared

    # Work on a snapshot of sys.modules: dynamic modules may import other
    # modules (and thus mutate sys.modules) when getattr is called on them.
    for candidate_name, candidate in list(sys.modules.items()):
        try:
            found = _getattribute(candidate, name)[0]
        except Exception:
            # Deliberate best effort: a module that fails during
            # introspection is simply not a match.
            continue
        if found is obj:
            return candidate_name
    return None
147+
148+
149+
def _is_global(obj, name=None):
150+
"""Determine if obj can be pickled as attribute of a file-backed module."""
151+
if name is None:
152+
name = getattr(obj, '__qualname__', None)
153+
if name is None:
154+
name = getattr(obj, '__name__', None)
155+
156+
module_name = whichmodule(obj, name)
157+
158+
if module_name is None:
159+
# In this case, obj.__module__ is None AND obj was not found in any
160+
# imported module. obj is thus treated as dynamic.
161+
return False
162+
163+
if module_name == "__main__":
164+
return False
165+
166+
module = sys.modules.get(module_name, None)
167+
if module is None:
168+
# The main reason why obj's module would not be imported is that this
169+
# module has been dynamically created, using for example
170+
# types.ModuleType. The other possibility is that module was removed
171+
# from sys.modules after obj was created/imported. But this case is not
172+
# supported, as the standard pickle does not support it either.
173+
return False
174+
175+
if _is_dynamic(module):
176+
# module has been added to sys.modules, but it can still be dynamic.
177+
return False
178+
179+
try:
180+
obj2, parent = _getattribute(module, name)
181+
except AttributeError:
182+
# obj was not found inside the module it points to
183+
return False
184+
return obj2 is obj
185+
186+
126187
def _make_cell_set_template_code():
127188
"""Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
128189
@@ -392,61 +453,9 @@ def save_function(self, obj, name=None):
392453
Determines what kind of function obj is (e.g. lambda, defined at
393454
interactive prompt, etc) and handles the pickling appropriately.
394455
"""
395-
write = self.write
396-
397-
if name is None:
398-
name = getattr(obj, '__qualname__', None)
399-
if name is None:
400-
name = getattr(obj, '__name__', None)
401-
try:
402-
# whichmodule() could fail, see
403-
# https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling
404-
modname = pickle.whichmodule(obj, name)
405-
except Exception:
406-
modname = None
407-
# print('which gives %s %s %s' % (modname, obj, name))
408-
try:
409-
themodule = sys.modules[modname]
410-
except KeyError:
411-
# eval'd items such as namedtuple give invalid items for their function __module__
412-
modname = '__main__'
413-
414-
if modname == '__main__':
415-
themodule = None
416-
417-
try:
418-
lookedup_by_name, _ = _getattribute(themodule, name)
419-
except Exception:
420-
lookedup_by_name = None
421-
422-
if themodule:
423-
if lookedup_by_name is obj:
424-
return self.save_global(obj, name)
425-
426-
# if func is lambda, def'ed at prompt, is in main, or is nested, then
427-
# we'll pickle the actual function object rather than simply saving a
428-
# reference (as is done in default pickler), via save_function_tuple.
429-
if (islambda(obj)
430-
or getattr(obj.__code__, 'co_filename', None) == '<stdin>'
431-
or themodule is None):
432-
self.save_function_tuple(obj)
433-
return
434-
else:
435-
# func is nested
436-
if lookedup_by_name is None or lookedup_by_name is not obj:
437-
self.save_function_tuple(obj)
438-
return
439-
440-
if obj.__dict__:
441-
# essentially save_reduce, but workaround needed to avoid recursion
442-
self.save(_restore_attr)
443-
write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
444-
self.memoize(obj)
445-
self.save(obj.__dict__)
446-
write(pickle.TUPLE + pickle.REDUCE)
447-
else:
448-
write(pickle.GLOBAL + modname + '\n' + name + '\n')
449-
self.memoize(obj)
456+
if not _is_global(obj, name=name):
457+
return self.save_function_tuple(obj)
458+
return Pickler.save_global(self, obj, name=name)
450459

451460
dispatch[types.FunctionType] = save_function
452461

@@ -801,23 +810,12 @@ def save_global(self, obj, name=None, pack=struct.pack):
801810
return self.save_reduce(type, (Ellipsis,), obj=obj)
802811
elif obj is type(NotImplemented):
803812
return self.save_reduce(type, (NotImplemented,), obj=obj)
804-
805-
if obj.__module__ == "__main__":
813+
elif obj in _BUILTIN_TYPE_NAMES:
814+
return self.save_reduce(
815+
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
816+
elif not _is_global(obj, name=name):
806817
return self.save_dynamic_class(obj)
807-
808-
try:
809-
return Pickler.save_global(self, obj, name=name)
810-
except Exception:
811-
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
812-
if obj in _BUILTIN_TYPE_NAMES:
813-
return self.save_reduce(
814-
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
815-
816-
typ = type(obj)
817-
if typ is not obj and isinstance(obj, (type, types.ClassType)):
818-
return self.save_dynamic_class(obj)
819-
820-
raise
818+
return Pickler.save_global(self, obj, name=name)
821819

822820
dispatch[type] = save_global
823821
dispatch[types.ClassType] = save_global

Diff for: tests/cloudpickle_test.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1023,12 +1023,12 @@ def __init__(self, x):
10231023
self.assertEqual(set(weakset), {depickled1, depickled2})
10241024

10251025
def test_faulty_module(self):
1026-
for module_name in ['_faulty_module', '_missing_module', None]:
1026+
for module_name in ['_missing_module', None]:
10271027
class FaultyModule(object):
10281028
def __getattr__(self, name):
10291029
# This throws an exception while looking up within
10301030
# pickle.whichmodule or getattr(module, name, None)
1031-
raise Exception()
1031+
raise Exception("FaultyModule error")
10321032

10331033
class Foo(object):
10341034
__module__ = module_name
@@ -1050,7 +1050,18 @@ def foo():
10501050
cloned = pickle_depickle(foo, protocol=self.protocol)
10511051
self.assertEqual(cloned(), "it works!")
10521052
finally:
1053-
sys.modules.pop("_faulty_module", None)
1053+
pass
1054+
1055+
# If a class/function points to a faulty module, the exception raised
1056+
# by the faulty module will not be caught by cloudpickle.
1057+
Foo.__module__ = foo.__module__ = "_faulty_module"
1058+
try:
1059+
for obj in [Foo, foo]:
1060+
with pytest.raises(Exception) as exc_info:
1061+
cloned = pickle_depickle(obj, protocol=self.protocol)
1062+
assert "FaultyModule error" == str(exc_info.value)
1063+
finally:
1064+
sys.modules.pop("_faulty_module", None)
10541065

10551066
def test_dynamic_pytest_module(self):
10561067
# Test case for pull request https://github.com/cloudpipe/cloudpickle/pull/116

0 commit comments

Comments
 (0)