From cb7d41ae663a31f19a691914d641d76321f1e2ab Mon Sep 17 00:00:00 2001 From: oesteban Date: Thu, 31 Oct 2019 11:43:58 -0700 Subject: [PATCH 1/3] enh(TransformChain): base implementation of transforms chains (composition) Closes #20 --- nitransforms/base.py | 157 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 2 deletions(-) diff --git a/nitransforms/base.py b/nitransforms/base.py index 42a63641..5339fc35 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -8,18 +8,22 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" from pathlib import Path +from collections.abc import Iterable import numpy as np import h5py import warnings from nibabel.loadsave import load from nibabel.nifti1 import intent_codes as INTENT_CODES from nibabel.cifti2 import Cifti2Image - from scipy import ndimage as ndi EQUALITY_TOL = 1e-5 +class TransformError(ValueError): + """A custom exception for transforms.""" + + class SpatialReference: """Factory to create spatial references.""" @@ -172,6 +176,23 @@ def __call__(self, x, inverse=False, index=0): """Apply y = f(x).""" return self.map(x, inverse=inverse, index=index) + def __add__(self, b): + """ + Compose this and other transforms. + + Example + ------- + >>> T1 = TransformBase() + >>> added = T1 + TransformBase() + >>> isinstance(added, TransformChain) + True + + >>> len(added.transforms) + 2 + + """ + return TransformChain(transforms=[self, b]) + @property def reference(self): """Access a reference space where data will be resampled onto.""" @@ -262,6 +283,8 @@ def map(self, x, inverse=False, index=0): r""" Apply :math:`y = f(x)`. + TransformBase implements the identity transform. + Parameters ---------- x : N x D numpy.ndarray @@ -277,7 +300,7 @@ def map(self, x, inverse=False, index=0): Transformed (mapped) RAS+ coordinates (i.e., physical coordinates). """ - raise NotImplementedError + return x def to_filename(self, filename, fmt='X5'): """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" @@ -294,6 +317,127 @@ def _to_hdf5(self, x5_root): raise NotImplementedError +class TransformChain(TransformBase): + """Implements the concatenation of transforms.""" + + __slots__ = ['_transforms'] + + def __init__(self, transforms=None): + """Initialize a chain of transforms.""" + self._transforms = None + if transforms is not None: + self.transforms = transforms + + def __add__(self, b): + """ + Compose this and other transforms. + + Example + ------- + >>> T1 = TransformBase() + >>> added = T1 + TransformBase() + TransformBase() + >>> isinstance(added, TransformChain) + True + + >>> len(added.transforms) + 3 + + """ + self.append(b) + return self + + def __getitem__(self, i): + """ + Enable indexed access of transform chains. + + Example + ------- + >>> T1 = TransformBase() + >>> chain = T1 + TransformBase() + >>> chain[0] == T1 + True + + """ + return self.transforms[i] + + def __len__(self): + """Enable using len().""" + return len(self.transforms) + + @property + def transforms(self): + """Get the internal list of transforms.""" + return self._transforms + + @transforms.setter + def transforms(self, value): + self._transforms = _as_chain(value) + if self.transforms[0].reference: + self.reference = self.transforms[0].reference + + def append(self, x): + """ + Concatenate one element to the chain. + + Example + ------- + >>> chain = TransformChain(transforms=TransformBase()) + >>> chain.append((TransformBase(), TransformBase())) + >>> len(chain) + 3 + + """ + self.transforms += _as_chain(x) + + def insert(self, i, x): + """ + Insert an item at a given position. + + Example + ------- + >>> chain = TransformChain(transforms=[TransformBase(), TransformBase()]) + >>> chain.insert(1, TransformBase()) + >>> len(chain) + 3 + + >>> chain.insert(1, TransformChain(chain)) + >>> len(chain) + 6 + + """ + self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:] + + def map(self, x, inverse=False, index=0): + """ + Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`. + + Example + ------- + >>> chain = TransformChain(transforms=[TransformBase(), TransformBase()]) + >>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)]) + [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)] + + >>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True) + [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)] + + >>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + TransformError: + + """ + if not self.transforms: + raise TransformError('Cannot apply an empty transforms chain.') + + transforms = self.transforms + if inverse: + transforms = reversed(self.transforms) + + for xfm in transforms: + x = xfm(x, inverse=inverse) + + return x + + def _as_homogeneous(xyz, dtype='float32', dim=3): """ Convert 2D and 3D coordinates into homogeneous coordinates. @@ -324,3 +468,12 @@ def _as_homogeneous(xyz, dtype='float32', dim=3): def _apply_affine(x, affine, dim): """Get the image array's indexes corresponding to coordinates.""" return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T + + +def _as_chain(x): + """Convert a value into a transform chain.""" + if isinstance(x, TransformChain): + return x.transforms + if isinstance(x, Iterable): + return list(x) + return [x] From 048c735d71885bf90844b1aed0daf8c44e36531b Mon Sep 17 00:00:00 2001 From: oesteban Date: Thu, 31 Oct 2019 12:15:08 -0700 Subject: [PATCH 2/3] enh: add one test with concatenation of affines --- nitransforms/tests/test_linear.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 50f89812..658b46e3 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -114,3 +114,11 @@ def test_Affine_to_x5(tmpdir, data_path): aff.reference = data_path / 'someones_anatomy.nii.gz' with h5py.File('withref-xfm.x5', 'w') as f: aff._to_hdf5(f.create_group('Affine')) + + +def test_concatenation(data_path): + """Check concatenation of affines.""" + aff = ntl.Affine(reference=data_path / 'someones_anatomy.nii.gz') + x = [(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)] + assert np.all((aff + ntl.Affine())(x) == x) + assert np.all((aff + ntl.Affine())(x, inverse=True) == x) From 986e9ca714486298913256fac48a476bda98d2ca Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 31 Oct 2019 12:47:40 -0700 Subject: [PATCH 3/3] Update nitransforms/base.py Co-Authored-By: Mathias Goncalves --- nitransforms/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nitransforms/base.py b/nitransforms/base.py index 5339fc35..466019d4 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -354,7 +354,7 @@ def __getitem__(self, i): ------- >>> T1 = TransformBase() >>> chain = T1 + TransformBase() - >>> chain[0] == T1 + >>> chain[0] is T1 True """