Skip to content

Commit 2607f17

Browse files
committed
RF: Recast Pointset as a dataclass with associated affines
1 parent 422441f commit 2607f17

File tree

1 file changed

+173
-53
lines changed

1 file changed

+173
-53
lines changed

nibabel/pointset.py

+173-53
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,151 @@
1-
import operator as op
2-
from functools import reduce
1+
"""Point-set structures
2+
3+
Imaging data are sampled at points in space, and these points
4+
can be described by coordinates.
5+
These structures are designed to enable operations on sets of
6+
points, as opposed to the data sampled at those points.
7+
8+
Abstractly, a point set is any collection of points, but there are
9+
two types that warrant special consideration in the neuroimaging
10+
context: grids and meshes.
11+
12+
A *grid* is a collection of regularly-spaced points. The canonical
13+
examples of grids are the indices of voxels and their affine
14+
projection into a reference space.
15+
16+
A *mesh* is a collection of points and some structure that enables
17+
adjacent points to be identified. A *triangular mesh* in particular
18+
uses triplets of adjacent vertices to describe faces.
19+
"""
20+
from __future__ import annotations
21+
22+
import math
23+
import typing as ty
24+
from dataclasses import dataclass, replace
325

426
import numpy as np
527

6-
from nibabel.affines import apply_affine
28+
from nibabel.casting import able_int_type
29+
from nibabel.fileslice import strided_scalar
30+
from nibabel.spatialimages import SpatialImage
31+
32+
if ty.TYPE_CHECKING: # pragma: no cover
33+
from typing_extensions import Self
734

35+
_DType = ty.TypeVar('_DType', bound=np.dtype[ty.Any])
836

37+
38+
class CoordinateArray(ty.Protocol):
39+
ndim: int
40+
shape: tuple[int, int]
41+
42+
@ty.overload
43+
def __array__(self, dtype: None = ..., /) -> np.ndarray[ty.Any, np.dtype[ty.Any]]:
44+
... # pragma: no cover
45+
46+
@ty.overload
47+
def __array__(self, dtype: _DType, /) -> np.ndarray[ty.Any, _DType]:
48+
... # pragma: no cover
49+
50+
51+
@dataclass(slots=True)
952
class Pointset:
10-
def __init__(self, coords):
11-
self._coords = coords
53+
"""A collection of points described by coordinates.
54+
55+
Parameters
56+
----------
57+
coords : array-like
58+
2-dimensional array with coordinates as rows
59+
affine : :class:`numpy.ndarray`
60+
Affine transform to be applied to coordinates array
61+
homogeneous : :class:`bool`
62+
Indicate whether the provided coordinates are homogeneous,
63+
i.e., homogeneous 3D coordinates have the form ``(x, y, z, 1)``
64+
"""
65+
66+
coordinates: CoordinateArray
67+
affine: np.ndarray
68+
homogeneous: bool = False
69+
ndim = 2
70+
__array_priority__ = 99
71+
72+
def __init__(
73+
self,
74+
coordinates: CoordinateArray,
75+
affine: np.ndarray | None = None,
76+
homogeneous: bool = False,
77+
):
78+
self.coordinates = coordinates
79+
self.homogeneous = homogeneous
80+
81+
if affine is None:
82+
self.affine = np.eye(self.dim + 1)
83+
else:
84+
self.affine = np.asanyarray(affine)
85+
86+
if self.affine.shape != (self.dim + 1,) * 2:
87+
raise ValueError(f'Invalid affine for {self.dim}D coordinates:\n{self.affine}')
88+
if np.any(self.affine[-1, :-1] != 0) or self.affine[-1, -1] != 1:
89+
raise ValueError(f'Invalid affine matrix:\n{self.affine}')
90+
91+
@property
92+
def shape(self) -> tuple[int, int]:
93+
"""The shape of the coordinate array"""
94+
return self.coordinates.shape
1295

1396
@property
14-
def n_coords(self):
97+
def n_coords(self) -> int:
1598
"""Number of coordinates
1699
17100
Subclasses should override with more efficient implementations.
18101
"""
19-
return self.get_coords().shape[0]
102+
return self.coordinates.shape[0]
103+
104+
@property
105+
def dim(self) -> int:
106+
"""The dimensionality of the space the coordinates are in"""
107+
return self.coordinates.shape[1] - self.homogeneous
108+
109+
def __rmatmul__(self, affine: np.ndarray) -> Self:
110+
"""Apply an affine transformation to the pointset
111+
112+
This will return a new pointset with an updated affine matrix only.
113+
"""
114+
return replace(self, affine=np.asanyarray(affine) @ self.affine)
115+
116+
def _homogeneous_coords(self):
117+
if self.homogeneous:
118+
return np.asanyarray(self.coordinates)
119+
120+
ones = strided_scalar(
121+
shape=(self.coordinates.shape[0], 1),
122+
scalar=np.array(1, dtype=self.coordinates.dtype),
123+
)
124+
return np.hstack((self.coordinates, ones))
125+
126+
def get_coords(self, *, as_homogeneous: bool = False):
127+
"""Retrieve the coordinates
20128
21-
def get_coords(self, name=None):
22-
"""Nx3 array of coordinates.
23-
24129
Parameters
25130
----------
26-
name : :obj:`str`
131+
as_homogeneous : :class:`bool`
132+
Return homogeneous coordinates if ``True``, or Cartesian
133+
coordiantes if ``False``.
134+
135+
name : :class:`str`
27136
Select a particular coordinate system if more than one may exist.
28137
By default, `None` is equivalent to `"world"` and corresponds to
29138
an RAS+ coordinate system.
30139
"""
31-
return self._coords
140+
ident = np.allclose(self.affine, np.eye(self.affine.shape[0]))
141+
if self.homogeneous == as_homogeneous and ident:
142+
return np.asanyarray(self.coordinates)
143+
coords = self._homogeneous_coords()
144+
if not ident:
145+
coords = (self.affine @ coords.T).T
146+
if not as_homogeneous:
147+
coords = coords[:, :-1]
148+
return coords
32149

33150

34151
class TriangularMesh(Pointset):
@@ -65,14 +182,6 @@ def get_names(self):
65182
"""
66183
raise NotImplementedError
67184

68-
## This method is called for by the BIAP, but it now seems simpler to wait to
69-
## provide it until there are any proposed implementations
70-
# def decimate(self, *, n_coords=None, ratio=None):
71-
# """ Return a TriangularMesh with a smaller number of vertices that
72-
# preserves the geometry of the original """
73-
# # To be overridden when a format provides optimization opportunities
74-
# raise NotImplementedError
75-
76185

77186
class TriMeshFamily(TriangularMesh):
78187
def __init__(self, mapping, default=None):
@@ -97,40 +206,51 @@ def get_coords(self, name=None):
97206
return self._coords[name]
98207

99208

100-
class NdGrid(Pointset):
101-
"""
102-
Attributes
103-
----------
104-
shape : 3-tuple
105-
number of coordinates in each dimension of grid
209+
class Grid(Pointset):
210+
r"""A regularly-spaced collection of coordinates
211+
212+
This class provides factory methods for generating Pointsets from
213+
:class:`~nibabel.spatialimages.SpatialImage`\s and generating masks
214+
from coordinate sets.
106215
"""
107216

108-
def __init__(self, shape, affines):
109-
self.shape = tuple(shape)
110-
try:
111-
self._affines = dict(affines)
112-
except (TypeError, ValueError):
113-
self._affines = {'world': np.array(affines)}
114-
if 'voxels' not in self._affines:
115-
self._affines['voxels'] = np.eye(4, dtype=np.uint8)
116-
117-
def get_affine(self, name=None):
118-
"""4x4 array"""
119-
if name is None:
120-
name = next(iter(self._affines))
121-
return self._affines[name]
217+
@classmethod
218+
def from_image(cls, spatialimage: SpatialImage) -> Self:
219+
return cls(coordinates=GridIndices(spatialimage.shape[:3]), affine=spatialimage.affine)
122220

123-
def get_coords(self, name=None):
124-
if name is None:
125-
name = next(iter(self._affines))
126-
aff = self.get_affine(name)
127-
dt = np.result_type(*(np.min_scalar_type(dim) for dim in self.shape))
128-
# This is pretty wasteful; we almost certainly want instead an
129-
# object that will retrieve a coordinate when indexed, but where
130-
# np.array(obj) returns this
131-
ijk_coords = np.array(list(np.ndindex(self.shape)), dtype=dt)
132-
return apply_affine(aff, ijk_coords)
221+
@classmethod
222+
def from_mask(cls, mask: SpatialImage) -> Self:
223+
mask_arr = np.bool_(mask.dataobj)
224+
return cls(
225+
coordinates=np.c_[np.nonzero(mask_arr)].astype(able_int_type(mask.shape)),
226+
affine=mask.affine,
227+
)
133228

134-
@property
135-
def n_coords(self):
136-
return reduce(op.mul, self.shape)
229+
def to_mask(self, shape=None) -> SpatialImage:
230+
if shape is None:
231+
shape = tuple(np.max(self.coordinates, axis=1)[: self.dim])
232+
mask_arr = np.zeros(shape, dtype='bool')
233+
mask_arr[np.asanyarray(self.coordinates)[:, : self.dim]] = True
234+
return SpatialImage(mask_arr, self.affine)
235+
236+
237+
class GridIndices:
238+
"""Class for generating indices just-in-time"""
239+
240+
__slots__ = ('gridshape', 'dtype', 'shape')
241+
ndim = 2
242+
243+
def __init__(self, shape, dtype=None):
244+
self.gridshape = shape
245+
self.dtype = dtype or able_int_type(shape)
246+
self.shape = (math.prod(self.gridshape), len(self.gridshape))
247+
248+
def __repr__(self):
249+
return f'<{self.__class__.__name__}{self.gridshape}>'
250+
251+
def __array__(self, dtype=None):
252+
if dtype is None:
253+
dtype = self.dtype
254+
255+
axes = [np.arange(s, dtype=dtype) for s in self.gridshape]
256+
return np.reshape(np.meshgrid(*axes, copy=False, indexing='ij'), (len(axes), -1)).T

0 commit comments

Comments
 (0)