@@ -577,7 +577,8 @@ def is_callable_compatible(left: CallableType, right: CallableType,
577
577
is_compat_return : Optional [Callable [[Type , Type ], bool ]] = None ,
578
578
ignore_return : bool = False ,
579
579
ignore_pos_arg_names : bool = False ,
580
- check_args_covariantly : bool = False ) -> bool :
580
+ check_args_covariantly : bool = False ,
581
+ allow_partial_overlap : bool = False ) -> bool :
581
582
"""Is the left compatible with the right, using the provided compatibility check?
582
583
583
584
is_compat:
@@ -616,6 +617,55 @@ def g(x: int) -> int: ...
616
617
617
618
In this case, the first call will succeed and the second will fail: f is a
618
619
valid stand-in for g but not vice-versa.
620
+
621
+ allow_partial_overlap:
622
+ By default this function returns True if and only if *all* calls to left are
623
+ also calls to right (with respect to the provided 'is_compat' function).
624
+
625
+ If this parameter is set to 'True', we return True if *there exists at least one*
626
+ call to left that's also a call to right.
627
+
628
+ In other words, we perform an existential check instead of a universal one;
629
+ we require left to only overlap with right instead of being a subset.
630
+
631
+ For example, suppose we set 'is_compat' to some subtype check and compare following:
632
+
633
+ f(x: float, y: str = "...", *args: bool) -> str
634
+ g(*args: int) -> str
635
+
636
+ This function would normally return 'False': f is not a subtype of g.
637
+ However, we would return True if this parameter is set to 'True': the two
638
+ calls are compatible if the user runs "f_or_g(3)". In the context of that
639
+ specific call, the two functions effectively have signatures of:
640
+
641
+ f2(float) -> str
642
+ g2(int) -> str
643
+
644
+ Here, f2 is a valid subtype of g2 so we return True.
645
+
646
+ Specifically, if this parameter is set this function will:
647
+
648
+ - Ignore optional arguments on either the left or right that have no
649
+ corresponding match.
650
+ - No longer mandate optional arguments on either side are also optional
651
+ on the other.
652
+ - No longer mandate that if right has a *arg or **kwarg that left must also
653
+ have the same.
654
+
655
+ Note: when this argument is set to True, this function becomes "symmetric" --
656
+ the following calls are equivalent:
657
+
658
+ is_callable_compatible(f, g,
659
+ is_compat=some_check,
660
+ check_args_covariantly=False,
661
+ allow_partial_overlap=True)
662
+ is_callable_compatible(g, f,
663
+ is_compat=some_check,
664
+ check_args_covariantly=True,
665
+ allow_partial_overlap=True)
666
+
667
+ If the 'some_check' function is also symmetric, the two calls would be equivalent
668
+ whether or not we check the args covariantly.
619
669
"""
620
670
if is_compat_return is None :
621
671
is_compat_return = is_compat
@@ -638,7 +688,6 @@ def g(x: int) -> int: ...
638
688
# type variables of L, because generating and solving
639
689
# constraints for the variables of L to make L a subtype of R
640
690
# (below) treats type variables on the two sides as independent.
641
-
642
691
if left .variables :
643
692
# Apply generic type variables away in left via type inference.
644
693
unified = unify_generic_callable (left , right , ignore_return = ignore_return )
@@ -647,6 +696,17 @@ def g(x: int) -> int: ...
647
696
else :
648
697
left = unified
649
698
699
+ # If we allow partial overlaps, we don't need to leave R generic:
700
+ # if we can find even just a single typevar assignment which
701
+ # would make these callables compatible, we should return True.
702
+
703
+ # So, we repeat the above checks in the opposite direction. This also
704
+ # lets us preserve the 'symmetry' property of allow_partial_overlap.
705
+ if allow_partial_overlap and right .variables :
706
+ unified = unify_generic_callable (right , left , ignore_return = ignore_return )
707
+ if unified is not None :
708
+ right = unified
709
+
650
710
# Check return types.
651
711
if not ignore_return and not is_compat_return (left .ret_type , right .ret_type ):
652
712
return False
@@ -657,16 +717,17 @@ def g(x: int) -> int: ...
657
717
if right .is_ellipsis_args :
658
718
return True
659
719
660
- right_star_type = None # type: Optional[Type]
661
- right_star2_type = None # type: Optional[Type]
720
+ left_star = left .var_arg
721
+ left_star2 = left .kw_arg
722
+ right_star = right .var_arg
723
+ right_star2 = right .kw_arg
662
724
663
725
# Match up corresponding arguments and check them for compatibility. In
664
726
# every pair (argL, argR) of corresponding arguments from L and R, argL must
665
727
# be "more general" than argR if L is to be a subtype of R.
666
728
667
729
# Arguments are corresponding if they either share a name, share a position,
668
- # or both. If L's corresponding argument is ambiguous, L is not a subtype of
669
- # R.
730
+ # or both. If L's corresponding argument is ambiguous, L is not a subtype of R.
670
731
671
732
# If left has one corresponding argument by name and another by position,
672
733
# consider them to be one "merged" argument (and not ambiguous) if they're
@@ -677,94 +738,92 @@ def g(x: int) -> int: ...
677
738
678
739
# Every argument in R must have a corresponding argument in L, and every
679
740
# required argument in L must have a corresponding argument in R.
680
- done_with_positional = False
681
- for i in range (len (right .arg_types )):
682
- right_kind = right .arg_kinds [i ]
683
- if right_kind in (ARG_STAR , ARG_STAR2 , ARG_NAMED , ARG_NAMED_OPT ):
684
- done_with_positional = True
685
- right_required = right_kind in (ARG_POS , ARG_NAMED )
686
- right_pos = None if done_with_positional else i
687
-
688
- right_arg = FormalArgument (
689
- right .arg_names [i ],
690
- right_pos ,
691
- right .arg_types [i ],
692
- right_required )
693
-
694
- if right_kind == ARG_STAR :
695
- right_star_type = right_arg .typ
696
- # Right has an infinite series of optional positional arguments
697
- # here. Get all further positional arguments of left, and make sure
698
- # they're more general than their corresponding member in this
699
- # series. Also make sure left has its own infinite series of
700
- # optional positional arguments.
701
- if not left .is_var_arg :
702
- return False
703
- j = i
704
- while j < len (left .arg_kinds ) and left .arg_kinds [j ] in (ARG_POS , ARG_OPT ):
705
- left_by_position = left .argument_by_position (j )
706
- assert left_by_position is not None
707
- # This fetches the synthetic argument that's from the *args
708
- right_by_position = right .argument_by_position (j )
709
- assert right_by_position is not None
710
- if not are_args_compatible (left_by_position , right_by_position ,
711
- ignore_pos_arg_names , is_compat ):
712
- return False
713
- j += 1
714
- continue
715
-
716
- if right_kind == ARG_STAR2 :
717
- right_star2_type = right_arg .typ
718
- # Right has an infinite set of optional named arguments here. Get
719
- # all further named arguments of left and make sure they're more
720
- # general than their corresponding member in this set. Also make
721
- # sure left has its own infinite set of optional named arguments.
722
- if not left .is_kw_arg :
723
- return False
724
- left_names = {name for name in left .arg_names if name is not None }
725
- right_names = {name for name in right .arg_names if name is not None }
726
- left_only_names = left_names - right_names
727
- for name in left_only_names :
728
- left_by_name = left .argument_by_name (name )
729
- assert left_by_name is not None
730
- # This fetches the synthetic argument that's from the **kwargs
731
- right_by_name = right .argument_by_name (name )
732
- assert right_by_name is not None
733
- if not are_args_compatible (left_by_name , right_by_name ,
734
- ignore_pos_arg_names , is_compat ):
735
- return False
736
- continue
737
741
738
- # Left must have some kind of corresponding argument.
742
+ # Phase 1: Confirm every argument in R has a corresponding argument in L.
743
+
744
+ # Phase 1a: If left and right can both accept an infinite number of args,
745
+ # their types must be compatible.
746
+ #
747
+ # Furthermore, if we're checking for compatibility in all cases,
748
+ # we confirm that if R accepts an infinite number of arguments,
749
+ # L must accept the same.
750
+ def _incompatible (left_arg : Optional [FormalArgument ],
751
+ right_arg : Optional [FormalArgument ]) -> bool :
752
+ if right_arg is None :
753
+ return False
754
+ if left_arg is None :
755
+ return not allow_partial_overlap
756
+ return not is_compat (right_arg .typ , left_arg .typ )
757
+
758
+ if _incompatible (left_star , right_star ) or _incompatible (left_star2 , right_star2 ):
759
+ return False
760
+
761
+ # Phase 1b: Check non-star args: for every arg right can accept, left must
762
+ # also accept. The only exception is if we are allowing partial
763
+ # partial overlaps: in that case, we ignore optional args on the right.
764
+ for right_arg in right .formal_arguments ():
739
765
left_arg = left .corresponding_argument (right_arg )
740
766
if left_arg is None :
767
+ if allow_partial_overlap and not right_arg .required :
768
+ continue
741
769
return False
742
-
743
- if not are_args_compatible (left_arg , right_arg ,
744
- ignore_pos_arg_names , is_compat ):
770
+ if not are_args_compatible (left_arg , right_arg , ignore_pos_arg_names ,
771
+ allow_partial_overlap , is_compat ):
745
772
return False
746
773
747
- done_with_positional = False
748
- for i in range (len (left .arg_types )):
749
- left_kind = left .arg_kinds [i ]
750
- if left_kind in (ARG_STAR , ARG_STAR2 , ARG_NAMED , ARG_NAMED_OPT ):
751
- done_with_positional = True
752
- left_arg = FormalArgument (
753
- left .arg_names [i ],
754
- None if done_with_positional else i ,
755
- left .arg_types [i ],
756
- left_kind in (ARG_POS , ARG_NAMED ))
757
-
758
- # Check that *args and **kwargs types match in this loop
759
- if left_kind == ARG_STAR :
760
- if right_star_type is not None and not is_compat (right_star_type , left_arg .typ ):
774
+ # Phase 1c: Check var args. Right has an infinite series of optional positional
775
+ # arguments. Get all further positional args of left, and make sure
776
+ # they're more general then the corresponding member in right.
777
+ if right_star is not None :
778
+ # Synthesize an anonymous formal argument for the right
779
+ right_by_position = right .try_synthesizing_arg_from_vararg (None )
780
+ assert right_by_position is not None
781
+
782
+ i = right_star .pos
783
+ assert i is not None
784
+ while i < len (left .arg_kinds ) and left .arg_kinds [i ] in (ARG_POS , ARG_OPT ):
785
+ if allow_partial_overlap and left .arg_kinds [i ] == ARG_OPT :
786
+ break
787
+
788
+ left_by_position = left .argument_by_position (i )
789
+ assert left_by_position is not None
790
+
791
+ if not are_args_compatible (left_by_position , right_by_position ,
792
+ ignore_pos_arg_names , allow_partial_overlap ,
793
+ is_compat ):
761
794
return False
762
- continue
763
- elif left_kind == ARG_STAR2 :
764
- if right_star2_type is not None and not is_compat (right_star2_type , left_arg .typ ):
795
+ i += 1
796
+
797
+ # Phase 1d: Check kw args. Right has an infinite series of optional named
798
+ # arguments. Get all further named args of left, and make sure
799
+ # they're more general then the corresponding member in right.
800
+ if right_star2 is not None :
801
+ right_names = {name for name in right .arg_names if name is not None }
802
+ left_only_names = set ()
803
+ for name , kind in zip (left .arg_names , left .arg_kinds ):
804
+ if name is None or kind in (ARG_STAR , ARG_STAR2 ) or name in right_names :
805
+ continue
806
+ left_only_names .add (name )
807
+
808
+ # Synthesize an anonymous formal argument for the right
809
+ right_by_name = right .try_synthesizing_arg_from_kwarg (None )
810
+ assert right_by_name is not None
811
+
812
+ for name in left_only_names :
813
+ left_by_name = left .argument_by_name (name )
814
+ assert left_by_name is not None
815
+
816
+ if allow_partial_overlap and not left_by_name .required :
817
+ continue
818
+
819
+ if not are_args_compatible (left_by_name , right_by_name , ignore_pos_arg_names ,
820
+ allow_partial_overlap , is_compat ):
765
821
return False
766
- continue
767
822
823
+ # Phase 2: Left must not impose additional restrictions.
824
+ # (Every required argument in L must have a corresponding argument in R)
825
+ # Note: we already checked the *arg and **kwarg arguments in phase 1a.
826
+ for left_arg in left .formal_arguments ():
768
827
right_by_name = (right .argument_by_name (left_arg .name )
769
828
if left_arg .name is not None
770
829
else None )
@@ -782,7 +841,7 @@ def g(x: int) -> int: ...
782
841
return False
783
842
784
843
# All *required* left-hand arguments must have a corresponding
785
- # right-hand argument. Optional args it does not matter.
844
+ # right-hand argument. Optional args do not matter.
786
845
if left_arg .required and right_by_pos is None and right_by_name is None :
787
846
return False
788
847
@@ -793,23 +852,46 @@ def are_args_compatible(
793
852
left : FormalArgument ,
794
853
right : FormalArgument ,
795
854
ignore_pos_arg_names : bool ,
855
+ allow_partial_overlap : bool ,
796
856
is_compat : Callable [[Type , Type ], bool ]) -> bool :
857
+ def is_different (left_item : Optional [object ], right_item : Optional [object ]) -> bool :
858
+ """Checks if the left and right items are different.
859
+
860
+ If the right item is unspecified (e.g. if the right callable doesn't care
861
+ about what name or position its arg has), we default to returning False.
862
+
863
+ If we're allowing partial overlap, we also default to returning False
864
+ if the left callable also doesn't care."""
865
+ if right_item is None :
866
+ return False
867
+ if allow_partial_overlap and left_item is None :
868
+ return False
869
+ return left_item != right_item
870
+
797
871
# If right has a specific name it wants this argument to be, left must
798
872
# have the same.
799
- if right . name is not None and left .name != right .name :
873
+ if is_different ( left .name , right .name ) :
800
874
# But pay attention to whether we're ignoring positional arg names
801
875
if not ignore_pos_arg_names or right .pos is None :
802
876
return False
877
+
803
878
# If right is at a specific position, left must have the same:
804
- if right . pos is not None and left .pos != right .pos :
879
+ if is_different ( left .pos , right .pos ) :
805
880
return False
806
- # Left must have a more general type
807
- if not is_compat ( right . typ , left . typ ):
808
- return False
809
- # If right's argument is optional, left's must also be .
810
- if not right .required and left .required :
881
+
882
+ # If right's argument is optional , left's must also be
883
+ # (unless we're relaxing the checks to allow potential
884
+ # rather then definite compatibility) .
885
+ if not allow_partial_overlap and not right .required and left .required :
811
886
return False
812
- return True
887
+
888
+ # If we're allowing partial overlaps and neither arg is required,
889
+ # the types don't actually need to be the same
890
+ if allow_partial_overlap and not left .required and not right .required :
891
+ return True
892
+
893
+ # Left must have a more general type
894
+ return is_compat (right .typ , left .typ )
813
895
814
896
815
897
def flip_compat_check (is_compat : Callable [[Type , Type ], bool ]) -> Callable [[Type , Type ], bool ]:
0 commit comments