Skip to content

Commit b610a9a

Browse files
authored
Merge pull request #165 from oesteban/fix/81-ordering-of-xforms
ENH: API change in ``TransformChain`` - new composition convention
2 parents e5a6b41 + e80bd3c commit b610a9a

File tree

2 files changed

+69
-12
lines changed

2 files changed

+69
-12
lines changed

nitransforms/linear.py

+19
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from scipy import ndimage as ndi
1414

1515
from nibabel.loadsave import load as _nbload
16+
from nibabel.affines import from_matvec
1617

1718
from nitransforms.base import (
1819
ImageGrid,
@@ -218,6 +219,24 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
218219
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})."
219220
)
220221

222+
@classmethod
223+
def from_matvec(cls, mat=None, vec=None, reference=None):
224+
"""
225+
Create an affine from a matrix and translation pair.
226+
227+
Example
228+
-------
229+
>>> Affine.from_matvec(vec=(4, 0, 0)) # doctest: +NORMALIZE_WHITESPACE
230+
array([[1., 0., 0., 4.],
231+
[0., 1., 0., 0.],
232+
[0., 0., 1., 0.],
233+
[0., 0., 0., 1.]])
234+
235+
"""
236+
mat = mat if mat is not None else np.eye(3)
237+
vec = vec if vec is not None else np.zeros((3,))
238+
return cls(from_matvec(mat, vector=vec), reference=reference)
239+
221240
def __repr__(self):
222241
"""
223242
Change representation to the internal matrix.

nitransforms/manip.py

+50-12
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Common interface for transforms."""
1010
from collections.abc import Iterable
11+
import numpy as np
1112

1213
from .base import (
1314
TransformBase,
@@ -74,8 +75,8 @@ def transforms(self):
7475
@transforms.setter
7576
def transforms(self, value):
7677
self._transforms = _as_chain(value)
77-
if self.transforms[-1].reference:
78-
self.reference = self.transforms[-1].reference
78+
if self.transforms[0].reference:
79+
self.reference = self.transforms[0].reference
7980

8081
def append(self, x):
8182
"""
@@ -131,19 +132,56 @@ def map(self, x, inverse=False):
131132
raise TransformError("Cannot apply an empty transforms chain.")
132133

133134
transforms = self.transforms
134-
if not inverse:
135-
transforms = self.transforms[::-1]
135+
if inverse:
136+
transforms = list(reversed(self.transforms))
136137

137138
for xfm in transforms:
138-
x = xfm(x, inverse=inverse)
139+
x = xfm.map(x, inverse=inverse)
139140

140141
return x
141142

142-
def asaffine(self):
143-
"""Combine a succession of linear transforms into one."""
144-
retval = self.transforms[-1]
145-
for xfm in self.transforms[:-1][::-1]:
146-
retval @= xfm
143+
def asaffine(self, indices=None):
144+
"""
145+
Combine a succession of linear transforms into one.
146+
147+
Example
148+
------
149+
>>> chain = TransformChain(transforms=[
150+
... Affine.from_matvec(vec=(2, -10, 3)),
151+
... Affine.from_matvec(vec=(-2, 10, -3)),
152+
... ])
153+
>>> chain.asaffine()
154+
array([[1., 0., 0., 0.],
155+
[0., 1., 0., 0.],
156+
[0., 0., 1., 0.],
157+
[0., 0., 0., 1.]])
158+
159+
>>> chain = TransformChain(transforms=[
160+
... Affine.from_matvec(vec=(1, 2, 3)),
161+
... Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]),
162+
... ])
163+
>>> chain.asaffine()
164+
array([[0., 1., 0., 2.],
165+
[0., 0., 1., 3.],
166+
[1., 0., 0., 1.],
167+
[0., 0., 0., 1.]])
168+
169+
>>> np.allclose(
170+
... chain.map((4, -2, 1)),
171+
... chain.asaffine().map((4, -2, 1)),
172+
... )
173+
True
174+
175+
Parameters
176+
----------
177+
indices : :obj:`numpy.array_like`
178+
The indices of the values to extract.
179+
180+
"""
181+
affines = self.transforms if indices is None else np.take(self.transforms, indices)
182+
retval = affines[0]
183+
for xfm in affines[1:]:
184+
retval = xfm @ retval
147185
return retval
148186

149187
@classmethod
@@ -157,9 +195,9 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
157195
xforms = itk.ITKCompositeH5.from_filename(filename)
158196
for xfmobj in xforms:
159197
if isinstance(xfmobj, itk.ITKLinearTransform):
160-
retval.append(Affine(xfmobj.to_ras(), reference=reference))
198+
retval.insert(0, Affine(xfmobj.to_ras(), reference=reference))
161199
else:
162-
retval.append(DisplacementsFieldTransform(xfmobj))
200+
retval.insert(0, DisplacementsFieldTransform(xfmobj))
163201

164202
return TransformChain(retval)
165203

0 commit comments

Comments
 (0)