Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🏷️ ufunc annotations for logical_{not,and,or,xor} #324

Merged
merged 8 commits into from
Mar 18, 2025
21 changes: 16 additions & 5 deletions src/numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,18 @@ from ._typing._char_codes import (
_UnsignedIntegerCodes,
_VoidCodes,
)
from ._typing._ufunc import _Call11Bool, _Call11Isnat, _Call21Bool, _gufunc_2_1, _ufunc_1_1, _ufunc_1_2, _ufunc_2_1, _ufunc_2_2
from ._typing._ufunc import (
_Call11Bool,
_Call11Isnat,
_Call11Logical,
_Call21Bool,
_Call21Logical,
_gufunc_2_1,
_ufunc_1_1,
_ufunc_1_2,
_ufunc_2_1,
_ufunc_2_2,
)
from .lib import scimath as emath
from .lib._arraypad_impl import pad
from .lib._arraysetops_impl import (
Expand Down Expand Up @@ -7102,7 +7113,7 @@ log: Final[_ufunc_1_1] = ...
log2: Final[_ufunc_1_1] = ...
log10: Final[_ufunc_1_1] = ...
log1p: Final[_ufunc_1_1] = ...
logical_not: Final[_ufunc_1_1] = ...
logical_not: Final[_ufunc_1_1[_Call11Logical]] = ...
negative: Final[_ufunc_1_1] = ...
positive: Final[_ufunc_1_1] = ...
rad2deg: Final[_ufunc_1_1] = ...
Expand Down Expand Up @@ -7159,9 +7170,9 @@ ldexp: Final[_ufunc_2_1] = ...
left_shift: Final[_ufunc_2_1] = ...
logaddexp: Final[_ufunc_2_1] = ...
logaddexp2: Final[_ufunc_2_1] = ...
logical_and: Final[_ufunc_2_1] = ...
logical_or: Final[_ufunc_2_1] = ...
logical_xor: Final[_ufunc_2_1] = ...
logical_and: Final[_ufunc_2_1[_Call21Logical]] = ...
logical_or: Final[_ufunc_2_1[_Call21Logical]] = ...
logical_xor: Final[_ufunc_2_1[_Call21Logical]] = ...
maximum: Final[_ufunc_2_1] = ...
minimum: Final[_ufunc_2_1] = ...
mod: Final[_ufunc_2_1] = ...
Expand Down
200 changes: 197 additions & 3 deletions src/numpy-stubs/_typing/_ufunc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@ import numpy as np
from numpy import _CastingKind, _OrderKACF # noqa: ICN003
from numpy._typing import _DTypeLikeBool, _NestedSequence

from ._array_like import ArrayLike, NDArray, _ArrayLike, _ArrayLikeBool_co, _ArrayLikeInt_co
from ._array_like import (
ArrayLike,
NDArray,
_ArrayLike,
_ArrayLikeBool_co,
_ArrayLikeInt_co,
_ArrayLikeNumber_co,
_ArrayLikeObject_co,
)
from ._dtype_like import DTypeLike, _DTypeLike
from ._scalars import _ScalarLike_co
from ._scalars import _NumberLike_co, _ScalarLike_co
from ._shape import _ShapeLike

###
Expand Down Expand Up @@ -201,7 +209,7 @@ class _Call11Bool(Protocol):
dtype: _DTypeLikeBool | None = None,
**kwds: Unpack[_Kwargs2],
) -> _ArrayT: ...
@overload # (array) -> Array[bool] | bool
@overload # (array) -> Array[bool]
def __call__(
self,
x: _AnyArray,
Expand Down Expand Up @@ -255,6 +263,79 @@ class _Call11Isnat(Protocol):
**kwds: Unpack[_Kwargs2],
) -> NDArray[np.bool]: ...

@type_check_only
class _Call11Logical(Protocol):
@overload
def __call__( # (scalar, dtype: np.object_) -> bool
self,
x: _ScalarLike_co,
/,
out: None = None,
*,
dtype: _DTypeLike[np.object_],
**kwargs: Unpack[_Kwargs2],
) -> bool: ...
@overload
def __call__( # (scalar) -> np.bool
self,
x: _NumberLike_co,
/,
out: None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwargs: Unpack[_Kwargs2],
) -> np.bool: ...
@overload
def __call__( # (array-like, dtype: np.object_) -> np.object_
self,
x: _ArrayLikeNumber_co | _ArrayLikeObject_co,
/,
out: None = None,
*,
dtype: _DTypeLike[np.object_],
**kwargs: Unpack[_Kwargs2],
) -> NDArray[np.object_] | bool: ...
@overload
def __call__( # (array-like, out: T) -> T
self,
x: _ArrayLikeNumber_co | _ArrayLikeObject_co,
/,
out: _Out1[_ArrayT],
*,
dtype: DTypeLike | None = None,
**kwargs: Unpack[_Kwargs2],
) -> _ArrayT: ...
@overload # (array) -> Array[bool]
def __call__(
self,
x: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
/,
out: _Out1[NDArray[np.bool]] | None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwds: Unpack[_Kwargs2],
) -> NDArray[np.bool]: ...
@overload
def __call__( # (array-like) -> Array[bool] | bool
self,
x: _ArrayLikeNumber_co,
/,
out: None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwargs: Unpack[_Kwargs2],
) -> NDArray[np.bool] | np.bool: ...
@overload
def __call__( # (?) -> ?
self,
x: _CanArrayUFunc,
/,
out: _Out1[_AnyArray] | None = None,
*,
dtype: DTypeLike | None = None,
**kwargs: Unpack[_Kwargs2],
) -> Any: ...

@type_check_only
class _Call12(Protocol):
@overload
Expand Down Expand Up @@ -418,6 +499,119 @@ class _Call21Bool(Protocol):
**kwds: Unpack[_Kwargs3],
) -> np.bool | NDArray[np.bool]: ...

@type_check_only
class _Call21Logical(Protocol):
@overload # (scalar, scalar, dtype: np.object_) -> np.object_
def __call__(
self,
x1: _ScalarLike_co,
x2: _ScalarLike_co,
/,
out: None = None,
*,
dtype: _DTypeLike[np.object_],
**kwds: Unpack[_Kwargs3],
) -> bool: ...
@overload # (scalar, scalar) -> bool
def __call__(
self,
x1: _NumberLike_co,
x2: _NumberLike_co,
/,
out: None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwds: Unpack[_Kwargs3],
) -> np.bool: ...
@overload # (array-like, array, dtype: object_) -> Array[object_]
def __call__(
self,
x1: _ArrayLikeNumber_co | _ArrayLikeObject_co,
x2: _AnyArray,
/,
out: None = None,
*,
dtype: _DTypeLike[np.object_],
**kwds: Unpack[_Kwargs3],
) -> NDArray[np.object_]: ...
@overload # (array, array-like, dtype: object_) -> Array[object_]
def __call__(
self,
x1: _AnyArray,
x2: _ArrayLikeNumber_co | _ArrayLikeObject_co,
/,
out: None = None,
*,
dtype: _DTypeLike[np.object_],
**kwds: Unpack[_Kwargs3],
) -> NDArray[np.object_]: ...
@overload # (array-like, array, dtype: dtype[T]) -> Array[T]
def __call__(
self,
x1: _ArrayLikeNumber_co,
x2: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
/,
out: None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwds: Unpack[_Kwargs3],
) -> NDArray[np.bool]: ...
@overload # (array, array-like, dtype: dtype[T]) -> Array[T]
def __call__(
self,
x1: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number],
x2: _ArrayLikeNumber_co,
/,
out: None = None,
*,
dtype: DTypeLike | None = None,
**kwds: Unpack[_Kwargs3],
) -> NDArray[np.bool]: ...
@overload # (array-like, array-like, out: T) -> T
def __call__(
self,
x1: _ArrayLikeNumber_co | _ArrayLikeObject_co,
x2: _ArrayLikeNumber_co | _ArrayLikeObject_co,
/,
out: _Out1[_ArrayT],
*,
dtype: None = None,
**kwds: Unpack[_Kwargs3],
) -> _ArrayT: ...
@overload # (array-like, array) -> Array[?]
def __call__(
self,
x1: _ArrayLikeNumber_co,
x2: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number | complex],
/,
out: _Out1[NDArray[np.bool]] | None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwds: Unpack[_Kwargs3],
) -> NDArray[np.bool]: ...
@overload # (array, array-like) -> Array[?]
def __call__(
self,
x1: NDArray[np.bool | np.number] | _NestedSequence[np.bool | np.number | complex],
x2: _ArrayLikeNumber_co,
/,
out: _Out1[NDArray[np.bool]] | None = None,
*,
dtype: _DTypeLikeBool | None = None,
**kwds: Unpack[_Kwargs3],
) -> NDArray[np.bool]: ...
@overload # (array-like, array-like) -> Array[?] | ?
def __call__(
self,
x1: _ArrayLikeNumber_co | _ArrayLikeObject_co,
x2: _ArrayLikeNumber_co | _ArrayLikeObject_co,
/,
out: _Out1[_AnyArray] | None = None,
*,
dtype: DTypeLike | None = None,
**kwds: Unpack[_Kwargs3],
) -> Any: ...

@type_check_only
class _Call21(Protocol):
@overload # (scalar, scalar, dtype: type[T]) -> T
Expand Down
26 changes: 26 additions & 0 deletions test/static/accept/ufuncs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,29 @@ assert_type(np.isinf(AR_f8, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.isfinite(f8), np.bool_)
assert_type(np.isfinite(AR_f8), npt.NDArray[np.bool_])
assert_type(np.isfinite(AR_f8, out=AR_bool), npt.NDArray[np.bool_])

assert_type(np.logical_not(True), np.bool_)
assert_type(np.logical_not(AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_not(AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_not(AR_bool, dtype=np.object_), npt.NDArray[np.object_] | bool)

assert_type(np.logical_and(True, True), np.bool_)
assert_type(np.logical_and(AR_bool, AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_and(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_and(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_and(AR_bool, AR_i8), npt.NDArray[np.bool_])
assert_type(np.logical_and(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])

assert_type(np.logical_or(True, True), np.bool_)
assert_type(np.logical_or(AR_bool, AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_or(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_or(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_or(AR_bool, AR_i8), npt.NDArray[np.bool_])
assert_type(np.logical_or(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])

assert_type(np.logical_xor(True, True), np.bool_)
assert_type(np.logical_xor(AR_bool, AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_xor(AR_bool, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_xor(AR_i8, AR_bool, out=AR_bool), npt.NDArray[np.bool_])
assert_type(np.logical_xor(AR_bool, AR_i8), npt.NDArray[np.bool_])
assert_type(np.logical_xor(AR_bool, AR_bool, dtype=np.object_), npt.NDArray[np.object_])
13 changes: 13 additions & 0 deletions test/static/reject/ufuncs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import numpy.typing as npt

i8: np.int64
AR_f8: npt.NDArray[np.float64]
dt64: np.datetime64

np.sin.nin + "foo" # type: ignore[operator] # pyright: ignore[reportOperatorIssue]

Expand Down Expand Up @@ -47,3 +48,15 @@ np.isnat(i8, dtype=np.int64) # type: ignore[call-overload] # pyright: ignore[r
np.isnat(i8) # type: ignore[call-overload] # pyright: ignore[reportCallIssue,reportArgumentType]
np.isinf(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
np.isfinite(i8, dtype=np.int64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

np.logical_not(i8, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
np.logical_not(dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

np.logical_and(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
np.logical_and(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

np.logical_or(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
np.logical_or(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]

np.logical_xor(dt64, dt64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]
np.logical_xor(dt64, dt64, dtype=np.datetime64) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue]