🩹 add overload for ndarray.__matmul__ #286

Draft: wants to merge 3 commits into main

guan404ming
Contributor

close #197

@@ -2566,12 +2569,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
@overload
def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
@overload
- def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
+ def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
Member

_ArrayLikeObject_co is an alias for _ArrayLike[np.object_], so it accepts only things that can be expressed in terms of np.object_, and rejects e.g. lists of decimal.Decimal, even though that'd be valid in this case:

>>> A = np.array([[Decimal(0), Decimal(1)], [Decimal(-1), Decimal(0)]])
>>> A
array([[Decimal('0'), Decimal('1')],
       [Decimal('-1'), Decimal('0')]], dtype=object)
>>> A @ [Decimal(2), Decimal(3)]
array([Decimal('3'), Decimal('-2')], dtype=object)

So this would be falsely rejected.

(I kinda expected that there would be a test for this, but apparently not)

Anyway, I guess I'm trying to say that object dtypes are very difficult to properly type, especially because of the lack of tests for it. So it might be for the best to leave it as rhs: object for now, and look at it again once we have better testing in place.
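A runtime check along the following lines could be a starting point for the missing test. This is only a sketch: the function name is hypothetical and not part of numtype's test suite, and it checks runtime behaviour only, not what a type checker infers.

```python
from decimal import Decimal

import numpy as np


def check_object_matmul_with_decimal_list() -> np.ndarray:
    """Hypothetical helper: object-dtype matrix @ plain list of Decimal."""
    a = np.array([[Decimal(0), Decimal(1)], [Decimal(-1), Decimal(0)]])
    # NumPy infers dtype=object for an array of Decimal instances.
    assert a.dtype == np.dtype(object)

    # The right-hand side is a plain Python list, not an ndarray.
    result = a @ [Decimal(2), Decimal(3)]
    assert result.dtype == np.dtype(object)
    return result


result = check_object_matmul_with_decimal_list()
print(result.tolist())  # [Decimal('3'), Decimal('-2')]
```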

Contributor Author

When using the decimal.Decimal type, an overlap issue arises between the following overloads:

@overload
def __matmul__(self: _Array1D[_MatmulScalarT], rhs: _Array1D[_MatmulScalarT], /) -> _MatmulScalarT: ...
  • Here, _MatmulScalarT includes decimal.Decimal, so this overload is a valid match.
@overload
def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
  • Since decimal.Decimal is also an instance of object, this overload is also a valid match.

The Conflict:

These two overloads return different types:

  • Overload 1 returns _MatmulScalarT (i.e., Decimal).
  • Overload 18 returns NDArray[object_].

Example:

from decimal import Decimal
import numpy as np

A = np.array([Decimal('1'), Decimal('2')])
B = np.array([Decimal('3'), Decimal('4')])

# A @ B matches both overloads
# But they return different types

Because both overloads match A @ B, the type checker cannot determine which one to use, leading to a type error. I'm not sure how this overlap should be resolved so that object is handled correctly and without ambiguity.
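For reference, the runtime side of the example above can be checked with a few lines. This sketch only shows what NumPy does at runtime and says nothing about which overload a type checker selects; it assumes a NumPy version in which matmul supports object arrays.

```python
from decimal import Decimal

import numpy as np

A = np.array([Decimal("1"), Decimal("2")])
B = np.array([Decimal("3"), Decimal("4")])

# 1d @ 1d is an inner product: 1*3 + 2*4.
result = A @ B

# The result is zero-dimensional, i.e. a scalar rather than an array.
print(np.ndim(result), result)
```

So at runtime the 1d @ 1d case behaves the way the first overload describes it.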

@overload
def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...

# keep in sync with __matmul__
@overload
def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
Member

There's no need for this overload: the parameter types are identical, so lhs @ rhs will always use __matmul__, and never __rmatmul__.
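The dispatch rule behind this can be seen with a toy class. The sketch below is plain Python with no NumPy involved; the Tracker class is made up for illustration and only records which dunder method the @ operator calls when both operands have the same type.

```python
class Tracker:
    """Toy operand that records which matmul method was invoked on it."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def __matmul__(self, other: "Tracker") -> str:
        self.calls.append("__matmul__")
        return "from __matmul__"

    def __rmatmul__(self, other: "Tracker") -> str:
        self.calls.append("__rmatmul__")
        return "from __rmatmul__"


lhs, rhs = Tracker(), Tracker()
result = lhs @ rhs

print(result)     # from __matmul__
print(lhs.calls)  # ['__matmul__']
print(rhs.calls)  # []
```

Python only falls back to the reflected method when the left operand's __matmul__ returns NotImplemented (or when the right operand is an instance of a subclass that overrides it), so an __rmatmul__ overload with the same self and rhs types as a __matmul__ overload describes a call that does not happen.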

# TODO(jorenham): Support the "1d @ 1d -> scalar" case
# https://github.com/numpy/numtype/issues/197
@overload
def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
Member

_ScalarT is bound to np.generic, so this would also accept e.g. datetime64 and str_, which would raise an error in np.matmul. Specifically, np.matmul accepts these types:

>>> import numpy as np
>>> "".join(t[0] for t in np.matmul.types)
'?bBhHiIlLqQefdgFDGO'

which translates to bool_ | number | object_. Maybe there's already a TypeVar with that bound that could be reused here.

If you feel like it, you could try to broaden the rhs type a bit more by exploiting the fact that np.bool and builtins.bool will always "promote". So for example,
rhs: _Array1D[_ScalarT | bool_] is also valid here, and so are things like Sequence[_ScalarT | py_bool | bool_].

And just to be clear; it's already valid and type-safe. It's just that there are some easy ways to improve it, if you want.
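The claim about the accepted dtypes can be checked at runtime. The following is a sketch; the helper name matmul_raises is made up for illustration, and it assumes that a missing matmul loop surfaces as a TypeError (or a subclass of it).

```python
import numpy as np

# First input type character of every registered matmul loop.
type_chars = "".join(t[0] for t in np.matmul.types)
scalar_types = [np.dtype(c).type for c in type_chars]

# Every supported scalar type should fall under bool_ | number | object_.
allowed = (np.bool_, np.number, np.object_)
all_within_bound = all(issubclass(t, allowed) for t in scalar_types)


def matmul_raises(dtype) -> bool:
    """Hypothetical helper: True if 1d @ 1d fails for the given dtype."""
    a = np.zeros(2, dtype=dtype)
    try:
        np.matmul(a, a)
    except TypeError:
        return True
    return False


print(all_within_bound)
print(matmul_raises("M8[D]"), matmul_raises("U1"), matmul_raises(np.float64))
```

A TypeVar bound to np.generic would let the datetime64 and str_ cases through, whereas one bound to bool_ | number | object_ would not.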

@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)

_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]]
Member

np. is not needed here

# https://github.com/numpy/numtype/issues/197
#
@overload
def __matmul__(self: _Array1D[_IntegralT], rhs: _Array1D[_IntegralT], /) -> _IntegralT: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
Member

@jorenham jorenham Mar 10, 2025

Hmm, the overlapping overloads are actually an issue in this case. In several cases mypy incorrectly reports overload-overlap, but when pyright reports it, it's usually for a good reason. And if both report it, then only in very rare circumstances can it be safe to ignore (and this isn't one of those).

For example, let's assume that self is 1d and float64, and rhs is a union of 1d and 2d bool arrays. Then the first overload does not match, because that requires rhs to only be 1d, but rhs could be either 1d or 2d (as it's a union type). But the second overload does match, because there, the shape of rhs doesn't matter. The return type is therefore inferred as NDArray[float64]. But that's incompatible with the return type of the first overload, because it doesn't account for the possibility that rhs is 1d, which would result in a scalar.
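The same kind of mismatch can be reproduced without NumPy. The sketch below uses made-up stand-in classes (Array, Array1D) and a toy matmul function; none of these names come from numtype. It illustrates how a value whose static type is the broader one gets resolved to the second overload, while the runtime takes the branch described by the first. This is the pattern that mypy's overload-overlap and pyright's reportOverlappingOverload checks are meant to flag.

```python
from typing import overload


class Array:
    """Stand-in for an array whose dimensionality is not known statically."""

    def __init__(self, ndim: int) -> None:
        self.ndim = ndim


class Array1D(Array):
    """Stand-in for an array that is statically known to be 1d."""

    def __init__(self) -> None:
        super().__init__(ndim=1)


@overload
def matmul(lhs: Array1D, rhs: Array1D) -> float: ...
@overload
def matmul(lhs: Array, rhs: Array) -> Array: ...
def matmul(lhs: Array, rhs: Array) -> "float | Array":
    if lhs.ndim == 1 and rhs.ndim == 1:
        return 0.0  # 1d @ 1d -> scalar
    return Array(ndim=2)  # toy result; real code would compute the shape


def load() -> Array:
    # Declared as the broad type, but 1d at runtime.
    return Array1D()


lhs, rhs = load(), load()
# Statically only the second overload applies, so the inferred type is Array.
result = matmul(lhs, rhs)
print(type(result).__name__)  # float
```

Code that trusts the inferred Array type, for example by accessing result.ndim, would pass type checking and then fail at runtime.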

It's very tricky to solve this unfortunately, which is why I made a separate issue for it. I've been building the _numtype internal type-check-only package so that we can deal with situations like these, but it's changing quite fast, so I'm a bit hesitant to start using it in something as important as ndarray at this point.
So until _numtype has a somewhat stable API, and is actually tested (which it isn't right now), it's probably for the best to put this PR in the freezer for the time being.

This shows one of the (many) reasons why shape-typing is so difficult, and why it has been taking so long to make progress on. We'll get there eventually, but I'm careful not to rush into it.

I realize that __matmul__ is currently also incorrect, so it might seem weird that I don't want to fix it for this specific situation, as it seems like a net win. But the problem with overlapping overloads is a very pernicious one, and it could (and almost certainly will) lead to unexpected issues that will be very difficult to debug. It's one of the most complicated parts of the Python typing system, and I probably don't fully understand it myself, so I'll just leave it at that for now.

I explain it in a bit more detail in numpy/numpy#27032 (comment), but even so, that only scratches the surface of this can of worms, I'm afraid 😅.

Contributor Author

Thanks for the explanation! Totally get the challenges with overlapping overloads and shape typing in ndarray. It makes sense to wait until _numtype is more stable before moving forward. Looking forward to the progress! 🚀

@jorenham jorenham marked this pull request as draft March 10, 2025 05:29
Successfully merging this pull request may close these issues.

ndarray.__matmul__ should return a scalar for 1d input
2 participants