Skip to content

Commit ef5a28f

Browse files
authored
Merge pull request #166 from nipy/enh/extend-nonlinear-api
ENH: Extend the nonlinear transforms API
2 parents b610a9a + b97e55c commit ef5a28f

File tree

4 files changed

+128
-37
lines changed

4 files changed

+128
-37
lines changed

nitransforms/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919
from . import linear, manip, nonlinear
2020
from .linear import Affine, LinearTransformsMapping
21-
from .nonlinear import DisplacementsFieldTransform
21+
from .nonlinear import DenseFieldTransform
2222
from .manip import TransformChain
2323

2424
try:
@@ -42,7 +42,7 @@
4242
"nonlinear",
4343
"Affine",
4444
"LinearTransformsMapping",
45-
"DisplacementsFieldTransform",
45+
"DenseFieldTransform",
4646
"TransformChain",
4747
"__copyright__",
4848
"__packagename__",

nitransforms/manip.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
TransformError,
1616
)
1717
from .linear import Affine
18-
from .nonlinear import DisplacementsFieldTransform
18+
from .nonlinear import DenseFieldTransform
1919

2020

2121
class TransformChain(TransformBase):
@@ -197,7 +197,7 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
197197
if isinstance(xfmobj, itk.ITKLinearTransform):
198198
retval.insert(0, Affine(xfmobj.to_ras(), reference=reference))
199199
else:
200-
retval.insert(0, DisplacementsFieldTransform(xfmobj))
200+
retval.insert(0, DenseFieldTransform(xfmobj))
201201

202202
return TransformChain(retval)
203203

nitransforms/nonlinear.py

+110-27
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,51 @@
2323
)
2424

2525

26-
class DisplacementsFieldTransform(TransformBase):
27-
"""Represents a dense field of displacements (one vector per voxel)."""
26+
class DenseFieldTransform(TransformBase):
27+
"""Represents dense field (voxel-wise) transforms."""
2828

29-
__slots__ = ["_field"]
29+
__slots__ = ("_field", "_deltas")
3030

31-
def __init__(self, field, reference=None):
31+
def __init__(self, field=None, is_deltas=True, reference=None):
3232
"""
33-
Create a dense deformation field transform.
33+
Create a dense field transform.
34+
35+
Converting to a field of deformations is straightforward by just adding the corresponding
36+
displacement to the :math:`(x, y, z)` coordinates of each voxel.
37+
Numerically, deformation fields are less susceptible to rounding errors
38+
than displacements fields.
39+
SPM generally prefers deformations for that reason.
40+
41+
Parameters
42+
----------
43+
field : :obj:`numpy.array_like` or :obj:`nibabel.SpatialImage`
44+
The field of deformations or displacements (*deltas*). If given as a data array,
45+
then the reference **must** be given.
46+
is_deltas : :obj:`bool`
47+
Whether this is a displacements (deltas) field (default), or deformations.
48+
reference : :obj:`ImageGrid`
49+
Defines the domain of the transform. If not provided, the domain is defined from
50+
the ``field`` input.
3451
3552
Example
3653
-------
37-
>>> DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
38-
<DisplacementFieldTransform[3D] (57, 67, 56)>
54+
>>> DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
55+
<DenseFieldTransform[3D] (57, 67, 56)>
3956
4057
"""
58+
if field is None and reference is None:
59+
raise TransformError("DenseFieldTransforms require a spatial reference")
60+
4161
super().__init__()
4262

43-
field = _ensure_image(field)
44-
self._field = np.squeeze(
45-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
46-
)
63+
if field is not None:
64+
field = _ensure_image(field)
65+
self._field = np.squeeze(
66+
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
67+
)
68+
else:
69+
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
70+
is_deltas = True
4771

4872
try:
4973
self.reference = ImageGrid(
@@ -59,45 +83,61 @@ def __init__(self, field, reference=None):
5983
ndim = self._field.ndim - 1
6084
if self._field.shape[-1] != ndim:
6185
raise TransformError(
62-
"The number of components of the displacements (%d) does not "
86+
"The number of components of the field (%d) does not match "
6387
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
6488
)
6589

90+
if is_deltas:
91+
self._deltas = self._field
92+
# Convert from displacements (deltas) to deformations fields
93+
# (just add its origin to each delta vector)
94+
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
95+
6696
def __repr__(self):
6797
"""Beautify the python representation."""
68-
return f"<DisplacementFieldTransform[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
98+
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
6999

70100
def map(self, x, inverse=False):
71101
r"""
72102
Apply the transformation to a list of physical coordinate points.
73103
74104
.. math::
75-
\mathbf{y} = \mathbf{x} + D(\mathbf{x}),
105+
\mathbf{y} = \mathbf{x} + \Delta(\mathbf{x}),
76106
\label{eq:2}\tag{2}
77107
78-
where :math:`D(\mathbf{x})` is the value of the discrete field of displacements
79-
:math:`D` interpolated at the location :math:`\mathbf{x}`.
108+
where :math:`\Delta(\mathbf{x})` is the value of the discrete field of displacements
109+
:math:`\Delta` interpolated at the location :math:`\mathbf{x}`.
80110
81111
Parameters
82112
----------
83-
x : N x D numpy.ndarray
113+
x : N x D :obj:`numpy.array_like`
84114
Input RAS+ coordinates (i.e., physical coordinates).
85-
inverse : bool
115+
inverse : :obj:`bool`
86116
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
87117
88118
Returns
89119
-------
90-
y : N x D numpy.ndarray
120+
y : N x D :obj:`numpy.array_like`
91121
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
92122
93123
Examples
94124
--------
95-
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
125+
>>> xfm = DenseFieldTransform(
126+
... test_dir / "someones_displacement_field.nii.gz",
127+
... is_deltas=False,
128+
... )
96129
>>> xfm.map([-6.5, -36., -19.5]).tolist()
97-
[[-6.5, -36.475167989730835, -19.5]]
130+
[[0.0, -0.47516798973083496, 0.0]]
98131
99132
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
100-
[[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
133+
[[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]
134+
135+
>>> xfm = DenseFieldTransform(
136+
... test_dir / "someones_displacement_field.nii.gz",
137+
... is_deltas=True,
138+
... )
139+
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
140+
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
101141
102142
"""
103143

@@ -106,9 +146,51 @@ def map(self, x, inverse=False):
106146
ijk = self.reference.index(x)
107147
indexes = np.round(ijk).astype("int")
108148
if np.any(np.abs(ijk - indexes) > 0.05):
109-
warnings.warn("Some coordinates are off-grid of the displacements field.")
149+
warnings.warn("Some coordinates are off-grid of the field.")
110150
indexes = tuple(tuple(i) for i in indexes.T)
111-
return x + self._field[indexes]
151+
return self._field[indexes]
152+
153+
def __matmul__(self, b):
154+
"""
155+
Compose with a transform on the right.
156+
157+
Examples
158+
--------
159+
>>> deff = DenseFieldTransform(
160+
... test_dir / "someones_displacement_field.nii.gz",
161+
... is_deltas=False,
162+
... )
163+
>>> deff2 = deff @ TransformBase()
164+
>>> deff == deff2
165+
True
166+
167+
>>> disp = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
168+
>>> disp2 = disp @ TransformBase()
169+
>>> disp == disp2
170+
True
171+
172+
"""
173+
retval = b.map(
174+
self._field.reshape((-1, self._field.shape[-1]))
175+
).reshape(self._field.shape)
176+
return DenseFieldTransform(retval, is_deltas=False, reference=self.reference)
177+
178+
def __eq__(self, other):
179+
"""
180+
Overload equals operator.
181+
182+
Examples
183+
--------
184+
>>> xfm1 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
185+
>>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
186+
>>> xfm1 == xfm2
187+
True
188+
189+
"""
190+
_eq = np.array_equal(self._field, other._field)
191+
if _eq and self._reference != other._reference:
192+
warnings.warn("Fields are equal, but references do not match.")
193+
return _eq
112194

113195
@classmethod
114196
def from_filename(cls, filename, fmt="X5"):
@@ -123,7 +205,7 @@ def from_filename(cls, filename, fmt="X5"):
123205
return cls(_factory[fmt].from_filename(filename))
124206

125207

126-
load = DisplacementsFieldTransform.from_filename
208+
load = DenseFieldTransform.from_filename
127209

128210

129211
class BSplineFieldTransform(TransformBase):
@@ -169,8 +251,9 @@ def to_field(self, reference=None, dtype="float32"):
169251
# 1 x Nvox : (1 x K) @ (K x Nvox)
170252
field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights
171253

172-
return DisplacementsFieldTransform(
173-
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref)
254+
return DenseFieldTransform(
255+
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
256+
)
174257

175258
def apply(
176259
self,

nitransforms/tests/test_nonlinear.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from nitransforms.io.base import TransformFileError
1313
from nitransforms.nonlinear import (
1414
BSplineFieldTransform,
15-
DisplacementsFieldTransform,
15+
DenseFieldTransform,
1616
load as nlload,
1717
)
1818
from ..io.itk import ITKDisplacementsField
@@ -45,7 +45,7 @@ def test_itk_disp_load(size):
4545
def test_displacements_bad_sizes(size):
4646
"""Checks field sizes."""
4747
with pytest.raises(TransformError):
48-
DisplacementsFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
48+
DenseFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
4949

5050

5151
def test_itk_disp_load_intent():
@@ -59,15 +59,23 @@ def test_itk_disp_load_intent():
5959

6060

6161
def test_displacements_init():
62-
DisplacementsFieldTransform(
62+
identity1 = DenseFieldTransform(
6363
np.zeros((10, 10, 10, 3)),
6464
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
6565
)
66+
identity2 = DenseFieldTransform(
67+
reference=nb.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4), None),
68+
)
69+
70+
assert np.array_equal(identity1._field, identity2._field)
71+
assert np.array_equal(identity1._deltas, identity2._deltas)
6672

6773
with pytest.raises(TransformError):
68-
DisplacementsFieldTransform(np.zeros((10, 10, 10, 3)))
74+
DenseFieldTransform()
75+
with pytest.raises(TransformError):
76+
DenseFieldTransform(np.zeros((10, 10, 10, 3)))
6977
with pytest.raises(TransformError):
70-
DisplacementsFieldTransform(
78+
DenseFieldTransform(
7179
np.zeros((10, 10, 10, 3)),
7280
reference=np.zeros((10, 10, 10, 3)),
7381
)
@@ -237,7 +245,7 @@ def test_bspline(tmp_path, testdata_path):
237245
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
238246

239247
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
240-
dispxfm = DisplacementsFieldTransform(disp_name)
248+
dispxfm = DenseFieldTransform(disp_name)
241249

242250
out_disp = dispxfm.apply(img_name)
243251
out_bspl = bsplxfm.apply(img_name)

0 commit comments

Comments
 (0)