Skip to content

Commit 93813b1

Browse files
authored
Merge pull request #90 from oesteban/enh/88-collapse-linear-xfms
ENH: Add an ``.asaffine()`` member to ``TransformChain``
2 parents 1e4aa00 + 6afb6f0 commit 93813b1

File tree

5 files changed

+135
-38
lines changed

5 files changed

+135
-38
lines changed

nitransforms/linear.py

+51-5
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
class Affine(TransformBase):
2525
"""Represents linear transforms on image data."""
2626

27-
__slots__ = ("_matrix", )
27+
__slots__ = ("_matrix", "_inverse")
2828

2929
def __init__(self, matrix=None, reference=None):
3030
"""
@@ -57,6 +57,7 @@ def __init__(self, matrix=None, reference=None):
5757
"""
5858
super().__init__(reference=reference)
5959
self._matrix = np.eye(4)
60+
self._inverse = np.eye(4)
6061

6162
if matrix is not None:
6263
matrix = np.array(matrix)
@@ -72,6 +73,7 @@ def __init__(self, matrix=None, reference=None):
7273

7374
# Normalize last row
7475
self._matrix[3, :] = (0, 0, 0, 1)
76+
self._inverse = np.linalg.inv(self._matrix)
7577

7678
def __eq__(self, other):
7779
"""
@@ -90,6 +92,44 @@ def __eq__(self, other):
9092
warnings.warn("Affines are equal, but references do not match.")
9193
return _eq
9294

95+
def __invert__(self):
96+
"""
97+
Get the inverse of this transform.
98+
99+
Example
100+
-------
101+
>>> matrix = [[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
102+
>>> Affine(np.linalg.inv(matrix)) == ~Affine(matrix)
103+
True
104+
105+
"""
106+
return self.__class__(self._inverse)
107+
108+
def __matmul__(self, b):
109+
"""
110+
Compose two Affines.
111+
112+
Example
113+
-------
114+
>>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
115+
>>> xfm1 @ ~xfm1 == Affine()
116+
True
117+
118+
>>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
119+
>>> xfm1 @ np.eye(4) == xfm1
120+
True
121+
122+
"""
123+
if not isinstance(b, self.__class__):
124+
_b = self.__class__(b)
125+
else:
126+
_b = b
127+
128+
retval = self.__class__(self.matrix.dot(_b.matrix))
129+
if _b.reference:
130+
retval.reference = _b.reference
131+
return retval
132+
93133
@property
94134
def matrix(self):
95135
"""Access the internal representation of this affine."""
@@ -124,14 +164,14 @@ def map(self, x, inverse=False):
124164
affine = self._matrix
125165
coords = _as_homogeneous(x, dim=affine.shape[0] - 1).T
126166
if inverse is True:
127-
affine = np.linalg.inv(self._matrix)
167+
affine = self._inverse
128168
return affine.dot(coords).T[..., :-1]
129169

130170
def _to_hdf5(self, x5_root):
131171
"""Serialize this object into the x5 file format."""
132172
xform = x5_root.create_dataset("Transform", data=[self._matrix])
133173
xform.attrs["Type"] = "affine"
134-
x5_root.create_dataset("Inverse", data=[np.linalg.inv(self._matrix)])
174+
x5_root.create_dataset("Inverse", data=[(~self).matrix])
135175

136176
if self._reference:
137177
self.reference._to_hdf5(x5_root.create_group("Reference"))
@@ -175,7 +215,7 @@ def to_filename(self, filename, fmt="X5", moving=None):
175215
lt["dst"] = io.VolumeGeometry.from_image(moving)
176216
# However, the affine needs to be inverted
177217
# (i.e., it is not a pure "points" convention).
178-
lt["m_L"] = np.linalg.inv(self.matrix)
218+
lt["m_L"] = (~self).matrix
179219
# to make LTA file format
180220
lta = io.LinearTransformArray()
181221
lta["type"] = 1 # RAS2RAS
@@ -234,6 +274,11 @@ def __init__(self, transforms, reference=None):
234274
[0., 1., 0., 2.],
235275
[0., 0., 1., 3.],
236276
[0., 0., 0., 1.]])
277+
>>> (~xfm)[0].matrix # doctest: +NORMALIZE_WHITESPACE
278+
array([[ 1., 0., 0., -1.],
279+
[ 0., 1., 0., -2.],
280+
[ 0., 0., 1., -3.],
281+
[ 0., 0., 0., 1.]])
237282
238283
"""
239284
super().__init__(reference=reference)
@@ -245,6 +290,7 @@ def __init__(self, transforms, reference=None):
245290
).matrix
246291
for xfm in transforms
247292
], axis=0)
293+
self._inverse = np.linalg.inv(self._matrix)
248294

249295
def __getitem__(self, i):
250296
"""Enable indexed access to the series of matrices."""
@@ -304,7 +350,7 @@ def map(self, x, inverse=False):
304350
affine = self.matrix
305351
coords = _as_homogeneous(x, dim=affine.shape[-1] - 1).T
306352
if inverse is True:
307-
affine = np.linalg.inv(affine)
353+
affine = self._inverse
308354
return np.swapaxes(affine.dot(coords), 1, 2)
309355

310356
def to_filename(self, filename, fmt="X5", moving=None):

nitransforms/manip.py

+7
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ def map(self, x, inverse=False):
139139

140140
return x
141141

142+
def asaffine(self):
143+
"""Combine a succession of linear transforms into one."""
144+
retval = self.transforms[-1]
145+
for xfm in self.transforms[:-1][::-1]:
146+
retval @= xfm
147+
return retval
148+
142149
@classmethod
143150
def from_filename(cls, filename, fmt="X5",
144151
reference=None, moving=None):

nitransforms/tests/test_base.py

+9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import h5py
66

77
from ..base import SpatialReference, SampledSpatialData, ImageGrid, TransformBase
8+
from .. import linear as nitl
89

910

1011
def test_SpatialReference(testdata_path):
@@ -134,3 +135,11 @@ def test_SampledSpatialData(testdata_path):
134135
with pytest.raises(TypeError):
135136
gii = nb.gifti.GiftiImage()
136137
SampledSpatialData(gii)
138+
139+
140+
def test_concatenation(testdata_path):
141+
"""Check concatenation of affines."""
142+
aff = nitl.Affine(reference=testdata_path / 'someones_anatomy.nii.gz')
143+
x = [(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)]
144+
assert np.all((aff + nitl.Affine())(x) == x)
145+
assert np.all((aff + nitl.Affine())(x, inverse=True) == x)

nitransforms/tests/test_linear.py

+45-32
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import nibabel as nb
1212
from nibabel.eulerangles import euler2mat
1313
from nibabel.affines import from_matvec
14-
from .. import linear as ntl
14+
from .. import linear as nitl
1515
from .utils import assert_affines_by_filename
1616

1717
TESTS_BORDER_TOLERANCE = 0.05
@@ -42,35 +42,35 @@
4242
def test_linear_typeerrors1(matrix):
4343
"""Exercise errors in Affine creation."""
4444
with pytest.raises(TypeError):
45-
ntl.Affine(matrix)
45+
nitl.Affine(matrix)
4646

4747

4848
def test_linear_typeerrors2(data_path):
4949
"""Exercise errors in Affine creation."""
5050
with pytest.raises(TypeError):
51-
ntl.Affine.from_filename(data_path / 'itktflist.tfm', fmt='itk')
51+
nitl.Affine.from_filename(data_path / 'itktflist.tfm', fmt='itk')
5252

5353

5454
def test_linear_valueerror():
5555
"""Exercise errors in Affine creation."""
5656
with pytest.raises(ValueError):
57-
ntl.Affine(np.ones((4, 4)))
57+
nitl.Affine(np.ones((4, 4)))
5858

5959

6060
def test_loadsave_itk(tmp_path, data_path, testdata_path):
6161
"""Test idempotency."""
6262
ref_file = testdata_path / 'someones_anatomy.nii.gz'
63-
xfm = ntl.load(data_path / 'itktflist2.tfm', fmt='itk')
64-
assert isinstance(xfm, ntl.LinearTransformsMapping)
63+
xfm = nitl.load(data_path / 'itktflist2.tfm', fmt='itk')
64+
assert isinstance(xfm, nitl.LinearTransformsMapping)
6565
xfm.reference = ref_file
6666
xfm.to_filename(tmp_path / 'transform-mapping.tfm', fmt='itk')
6767

6868
assert (data_path / 'itktflist2.tfm').read_text() \
6969
== (tmp_path / 'transform-mapping.tfm').read_text()
7070

71-
single_xfm = ntl.load(data_path / 'affine-LAS.itk.tfm', fmt='itk')
72-
assert isinstance(single_xfm, ntl.Affine)
73-
assert single_xfm == ntl.Affine.from_filename(
71+
single_xfm = nitl.load(data_path / 'affine-LAS.itk.tfm', fmt='itk')
72+
assert isinstance(single_xfm, nitl.Affine)
73+
assert single_xfm == nitl.Affine.from_filename(
7474
data_path / 'affine-LAS.itk.tfm', fmt='itk')
7575

7676

@@ -79,23 +79,23 @@ def test_loadsave_itk(tmp_path, data_path, testdata_path):
7979
def test_loadsave(tmp_path, data_path, testdata_path, fmt):
8080
"""Test idempotency."""
8181
ref_file = testdata_path / 'someones_anatomy.nii.gz'
82-
xfm = ntl.load(data_path / 'itktflist2.tfm', fmt='itk')
82+
xfm = nitl.load(data_path / 'itktflist2.tfm', fmt='itk')
8383
xfm.reference = ref_file
8484

8585
fname = tmp_path / '.'.join(('transform-mapping', fmt))
8686
xfm.to_filename(fname, fmt=fmt)
87-
xfm == ntl.load(fname, fmt=fmt, reference=ref_file)
87+
xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
8888
xfm.to_filename(fname, fmt=fmt, moving=ref_file)
89-
xfm == ntl.load(fname, fmt=fmt, reference=ref_file)
89+
xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
9090

9191
ref_file = testdata_path / 'someones_anatomy.nii.gz'
92-
xfm = ntl.load(data_path / 'affine-LAS.itk.tfm', fmt='itk')
92+
xfm = nitl.load(data_path / 'affine-LAS.itk.tfm', fmt='itk')
9393
xfm.reference = ref_file
9494
fname = tmp_path / '.'.join(('single-transform', fmt))
9595
xfm.to_filename(fname, fmt=fmt)
96-
xfm == ntl.load(fname, fmt=fmt, reference=ref_file)
96+
xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
9797
xfm.to_filename(fname, fmt=fmt, moving=ref_file)
98-
xfm == ntl.load(fname, fmt=fmt, reference=ref_file)
98+
xfm == nitl.load(fname, fmt=fmt, reference=ref_file)
9999

100100

101101
@pytest.mark.xfail(reason="Not fully implemented")
@@ -107,7 +107,7 @@ def test_linear_save(tmpdir, data_path, get_testdata, image_orientation, sw_tool
107107
img = get_testdata[image_orientation]
108108
# Generate test transform
109109
T = from_matvec(euler2mat(x=0.9, y=0.001, z=0.001), [4.0, 2.0, -1.0])
110-
xfm = ntl.Affine(T)
110+
xfm = nitl.Affine(T)
111111
xfm.reference = img
112112

113113
ext = ''
@@ -140,7 +140,7 @@ def test_apply_linear_transform(
140140
img = get_testdata[image_orientation]
141141
# Generate test transform
142142
T = from_matvec(euler2mat(x=0.9, y=0.001, z=0.001), [4.0, 2.0, -1.0])
143-
xfm = ntl.Affine(T)
143+
xfm = nitl.Affine(T)
144144
xfm.reference = img
145145

146146
ext = ''
@@ -172,11 +172,16 @@ def test_apply_linear_transform(
172172
# A certain tolerance is necessary because of resampling at borders
173173
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
174174

175+
nt_moved = xfm.apply('img.nii.gz', order=0)
176+
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
177+
# A certain tolerance is necessary because of resampling at borders
178+
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
179+
175180

176181
def test_Affine_to_x5(tmpdir, testdata_path):
177182
"""Test affine's operations."""
178183
tmpdir.chdir()
179-
aff = ntl.Affine()
184+
aff = nitl.Affine()
180185
with h5py.File('xfm.x5', 'w') as f:
181186
aff._to_hdf5(f.create_group('Affine'))
182187

@@ -185,34 +190,42 @@ def test_Affine_to_x5(tmpdir, testdata_path):
185190
aff._to_hdf5(f.create_group('Affine'))
186191

187192

188-
def test_concatenation(testdata_path):
189-
"""Check concatenation of affines."""
190-
aff = ntl.Affine(reference=testdata_path / 'someones_anatomy.nii.gz')
191-
x = [(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)]
192-
assert np.all((aff + ntl.Affine())(x) == x)
193-
assert np.all((aff + ntl.Affine())(x, inverse=True) == x)
194-
195-
196193
def test_LinearTransformsMapping_apply(tmp_path, data_path, testdata_path):
197194
"""Apply transform mappings."""
198-
hmc = ntl.load(data_path / 'hmc-itk.tfm', fmt='itk',
199-
reference=testdata_path / 'sbref.nii.gz')
200-
assert isinstance(hmc, ntl.LinearTransformsMapping)
195+
hmc = nitl.load(data_path / 'hmc-itk.tfm', fmt='itk',
196+
reference=testdata_path / 'sbref.nii.gz')
197+
assert isinstance(hmc, nitl.LinearTransformsMapping)
201198

202-
# Test-case: realing functional data on to sbref
199+
# Test-case: realign functional data on to sbref
203200
nii = hmc.apply(testdata_path / 'func.nii.gz', order=1,
204201
reference=testdata_path / 'sbref.nii.gz')
205202
assert nii.dataobj.shape[-1] == len(hmc)
206203

207204
# Test-case: write out a fieldmap moved with head
208-
hmcinv = ntl.LinearTransformsMapping(
205+
hmcinv = nitl.LinearTransformsMapping(
209206
np.linalg.inv(hmc.matrix),
210207
reference=testdata_path / 'func.nii.gz')
211208
nii = hmcinv.apply(testdata_path / 'fmap.nii.gz', order=1)
212209
assert nii.dataobj.shape[-1] == len(hmc)
213210

214211
# Ensure a ValueError is issued when trying to do weird stuff
215-
hmc = ntl.LinearTransformsMapping(hmc.matrix[:1, ...])
212+
hmc = nitl.LinearTransformsMapping(hmc.matrix[:1, ...])
216213
with pytest.raises(ValueError):
217214
hmc.apply(testdata_path / 'func.nii.gz', order=1,
218215
reference=testdata_path / 'sbref.nii.gz')
216+
217+
218+
def test_mulmat_operator(testdata_path):
219+
"""Check the @ operator."""
220+
ref = testdata_path / 'someones_anatomy.nii.gz'
221+
mat1 = np.diag([2., 2., 2., 1.])
222+
mat2 = from_matvec(np.eye(3), (4, 2, -1))
223+
aff = nitl.Affine(mat1, reference=ref)
224+
225+
composed = aff @ mat2
226+
assert composed.reference is None
227+
assert composed == nitl.Affine(mat1.dot(mat2))
228+
229+
composed = nitl.Affine(mat2) @ aff
230+
assert composed.reference == aff.reference
231+
assert composed == nitl.Affine(mat2.dot(mat1), reference=ref)

nitransforms/tests/test_manip.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88

99
import numpy as np
1010
import nibabel as nb
11-
from ..manip import load as _load
11+
from ..manip import load as _load, TransformChain
12+
from ..linear import Affine
1213
from .test_nonlinear import (
1314
TESTS_BORDER_TOLERANCE,
1415
APPLY_NONLINEAR_CMD,
1516
)
1617

18+
FMT = {"lta": "fs", "tfm": "itk"}
19+
1720

1821
def test_itk_h5(tmp_path, testdata_path):
1922
"""Check a translation-only field on one or more axes, different image orientations."""
@@ -51,3 +54,22 @@ def test_itk_h5(tmp_path, testdata_path):
5154
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
5255
# A certain tolerance is necessary because of resampling at borders
5356
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
57+
58+
59+
@pytest.mark.parametrize("ext0", ["lta", "tfm"])
60+
@pytest.mark.parametrize("ext1", ["lta", "tfm"])
61+
@pytest.mark.parametrize("ext2", ["lta", "tfm"])
62+
def test_collapse_affines(tmp_path, data_path, ext0, ext1, ext2):
63+
"""Check whether affines are correctly collapsed."""
64+
chain = TransformChain([
65+
Affine.from_filename(data_path / "regressions"
66+
/ f"from-fsnative_to-scanner_mode-image.{ext0}", fmt=f"{FMT[ext0]}"),
67+
Affine.from_filename(data_path / "regressions"
68+
/ f"from-scanner_to-bold_mode-image.{ext1}", fmt=f"{FMT[ext1]}"),
69+
])
70+
assert np.allclose(
71+
chain.asaffine().matrix,
72+
Affine.from_filename(
73+
data_path / "regressions" / f"from-fsnative_to-bold_mode-image.{ext2}",
74+
fmt=f"{FMT[ext2]}").matrix,
75+
)

0 commit comments

Comments
 (0)