Skip to content

Commit 3e80b26

Browse files
authored
Support pickling dynamic classes subclassing typing.Generic instances on 3.7+ (#351)
1 parent 215d3dd commit 3e80b26

File tree

4 files changed

+107
-5
lines changed

4 files changed

+107
-5
lines changed

Diff for: CHANGES.md

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
and expand the support for pickling `TypeVar` instances (dynamic or non-dynamic)
1212
to Python 3.5-3.6 ([PR #350](https://github.com/cloudpipe/cloudpickle/pull/350))
1313

14+
- Add support for pickling dynamic classes subclassing `typing.Generic`
15+
instances on Python 3.7+
16+
([PR #351](https://github.com/cloudpipe/cloudpickle/pull/351))
17+
1418
1.3.0
1519
=====
1620

Diff for: cloudpickle/cloudpickle.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def dump(self, obj):
446446
raise
447447

448448
def save_typevar(self, obj):
449-
self.save_reduce(*_typevar_reduce(obj))
449+
self.save_reduce(*_typevar_reduce(obj), obj=obj)
450450

451451
dispatch[typing.TypeVar] = save_typevar
452452

@@ -645,7 +645,7 @@ def save_dynamic_class(self, obj):
645645
# "Regular" class definition:
646646
tp = type(obj)
647647
self.save_reduce(_make_skeleton_class,
648-
(tp, obj.__name__, obj.__bases__, type_kwargs,
648+
(tp, obj.__name__, _get_bases(obj), type_kwargs,
649649
_ensure_tracking(obj), None),
650650
obj=obj)
651651

@@ -1163,7 +1163,10 @@ class id will also reuse this class definition.
11631163
The "extra" variable is meant to be a dict (or None) that can be used for
11641164
forward compatibility shall the need arise.
11651165
"""
1166-
skeleton_class = type_constructor(name, bases, type_kwargs)
1166+
skeleton_class = types.new_class(
1167+
name, bases, {'metaclass': type_constructor},
1168+
lambda ns: ns.update(type_kwargs)
1169+
)
11671170
return _lookup_class_or_track(class_tracker_id, skeleton_class)
11681171

11691172

@@ -1268,3 +1271,13 @@ def _typevar_reduce(obj):
12681271
if module_and_name is None:
12691272
return (_make_typevar, _decompose_typevar(obj))
12701273
return (getattr, module_and_name)
1274+
1275+
1276+
def _get_bases(typ):
1277+
if hasattr(typ, '__orig_bases__'):
1278+
# For generic types (see PEP 560)
1279+
bases_attr = '__orig_bases__'
1280+
else:
1281+
# For regular class objects
1282+
bases_attr = '__bases__'
1283+
return getattr(typ, bases_attr)

Diff for: cloudpickle/cloudpickle_fast.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
_is_dynamic, _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
2929
_find_imported_submodules, _get_cell_contents, _is_importable_by_name, _builtin_type,
3030
Enum, _ensure_tracking, _make_skeleton_class, _make_skeleton_enum,
31-
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce,
31+
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases,
3232
)
3333

3434
load, loads = _pickle.load, _pickle.loads
@@ -76,7 +76,7 @@ def _class_getnewargs(obj):
7676
if isinstance(__dict__, property):
7777
type_kwargs['__dict__'] = __dict__
7878

79-
return (type(obj), obj.__name__, obj.__bases__, type_kwargs,
79+
return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
8080
_ensure_tracking(obj), None)
8181

8282

Diff for: tests/cloudpickle_test.py

+85
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
from .testutils import subprocess_pickle_echo
5252
from .testutils import assert_run_python_script
53+
from .testutils import subprocess_worker
5354

5455

5556
_TEST_GLOBAL_VARIABLE = "default_value"
@@ -2121,6 +2122,12 @@ def test_pickle_dynamic_typevar(self):
21212122
for attr in attr_list:
21222123
assert getattr(T, attr) == getattr(depickled_T, attr)
21232124

2125+
def test_pickle_dynamic_typevar_memoization(self):
2126+
T = typing.TypeVar('T')
2127+
depickled_T1, depickled_T2 = pickle_depickle((T, T),
2128+
protocol=self.protocol)
2129+
assert depickled_T1 is depickled_T2
2130+
21242131
def test_pickle_importable_typevar(self):
21252132
from .mypkg import T
21262133
T1 = pickle_depickle(T, protocol=self.protocol)
@@ -2130,6 +2137,61 @@ def test_pickle_importable_typevar(self):
21302137
from typing import AnyStr
21312138
assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)
21322139

2140+
@unittest.skipIf(sys.version_info < (3, 7),
2141+
"Pickling generics not supported below py37")
2142+
def test_generic_type(self):
2143+
T = typing.TypeVar('T')
2144+
2145+
class C(typing.Generic[T]):
2146+
pass
2147+
2148+
assert pickle_depickle(C, protocol=self.protocol) is C
2149+
assert pickle_depickle(C[int], protocol=self.protocol) is C[int]
2150+
2151+
with subprocess_worker(protocol=self.protocol) as worker:
2152+
2153+
def check_generic(generic, origin, type_value):
2154+
assert generic.__origin__ is origin
2155+
assert len(generic.__args__) == 1
2156+
assert generic.__args__[0] is type_value
2157+
2158+
assert len(origin.__orig_bases__) == 1
2159+
ob = origin.__orig_bases__[0]
2160+
assert ob.__origin__ is typing.Generic
2161+
assert len(ob.__parameters__) == 1
2162+
2163+
return "ok"
2164+
2165+
assert check_generic(C[int], C, int) == "ok"
2166+
assert worker.run(check_generic, C[int], C, int) == "ok"
2167+
2168+
@unittest.skipIf(sys.version_info < (3, 7),
2169+
"Pickling type hints not supported below py37")
2170+
def test_locally_defined_class_with_type_hints(self):
2171+
with subprocess_worker(protocol=self.protocol) as worker:
2172+
for type_ in _all_types_to_test():
2173+
# The type annotation syntax causes a SyntaxError on Python 3.5
2174+
code = textwrap.dedent("""\
2175+
class MyClass:
2176+
attribute: type_
2177+
2178+
def method(self, arg: type_) -> type_:
2179+
return arg
2180+
""")
2181+
ns = {"type_": type_}
2182+
exec(code, ns)
2183+
MyClass = ns["MyClass"]
2184+
2185+
def check_annotations(obj, expected_type):
2186+
assert obj.__annotations__["attribute"] is expected_type
2187+
assert obj.method.__annotations__["arg"] is expected_type
2188+
assert obj.method.__annotations__["return"] is expected_type
2189+
return "ok"
2190+
2191+
obj = MyClass()
2192+
assert check_annotations(obj, type_) == "ok"
2193+
assert worker.run(check_annotations, obj, type_) == "ok"
2194+
21332195

21342196
class Protocol2CloudPickleTest(CloudPickleTest):
21352197

@@ -2161,5 +2223,28 @@ def test_lookup_module_and_qualname_stdlib_typevar():
21612223
assert name == 'AnyStr'
21622224

21632225

2226+
def _all_types_to_test():
2227+
T = typing.TypeVar('T')
2228+
2229+
class C(typing.Generic[T]):
2230+
pass
2231+
2232+
return [
2233+
C, C[int],
2234+
T, typing.Any, typing.NoReturn, typing.Optional,
2235+
typing.Generic, typing.Union, typing.ClassVar,
2236+
typing.Optional[int],
2237+
typing.Generic[T],
2238+
typing.Callable[[int], typing.Any],
2239+
typing.Callable[..., typing.Any],
2240+
typing.Callable[[], typing.Any],
2241+
typing.Tuple[int, ...],
2242+
typing.Tuple[int, C[int]],
2243+
typing.ClassVar[C[int]],
2244+
typing.List[int],
2245+
typing.Dict[int, str],
2246+
]
2247+
2248+
21642249
if __name__ == '__main__':
21652250
unittest.main()

0 commit comments

Comments
 (0)