@@ -27,8 +27,6 @@ from typing import (
27
27
)
28
28
from typing_extensions import Buffer , CapsuleType , LiteralString , Never , Protocol , Self , TypeVar , Unpack , deprecated , override
29
29
30
- import numpy as np
31
-
32
30
from . import (
33
31
__config__ as __config__ ,
34
32
_array_api_info as _array_api_info ,
@@ -590,6 +588,7 @@ _IntegerT = TypeVar("_IntegerT", bound=integer)
590
588
_SignedIntegerT = TypeVar ("_SignedIntegerT" , bound = signedinteger )
591
589
_UnsignedIntegerT = TypeVar ("_UnsignedIntegerT" , bound = unsignedinteger )
592
590
_CharT = TypeVar ("_CharT" , bound = character )
591
+ _MatmulScalarT = TypeVar ("_MatmulScalarT" , bound = bool_ | number | object_ )
593
592
594
593
_NBitT = TypeVar ("_NBitT" , bound = NBitBase , default = Any )
595
594
_NBitT1 = TypeVar ("_NBitT1" , bound = NBitBase , default = Any )
@@ -613,7 +612,7 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
613
612
_DT64ItemT_co = TypeVar ("_DT64ItemT_co" , bound = dt .date | int | None , default = dt .date | int | None , covariant = True )
614
613
_TD64UnitT = TypeVar ("_TD64UnitT" , bound = _TD64Unit , default = _TD64Unit )
615
614
616
- _Array1D : TypeAlias = np . ndarray [tuple [int ], np . dtype [_ScalarT ]]
615
+ _Array1D : TypeAlias = ndarray [tuple [int ], dtype [_ScalarT ]]
617
616
618
617
###
619
618
# Type Aliases (for internal use only)
@@ -2534,8 +2533,10 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2534
2533
def __imul__ (self : NDArray [complexfloating ], rhs : _ArrayLikeComplex_co , / ) -> ndarray [_ShapeT_co , _DTypeT_co ]: ...
2535
2534
@overload
2536
2535
def __imul__ (self : NDArray [object_ ], rhs : object , / ) -> ndarray [_ShapeT_co , _DTypeT_co ]: ...
2536
+
2537
+ #
2537
2538
@overload
2538
- def __matmul__ (self : _Array1D [_ScalarT ], rhs : _Array1D [_ScalarT ], / ) -> _ScalarT : ...
2539
+ def __matmul__ (self : _Array1D [_MatmulScalarT ], rhs : _Array1D [_MatmulScalarT ], / ) -> _MatmulScalarT : ...
2539
2540
@overload
2540
2541
def __matmul__ (self : NDArray [_NumberT ], rhs : _ArrayLikeBool_co , / ) -> NDArray [_NumberT ]: ...
2541
2542
@overload
@@ -2569,14 +2570,12 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2569
2570
@overload
2570
2571
def __matmul__ (self : NDArray [bool_ | number ], rhs : _ArrayLikeNumber_co , / ) -> NDArray [Incomplete ]: ...
2571
2572
@overload
2572
- def __matmul__ (self : NDArray [object_ ], rhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2573
+ def __matmul__ (self : NDArray [object_ ], rhs : object , / ) -> NDArray [object_ ]: ...
2573
2574
@overload
2574
2575
def __matmul__ (self , rhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2575
2576
2576
2577
# keep in sync with __matmul__
2577
2578
@overload
2578
- def __rmatmul__ (self : _Array1D [_ScalarT ], rhs : _Array1D [_ScalarT ], / ) -> _ScalarT : ...
2579
- @overload
2580
2579
def __rmatmul__ (self : NDArray [_NumberT ], lhs : _ArrayLikeBool_co , / ) -> NDArray [_NumberT ]: ...
2581
2580
@overload
2582
2581
def __rmatmul__ (self : NDArray [bool_ ], lhs : _ArrayLike [_NumberT ], / ) -> NDArray [_NumberT ]: ...
@@ -2609,7 +2608,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
2609
2608
@overload
2610
2609
def __rmatmul__ (self : NDArray [bool_ | number ], lhs : _ArrayLikeNumber_co , / ) -> NDArray [Incomplete ]: ...
2611
2610
@overload
2612
- def __rmatmul__ (self : NDArray [object_ ], lhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2611
+ def __rmatmul__ (self : NDArray [object_ ], lhs : object , / ) -> NDArray [object_ ]: ...
2613
2612
@overload
2614
2613
def __rmatmul__ (self , lhs : _ArrayLikeObject_co , / ) -> NDArray [object_ ]: ...
2615
2614
0 commit comments