Skip to content

Commit 4fe2220

Browse files
authored
Add support for detecting overloads with overlapping arities (#5163)
This commit addresses TODO 2 from #5119 by adding support for detecting overloads with partially overlapping arities. It also refactors the `is_callable_compatible` method. Specifically, this pull request... 1. Pulls out a lot of the logic for iterating over formal arguments into a helper method in CallableType. 2. Pulls out logic for handling varargs and kwargs outside of loops. 3. Rearranges some of the logic so we can return earlier slightly more frequently.
1 parent 595545f commit 4fe2220

File tree

4 files changed

+485
-144
lines changed

4 files changed

+485
-144
lines changed

mypy/checker.py

+6-18
Original file line numberDiff line numberDiff line change
@@ -3621,7 +3621,7 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType,
36213621
Assumes that 'signature' appears earlier in the list of overload
36223622
alternatives then 'other' and that their argument counts are overlapping.
36233623
"""
3624-
# TODO: Handle partially overlapping parameter types and argument counts
3624+
# TODO: Handle partially overlapping parameter types
36253625
#
36263626
# For example, the signatures "f(x: Union[A, B]) -> int" and "f(x: Union[B, C]) -> str"
36273627
# is unsafe: the parameter types are partially overlapping.
@@ -3632,27 +3632,15 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType,
36323632
#
36333633
# (We already have a rudimentary implementation of 'is_partially_overlapping', but it only
36343634
# attempts to handle the obvious cases -- see its docstring for more info.)
3635-
#
3636-
# Similarly, the signatures "f(x: A, y: A) -> str" and "f(*x: A) -> int" are also unsafe:
3637-
# the parameter *counts* or arity are partially overlapping.
3638-
#
3639-
# To fix this, we need to modify is_callable_compatible so it can optionally detect
3640-
# functions that are *potentially* compatible rather then *definitely* compatible.
36413635

36423636
def is_more_precise_or_partially_overlapping(t: Type, s: Type) -> bool:
36433637
return is_more_precise(t, s) or is_partially_overlapping_types(t, s)
36443638

3645-
# The reason we repeat this check twice is so we can do a slightly better job of
3646-
# checking for potentially overlapping param counts. Both calls will actually check
3647-
# the param and return types in the same "direction" -- the only thing that differs
3648-
# is how is_callable_compatible checks non-positional arguments.
3649-
return (is_callable_compatible(signature, other,
3650-
is_compat=is_more_precise_or_partially_overlapping,
3651-
is_compat_return=lambda l, r: not is_subtype(l, r),
3652-
check_args_covariantly=True) or
3653-
is_callable_compatible(other, signature,
3654-
is_compat=is_more_precise_or_partially_overlapping,
3655-
is_compat_return=lambda l, r: not is_subtype(r, l)))
3639+
return is_callable_compatible(signature, other,
3640+
is_compat=is_more_precise_or_partially_overlapping,
3641+
is_compat_return=lambda l, r: not is_subtype(l, r),
3642+
check_args_covariantly=True,
3643+
allow_partial_overlap=True)
36563644

36573645

36583646
def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:

mypy/subtypes.py

+176-94
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,8 @@ def is_callable_compatible(left: CallableType, right: CallableType,
577577
is_compat_return: Optional[Callable[[Type, Type], bool]] = None,
578578
ignore_return: bool = False,
579579
ignore_pos_arg_names: bool = False,
580-
check_args_covariantly: bool = False) -> bool:
580+
check_args_covariantly: bool = False,
581+
allow_partial_overlap: bool = False) -> bool:
581582
"""Is the left compatible with the right, using the provided compatibility check?
582583
583584
is_compat:
@@ -616,6 +617,55 @@ def g(x: int) -> int: ...
616617
617618
In this case, the first call will succeed and the second will fail: f is a
618619
valid stand-in for g but not vice-versa.
620+
621+
allow_partial_overlap:
622+
By default this function returns True if and only if *all* calls to left are
623+
also calls to right (with respect to the provided 'is_compat' function).
624+
625+
If this parameter is set to 'True', we return True if *there exists at least one*
626+
call to left that's also a call to right.
627+
628+
In other words, we perform an existential check instead of a universal one;
629+
we require left to only overlap with right instead of being a subset.
630+
631+
For example, suppose we set 'is_compat' to some subtype check and compare following:
632+
633+
f(x: float, y: str = "...", *args: bool) -> str
634+
g(*args: int) -> str
635+
636+
This function would normally return 'False': f is not a subtype of g.
637+
However, we would return True if this parameter is set to 'True': the two
638+
calls are compatible if the user runs "f_or_g(3)". In the context of that
639+
specific call, the two functions effectively have signatures of:
640+
641+
f2(float) -> str
642+
g2(int) -> str
643+
644+
Here, f2 is a valid subtype of g2 so we return True.
645+
646+
Specifically, if this parameter is set this function will:
647+
648+
- Ignore optional arguments on either the left or right that have no
649+
corresponding match.
650+
- No longer mandate optional arguments on either side are also optional
651+
on the other.
652+
- No longer mandate that if right has a *arg or **kwarg that left must also
653+
have the same.
654+
655+
Note: when this argument is set to True, this function becomes "symmetric" --
656+
the following calls are equivalent:
657+
658+
is_callable_compatible(f, g,
659+
is_compat=some_check,
660+
check_args_covariantly=False,
661+
allow_partial_overlap=True)
662+
is_callable_compatible(g, f,
663+
is_compat=some_check,
664+
check_args_covariantly=True,
665+
allow_partial_overlap=True)
666+
667+
If the 'some_check' function is also symmetric, the two calls would be equivalent
668+
whether or not we check the args covariantly.
619669
"""
620670
if is_compat_return is None:
621671
is_compat_return = is_compat
@@ -638,7 +688,6 @@ def g(x: int) -> int: ...
638688
# type variables of L, because generating and solving
639689
# constraints for the variables of L to make L a subtype of R
640690
# (below) treats type variables on the two sides as independent.
641-
642691
if left.variables:
643692
# Apply generic type variables away in left via type inference.
644693
unified = unify_generic_callable(left, right, ignore_return=ignore_return)
@@ -647,6 +696,17 @@ def g(x: int) -> int: ...
647696
else:
648697
left = unified
649698

699+
# If we allow partial overlaps, we don't need to leave R generic:
700+
# if we can find even just a single typevar assignment which
701+
# would make these callables compatible, we should return True.
702+
703+
# So, we repeat the above checks in the opposite direction. This also
704+
# lets us preserve the 'symmetry' property of allow_partial_overlap.
705+
if allow_partial_overlap and right.variables:
706+
unified = unify_generic_callable(right, left, ignore_return=ignore_return)
707+
if unified is not None:
708+
right = unified
709+
650710
# Check return types.
651711
if not ignore_return and not is_compat_return(left.ret_type, right.ret_type):
652712
return False
@@ -657,16 +717,17 @@ def g(x: int) -> int: ...
657717
if right.is_ellipsis_args:
658718
return True
659719

660-
right_star_type = None # type: Optional[Type]
661-
right_star2_type = None # type: Optional[Type]
720+
left_star = left.var_arg
721+
left_star2 = left.kw_arg
722+
right_star = right.var_arg
723+
right_star2 = right.kw_arg
662724

663725
# Match up corresponding arguments and check them for compatibility. In
664726
# every pair (argL, argR) of corresponding arguments from L and R, argL must
665727
# be "more general" than argR if L is to be a subtype of R.
666728

667729
# Arguments are corresponding if they either share a name, share a position,
668-
# or both. If L's corresponding argument is ambiguous, L is not a subtype of
669-
# R.
730+
# or both. If L's corresponding argument is ambiguous, L is not a subtype of R.
670731

671732
# If left has one corresponding argument by name and another by position,
672733
# consider them to be one "merged" argument (and not ambiguous) if they're
@@ -677,94 +738,92 @@ def g(x: int) -> int: ...
677738

678739
# Every argument in R must have a corresponding argument in L, and every
679740
# required argument in L must have a corresponding argument in R.
680-
done_with_positional = False
681-
for i in range(len(right.arg_types)):
682-
right_kind = right.arg_kinds[i]
683-
if right_kind in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT):
684-
done_with_positional = True
685-
right_required = right_kind in (ARG_POS, ARG_NAMED)
686-
right_pos = None if done_with_positional else i
687-
688-
right_arg = FormalArgument(
689-
right.arg_names[i],
690-
right_pos,
691-
right.arg_types[i],
692-
right_required)
693-
694-
if right_kind == ARG_STAR:
695-
right_star_type = right_arg.typ
696-
# Right has an infinite series of optional positional arguments
697-
# here. Get all further positional arguments of left, and make sure
698-
# they're more general than their corresponding member in this
699-
# series. Also make sure left has its own infinite series of
700-
# optional positional arguments.
701-
if not left.is_var_arg:
702-
return False
703-
j = i
704-
while j < len(left.arg_kinds) and left.arg_kinds[j] in (ARG_POS, ARG_OPT):
705-
left_by_position = left.argument_by_position(j)
706-
assert left_by_position is not None
707-
# This fetches the synthetic argument that's from the *args
708-
right_by_position = right.argument_by_position(j)
709-
assert right_by_position is not None
710-
if not are_args_compatible(left_by_position, right_by_position,
711-
ignore_pos_arg_names, is_compat):
712-
return False
713-
j += 1
714-
continue
715-
716-
if right_kind == ARG_STAR2:
717-
right_star2_type = right_arg.typ
718-
# Right has an infinite set of optional named arguments here. Get
719-
# all further named arguments of left and make sure they're more
720-
# general than their corresponding member in this set. Also make
721-
# sure left has its own infinite set of optional named arguments.
722-
if not left.is_kw_arg:
723-
return False
724-
left_names = {name for name in left.arg_names if name is not None}
725-
right_names = {name for name in right.arg_names if name is not None}
726-
left_only_names = left_names - right_names
727-
for name in left_only_names:
728-
left_by_name = left.argument_by_name(name)
729-
assert left_by_name is not None
730-
# This fetches the synthetic argument that's from the **kwargs
731-
right_by_name = right.argument_by_name(name)
732-
assert right_by_name is not None
733-
if not are_args_compatible(left_by_name, right_by_name,
734-
ignore_pos_arg_names, is_compat):
735-
return False
736-
continue
737741

738-
# Left must have some kind of corresponding argument.
742+
# Phase 1: Confirm every argument in R has a corresponding argument in L.
743+
744+
# Phase 1a: If left and right can both accept an infinite number of args,
745+
# their types must be compatible.
746+
#
747+
# Furthermore, if we're checking for compatibility in all cases,
748+
# we confirm that if R accepts an infinite number of arguments,
749+
# L must accept the same.
750+
def _incompatible(left_arg: Optional[FormalArgument],
751+
right_arg: Optional[FormalArgument]) -> bool:
752+
if right_arg is None:
753+
return False
754+
if left_arg is None:
755+
return not allow_partial_overlap
756+
return not is_compat(right_arg.typ, left_arg.typ)
757+
758+
if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
759+
return False
760+
761+
# Phase 1b: Check non-star args: for every arg right can accept, left must
762+
# also accept. The only exception is if we are allowing partial
763+
# partial overlaps: in that case, we ignore optional args on the right.
764+
for right_arg in right.formal_arguments():
739765
left_arg = left.corresponding_argument(right_arg)
740766
if left_arg is None:
767+
if allow_partial_overlap and not right_arg.required:
768+
continue
741769
return False
742-
743-
if not are_args_compatible(left_arg, right_arg,
744-
ignore_pos_arg_names, is_compat):
770+
if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names,
771+
allow_partial_overlap, is_compat):
745772
return False
746773

747-
done_with_positional = False
748-
for i in range(len(left.arg_types)):
749-
left_kind = left.arg_kinds[i]
750-
if left_kind in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT):
751-
done_with_positional = True
752-
left_arg = FormalArgument(
753-
left.arg_names[i],
754-
None if done_with_positional else i,
755-
left.arg_types[i],
756-
left_kind in (ARG_POS, ARG_NAMED))
757-
758-
# Check that *args and **kwargs types match in this loop
759-
if left_kind == ARG_STAR:
760-
if right_star_type is not None and not is_compat(right_star_type, left_arg.typ):
774+
# Phase 1c: Check var args. Right has an infinite series of optional positional
775+
# arguments. Get all further positional args of left, and make sure
776+
# they're more general then the corresponding member in right.
777+
if right_star is not None:
778+
# Synthesize an anonymous formal argument for the right
779+
right_by_position = right.try_synthesizing_arg_from_vararg(None)
780+
assert right_by_position is not None
781+
782+
i = right_star.pos
783+
assert i is not None
784+
while i < len(left.arg_kinds) and left.arg_kinds[i] in (ARG_POS, ARG_OPT):
785+
if allow_partial_overlap and left.arg_kinds[i] == ARG_OPT:
786+
break
787+
788+
left_by_position = left.argument_by_position(i)
789+
assert left_by_position is not None
790+
791+
if not are_args_compatible(left_by_position, right_by_position,
792+
ignore_pos_arg_names, allow_partial_overlap,
793+
is_compat):
761794
return False
762-
continue
763-
elif left_kind == ARG_STAR2:
764-
if right_star2_type is not None and not is_compat(right_star2_type, left_arg.typ):
795+
i += 1
796+
797+
# Phase 1d: Check kw args. Right has an infinite series of optional named
798+
# arguments. Get all further named args of left, and make sure
799+
# they're more general then the corresponding member in right.
800+
if right_star2 is not None:
801+
right_names = {name for name in right.arg_names if name is not None}
802+
left_only_names = set()
803+
for name, kind in zip(left.arg_names, left.arg_kinds):
804+
if name is None or kind in (ARG_STAR, ARG_STAR2) or name in right_names:
805+
continue
806+
left_only_names.add(name)
807+
808+
# Synthesize an anonymous formal argument for the right
809+
right_by_name = right.try_synthesizing_arg_from_kwarg(None)
810+
assert right_by_name is not None
811+
812+
for name in left_only_names:
813+
left_by_name = left.argument_by_name(name)
814+
assert left_by_name is not None
815+
816+
if allow_partial_overlap and not left_by_name.required:
817+
continue
818+
819+
if not are_args_compatible(left_by_name, right_by_name, ignore_pos_arg_names,
820+
allow_partial_overlap, is_compat):
765821
return False
766-
continue
767822

823+
# Phase 2: Left must not impose additional restrictions.
824+
# (Every required argument in L must have a corresponding argument in R)
825+
# Note: we already checked the *arg and **kwarg arguments in phase 1a.
826+
for left_arg in left.formal_arguments():
768827
right_by_name = (right.argument_by_name(left_arg.name)
769828
if left_arg.name is not None
770829
else None)
@@ -782,7 +841,7 @@ def g(x: int) -> int: ...
782841
return False
783842

784843
# All *required* left-hand arguments must have a corresponding
785-
# right-hand argument. Optional args it does not matter.
844+
# right-hand argument. Optional args do not matter.
786845
if left_arg.required and right_by_pos is None and right_by_name is None:
787846
return False
788847

@@ -793,23 +852,46 @@ def are_args_compatible(
793852
left: FormalArgument,
794853
right: FormalArgument,
795854
ignore_pos_arg_names: bool,
855+
allow_partial_overlap: bool,
796856
is_compat: Callable[[Type, Type], bool]) -> bool:
857+
def is_different(left_item: Optional[object], right_item: Optional[object]) -> bool:
858+
"""Checks if the left and right items are different.
859+
860+
If the right item is unspecified (e.g. if the right callable doesn't care
861+
about what name or position its arg has), we default to returning False.
862+
863+
If we're allowing partial overlap, we also default to returning False
864+
if the left callable also doesn't care."""
865+
if right_item is None:
866+
return False
867+
if allow_partial_overlap and left_item is None:
868+
return False
869+
return left_item != right_item
870+
797871
# If right has a specific name it wants this argument to be, left must
798872
# have the same.
799-
if right.name is not None and left.name != right.name:
873+
if is_different(left.name, right.name):
800874
# But pay attention to whether we're ignoring positional arg names
801875
if not ignore_pos_arg_names or right.pos is None:
802876
return False
877+
803878
# If right is at a specific position, left must have the same:
804-
if right.pos is not None and left.pos != right.pos:
879+
if is_different(left.pos, right.pos):
805880
return False
806-
# Left must have a more general type
807-
if not is_compat(right.typ, left.typ):
808-
return False
809-
# If right's argument is optional, left's must also be.
810-
if not right.required and left.required:
881+
882+
# If right's argument is optional, left's must also be
883+
# (unless we're relaxing the checks to allow potential
884+
# rather then definite compatibility).
885+
if not allow_partial_overlap and not right.required and left.required:
811886
return False
812-
return True
887+
888+
# If we're allowing partial overlaps and neither arg is required,
889+
# the types don't actually need to be the same
890+
if allow_partial_overlap and not left.required and not right.required:
891+
return True
892+
893+
# Left must have a more general type
894+
return is_compat(right.typ, left.typ)
813895

814896

815897
def flip_compat_check(is_compat: Callable[[Type, Type], bool]) -> Callable[[Type, Type], bool]:

0 commit comments

Comments
 (0)