Skip to content

Commit e479b6d

Browse files
authored
Support selecting TypedDicts from unions (#7184)
It is a relatively common pattern to narrow down typed dicts from unions with non-dict types using `isinstance(x, dict)`. Currently mypy infers `Dict[Any, Any]` after such checks which is suboptimal. I propose to special-case this in `narrow_declared_type()` and `restrict_subtype_away()`. Using this opportunity I factored out special cases from the latter in a separate helper function. Using this opportunity I also fix an old type erasure bug in `isinstance()` checks (type should be erased after mapping to supertype, not before).
1 parent fc4baa6 commit e479b6d

File tree

3 files changed

+98
-22
lines changed

3 files changed

+98
-22
lines changed

mypy/meet.py

+10
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
5454
return TypeType.make_normalized(narrow_declared_type(declared.item, narrowed.item))
5555
elif isinstance(declared, (Instance, TupleType, TypeType, LiteralType)):
5656
return meet_types(declared, narrowed)
57+
elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance):
58+
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
59+
if (narrowed.type.fullname() == 'builtins.dict' and
60+
all(isinstance(t, AnyType) for t in narrowed.args)):
61+
return declared
62+
return meet_types(declared, narrowed)
5763
return narrowed
5864

5965

@@ -478,6 +484,8 @@ def visit_instance(self, t: Instance) -> Type:
478484
return meet_types(t, self.s)
479485
elif isinstance(self.s, LiteralType):
480486
return meet_types(t, self.s)
487+
elif isinstance(self.s, TypedDictType):
488+
return meet_types(t, self.s)
481489
return self.default(self.s)
482490

483491
def visit_callable_type(self, t: CallableType) -> Type:
@@ -555,6 +563,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
555563
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
556564
required_keys = t.required_keys | self.s.required_keys
557565
return TypedDictType(items, required_keys, fallback)
566+
elif isinstance(self.s, Instance) and is_subtype(t, self.s):
567+
return t
558568
else:
559569
return self.default(self.s)
560570

mypy/subtypes.py

+47-20
Original file line numberDiff line numberDiff line change
@@ -1007,58 +1007,81 @@ def unify_generic_callable(type: CallableType, target: CallableType,
10071007

10081008

10091009
def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
1010-
"""Return t minus s.
1010+
"""Return t minus s for runtime type assertions.
10111011
10121012
If we can't determine a precise result, return a supertype of the
10131013
ideal result (just t is a valid result).
10141014
10151015
This is used for type inference of runtime type checks such as
1016-
isinstance.
1017-
1018-
Currently this just removes elements of a union type.
1016+
isinstance(). Currently this just removes elements of a union type.
10191017
"""
10201018
if isinstance(t, UnionType):
1021-
# Since runtime type checks will ignore type arguments, erase the types.
1022-
erased_s = erase_type(s)
1023-
# TODO: Implement more robust support for runtime isinstance() checks,
1024-
# see issue #3827
10251019
new_items = [item for item in t.relevant_items()
1026-
if (not (is_proper_subtype(erase_type(item), erased_s,
1027-
ignore_promotions=ignore_promotions) or
1028-
is_proper_subtype(item, erased_s,
1029-
ignore_promotions=ignore_promotions))
1030-
or isinstance(item, AnyType))]
1020+
if (isinstance(item, AnyType) or
1021+
not covers_at_runtime(item, s, ignore_promotions))]
10311022
return UnionType.make_union(new_items)
10321023
else:
10331024
return t
10341025

10351026

1036-
def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False) -> bool:
1027+
def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> bool:
1028+
"""Will isinstance(item, supertype) always return True at runtime?"""
1029+
# Since runtime type checks will ignore type arguments, erase the types.
1030+
supertype = erase_type(supertype)
1031+
if is_proper_subtype(erase_type(item), supertype, ignore_promotions=ignore_promotions,
1032+
erase_instances=True):
1033+
return True
1034+
if isinstance(supertype, Instance) and supertype.type.is_protocol:
1035+
# TODO: Implement more robust support for runtime isinstance() checks, see issue #3827.
1036+
if is_proper_subtype(item, supertype, ignore_promotions=ignore_promotions):
1037+
return True
1038+
if isinstance(item, TypedDictType) and isinstance(supertype, Instance):
1039+
# Special case useful for selecting TypedDicts from unions using isinstance(x, dict).
1040+
if supertype.type.fullname() == 'builtins.dict':
1041+
return True
1042+
# TODO: Add more special cases.
1043+
return False
1044+
1045+
1046+
def is_proper_subtype(left: Type, right: Type, *, ignore_promotions: bool = False,
1047+
erase_instances: bool = False) -> bool:
10371048
"""Is left a proper subtype of right?
10381049
10391050
For proper subtypes, there's no need to rely on compatibility due to
10401051
Any types. Every usable type is a proper subtype of itself.
1052+
1053+
If erase_instances is True, erase left instance *after* mapping it to supertype
1054+
(this is useful for runtime isinstance() checks).
10411055
"""
10421056
if isinstance(right, UnionType) and not isinstance(left, UnionType):
1043-
return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions)
1057+
return any([is_proper_subtype(left, item, ignore_promotions=ignore_promotions,
1058+
erase_instances=erase_instances)
10441059
for item in right.items])
1045-
return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions))
1060+
return left.accept(ProperSubtypeVisitor(right, ignore_promotions=ignore_promotions,
1061+
erase_instances=erase_instances))
10461062

10471063

10481064
class ProperSubtypeVisitor(TypeVisitor[bool]):
1049-
def __init__(self, right: Type, *, ignore_promotions: bool = False) -> None:
1065+
def __init__(self, right: Type, *,
1066+
ignore_promotions: bool = False,
1067+
erase_instances: bool = False) -> None:
10501068
self.right = right
10511069
self.ignore_promotions = ignore_promotions
1070+
self.erase_instances = erase_instances
10521071
self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind(
10531072
ignore_promotions=ignore_promotions,
1073+
erase_instances=erase_instances,
10541074
)
10551075

10561076
@staticmethod
1057-
def build_subtype_kind(*, ignore_promotions: bool = False) -> SubtypeKind:
1058-
return (True, ignore_promotions)
1077+
def build_subtype_kind(*, ignore_promotions: bool = False,
1078+
erase_instances: bool = False) -> SubtypeKind:
1079+
return True, ignore_promotions, erase_instances
10591080

10601081
def _is_proper_subtype(self, left: Type, right: Type) -> bool:
1061-
return is_proper_subtype(left, right, ignore_promotions=self.ignore_promotions)
1082+
return is_proper_subtype(left, right,
1083+
ignore_promotions=self.ignore_promotions,
1084+
erase_instances=self.erase_instances)
10621085

10631086
def visit_unbound_type(self, left: UnboundType) -> bool:
10641087
# This can be called if there is a bad type annotation. The result probably
@@ -1107,6 +1130,10 @@ def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool:
11071130
return mypy.sametypes.is_same_type(leftarg, rightarg)
11081131
# Map left type to corresponding right instances.
11091132
left = map_instance_to_supertype(left, right.type)
1133+
if self.erase_instances:
1134+
erased = erase_type(left)
1135+
assert isinstance(erased, Instance)
1136+
left = erased
11101137

11111138
nominal = all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in
11121139
zip(left.args, right.args, right.type.defn.type_vars))

test-data/unit/check-typeddict.test

+41-2
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,6 @@ def g(x: X, y: M) -> None: pass
580580
reveal_type(f(g)) # N: Revealed type is '<nothing>'
581581
[builtins fixtures/dict.pyi]
582582

583-
# TODO: It would be more accurate for the meet to be TypedDict instead.
584583
[case testMeetOfTypedDictWithCompatibleMappingSuperclassIsUninhabitedForNow]
585584
# flags: --strict-optional
586585
from mypy_extensions import TypedDict
@@ -590,7 +589,7 @@ I = Iterable[str]
590589
T = TypeVar('T')
591590
def f(x: Callable[[T, T], None]) -> T: pass
592591
def g(x: X, y: I) -> None: pass
593-
reveal_type(f(g)) # N: Revealed type is '<nothing>'
592+
reveal_type(f(g)) # N: Revealed type is 'TypedDict('__main__.X', {'x': builtins.int})'
594593
[builtins fixtures/dict.pyi]
595594

596595
[case testMeetOfTypedDictsWithNonTotal]
@@ -1838,3 +1837,43 @@ def func(x):
18381837
pass
18391838
[builtins fixtures/dict.pyi]
18401839
[typing fixtures/typing-full.pyi]
1840+
1841+
[case testTypedDictIsInstance]
1842+
from typing import TypedDict, Union
1843+
1844+
class User(TypedDict):
1845+
id: int
1846+
name: str
1847+
1848+
u: Union[str, User]
1849+
u2: User
1850+
1851+
if isinstance(u, dict):
1852+
reveal_type(u) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
1853+
else:
1854+
reveal_type(u) # N: Revealed type is 'builtins.str'
1855+
1856+
assert isinstance(u2, dict)
1857+
reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
1858+
[builtins fixtures/dict.pyi]
1859+
[typing fixtures/typing-full.pyi]
1860+
1861+
[case testTypedDictIsInstanceABCs]
1862+
from typing import TypedDict, Union, Mapping, Iterable
1863+
1864+
class User(TypedDict):
1865+
id: int
1866+
name: str
1867+
1868+
u: Union[int, User]
1869+
u2: User
1870+
1871+
if isinstance(u, Iterable):
1872+
reveal_type(u) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
1873+
else:
1874+
reveal_type(u) # N: Revealed type is 'builtins.int'
1875+
1876+
assert isinstance(u2, Mapping)
1877+
reveal_type(u2) # N: Revealed type is 'TypedDict('__main__.User', {'id': builtins.int, 'name': builtins.str})'
1878+
[builtins fixtures/dict.pyi]
1879+
[typing fixtures/typing-full.pyi]

0 commit comments

Comments
 (0)