Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

type: Track type of SpatialImage.affine, test type inference #1411

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -239,22 +239,18 @@ jobs:
continue-on-error: true
strategy:
matrix:
check: ['style', 'doctest', 'typecheck', 'spellcheck']
check: ['style', 'doctest', 'typecheck', 'spellcheck', 'type-inference']

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: 3
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v5
- name: Install tox
run: uv tool install tox --with=tox-uv
- name: Show tox config
run: pipx run tox c
- name: Show tox config (this call)
run: pipx run tox c -e ${{ matrix.check }}
run: tox c -e ${{ matrix.check }}
- name: Run check
run: pipx run tox -e ${{ matrix.check }}
run: tox -e ${{ matrix.check }}

publish:
runs-on: ubuntu-latest
Expand Down
25 changes: 25 additions & 0 deletions nibabel/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Helpers for typing compatibility across Python versions"""

import sys

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
from typing import ParamSpec

if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self

if sys.version_info < (3, 13):
from typing_extensions import TypeVar
else:
from typing import TypeVar


__all__ = [
'ParamSpec',
'Self',
'TypeVar',
]
24 changes: 21 additions & 3 deletions nibabel/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@

from __future__ import annotations

import typing as ty

import numpy as np

from .arrayproxy import ArrayProxy
from .arraywriters import ArrayWriter, WriterError, get_slope_inter, make_array_writer
from .batteryrunners import Report
from .fileholders import copy_file_map
from .spatialimages import HeaderDataError, HeaderTypeError, SpatialHeader, SpatialImage
from .spatialimages import AffT, HeaderDataError, HeaderTypeError, SpatialHeader, SpatialImage
from .volumeutils import (
apply_read_scaling,
array_from_file,
Expand All @@ -102,6 +104,13 @@
)
from .wrapstruct import LabeledWrapStruct

if ty.TYPE_CHECKING:
from collections.abc import Mapping

from .arrayproxy import ArrayLike
from .filebasedimages import FileBasedHeader
from .fileholders import FileMap

# Sub-parts of standard analyze header from
# Mayo dbh.h file
header_key_dtd = [
Expand Down Expand Up @@ -893,11 +902,12 @@ def may_contain_header(klass, binaryblock):
return 348 in (hdr_struct['sizeof_hdr'], bs_hdr_struct['sizeof_hdr'])


class AnalyzeImage(SpatialImage):
class AnalyzeImage(SpatialImage[AffT]):
"""Class for basic Analyze format image"""

header_class: type[AnalyzeHeader] = AnalyzeHeader
header: AnalyzeHeader
_header: AnalyzeHeader
_meta_sniff_len = header_class.sizeof_hdr
files_types: tuple[tuple[str, str], ...] = (('image', '.img'), ('header', '.hdr'))
valid_exts: tuple[str, ...] = ('.img', '.hdr')
Expand All @@ -908,7 +918,15 @@ class AnalyzeImage(SpatialImage):

ImageArrayProxy = ArrayProxy

def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, dtype=None):
def __init__(
self,
dataobj: ArrayLike,
affine: AffT,
header: FileBasedHeader | Mapping | None = None,
extra: Mapping | None = None,
file_map: FileMap | None = None,
dtype=None,
) -> None:
super().__init__(dataobj, affine, header, extra, file_map)
# Reset consumable values
self._header.set_data_offset(0)
Expand Down
5 changes: 3 additions & 2 deletions nibabel/arrayproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@

if ty.TYPE_CHECKING:
import numpy.typing as npt
from typing_extensions import Self # PY310

from ._typing import Self, TypeVar

# Taken from numpy/__init__.pyi
_DType = ty.TypeVar('_DType', bound=np.dtype[ty.Any])
_DType = TypeVar('_DType', bound=np.dtype[ty.Any])


class ArrayLike(ty.Protocol):
Expand Down
4 changes: 2 additions & 2 deletions nibabel/brikhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from .arrayproxy import ArrayProxy
from .fileslice import strided_scalar
from .spatialimages import HeaderDataError, ImageDataError, SpatialHeader, SpatialImage
from .spatialimages import Affine, HeaderDataError, ImageDataError, SpatialHeader, SpatialImage
from .volumeutils import Recoder

# used for doc-tests
Expand Down Expand Up @@ -453,7 +453,7 @@ def get_volume_labels(self):
return labels


class AFNIImage(SpatialImage):
class AFNIImage(SpatialImage[Affine]):
"""
AFNI Image file

Expand Down
37 changes: 33 additions & 4 deletions nibabel/dataobj_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

from ._typing import TypeVar
from .deprecated import deprecate_with_version
from .filebasedimages import FileBasedHeader, FileBasedImage

Expand All @@ -24,7 +25,13 @@
from .fileholders import FileMap
from .filename_parser import FileSpec

ArrayImgT = ty.TypeVar('ArrayImgT', bound='DataobjImage')
FT = TypeVar('FT', bound=np.floating)
F16 = ty.Literal['float16', 'f2', '|f2', '=f2', '<f2', '>f2']
F32 = ty.Literal['float32', 'f4', '|f4', '=f4', '<f4', '>f4']
F64 = ty.Literal['float64', 'f8', '|f8', '=f8', '<f8', '>f8']
Caching = ty.Literal['fill', 'unchanged']

ArrayImgT = TypeVar('ArrayImgT', bound='DataobjImage')


class DataobjImage(FileBasedImage):
Expand All @@ -39,7 +46,7 @@ def __init__(
header: FileBasedHeader | ty.Mapping | None = None,
extra: ty.Mapping | None = None,
file_map: FileMap | None = None,
):
) -> None:
"""Initialize dataobj image

The datobj image is a combination of (dataobj, header), with optional
Expand Down Expand Up @@ -224,11 +231,33 @@ def get_data(self, caching='fill'):
self._data_cache = data
return data

# Types and dtypes, e.g., np.float64 or np.dtype('f8')
@ty.overload
def get_fdata(
self, *, caching: Caching = 'fill', dtype: type[FT] | np.dtype[FT]
) -> npt.NDArray[FT]: ...
@ty.overload
def get_fdata(self, caching: Caching, dtype: type[FT] | np.dtype[FT]) -> npt.NDArray[FT]: ...
# Support string literals
@ty.overload
def get_fdata(self, caching: Caching, dtype: F16) -> npt.NDArray[np.float16]: ...
@ty.overload
def get_fdata(self, caching: Caching, dtype: F32) -> npt.NDArray[np.float32]: ...
@ty.overload
def get_fdata(self, *, caching: Caching = 'fill', dtype: F16) -> npt.NDArray[np.float16]: ...
@ty.overload
def get_fdata(self, *, caching: Caching = 'fill', dtype: F32) -> npt.NDArray[np.float32]: ...
# Double-up on float64 literals and the default (no arguments) case
@ty.overload
def get_fdata(
self, caching: Caching = 'fill', dtype: F64 = 'f8'
) -> npt.NDArray[np.float64]: ...

def get_fdata(
self,
caching: ty.Literal['fill', 'unchanged'] = 'fill',
caching: Caching = 'fill',
dtype: npt.DTypeLike = np.float64,
) -> np.ndarray[ty.Any, np.dtype[np.floating]]:
) -> npt.NDArray[np.floating]:
"""Return floating point image data with necessary scaling applied

The image ``dataobj`` property can be an array proxy or an array. An
Expand Down
8 changes: 2 additions & 6 deletions nibabel/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@
import typing as ty
import warnings

from ._typing import ParamSpec
from .deprecator import Deprecator
from .pkg_info import cmp_pkg_version

if ty.TYPE_CHECKING:
# PY39: ParamSpec is available in Python 3.10+
P = ty.ParamSpec('P')
else:
# Just to keep the runtime happy
P = ty.TypeVar('P')
P = ParamSpec('P')


class ModuleProxy:
Expand Down
6 changes: 4 additions & 2 deletions nibabel/deprecator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from textwrap import dedent

if ty.TYPE_CHECKING:
T = ty.TypeVar('T')
P = ty.ParamSpec('P')
from ._typing import ParamSpec, TypeVar

T = TypeVar('T')
P = ParamSpec('P')

_LEADING_WHITE = re.compile(r'^(\s*)')

Expand Down
74 changes: 47 additions & 27 deletions nibabel/ecat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,30 @@
below). It's not clear what the licenses are for these files.
"""

from __future__ import annotations

import warnings
from numbers import Integral
from typing import TYPE_CHECKING

import numpy as np

from .arraywriters import make_array_writer
from .fileslice import canonical_slicers, predict_shape, slice2outax
from .spatialimages import SpatialHeader, SpatialImage
from .spatialimages import Affine, AffT, SpatialHeader, SpatialImage
from .volumeutils import array_from_file, make_dt_codes, native_code, swapped_code
from .wrapstruct import WrapStruct

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Literal as L

import numpy.typing as npt

from .arrayproxy import ArrayLike
from .filebasedimages import FileBasedHeader
from .fileholders import FileMap

BLOCK_SIZE = 512

main_header_dtd = [
Expand Down Expand Up @@ -743,7 +756,7 @@ def __getitem__(self, sliceobj):
return out_data


class EcatImage(SpatialImage):
class EcatImage(SpatialImage[AffT]):
"""Class returns a list of Ecat images, with one image(hdr/data) per frame"""

header_class = EcatHeader
Expand All @@ -756,7 +769,16 @@ class EcatImage(SpatialImage):

ImageArrayProxy = EcatImageArrayProxy

def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_map=None):
def __init__(
self,
dataobj: ArrayLike,
affine: AffT,
header: FileBasedHeader | Mapping | None,
subheader: EcatSubHeader,
mlist: npt.NDArray[np.integer],
extra: Mapping | None = None,
file_map: FileMap | None = None,
) -> None:
"""Initialize Image

The image is a combination of
Expand Down Expand Up @@ -798,40 +820,38 @@ def __init__(self, dataobj, affine, header, subheader, mlist, extra=None, file_m
>>> data4d.shape == (10, 10, 3, 1)
True
"""
super().__init__(
dataobj=dataobj,
affine=affine,
header=header,
extra=extra,
file_map=file_map,
)
self._subheader = subheader
self._mlist = mlist
self._dataobj = dataobj
if affine is not None:
# Check that affine is array-like 4,4. Maybe this is too strict at
# this abstract level, but so far I think all image formats we know
# do need 4,4.
affine = np.array(affine, dtype=np.float64, copy=True)
if not affine.shape == (4, 4):
raise ValueError('Affine should be shape 4,4')
self._affine = affine
if extra is None:
extra = {}
self.extra = extra
self._header = header
if file_map is None:
file_map = self.__class__.make_file_map()
self.file_map = file_map
self._data_cache = None
self._fdata_cache = None

# Override SpatialImage default, which attempts to set the
# affine in the header.
def update_header(self) -> None:
"""Does nothing"""

@property
def affine(self):
def affine(self) -> AffT:
if not self._subheader._check_affines():
warnings.warn(
'Affines different across frames, loading affine from FIRST frame', UserWarning
)
return self._affine

def get_frame_affine(self, frame):
def get_frame_affine(self, frame: int) -> Affine:
"""returns 4X4 affine"""
return self._subheader.get_frame_affine(frame=frame)

def get_frame(self, frame, orientation=None):
def get_frame(
self,
frame: int,
orientation: L['neurological', 'radiological'] | None = None,
) -> np.ndarray:
"""
Get full volume for a time frame

Expand All @@ -847,16 +867,16 @@ def get_data_dtype(self, frame):
return dt

@property
def shape(self):
def shape(self) -> tuple[int, int, int, int]:
x, y, z = self._subheader.get_shape()
nframes = self._subheader.get_nframes()
return (x, y, z, nframes)

def get_mlist(self):
def get_mlist(self) -> npt.NDArray[np.integer]:
"""get access to the mlist"""
return self._mlist

def get_subheaders(self):
def get_subheaders(self) -> EcatSubHeader:
"""get access to subheaders"""
return self._subheader

Expand Down
Loading
Loading