🩹 add overload for ndarray.__matmul__
#286
src/numpy-stubs/__init__.pyi
Outdated
@@ -2566,12 +2569,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
 @overload
 def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
 @overload
-def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
+def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
_ArrayLikeObject_co is an alias for _ArrayLike[np.object_], so it accepts only things that can be expressed using the numpy np.object_, and rejects e.g. lists of decimal.Decimal, even though that'd be valid in this case:
>>> from decimal import Decimal
>>> import numpy as np
>>> A = np.array([[Decimal(0), Decimal(1)], [Decimal(-1), Decimal(0)]])
>>> A
array([[Decimal('0'), Decimal('1')],
       [Decimal('-1'), Decimal('0')]], dtype=object)
>>> A @ [Decimal(2), Decimal(3)]
array([Decimal('3'), Decimal('-2')], dtype=object)
So this would be falsely rejected.
(I kinda expected that there would be a test for this, but apparently not)
Anyway, I guess I'm trying to say that object dtypes are very difficult to properly type, especially because of the lack of tests for it. So it might be for the best to leave it as rhs: object for now, and look at it again once we have better testing in place.
When using the decimal.Decimal type, an overlap issue arises between the following overloads:
@overload
def __matmul__(self: _Array1D[_MatmulScalarT], rhs: _Array1D[_MatmulScalarT], /) -> _MatmulScalarT: ...
- Here, _MatmulScalarT includes decimal.Decimal, so this overload is a valid match.
@overload
def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
- Since decimal.Decimal is also an instance of object, this overload is also a valid match.
The Conflict:
These two overloads return different types:
- Overload 1 returns _MatmulScalarT (i.e., Decimal).
- Overload 18 returns NDArray[object_].
Example:
from decimal import Decimal
import numpy as np
A = np.array([Decimal('1'), Decimal('2')])
B = np.array([Decimal('3'), Decimal('4')])
# A @ B matches both overloads
# But they return different types
Because both overloads match A @ B, the type checker cannot determine which one to use, leading to a type error. I'm not sure how this overlap should be resolved so that object is handled correctly and without ambiguity.
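For reference, a minimal runtime sketch of the two cases discussed in this thread. The names inner, product, and M are illustrative only. It assumes a NumPy build whose matmul has an object loop, and it only shows what the interpreter returns, not how a type checker resolves the overloads.

```python
from decimal import Decimal

import numpy as np

# Object-dtype operands; NumPy infers dtype=object for Decimal elements.
A = np.array([Decimal("1"), Decimal("2")])
B = np.array([Decimal("3"), Decimal("4")])
M = np.array([[Decimal("0"), Decimal("1")], [Decimal("-1"), Decimal("0")]])

# 1d @ 1d: an inner product, so the result is a single element, not an array.
inner = A @ B

# 2d @ 1d: a matrix-vector product, so the result is a 1d object array.
product = M @ A

print(repr(inner))
print(repr(product))
```

So the runtime result depends on the dimensionality of the operands, which is the distinction the stubs would need to express.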
src/numpy-stubs/__init__.pyi
Outdated
@overload
def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...

# keep in sync with __matmul__
@overload
def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
There's no need for this overload: the parameter types are identical, so lhs @ rhs will always use __matmul__, and never __rmatmul__.
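The dispatch rule behind this can be sketched with a toy class. Vec is hypothetical and unrelated to numpy; it only records which method Python calls for the @ operator.

```python
class Vec:
    """Toy operand (hypothetical) that reports which dunder handled `@`."""

    def __matmul__(self, other):
        return "__matmul__"

    def __rmatmul__(self, other):
        return "__rmatmul__"


# Both operands have the same type: Python calls the left operand's
# __matmul__ and never reaches __rmatmul__.
same_type = Vec() @ Vec()

# The left operand (a list) has no __matmul__, so Python falls back to the
# right operand's reflected method.
reflected = [1, 2] @ Vec()

print(same_type, reflected)
```

The reflected method is only tried first when the right operand's type is a proper subclass of the left operand's type and overrides it, which cannot happen when both sides are the same array type.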
src/numpy-stubs/__init__.pyi
Outdated
# TODO(jorenham): Support the "1d @ 1d -> scalar" case
# https://github.com/numpy/numtype/issues/197
@overload
def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
The _ScalarT is bound to np.generic, so this would also accept e.g. datetime64 and str_, which would raise an error in np.matmul. Specifically, np.matmul accepts these types:
>>> import numpy as np
>>> "".join(t[0] for t in np.matmul.types)
'?bBhHiIlLqQefdgFDGO'
which translates to bool_ | number | object_. Maybe there's already a TypeVar with that bound that could be reused here.
If you feel like it, you could try to broaden the rhs type a bit more by exploiting the fact that np.bool and builtins.bool will always "promote". So for example, rhs: _Array1D[_ScalarT | bool_] is also valid here, and also things like Sequence[_ScalarT | py_bool | bool_].
And just to be clear; it's already valid and type-safe. It's just that there are some easy ways to improve it, if you want.
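A short runtime check of both points, as a sketch. The helper matmul_rejects and the sample arrays are made up for illustration, and it assumes that np.matmul signals an unsupported dtype with a TypeError subclass.

```python
import numpy as np

# First character of each signature is the dtype char of the left operand.
supported = "".join(sig[0] for sig in np.matmul.types)


def matmul_rejects(a, b):
    """Return True if np.matmul raises TypeError for these operands."""
    try:
        np.matmul(a, b)
    except TypeError:
        return True
    return False


dates = np.array(["2024-01-01", "2024-01-02"], dtype="datetime64[D]")
words = np.array(["a", "b"])
floats = np.array([1.0, 2.0])

# datetime64 and str_ are np.generic subclasses, but matmul has no loop for them.
rejects_dates = matmul_rejects(dates, dates)
rejects_words = matmul_rejects(words, words)

# Booleans promote to the other operand's dtype, for np.bool and builtins.bool.
promoted_np = floats @ np.array([True, False])
promoted_py = floats @ [True, False]

print(supported, rejects_dates, rejects_words, promoted_np, promoted_py)
```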
src/numpy-stubs/__init__.pyi
Outdated
@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)

_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]]
np. is not needed here
# https://github.com/numpy/numtype/issues/197
#
@overload
def __matmul__(self: _Array1D[_IntegralT], rhs: _Array1D[_IntegralT], /) -> _IntegralT: ...  # type: ignore[overload-overlap]  # pyright: ignore[reportOverlappingOverload]
Hmm, the overlapping overloads are actually an issue in this case. In several cases mypy incorrectly reports overload-overlap, but when pyright reports it, it's usually for a good reason. And if both report it, then only in very rare circumstances can it be safe to ignore (and this isn't one of those).
For example, let's assume that self is 1d and float64, and rhs is the union of a 1d and a 2d bool array. Then the first overload does not match, because that requires rhs to only be 1d, but rhs could be either 1d or 2d (as it's a union type). But the second overload does match, because there, the shape of rhs doesn't matter. The return type is therefore inferred as NDArray[float64]. But that's incompatible with the return type of the first overload, because it doesn't account for the possibility that rhs is 1d, which would result in a scalar.
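The runtime side of that scenario can be sketched as follows. The variable names are illustrative, and the snippet only shows what the interpreter returns for each member of the union, not what a type checker infers.

```python
import numpy as np

lhs = np.array([1.0, 2.0])  # 1d float64

# The two members of the hypothetical union type for rhs.
rhs_1d = np.array([True, False])
rhs_2d = np.array([[True, False], [False, True]])

# 1d @ 1d collapses to a scalar ...
scalar_result = lhs @ rhs_1d

# ... while 1d @ 2d stays an array.
array_result = lhs @ rhs_2d

print(type(scalar_result).__name__, type(array_result).__name__)
```

A return annotation of NDArray[float64] for the union therefore misses the scalar outcome, which is the unsoundness that the overlap warning points at.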
It's very tricky to solve this unfortunately, which is why I made a separate issue for it. I've been building the _numtype internal type-check-only package so that we can deal with situations like these, but it's changing quite fast, so I'm a bit hesitant to start using it in something as important as ndarray at this point.
So until _numtype has a somewhat stable API, and is actually tested (which it isn't right now), it's probably for the best to put this PR in the freezer for the time being.
This shows one of the (many) reasons why shape-typing is so difficult, and why it has been taking so long to make progress on. We'll get there eventually, but I'm careful not to rush into it.
I realize that __matmul__ is currently also incorrect, so it might seem weird that I don't want to fix it for this specific situation, as it seems like a net win. But the problem with overlapping overloads is a very pernicious one, and it could (and almost certainly will) lead to unexpected issues that will be very difficult to debug. It's one of the most complicated parts of the Python typing system, and I probably don't fully understand it myself, so I'll just leave it at that for now.
In numpy/numpy#27032 (comment) I explain it in a bit more detail, but even so, that only scratches the surface of this can of worms I'm afraid 😅.
Thanks for the explanation! Totally get the challenges with overlapping overloads and shape typing in ndarray. It makes sense to wait until _numtype is more stable before moving forward. Looking forward to the progress! 🚀
close #197