Skip to content

Commit 670f486

Browse files
authored
stubtest: Fix crash with numpy array default values (#18353)
See #18343 (comment)
1 parent 60da03a commit 670f486

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

mypy/stubtest.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ def _verify_arg_default_value(
670670
stub_arg: nodes.Argument, runtime_arg: inspect.Parameter
671671
) -> Iterator[str]:
672672
"""Checks whether argument default values are compatible."""
673-
if runtime_arg.default != inspect.Parameter.empty:
673+
if runtime_arg.default is not inspect.Parameter.empty:
674674
if stub_arg.kind.is_required():
675675
yield (
676676
f'runtime argument "{runtime_arg.name}" '
@@ -705,18 +705,26 @@ def _verify_arg_default_value(
705705
stub_default is not UNKNOWN
706706
and stub_default is not ...
707707
and runtime_arg.default is not UNREPRESENTABLE
708-
and (
709-
stub_default != runtime_arg.default
710-
# We want the types to match exactly, e.g. in case the stub has
711-
# True and the runtime has 1 (or vice versa).
712-
or type(stub_default) is not type(runtime_arg.default)
713-
)
714708
):
715-
yield (
716-
f'runtime argument "{runtime_arg.name}" '
717-
f"has a default value of {runtime_arg.default!r}, "
718-
f"which is different from stub argument default {stub_default!r}"
719-
)
709+
defaults_match = True
710+
# We want the types to match exactly, e.g. in case the stub has
711+
# True and the runtime has 1 (or vice versa).
712+
if type(stub_default) is not type(runtime_arg.default):
713+
defaults_match = False
714+
else:
715+
try:
716+
defaults_match = bool(stub_default == runtime_arg.default)
717+
except Exception:
718+
# Exception can be raised in bool dunder method (e.g. numpy arrays)
719+
# At this point, consider the default to be different, it is probably
720+
# too complex to put in a stub anyway.
721+
defaults_match = False
722+
if not defaults_match:
723+
yield (
724+
f'runtime argument "{runtime_arg.name}" '
725+
f"has a default value of {runtime_arg.default!r}, "
726+
f"which is different from stub argument default {stub_default!r}"
727+
)
720728
else:
721729
if stub_arg.kind.is_optional():
722730
yield (
@@ -758,7 +766,7 @@ def get_type(arg: Any) -> str | None:
758766

759767
def has_default(arg: Any) -> bool:
760768
if isinstance(arg, inspect.Parameter):
761-
return bool(arg.default != inspect.Parameter.empty)
769+
return arg.default is not inspect.Parameter.empty
762770
if isinstance(arg, nodes.Argument):
763771
return arg.kind.is_optional()
764772
raise AssertionError
@@ -1628,13 +1636,13 @@ def anytype() -> mypy.types.AnyType:
16281636
arg_names.append(
16291637
None if arg.kind == inspect.Parameter.POSITIONAL_ONLY else arg.name
16301638
)
1631-
has_default = arg.default == inspect.Parameter.empty
1639+
no_default = arg.default is inspect.Parameter.empty
16321640
if arg.kind == inspect.Parameter.POSITIONAL_ONLY:
1633-
arg_kinds.append(nodes.ARG_POS if has_default else nodes.ARG_OPT)
1641+
arg_kinds.append(nodes.ARG_POS if no_default else nodes.ARG_OPT)
16341642
elif arg.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
1635-
arg_kinds.append(nodes.ARG_POS if has_default else nodes.ARG_OPT)
1643+
arg_kinds.append(nodes.ARG_POS if no_default else nodes.ARG_OPT)
16361644
elif arg.kind == inspect.Parameter.KEYWORD_ONLY:
1637-
arg_kinds.append(nodes.ARG_NAMED if has_default else nodes.ARG_NAMED_OPT)
1645+
arg_kinds.append(nodes.ARG_NAMED if no_default else nodes.ARG_NAMED_OPT)
16381646
elif arg.kind == inspect.Parameter.VAR_POSITIONAL:
16391647
arg_kinds.append(nodes.ARG_STAR)
16401648
elif arg.kind == inspect.Parameter.VAR_KEYWORD:

mypy/test/teststubtest.py

+12
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,18 @@ def f11(text=None) -> None: pass
529529
error="f11",
530530
)
531531

532+
# Simulate numpy ndarray.__bool__ that raises an error
533+
yield Case(
534+
stub="def f12(x=1): ...",
535+
runtime="""
536+
class _ndarray:
537+
def __eq__(self, obj): return self
538+
def __bool__(self): raise ValueError
539+
def f12(x=_ndarray()) -> None: pass
540+
""",
541+
error="f12",
542+
)
543+
532544
@collect_cases
533545
def test_static_class_method(self) -> Iterator[Case]:
534546
yield Case(

0 commit comments

Comments
 (0)