diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 00acfafb..84c9126b 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -13,6 +13,7 @@ from scipy import ndimage as ndi from nibabel.loadsave import load as _nbload +from nibabel.affines import from_matvec from nitransforms.base import ( ImageGrid, @@ -218,6 +219,24 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None): f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})." ) + @classmethod + def from_matvec(cls, mat=None, vec=None, reference=None): + """ + Create an affine from a matrix and translation pair. + + Example + ------- + >>> Affine.from_matvec(vec=(4, 0, 0)) # doctest: +NORMALIZE_WHITESPACE + array([[1., 0., 0., 4.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]]) + + """ + mat = mat if mat is not None else np.eye(3) + vec = vec if vec is not None else np.zeros((3,)) + return cls(from_matvec(mat, vector=vec), reference=reference) + def __repr__(self): """ Change representation to the internal matrix. diff --git a/nitransforms/manip.py b/nitransforms/manip.py index d4e7e651..1372ef31 100644 --- a/nitransforms/manip.py +++ b/nitransforms/manip.py @@ -8,6 +8,7 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" from collections.abc import Iterable +import numpy as np from .base import ( TransformBase, @@ -74,8 +75,8 @@ def transforms(self): @transforms.setter def transforms(self, value): self._transforms = _as_chain(value) - if self.transforms[-1].reference: - self.reference = self.transforms[-1].reference + if self.transforms[0].reference: + self.reference = self.transforms[0].reference def append(self, x): """ @@ -131,19 +132,56 @@ def map(self, x, inverse=False): raise TransformError("Cannot apply an empty transforms chain.") transforms = self.transforms - if not inverse: - transforms = self.transforms[::-1] + if inverse: + transforms = list(reversed(self.transforms)) for xfm in transforms: - x = xfm(x, inverse=inverse) + x = xfm.map(x, inverse=inverse) return x - def asaffine(self): - """Combine a succession of linear transforms into one.""" - retval = self.transforms[-1] - for xfm in self.transforms[:-1][::-1]: - retval @= xfm + def asaffine(self, indices=None): + """ + Combine a succession of linear transforms into one. + + Example + ------ + >>> chain = TransformChain(transforms=[ + ... Affine.from_matvec(vec=(2, -10, 3)), + ... Affine.from_matvec(vec=(-2, 10, -3)), + ... ]) + >>> chain.asaffine() + array([[1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]]) + + >>> chain = TransformChain(transforms=[ + ... Affine.from_matvec(vec=(1, 2, 3)), + ... Affine.from_matvec(mat=[[0, 1, 0], [0, 0, 1], [1, 0, 0]]), + ... ]) + >>> chain.asaffine() + array([[0., 1., 0., 2.], + [0., 0., 1., 3.], + [1., 0., 0., 1.], + [0., 0., 0., 1.]]) + + >>> np.allclose( + ... chain.map((4, -2, 1)), + ... chain.asaffine().map((4, -2, 1)), + ... ) + True + + Parameters + ---------- + indices : :obj:`numpy.array_like` + The indices of the values to extract. + + """ + affines = self.transforms if indices is None else np.take(self.transforms, indices) + retval = affines[0] + for xfm in affines[1:]: + retval = xfm @ retval return retval @classmethod @@ -157,9 +195,9 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None): xforms = itk.ITKCompositeH5.from_filename(filename) for xfmobj in xforms: if isinstance(xfmobj, itk.ITKLinearTransform): - retval.append(Affine(xfmobj.to_ras(), reference=reference)) + retval.insert(0, Affine(xfmobj.to_ras(), reference=reference)) else: - retval.append(DisplacementsFieldTransform(xfmobj)) + retval.insert(0, DisplacementsFieldTransform(xfmobj)) return TransformChain(retval)