🩹 add overload for ndarray.__matmul__ #286

Draft
wants to merge 3 commits into
base: main
Changes from 1 commit
15 changes: 10 additions & 5 deletions src/numpy-stubs/__init__.pyi
@@ -27,6 +27,8 @@ from typing import (
)
from typing_extensions import Buffer, CapsuleType, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, override

+import numpy as np

from . import (
__config__ as __config__,
_array_api_info as _array_api_info,
@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)

+_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]]
Member

The np. prefix is not needed here.


###
# Type Aliases (for internal use only)

@@ -2530,9 +2534,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
def __imul__(self: NDArray[complexfloating], rhs: _ArrayLikeComplex_co, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
@overload
def __imul__(self: NDArray[object_], rhs: object, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...

-# TODO(jorenham): Support the "1d @ 1d -> scalar" case
-# https://github.com/numpy/numtype/issues/197
+@overload
+def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
Member

_ScalarT is bound to np.generic, so this overload would also accept e.g. datetime64 and str_, which raise an error in np.matmul. Specifically, np.matmul accepts only these types:

>>> import numpy as np
>>> "".join(t[0] for t in np.matmul.types)
'?bBhHiIlLqQefdgFDGO'

which translates to bool_ | number | object_. Maybe there's already a TypeVar with that bound that could be reused here.
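
As a rough runtime cross-check of the bound described above, the following sketch maps each type code reported by np.matmul.types to its scalar class and tests it against bool_ | number | object_. The names allowed, scalar_types, and matmul_fails are made up for illustration and are not part of the stubs; the datetime64 check assumes the missing ufunc loop surfaces as a TypeError.

```python
import numpy as np

# The first character of each signature is the type code of the left operand.
codes = "".join(sig[0] for sig in np.matmul.types)

# Map every type code to its scalar class and compare it against the bound
# suggested in the review: bool_ | number | object_.
allowed = (np.bool_, np.number, np.object_)
scalar_types = [np.dtype(code).type for code in codes]
all_within_bound = all(issubclass(t, allowed) for t in scalar_types)


def matmul_fails(dtype_str: str) -> bool:
    """Return True if a 1-D @ 1-D product of this dtype raises TypeError."""
    a = np.zeros(2, dtype=dtype_str)
    try:
        a @ a
    except TypeError:
        return True
    return False


print(codes, all_within_bound)
print("datetime64 rejected:", matmul_fails("M8[D]"))
print("float64 rejected:", matmul_fails("f8"))
```

This only confirms what np.matmul does at runtime; whether an existing TypeVar in the stubs already has this bound still has to be checked in the stub sources.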

If you feel like it, you could broaden the rhs type a bit more by exploiting the fact that np.bool and builtins.bool will always "promote". For example, rhs: _Array1D[_ScalarT | bool_] is also valid here, as are things like Sequence[_ScalarT | py_bool | bool_].
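
The promotion behaviour mentioned above can be checked at runtime with a small sketch like the one below. It only samples three left-hand dtypes, so it illustrates the claim rather than proving it for every scalar type; the variable names are hypothetical.

```python
import numpy as np

mask = np.array([True, False])  # np.bool_ array used as the right operand

result_types = {}
list_result_types = {}
for dtype in (np.int8, np.float32, np.complex128):
    lhs = np.ones(2, dtype=dtype)
    # 1-D @ 1-D yields a scalar; its type should match the left operand.
    result_types[dtype] = type(lhs @ mask)
    # A plain list of builtin bools is converted to a bool array first.
    list_result_types[dtype] = type(lhs @ [True, False])

for dtype, result_type in result_types.items():
    print(dtype.__name__, "->", result_type.__name__)
```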

And just to be clear: it's already valid and type-safe. It's just that there are some easy ways to improve it, if you want.

@overload
def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
@overload
@@ -2566,12 +2569,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
@overload
def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
@overload
-def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
+def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
Member

_ArrayLikeObject_co is an alias for _ArrayLike[np.object_], so it accepts only things that can be expressed in terms of np.object_, and rejects e.g. lists of decimal.Decimal, even though those are valid in this case:

>>> from decimal import Decimal
>>> import numpy as np
>>> A = np.array([[Decimal(0), Decimal(1)], [Decimal(-1), Decimal(0)]])
>>> A
array([[Decimal('0'), Decimal('1')],
       [Decimal('-1'), Decimal('0')]], dtype=object)
>>> A @ [Decimal(2), Decimal(3)]
array([Decimal('3'), Decimal('-2')], dtype=object)

So this would be falsely rejected.

(I kinda expected that there would be a test for this, but apparently not)

Anyway, I guess I'm trying to say that object dtypes are very difficult to properly type, especially because of the lack of tests for it. So it might be for the best to leave it as rhs: object for now, and look at it again once we have better testing in place.

Contributor Author

When using the decimal.Decimal type, an overlap issue arises between the following overloads:

@overload
def __matmul__(self: _Array1D[_MatmulScalarT], rhs: _Array1D[_MatmulScalarT], /) -> _MatmulScalarT: ...
  • Here, _MatmulScalarT includes decimal.Decimal, so this overload is a valid match.
@overload
def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
  • Since decimal.Decimal is also an instance of object, this overload is also a valid match.

The Conflict:

These two overloads return different types:

  • Overload 1 returns _MatmulScalarT (i.e., Decimal).
  • Overload 18 returns NDArray[object_].

Example:

from decimal import Decimal
import numpy as np

A = np.array([Decimal('1'), Decimal('2')])
B = np.array([Decimal('3'), Decimal('4')])

# A @ B matches both overloads
# But they return different types

Because both overloads match A @ B, the type checker cannot determine which one to use, leading to a type error. I'm not sure how this overlap should be resolved so that object is handled correctly without ambiguity.
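
For reference, here is a minimal runtime sketch of the example above. It shows what NumPy actually returns for the two 1-D Decimal arrays, which may help decide which return type the stubs should report; it says nothing about how a type checker resolves the overlap.

```python
from decimal import Decimal

import numpy as np

# Arrays of Decimal values are stored with dtype=object.
A = np.array([Decimal("1"), Decimal("2")])
B = np.array([Decimal("3"), Decimal("4")])

# 1-D @ 1-D is an inner product: 1*3 + 2*4.
result = A @ B

print(A.dtype, type(result).__name__, result)
```

At runtime the product of two 1-D object arrays comes back as the element type (here Decimal) rather than as an ndarray.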

@overload
def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...

# keep in sync with __matmul__
+@overload
+def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
Member

There's no need for this overload: the parameter types are identical, so lhs @ rhs will always use __matmul__, and never __rmatmul__.
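
The dispatch rule behind this remark can be illustrated with a small pure-Python sketch. Vec is a hypothetical stand-in class, not ndarray: when both operands have the same type, Python calls the left operand's __matmul__ and only falls back to the right operand's __rmatmul__ when the left operand does not support the operation.

```python
class Vec:
    """Hypothetical class that records which dunder method handled `@`."""

    calls: list[str] = []

    def __matmul__(self, other: object) -> int:
        Vec.calls.append("__matmul__")
        return 0

    def __rmatmul__(self, other: object) -> int:
        Vec.calls.append("__rmatmul__")
        return 0


# Same type on both sides: only the left operand's __matmul__ runs.
Vec() @ Vec()
same_type_calls = list(Vec.calls)

# A list has no __matmul__, so Python falls back to Vec.__rmatmul__.
Vec.calls.clear()
[1, 2] @ Vec()
fallback_calls = list(Vec.calls)

print(same_type_calls, fallback_calls)
```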

@overload
def __rmatmul__(self: NDArray[_NumberT], lhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
@overload
def __rmatmul__(self: NDArray[bool_], lhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
@@ -2604,7 +2609,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
@overload
def __rmatmul__(self: NDArray[bool_ | number], lhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
@overload
-def __rmatmul__(self: NDArray[object_], lhs: object, /) -> NDArray[object_]: ...
+def __rmatmul__(self: NDArray[object_], lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
@overload
def __rmatmul__(self, lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
