Skip to content

Commit 6b555fc

Browse files
committed
WIP try to generalize func base globals handling
1 parent d2e879b commit 6b555fc

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

Diff for: cloudpickle/cloudpickle.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -626,11 +626,12 @@ def extract_func_data(self, func):
626626

627627
base_globals = self.globals_ref.get(id(func.__globals__), None)
628628
if base_globals is None:
629-
# For functions defined in __main__, use vars(__main__) for
630-
# base_global. This is necessary to share the global variables
631-
# across multiple functions in this module.
632-
if func.__module__ == "__main__":
633-
base_globals = "__main__"
629+
# For functions defined in a well behaved module use
630+
# vars(func.__module__) for base_globals. This is necessary to
631+
# share the global variables across multiple pickled functions from
632+
# this module.
633+
if hasattr(func, '__module__') and func.__module__ is not None:
634+
base_globals = func.__module__
634635
else:
635636
base_globals = {}
636637
self.globals_ref[id(func.__globals__)] = base_globals

Diff for: tests/cloudpickle_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
from .testutils import assert_run_python_script
4848

4949

50+
_TEST_GLOBAL_VARIABLE = "default_value"
51+
52+
5053
class RaiserOnPickle(object):
5154

5255
def __init__(self, exc):
@@ -880,6 +883,39 @@ def f1():
880883
clone_func=clone_func)
881884
assert_run_python_script(textwrap.dedent(code))
882885

886+
def test_closure_interacting_with_a_global_variable(self):
887+
global _TEST_GLOBAL_VARIABLE
888+
orig_value = _TEST_GLOBAL_VARIABLE
889+
try:
890+
def f0():
891+
global _TEST_GLOBAL_VARIABLE
892+
_TEST_GLOBAL_VARIABLE = "changed_by_f0"
893+
894+
def f1():
895+
return _TEST_GLOBAL_VARIABLE
896+
897+
cloned_f0 = cloudpickle.loads(cloudpickle.dumps(
898+
f0, protocol=self.protocol))
899+
cloned_f1 = cloudpickle.loads(cloudpickle.dumps(
900+
f1, protocol=self.protocol))
901+
pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol)
902+
903+
# Change the value of the global variable
904+
cloned_f0()
905+
906+
# Ensure that the global variable is the same for another function
907+
result_f1 = cloned_f1()
908+
assert result_f1 == "changed_by_f0", result_f1
909+
assert f1() == result_f1
910+
911+
# Ensure that unpickling the global variable does not change its
912+
# value
913+
result_pickled_f1 = cloudpickle.loads(pickled_f1)()
914+
assert result_pickled_f1 == "changed_by_f0", result_pickled_f1
915+
finally:
916+
_TEST_GLOBAL_VARIABLE = orig_value
917+
918+
883919
@pytest.mark.skipif(sys.version_info >= (3, 0),
884920
reason="hardcoded pickle bytes for 2.7")
885921
def test_function_pickle_compat_0_4_0(self):

0 commit comments

Comments
 (0)