Skip to content

Commit 9a52ba3

Browse files
tomMoral authored and ogrisel committed
FIX globals for functions defined in __main__ (#188)
Fix #187. The cause of this issue is that global variables defined in `__main__` are not shared by unpickled functions, and their values are overwritten at each unpickling.
1 parent a249c44 commit 9a52ba3

File tree

6 files changed

+80
-3
lines changed

6 files changed

+80
-3
lines changed

Diff for: .coveragerc

+1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
[run]
22
branch = True
3+
parallel = True
34
source = cloudpickle
45

56
[report]

Diff for: .travis.yml

+3 -1
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,11 @@ before_script:
2626
- flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
2727
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
2828
- flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
29+
- python ci/install_coverage_subprocess_pth.py
2930
script:
3031
- if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then source activate testenv; fi
31-
- PYTHONPATH='.:tests' py.test -r s --cov-config .coveragerc --cov=cloudpickle
32+
- COVERAGE_PROCESS_START="$TRAVIS_BUILD_DIR/.coveragerc" PYTHONPATH='.:tests' py.test -r s
3233
after_success:
3334
- if [[ $TRAVIS_PYTHON_VERSION != 'pypy'* ]]; then source activate testenv; fi
35+
- coverage combine --append
3436
- codecov

Diff for: ci/install_coverage_subprocess_pth.py

+16
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,16 @@
1+
# Make it possible to enable test coverage reporting for Python
2+
# code run in child processes.
3+
# http://coverage.readthedocs.io/en/latest/subprocess.html
4+
5+
import os.path as op
6+
from distutils.sysconfig import get_python_lib
7+
8+
FILE_CONTENT = u"""\
9+
import coverage; coverage.process_startup()
10+
"""
11+
12+
filename = op.join(get_python_lib(), 'coverage_subprocess.pth')
13+
with open(filename, 'wb') as f:
14+
f.write(FILE_CONTENT.encode('ascii'))
15+
16+
print('Installed subprocess coverage support: %s' % filename)

Diff for: cloudpickle/cloudpickle.py

+17 -2
Original file line number | Diff line number | Diff line change
@@ -431,6 +431,7 @@ def _save_subimports(self, code, top_level_dependencies):
431431
Ensure de-pickler imports any package child-modules that
432432
are needed by the function
433433
"""
434+
434435
# check if any known dependency is an imported package
435436
for x in top_level_dependencies:
436437
if isinstance(x, types.ModuleType) and hasattr(x, '__package__') and x.__package__:
@@ -632,7 +633,15 @@ def extract_func_data(self, func):
632633
# save the dict
633634
dct = func.__dict__
634635

635-
base_globals = self.globals_ref.get(id(func.__globals__), {})
636+
base_globals = self.globals_ref.get(id(func.__globals__), None)
637+
if base_globals is None:
638+
# For functions defined in __main__, use vars(__main__) for
639+
# base_globals. This is necessary to share the global variables
640+
# across multiple functions in this module.
641+
if func.__module__ == "__main__":
642+
base_globals = "__main__"
643+
else:
644+
base_globals = {}
636645
self.globals_ref[id(func.__globals__)] = base_globals
637646

638647
return (code, f_globals, defaults, closure, dct, base_globals)
@@ -1037,7 +1046,11 @@ def _fill_function(*args):
10371046
else:
10381047
raise ValueError('Unexpected _fill_value arguments: %r' % (args,))
10391048

1040-
func.__globals__.update(state['globals'])
1049+
# Only set global variables that do not exist.
1050+
for k, v in state['globals'].items():
1051+
if k not in func.__globals__:
1052+
func.__globals__[k] = v
1053+
10411054
func.__defaults__ = state['defaults']
10421055
func.__dict__ = state['dict']
10431056
if 'annotations' in state:
@@ -1076,6 +1089,8 @@ def _make_skel_func(code, cell_count, base_globals=None):
10761089
"""
10771090
if base_globals is None:
10781091
base_globals = {}
1092+
elif isinstance(base_globals, str):
1093+
base_globals = vars(sys.modules[base_globals])
10791094
base_globals['__builtins__'] = __builtins__
10801095

10811096
closure = (

Diff for: tests/cloudpickle_test.py

+39
Original file line number | Diff line number | Diff line change
@@ -848,6 +848,45 @@ def f4(x):
848848
""".format(protocol=self.protocol)
849849
assert_run_python_script(textwrap.dedent(code))
850850

851+
def test_interactively_defined_global_variable(self):
852+
# Check that callables defined in the __main__ module of a Python
853+
# script (or jupyter kernel) correctly retrieve global variables.
854+
code_template = """\
855+
from testutils import subprocess_pickle_echo
856+
from cloudpickle import dumps, loads
857+
858+
def local_clone(obj, protocol=None):
859+
return loads(dumps(obj, protocol=protocol))
860+
861+
VARIABLE = "default_value"
862+
863+
def f0():
864+
global VARIABLE
865+
VARIABLE = "changed_by_f0"
866+
867+
def f1():
868+
return VARIABLE
869+
870+
cloned_f0 = {clone_func}(f0, protocol={protocol})
871+
cloned_f1 = {clone_func}(f1, protocol={protocol})
872+
pickled_f1 = dumps(f1, protocol={protocol})
873+
874+
# Change the value of the global variable
875+
cloned_f0()
876+
877+
# Ensure that the global variable is the same for another function
878+
result_f1 = cloned_f1()
879+
assert result_f1 == "changed_by_f0", result_f1
880+
881+
# Ensure that unpickling the global variable does not change its value
882+
result_pickled_f1 = loads(pickled_f1)()
883+
assert result_pickled_f1 == "changed_by_f0", result_pickled_f1
884+
"""
885+
for clone_func in ['local_clone', 'subprocess_pickle_echo']:
886+
code = code_template.format(protocol=self.protocol,
887+
clone_func=clone_func)
888+
assert_run_python_script(textwrap.dedent(code))
889+
851890
@pytest.mark.skipif(sys.version_info >= (3, 0),
852891
reason="hardcoded pickle bytes for 2.7")
853892
def test_function_pickle_compat_0_4_0(self):

Diff for: tests/testutils.py

+4
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,10 @@ def assert_run_python_script(source_code, timeout=5):
9393
'stderr': STDOUT,
9494
'env': {'PYTHONPATH': pythonpath},
9595
}
96+
# If coverage is running, pass the config file to the subprocess
97+
coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
98+
if coverage_rc:
99+
kwargs['env']['COVERAGE_PROCESS_START'] = coverage_rc
96100
if timeout_supported:
97101
kwargs['timeout'] = timeout
98102
try:

0 commit comments

Comments
 (0)