Skip to content

Commit c393b08

Browse files
committed
enh: extend the nonlinear transforms API
This PR lays the ground for future work on #56, and #89, by defining the matrix multiplication operator on field-based transforms.
1 parent b610a9a commit c393b08

File tree

2 files changed

+84
-10
lines changed

2 files changed

+84
-10
lines changed

nitransforms/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ def apply(
303303

304304
return resampled
305305

306+
def __matmul__(self, b):
307+
"""Compose with a transform on the right."""
308+
return b
309+
306310
def map(self, x, inverse=False):
307311
r"""
308312
Apply :math:`y = f(x)`.

nitransforms/nonlinear.py

+80-10
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
ImageGrid,
2121
SpatialReference,
2222
_as_homogeneous,
23+
EQUALITY_TOL,
2324
)
2425

2526

26-
class DisplacementsFieldTransform(TransformBase):
27-
"""Represents a dense field of displacements (one vector per voxel)."""
27+
class DeformationFieldTransform(TransformBase):
28+
"""Represents a dense field of deformed locations (corresponding to each voxel)."""
2829

2930
__slots__ = ["_field"]
3031

@@ -34,8 +35,8 @@ def __init__(self, field, reference=None):
3435
3536
Example
3637
-------
37-
>>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
38-
<DisplacementFieldTransform[3D] (57, 67, 56)>
38+
>>> DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
39+
<DeformationFieldTransform[3D] (57, 67, 56)>
3940
4041
"""
4142
super().__init__()
@@ -59,13 +60,13 @@ def __init__(self, field, reference=None):
5960
ndim = self._field.ndim - 1
6061
if self._field.shape[-1] != ndim:
6162
raise TransformError(
62-
"The number of components of the displacements (%d) does not "
63+
"The number of components of the displacements (%d) does not match "
6364
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
6465
)
6566

6667
def __repr__(self):
6768
"""Beautify the python representation."""
68-
return f"<DisplacementFieldTransform[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
69+
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
6970

7071
def map(self, x, inverse=False):
7172
r"""
@@ -92,12 +93,12 @@ def map(self, x, inverse=False):
9293
9394
Examples
9495
--------
95-
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
96+
>>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
9697
>>> xfm.map([-6.5, -36., -19.5]).tolist()
97-
[[-6.5, -36.475167989730835, -19.5]]
98+
[[0.0, -0.47516798973083496, 0.0]]
9899
99100
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
100-
[[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
101+
[[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]
101102
102103
"""
103104

@@ -108,7 +109,76 @@ def map(self, x, inverse=False):
108109
if np.any(np.abs(ijk - indexes) > 0.05):
109110
warnings.warn("Some coordinates are off-grid of the displacements field.")
110111
indexes = tuple(tuple(i) for i in indexes.T)
111-
return x + self._field[indexes]
112+
return self._field[indexes]
113+
114+
def __matmul__(self, b):
115+
"""
116+
Compose with a transform on the right.
117+
118+
Examples
119+
--------
120+
>>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
121+
>>> xfm2 = xfm @ TransformBase()
122+
>>> xfm == xfm2
123+
True
124+
125+
"""
126+
retval = b.map(
127+
self._field.reshape((-1, self._field.shape[-1]))
128+
).reshape(self._field.shape)
129+
return DeformationFieldTransform(retval, reference=self.reference)
130+
131+
def __eq__(self, other):
132+
"""
133+
Overload equals operator.
134+
135+
Examples
136+
--------
137+
>>> xfm1 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
138+
>>> xfm2 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
139+
>>> xfm1 == xfm2
140+
True
141+
142+
"""
143+
_eq = np.allclose(self._field, other._field, rtol=EQUALITY_TOL)
144+
if _eq and self._reference != other._reference:
145+
warnings.warn("Fields are equal, but references do not match.")
146+
return _eq
147+
148+
149+
class DisplacementsFieldTransform(DeformationFieldTransform):
150+
"""
151+
Represents a dense field of displacements (one vector per voxel).
152+
153+
Converting to a field of deformations is straightforward by just adding the corresponding
154+
displacement to the :math:`(x, y, z)` coordinates of each voxel.
155+
Numerically, deformation fields are less susceptible to rounding errors
156+
than displacements fields.
157+
SPM generally prefers deformations for that reason.
158+
159+
"""
160+
161+
__slots__ = ["_displacements"]
162+
163+
def __init__(self, field, reference=None):
164+
"""
165+
Create a transform supported by a field of voxel-wise displacements.
166+
167+
Example
168+
-------
169+
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
170+
>>> xfm
171+
<DisplacementsFieldTransform[3D] (57, 67, 56)>
172+
173+
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174+
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
175+
176+
"""
177+
super().__init__(field, reference=reference)
178+
self._displacements = self._field
179+
# Convert from displacements to deformations fields
180+
# (just add the origin to the displacements vector)
181+
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
112182

113183
@classmethod
114184
def from_filename(cls, filename, fmt="X5"):

0 commit comments

Comments
 (0)