Skip to content

Commit 26f9e80

Browse files
committed
feat: split transform chains out from base and add a load function
This was necessary to integrate one test equivalent to resampling with ``antsApplyTransforms``, but via nitransforms.
1 parent 3869064 commit 26f9e80

File tree

5 files changed

+234
-141
lines changed

5 files changed

+234
-141
lines changed

nitransforms/base.py

+3-136
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Common interface for transforms."""
1010
from pathlib import Path
11-
from collections.abc import Iterable
1211
import numpy as np
1312
import h5py
1413
import warnings
@@ -168,10 +167,10 @@ def __ne__(self, other):
168167
return not self == other
169168

170169

171-
class TransformBase(object):
170+
class TransformBase:
172171
"""Abstract image class to represent transforms."""
173172

174-
__slots__ = ['_reference']
173+
__slots__ = ('_reference', )
175174

176175
def __init__(self, reference=None):
177176
"""Instantiate a transform."""
@@ -191,13 +190,11 @@ def __add__(self, b):
191190
-------
192191
>>> T1 = TransformBase()
193192
>>> added = T1 + TransformBase()
194-
>>> isinstance(added, TransformChain)
195-
True
196-
197193
>>> len(added.transforms)
198194
2
199195
200196
"""
197+
from .manip import TransformChain
201198
return TransformChain(transforms=[self, b])
202199

203200
@property
@@ -322,127 +319,6 @@ def _to_hdf5(self, x5_root):
322319
raise NotImplementedError
323320

324321

325-
class TransformChain(TransformBase):
326-
"""Implements the concatenation of transforms."""
327-
328-
__slots__ = ['_transforms']
329-
330-
def __init__(self, transforms=None):
331-
"""Initialize a chain of transforms."""
332-
self._transforms = None
333-
if transforms is not None:
334-
self.transforms = transforms
335-
336-
def __add__(self, b):
337-
"""
338-
Compose this and other transforms.
339-
340-
Example
341-
-------
342-
>>> T1 = TransformBase()
343-
>>> added = T1 + TransformBase() + TransformBase()
344-
>>> isinstance(added, TransformChain)
345-
True
346-
347-
>>> len(added.transforms)
348-
3
349-
350-
"""
351-
self.append(b)
352-
return self
353-
354-
def __getitem__(self, i):
355-
"""
356-
Enable indexed access of transform chains.
357-
358-
Example
359-
-------
360-
>>> T1 = TransformBase()
361-
>>> chain = T1 + TransformBase()
362-
>>> chain[0] is T1
363-
True
364-
365-
"""
366-
return self.transforms[i]
367-
368-
def __len__(self):
369-
"""Enable using len()."""
370-
return len(self.transforms)
371-
372-
@property
373-
def transforms(self):
374-
"""Get the internal list of transforms."""
375-
return self._transforms
376-
377-
@transforms.setter
378-
def transforms(self, value):
379-
self._transforms = _as_chain(value)
380-
if self.transforms[0].reference:
381-
self.reference = self.transforms[0].reference
382-
383-
def append(self, x):
384-
"""
385-
Concatenate one element to the chain.
386-
387-
Example
388-
-------
389-
>>> chain = TransformChain(transforms=TransformBase())
390-
>>> chain.append((TransformBase(), TransformBase()))
391-
>>> len(chain)
392-
3
393-
394-
"""
395-
self.transforms += _as_chain(x)
396-
397-
def insert(self, i, x):
398-
"""
399-
Insert an item at a given position.
400-
401-
Example
402-
-------
403-
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
404-
>>> chain.insert(1, TransformBase())
405-
>>> len(chain)
406-
3
407-
408-
>>> chain.insert(1, TransformChain(chain))
409-
>>> len(chain)
410-
6
411-
412-
"""
413-
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]
414-
415-
def map(self, x, inverse=False):
416-
"""
417-
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
418-
419-
Example
420-
-------
421-
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
422-
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
423-
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
424-
425-
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
426-
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
427-
428-
>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
429-
Traceback (most recent call last):
430-
TransformError:
431-
432-
"""
433-
if not self.transforms:
434-
raise TransformError('Cannot apply an empty transforms chain.')
435-
436-
transforms = self.transforms
437-
if inverse:
438-
transforms = reversed(self.transforms)
439-
440-
for xfm in transforms:
441-
x = xfm(x, inverse=inverse)
442-
443-
return x
444-
445-
446322
def _as_homogeneous(xyz, dtype='float32', dim=3):
447323
"""
448324
Convert 2D and 3D coordinates into homogeneous coordinates.
@@ -473,12 +349,3 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
473349
def _apply_affine(x, affine, dim):
474350
"""Get the image array's indexes corresponding to coordinates."""
475351
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T
476-
477-
478-
def _as_chain(x):
479-
"""Convert a value into a transform chain."""
480-
if isinstance(x, TransformChain):
481-
return x.transforms
482-
if isinstance(x, Iterable):
483-
return list(x)
484-
return [x]

nitransforms/io/itk.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def from_h5obj(cls, fileobj, check=True):
309309
except KeyError:
310310
typo_fallback = "Tranform"
311311

312-
for xfm in reversed(list(h5group.values())[1:]):
312+
for xfm in list(h5group.values())[1:]:
313313
if xfm["TransformType"][0].startswith(b"AffineTransform"):
314314
_params = np.asanyarray(xfm[f"{typo_fallback}Parameters"])
315315
xfm_list.append(

nitransforms/manip.py

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4+
#
5+
# See COPYING file distributed along with the NiBabel package for the
6+
# copyright and license terms.
7+
#
8+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9+
"""Common interface for transforms."""
10+
from collections.abc import Iterable
11+
12+
from .base import (
13+
TransformBase,
14+
TransformError,
15+
)
16+
from .linear import Affine
17+
from .nonlinear import DisplacementsFieldTransform
18+
19+
20+
class TransformChain(TransformBase):
21+
"""Implements the concatenation of transforms."""
22+
23+
__slots__ = ('_transforms', )
24+
25+
def __init__(self, transforms=None):
26+
"""Initialize a chain of transforms."""
27+
super().__init__()
28+
self._transforms = None
29+
30+
if transforms is not None:
31+
self.transforms = transforms
32+
33+
def __add__(self, b):
34+
"""
35+
Compose this and other transforms.
36+
37+
Example
38+
-------
39+
>>> T1 = TransformBase()
40+
>>> added = T1 + TransformBase() + TransformBase()
41+
>>> isinstance(added, TransformChain)
42+
True
43+
44+
>>> len(added.transforms)
45+
3
46+
47+
"""
48+
self.append(b)
49+
return self
50+
51+
def __getitem__(self, i):
52+
"""
53+
Enable indexed access of transform chains.
54+
55+
Example
56+
-------
57+
>>> T1 = TransformBase()
58+
>>> chain = T1 + TransformBase()
59+
>>> chain[0] is T1
60+
True
61+
62+
"""
63+
return self.transforms[i]
64+
65+
def __len__(self):
66+
"""Enable using len()."""
67+
return len(self.transforms)
68+
69+
@property
70+
def transforms(self):
71+
"""Get the internal list of transforms."""
72+
return self._transforms
73+
74+
@transforms.setter
75+
def transforms(self, value):
76+
self._transforms = _as_chain(value)
77+
if self.transforms[-1].reference:
78+
self.reference = self.transforms[-1].reference
79+
80+
def append(self, x):
81+
"""
82+
Concatenate one element to the chain.
83+
84+
Example
85+
-------
86+
>>> chain = TransformChain(transforms=TransformBase())
87+
>>> chain.append((TransformBase(), TransformBase()))
88+
>>> len(chain)
89+
3
90+
91+
"""
92+
self.transforms += _as_chain(x)
93+
94+
def insert(self, i, x):
95+
"""
96+
Insert an item at a given position.
97+
98+
Example
99+
-------
100+
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
101+
>>> chain.insert(1, TransformBase())
102+
>>> len(chain)
103+
3
104+
105+
>>> chain.insert(1, TransformChain(chain))
106+
>>> len(chain)
107+
6
108+
109+
"""
110+
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]
111+
112+
def map(self, x, inverse=False):
113+
"""
114+
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
115+
116+
Example
117+
-------
118+
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
119+
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
120+
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
121+
122+
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
123+
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
124+
125+
>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
126+
Traceback (most recent call last):
127+
TransformError:
128+
129+
"""
130+
if not self.transforms:
131+
raise TransformError('Cannot apply an empty transforms chain.')
132+
133+
transforms = self.transforms
134+
if not inverse:
135+
transforms = self.transforms[::-1]
136+
137+
for xfm in transforms:
138+
x = xfm(x, inverse=inverse)
139+
140+
return x
141+
142+
@classmethod
143+
def from_filename(cls, filename, fmt="X5",
144+
reference=None, moving=None):
145+
"""Load a transform file."""
146+
from .io import itk
147+
148+
retval = []
149+
if str(filename).endswith(".h5"):
150+
reference = None
151+
xforms = itk.ITKCompositeH5.from_filename(filename)
152+
for xfmobj in xforms:
153+
if isinstance(xfmobj, itk.ITKLinearTransform):
154+
retval.append(Affine(xfmobj.to_ras(), reference=reference))
155+
else:
156+
retval.append(DisplacementsFieldTransform(xfmobj))
157+
158+
return TransformChain(retval)
159+
160+
raise NotImplementedError
161+
162+
163+
def _as_chain(x):
164+
"""Convert a value into a transform chain."""
165+
if isinstance(x, TransformChain):
166+
return x.transforms
167+
if isinstance(x, Iterable):
168+
return list(x)
169+
return [x]
170+
171+
172+
load = TransformChain.from_filename

nitransforms/tests/test_io.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,14 @@ def test_afni_Displacements():
341341

342342
def test_itk_h5(data_path):
343343
"""Test displacements fields."""
344-
itk.ITKCompositeH5.from_filename(
344+
assert len(list(itk.ITKCompositeH5.from_filename(
345345
data_path / 'ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5'
346-
)
346+
))) == 2
347347

348348
with pytest.raises(RuntimeError):
349-
itk.ITKCompositeH5.from_filename(
349+
list(itk.ITKCompositeH5.from_filename(
350350
data_path / 'ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.x5'
351-
)
351+
))
352352

353353

354354
@pytest.mark.parametrize('file_type, test_file', [

0 commit comments

Comments
 (0)