Commit ecf56ea

Adds support for basic union math with overloads
This commit adds support for very basic union math when calling overloaded functions, resolving python#4576.

One thing led to another, and this ended up fixing or touching on several different overload-related issues. In particular, I believe this pull request:

1. Fixes the bug (?) where calling overloaded functions can sometimes silently infer a return type of `Any`.
2. Changes the semantics of how mypy handles overlapping functions, which I believe is currently under discussion in python/typing#253.

Although this change is functional and mergeable, I was planning on polishing it more: adding more tests, fleshing out the union math behavior, etc. However, these are fairly big changes, and I wanted to check in and make sure this pull request is actually welcome and a good idea. If not, let me know, and I'd be happy to abandon it.

---

Details on specific changes made:

1. The new algorithm works by modifying checkexpr.overload_call_targets to return all possible matches, rather than just one. We start by trying the first matching signature. If there was some error, we (conservatively) attempt to union all of the matching signatures together and repeat the typechecking process.

   If it doesn't seem possible to combine the matching signatures in a sound way, we stop and just output the errors we obtained from typechecking the first match.

   The "signature-unioning" code is currently deliberately very conservative. I figured it was better to start small, handle only basic cases like python#1943, and relax the restrictions later as needed. For more details on this algorithm, see the comments in checkexpr.union_overload_matches.

2. This change incidentally resolves any bugs related to how calling an overloaded function can sometimes silently infer a return type of `Any`. Previously, if a function call caused an overload to be less precise than a previous one, we gave up and returned a silent `Any`.

   This change removes this case altogether and only infers `Any` if either (a) the caller arguments explicitly contain `Any` or (b) there was some error.

   For example, see python#3295 and python#1322. I believe this pull request touches on and maybe resolves (??) those two issues.

3. As a result, I needed to fix a few parts of mypy that were relying on this "silently infer Any" behavior; see the changes in checker.py and semanal.py. Both files were using expressions of the form `zip(*iterable)`, which ended up having a type of `Any` under the old algorithm. The new algorithm will instead infer `Iterable[Tuple[Any, ...]]`, which actually matches the stubs in typeshed.

4. These changes cause the attr stubs in `test-data/unit/lib-stub` to no longer work. It seems that the stubs both here and in typeshed were also falling prey to the "silently infer Any" bug: code like `a = attr.ib()` typechecked not because it matched the signature of any of the overloads, but because that particular call caused one or more overloads to overlap, which made mypy give up and infer `Any`.

   I couldn't find a clean way of fixing the stubs to infer the correct thing under this new behavior, so I gave up and removed the overloads altogether. I think this is fine, though: it seems like the attrs plugin infers the correct type for us anyway, regardless of what the stubs say.

   If this pull request is accepted, I plan on submitting a similar pull request to the stubs in typeshed.

5. This pull request also probably touches on python/typing#253. We still require the overloads to be written from the most narrow to the most general and disallow overlapping signatures. However, if a *call* now causes overlaps, we try the "union" algorithm described above and default to selecting the first matching overload instead of giving up.
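The fallback order described in item 1 can be sketched with a toy model. This is only an illustration under loud assumptions: types are plain Python classes, a "union" is a tuple of classes, and the names `fits`, `plausible`, `merge`, `union_matches`, and `call` are hypothetical and are not mypy APIs. It tries the first plausible overload, and only if that fails does it try a synthesized signature with at most one union-typed parameter.

```python
from typing import List, Optional, Tuple

Union_ = Tuple[type, ...]            # toy union: a tuple of classes
Sig = Tuple[List[Union_], Union_]    # (parameter types, return type)


def fits(arg: Union_, param: Union_) -> bool:
    """True if every member of the argument union is accepted by the parameter."""
    return all(any(issubclass(a, p) for p in param) for a in arg)


def plausible(arg: Union_, param: Union_) -> bool:
    """True if at least one member of the argument union is accepted."""
    return any(issubclass(a, p) for a in arg for p in param)


def merge(types: List[Union_]) -> Union_:
    """Union several toy types, dropping duplicates but keeping order."""
    out: List[type] = []
    for t in types:
        for member in t:
            if member not in out:
                out.append(member)
    return tuple(out)


def union_matches(matches: List[Sig]) -> Optional[Sig]:
    """Combine signatures, refusing if two or more parameters become unions."""
    arity = len(matches[0][0])
    if any(len(params) != arity for params, _ in matches):
        return None
    params = [merge([m[0][i] for m in matches]) for i in range(arity)]
    if sum(len(p) > 1 for p in params) >= 2:
        return None
    return params, merge([m[1] for m in matches])


def call(overloads: List[Sig], args: List[Union_]) -> Optional[Union_]:
    """Return the inferred return type, or None to signal a type error."""
    matches = [s for s in overloads
               if len(s[0]) == len(args)
               and all(plausible(a, p) for a, p in zip(args, s[0]))]
    if not matches:
        return None
    first_params, first_ret = matches[0]
    if all(fits(a, p) for a, p in zip(args, first_params)):
        return first_ret              # the first match typechecks cleanly
    if len(matches) == 1:
        return None                   # nothing else to try
    combined = union_matches(matches)
    if combined is None:
        return None                   # keep the first match's errors
    params, ret = combined
    return ret if all(fits(a, p) for a, p in zip(args, params)) else None


# Toy overloads: f(int) -> int and f(str) -> str.
overloads: List[Sig] = [([(int,)], (int,)), ([(str,)], (str,))]
```

With these toy overloads, `call(overloads, [(int, str)])` yields `(int, str)`, mirroring a call with a `Union[int, str]` argument producing a `Union[int, str]` result.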
1 parent 21cd8e2 commit ecf56ea

9 files changed, +278 -85 lines changed

mypy/checker.py

+2-2
@@ -1798,8 +1798,8 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E
                 expr = expr.expr
             types, declared_types = zip(*items)
             self.binder.assign_type(expr,
-                                    UnionType.make_simplified_union(types),
-                                    UnionType.make_simplified_union(declared_types),
+                                    UnionType.make_simplified_union(list(types)),
+                                    UnionType.make_simplified_union(list(declared_types)),
                                     False)
         for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
             # Properly store the inferred types.
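The `list(...)` wrappers above follow from what `zip(*items)` produces: each unpacked value is a tuple, so a callee annotated to take a list no longer typechecks once the result stops being inferred as `Any`. A minimal sketch, where `total` is a hypothetical function standing in for any callee that expects a `List`:

```python
from typing import List

pairs = [(1, "a"), (2, "b"), (3, "c")]

# Transposing with zip(*...) yields tuples, not lists.
nums, letters = zip(*pairs)


def total(values: List[int]) -> int:
    """Hypothetical callee that is annotated to accept a list."""
    return sum(values)


# Passing `nums` directly would be a tuple where a List[int] is expected,
# so it is converted explicitly, as the diff does with list(types).
result = total(list(nums))
```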

mypy/checkexpr.py

+151-58
@@ -611,10 +611,63 @@ def check_call(self, callee: Type, args: List[Expression],
             arg_types = self.infer_arg_types_in_context(None, args)
             self.msg.enable_errors()
 
-            target = self.overload_call_target(arg_types, arg_kinds, arg_names,
-                                               callee, context,
-                                               messages=arg_messages)
-            return self.check_call(target, args, arg_kinds, context, arg_names,
+            overload_messages = arg_messages.copy()
+            targets = self.overload_call_targets(arg_types, arg_kinds, arg_names,
+                                                 callee, context,
+                                                 messages=overload_messages)
+
+            # If there are multiple targets, that means that there were
+            # either multiple possible matches or the types were overlapping in some
+            # way. In either case, we default to picking the first match and
+            # see what happens if we try using it.
+            #
+            # Note: if we pass in an argument that inherits from two overloaded
+            # types, we default to picking the first match. For example:
+            #
+            #     class A: pass
+            #     class B: pass
+            #     class C(A, B): pass
+            #
+            #     @overload
+            #     def f(x: A) -> int: ...
+            #     @overload
+            #     def f(x: B) -> str: ...
+            #     def f(x): ...
+            #
+            #     reveal_type(f(C()))  # Will be 'int', not 'Union[int, str]'
+            #
+            # It's unclear if this is really the best thing to do, but multiple
+            # inheritance is rare. See the docstring of mypy.meet.is_overlapping_types
+            # for more about this.
+
+            original_output = self.check_call(targets[0], args, arg_kinds, context, arg_names,
+                                              arg_messages=overload_messages,
+                                              callable_name=callable_name,
+                                              object_type=object_type)
+
+            if not overload_messages.is_errors() or len(targets) == 1:
+                # If there were no errors or if there was only one match, we can end now.
+                #
+                # Note that if we have only one target, there's nothing else we
+                # can try doing. In that case, we just give up and return early
+                # and skip the below steps.
+                arg_messages.add_errors(overload_messages)
+                return original_output
+
+            # Otherwise, we attempt to synthesize together a new callable by combining
+            # together the different matches by union-ing together their arguments
+            # and return type.
+
+            targets = cast(List[CallableType], targets)
+            unioned_callable = self.union_overload_matches(targets)
+            if unioned_callable is None:
+                # If it was not possible to actually combine together the
+                # callables in a sound way, we give up and return the original
+                # error message.
+                arg_messages.add_errors(overload_messages)
+                return original_output
+
+            return self.check_call(unioned_callable, args, arg_kinds, context, arg_names,
                                    arg_messages=arg_messages,
                                    callable_name=callable_name,
                                    object_type=object_type)
@@ -1089,83 +1142,123 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
               (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and
               # ...except for classmethod first argument
               not caller_type.is_classmethod_class):
-            self.msg.concrete_only_call(callee_type, context)
+            messages.concrete_only_call(callee_type, context)
         elif not is_subtype(caller_type, callee_type):
             if self.chk.should_suppress_optional_error([caller_type, callee_type]):
                 return
             messages.incompatible_argument(n, m, callee, original_caller_type,
                                            caller_kind, context)
             if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and
                     isinstance(callee_type, Instance) and callee_type.type.is_protocol):
-                self.msg.report_protocol_problems(original_caller_type, callee_type, context)
+                messages.report_protocol_problems(original_caller_type, callee_type, context)
             if (isinstance(callee_type, CallableType) and
                     isinstance(original_caller_type, Instance)):
                 call = find_member('__call__', original_caller_type, original_caller_type)
                 if call:
-                    self.msg.note_call(original_caller_type, call, context)
-
-    def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int],
-                             arg_names: Optional[Sequence[Optional[str]]],
-                             overload: Overloaded, context: Context,
-                             messages: Optional[MessageBuilder] = None) -> Type:
-        """Infer the correct overload item to call with given argument types.
-
-        The return value may be CallableType or AnyType (if an unique item
-        could not be determined).
+                    messages.note_call(original_caller_type, call, context)
+
+    def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int],
+                              arg_names: Optional[Sequence[Optional[str]]],
+                              overload: Overloaded, context: Context,
+                              messages: Optional[MessageBuilder] = None) -> Sequence[Type]:
+        """Infer all possible overload targets to call with given argument types.
+        The list is guaranteed be one of the following:
+
+        1. A List[CallableType] of length 1 if we were able to find an
+           unambiguous best match.
+        2. A List[AnyType] of length 1 if we were unable to find any match
+           or discovered the match was ambiguous due to conflicting Any types.
+        3. A List[CallableType] of length 2 or more if there were multiple
+           plausible matches. The matches are returned in the order they
+           were defined.
         """
         messages = messages or self.msg
-        # TODO: For overlapping signatures we should try to get a more precise
-        #       result than 'Any'.
         match = []  # type: List[CallableType]
         best_match = 0
         for typ in overload.items():
             similarity = self.erased_signature_similarity(arg_types, arg_kinds, arg_names,
                                                           typ, context=context)
             if similarity > 0 and similarity >= best_match:
-                if (match and not is_same_type(match[-1].ret_type,
-                                               typ.ret_type) and
-                    (not mypy.checker.is_more_precise_signature(match[-1], typ)
-                     or (any(isinstance(arg, AnyType) for arg in arg_types)
-                         and any_arg_causes_overload_ambiguity(
-                             match + [typ], arg_types, arg_kinds, arg_names)))):
-                    # Ambiguous return type. Either the function overload is
-                    # overlapping (which we don't handle very well here) or the
-                    # caller has provided some Any argument types; in either
-                    # case we'll fall back to Any. It's okay to use Any types
-                    # in calls.
-                    #
-                    # Overlapping overload items are generally fine if the
-                    # overlapping is only possible when there is multiple
-                    # inheritance, as this is rare. See docstring of
-                    # mypy.meet.is_overlapping_types for more about this.
-                    #
-                    # Note that there is no ambiguity if the items are
-                    # covariant in both argument types and return types with
-                    # respect to type precision. We'll pick the best/closest
-                    # match.
-                    #
-                    # TODO: Consider returning a union type instead if the
-                    #       overlapping is NOT due to Any types?
-                    return AnyType(TypeOfAny.special_form)
-                else:
-                    match.append(typ)
+                if (match and not is_same_type(match[-1].ret_type, typ.ret_type)
+                        and any(isinstance(arg, AnyType) for arg in arg_types)
+                        and any_arg_causes_overload_ambiguity(
+                            match + [typ], arg_types, arg_kinds, arg_names)):
+                    # Ambiguous return type. The caller has provided some
+                    # Any argument types (which are okay to use in calls),
+                    # so we fall back to returning 'Any'.
+                    return [AnyType(TypeOfAny.special_form)]
+                match.append(typ)
                 best_match = max(best_match, similarity)
-        if not match:
+
+        if len(match) == 0:
             if not self.chk.should_suppress_optional_error(arg_types):
                 messages.no_variant_matches_arguments(overload, arg_types, context)
-            return AnyType(TypeOfAny.from_error)
+            return [AnyType(TypeOfAny.from_error)]
+        elif len(match) == 1:
+            return match
         else:
-            if len(match) == 1:
-                return match[0]
-            else:
-                # More than one signature matches. Pick the first *non-erased*
-                # matching signature, or default to the first one if none
-                # match.
-                for m in match:
-                    if self.match_signature_types(arg_types, arg_kinds, arg_names, m,
-                                                  context=context):
-                        return m
-                return match[0]
+            # More than one signature matches or the signatures are
+            # overlapping. In either case, we return all of the matching
+            # signatures and let the caller decide what to do with them.
+            out = [m for m in match if self.match_signature_types(
+                arg_types, arg_kinds, arg_names, m, context=context)]
+            return out if len(out) >= 1 else match
+
+    def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
+        """Accepts a list of overload signatures and attempts to combine them together into a
+        new CallableType consisting of the union of all of the given arguments and return types.
+
+        Returns None if it is not possible to combine the different callables together in a
+        sound manner."""
+
+        new_args: List[List[Type]] = [[] for _ in range(len(callables[0].arg_types))]
+
+        expected_names = callables[0].arg_names
+        expected_kinds = callables[0].arg_kinds
+
+        for target in callables:
+            if target.arg_names != expected_names or target.arg_kinds != expected_kinds:
+                # We conservatively end if the overloads do not have the exact same signature.
+                # TODO: Enhance the union overload logic to handle a wider variety of signatures.
+                return None
+
+            for i, arg in enumerate(target.arg_types):
+                new_args[i].append(arg)
+
+        union_count = 0
+        final_args = []
+        for args in new_args:
+            new_type = UnionType.make_simplified_union(args)
+            union_count += 1 if isinstance(new_type, UnionType) else 0
+            final_args.append(new_type)
+
+        # TODO: Modify this check to be less conservative.
+        #
+        # Currently, we permit only one union union in the arguments because if we allow
+        # multiple, we can't always guarantee the synthesized callable will be correct.
+        #
+        # For example, suppose we had the following two overloads:
+        #
+        #     @overload
+        #     def f(x: A, y: B) -> None: ...
+        #     @overload
+        #     def f(x: B, y: A) -> None: ...
+        #
+        # If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
+        # then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
+        # be rejected.
+        #
+        # However, that means we'll also give up if the original overloads contained
+        # any unions. This is likely unnecessary -- we only really need to give up if
+        # there are more then one *synthesized* union arguments.
+        if union_count >= 2:
+            return None
+
+        return callables[0].copy_modified(
+            arg_types=final_args,
+            ret_type=UnionType.make_simplified_union([t.ret_type for t in callables]),
+            implicit=True,
+            from_overloads=True)
 
     def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int],
                                     arg_names: Optional[Sequence[Optional[str]]],
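The `union_count >= 2` check in `union_overload_matches` can be motivated by counting argument combinations. The sketch below is a toy that assumes type names are plain strings; `wrongly_accepted` is a hypothetical helper, not part of mypy. It compares the combinations the overloads accept against those accepted by a signature whose every parameter is the union of the types seen at that position.

```python
from itertools import product
from typing import List, Tuple


def wrongly_accepted(overloads: List[Tuple[str, ...]]) -> List[Tuple[str, ...]]:
    """Combinations allowed by the per-position union but by no single overload."""
    arity = len(overloads[0])
    # The synthesized signature accepts any member of the union at each position.
    per_position = [sorted({sig[i] for sig in overloads}) for i in range(arity)]
    accepted_by_union = set(product(*per_position))
    return sorted(accepted_by_union - set(overloads))


# Two positions vary: the union signature admits f(A, A) and f(B, B).
two_unions = wrongly_accepted([("A", "B"), ("B", "A")])

# Only one position varies: the union signature admits nothing extra.
one_union = wrongly_accepted([("A", "C"), ("B", "C")])
```

When only one parameter differs between the overloads, the Cartesian product adds no new combinations, which is why a single synthesized union is treated as safe in this toy model.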

mypy/messages.py

+12-2
@@ -629,8 +629,19 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
                 expected_type = callee.arg_types[m - 1]
             except IndexError:  # Varargs callees
                 expected_type = callee.arg_types[-1]
+
             arg_type_str, expected_type_str = self.format_distinctly(
                 arg_type, expected_type, bare=True)
+            expected_type_str = self.quote_type_string(expected_type_str)
+
+            if callee.from_overloads and isinstance(expected_type, UnionType):
+                expected_formatted = []
+                for e in expected_type.items:
+                    type_str = self.format_distinctly(arg_type, e, bare=True)[1]
+                    expected_formatted.append(self.quote_type_string(type_str))
+                expected_type_str = 'one of {} based on available overloads'.format(
+                    ', '.join(expected_formatted))
+
             if arg_kind == ARG_STAR:
                 arg_type_str = '*' + arg_type_str
             elif arg_kind == ARG_STAR2:
@@ -645,8 +656,7 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
                 arg_label = '"{}"'.format(arg_name)
 
             msg = 'Argument {} {}has incompatible type {}; expected {}'.format(
-                arg_label, target, self.quote_type_string(arg_type_str),
-                self.quote_type_string(expected_type_str))
+                arg_label, target, self.quote_type_string(arg_type_str), expected_type_str)
             if isinstance(arg_type, Instance) and isinstance(expected_type, Instance):
                 notes = append_invariance_notes(notes, arg_type, expected_type)
         self.fail(msg, context)
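The effect of the `from_overloads` branch on the error text can be sketched in isolation. The helpers below are hypothetical simplifications: `quote` only assumes that `quote_type_string` wraps a name in double quotes, type names are plain strings, and the real method takes different arguments.

```python
from typing import List


def quote(type_str: str) -> str:
    """Assumed stand-in for quote_type_string: wrap a type name in double quotes."""
    return '"{}"'.format(type_str)


def describe_expected(expected_items: List[str], from_overloads: bool) -> str:
    """Describe the expected type, listing overload alternatives one by one."""
    if from_overloads and len(expected_items) > 1:
        return 'one of {} based on available overloads'.format(
            ', '.join(quote(t) for t in expected_items))
    if len(expected_items) > 1:
        return quote('Union[{}]'.format(', '.join(expected_items)))
    return quote(expected_items[0])


def incompatible_argument(n: int, callee_name: str, arg_type: str,
                          expected_items: List[str], from_overloads: bool) -> str:
    """Build a message shaped like the format string in the diff above."""
    return 'Argument {} to {} has incompatible type {}; expected {}'.format(
        n, quote(callee_name), quote(arg_type),
        describe_expected(expected_items, from_overloads))


message = incompatible_argument(1, "f", "float", ["int", "str"], from_overloads=True)
```

Listing the alternatives separately avoids showing the caller a `Union[...]` type that appears nowhere in the overloads they wrote.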

mypy/semanal.py

+1-1
@@ -2870,7 +2870,7 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression,
             # about the length mismatch in type-checking.
             elementwise_assignments = zip(rval.items, *[v.items for v in seq_lvals])
             for rv, *lvs in elementwise_assignments:
-                self.process_module_assignment(lvs, rv, ctx)
+                self.process_module_assignment(list(lvs), rv, ctx)
         elif isinstance(rval, RefExpr):
             rnode = self.lookup_type_node(rval)
             if rnode and rnode.kind == MODULE_REF:

mypy/types.py

+10-1
@@ -660,6 +660,8 @@ class CallableType(FunctionLike):
     special_sig = None  # type: Optional[str]
     # Was this callable generated by analyzing Type[...] instantiation?
     from_type_type = False  # type: bool
+    # Was this callable generated by synthesizing multiple overloads?
+    from_overloads = False  # type: bool
 
     bound_args = None  # type: List[Optional[Type]]
 
@@ -679,6 +681,7 @@ def __init__(self,
                  is_classmethod_class: bool = False,
                  special_sig: Optional[str] = None,
                  from_type_type: bool = False,
+                 from_overloads: bool = False,
                  bound_args: Optional[List[Optional[Type]]] = None,
                  ) -> None:
         assert len(arg_types) == len(arg_kinds) == len(arg_names)
@@ -703,6 +706,7 @@ def __init__(self,
         self.is_classmethod_class = is_classmethod_class
         self.special_sig = special_sig
         self.from_type_type = from_type_type
+        self.from_overloads = from_overloads
         self.bound_args = bound_args or []
         super().__init__(line, column)
 
@@ -718,8 +722,10 @@ def copy_modified(self,
                       line: int = _dummy,
                       column: int = _dummy,
                       is_ellipsis_args: bool = _dummy,
+                      implicit: bool = _dummy,
                       special_sig: Optional[str] = _dummy,
                       from_type_type: bool = _dummy,
+                      from_overloads: bool = _dummy,
                       bound_args: List[Optional[Type]] = _dummy) -> 'CallableType':
         return CallableType(
             arg_types=arg_types if arg_types is not _dummy else self.arg_types,
@@ -734,10 +740,11 @@ def copy_modified(self,
             column=column if column is not _dummy else self.column,
             is_ellipsis_args=(
                 is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args),
-            implicit=self.implicit,
+            implicit=implicit if implicit is not _dummy else self.implicit,
             is_classmethod_class=self.is_classmethod_class,
             special_sig=special_sig if special_sig is not _dummy else self.special_sig,
             from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
+            from_overloads=from_overloads if from_overloads is not _dummy else self.from_overloads,
             bound_args=bound_args if bound_args is not _dummy else self.bound_args,
         )
 
@@ -889,6 +896,7 @@ def serialize(self) -> JsonDict:
                 'is_ellipsis_args': self.is_ellipsis_args,
                 'implicit': self.implicit,
                 'is_classmethod_class': self.is_classmethod_class,
+                'from_overloads': self.from_overloads,
                 'bound_args': [(None if t is None else t.serialize())
                                for t in self.bound_args],
                 }
@@ -907,6 +915,7 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
             is_ellipsis_args=data['is_ellipsis_args'],
             implicit=data['implicit'],
             is_classmethod_class=data['is_classmethod_class'],
+            from_overloads=data['from_overloads'],
             bound_args=[(None if t is None else deserialize_type(t))
                         for t in data['bound_args']],
         )
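The `copy_modified` changes rely on a sentinel default so that "argument not passed" can be told apart from a legitimate `None` or `False`. A minimal sketch of that pattern, using a hypothetical `Signature` class rather than mypy's `CallableType`:

```python
from typing import Any, Optional

_dummy: Any = object()  # sentinel meaning "argument not supplied"


class Signature:
    """Hypothetical, much smaller stand-in for a callable type."""

    def __init__(self, name: str, implicit: bool = False,
                 special_sig: Optional[str] = None) -> None:
        self.name = name
        self.implicit = implicit
        self.special_sig = special_sig

    def copy_modified(self, name: str = _dummy, implicit: bool = _dummy,
                      special_sig: Optional[str] = _dummy) -> "Signature":
        # Identity comparison with the sentinel keeps None and False usable
        # as real override values.
        return Signature(
            name=name if name is not _dummy else self.name,
            implicit=implicit if implicit is not _dummy else self.implicit,
            special_sig=special_sig if special_sig is not _dummy else self.special_sig,
        )


original = Signature("f", implicit=False, special_sig="dict")
synthesized = original.copy_modified(implicit=True)
cleared = original.copy_modified(special_sig=None)
```

Here `synthesized` keeps `special_sig="dict"` while flipping `implicit`, and `cleared` shows that `None` can still be passed deliberately, which a `None` default could not express.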
