Skip to content

Commit ce25fe9

Browse files
committed
WIP try to generalize func base globals handling
1 parent 9a52ba3 commit ce25fe9

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

Diff for: cloudpickle/cloudpickle.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -635,11 +635,12 @@ def extract_func_data(self, func):
635635

636636
base_globals = self.globals_ref.get(id(func.__globals__), None)
637637
if base_globals is None:
638-
# For functions defined in __main__, use vars(__main__) for
639-
# base_global. This is necessary to share the global variables
640-
# across multiple functions in this module.
641-
if func.__module__ == "__main__":
642-
base_globals = "__main__"
638+
# For functions defined in a well behaved module use
639+
# vars(func.__module__) for base_globals. This is necessary to
640+
# share the global variables across multiple pickled functions from
641+
# this module.
642+
if hasattr(func, '__module__') and func.__module__ is not None:
643+
base_globals = func.__module__
643644
else:
644645
base_globals = {}
645646
self.globals_ref[id(func.__globals__)] = base_globals

Diff for: tests/cloudpickle_test.py

+36
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
from .testutils import assert_run_python_script
4848

4949

50+
_TEST_GLOBAL_VARIABLE = "default_value"
51+
52+
5053
class RaiserOnPickle(object):
5154

5255
def __init__(self, exc):
@@ -887,6 +890,39 @@ def f1():
887890
clone_func=clone_func)
888891
assert_run_python_script(textwrap.dedent(code))
889892

893+
def test_closure_interacting_with_a_global_variable(self):
894+
global _TEST_GLOBAL_VARIABLE
895+
orig_value = _TEST_GLOBAL_VARIABLE
896+
try:
897+
def f0():
898+
global _TEST_GLOBAL_VARIABLE
899+
_TEST_GLOBAL_VARIABLE = "changed_by_f0"
900+
901+
def f1():
902+
return _TEST_GLOBAL_VARIABLE
903+
904+
cloned_f0 = cloudpickle.loads(cloudpickle.dumps(
905+
f0, protocol=self.protocol))
906+
cloned_f1 = cloudpickle.loads(cloudpickle.dumps(
907+
f1, protocol=self.protocol))
908+
pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol)
909+
910+
# Change the value of the global variable
911+
cloned_f0()
912+
913+
# Ensure that the global variable is the same for another function
914+
result_f1 = cloned_f1()
915+
assert result_f1 == "changed_by_f0", result_f1
916+
assert f1() == result_f1
917+
918+
# Ensure that unpickling the global variable does not change its
919+
# value
920+
result_pickled_f1 = cloudpickle.loads(pickled_f1)()
921+
assert result_pickled_f1 == "changed_by_f0", result_pickled_f1
922+
finally:
923+
_TEST_GLOBAL_VARIABLE = orig_value
924+
925+
890926
@pytest.mark.skipif(sys.version_info >= (3, 0),
891927
reason="hardcoded pickle bytes for 2.7")
892928
def test_function_pickle_compat_0_4_0(self):

0 commit comments

Comments
 (0)