Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Base implementation of transforms chains (composition) #43

Merged
merged 3 commits into from
Oct 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 155 additions & 2 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,22 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Common interface for transforms."""
from pathlib import Path
from collections.abc import Iterable
import numpy as np
import h5py
import warnings
from nibabel.loadsave import load
from nibabel.nifti1 import intent_codes as INTENT_CODES
from nibabel.cifti2 import Cifti2Image

from scipy import ndimage as ndi

EQUALITY_TOL = 1e-5


class TransformError(ValueError):
"""A custom exception for transforms."""


class SpatialReference:
"""Factory to create spatial references."""

Expand Down Expand Up @@ -172,6 +176,23 @@ def __call__(self, x, inverse=False, index=0):
"""Apply y = f(x)."""
return self.map(x, inverse=inverse, index=index)

def __add__(self, b):
"""
Compose this and other transforms.

Example
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase()
>>> isinstance(added, TransformChain)
True

>>> len(added.transforms)
2

"""
return TransformChain(transforms=[self, b])

@property
def reference(self):
"""Access a reference space where data will be resampled onto."""
Expand Down Expand Up @@ -262,6 +283,8 @@ def map(self, x, inverse=False, index=0):
r"""
Apply :math:`y = f(x)`.

TransformBase implements the identity transform.

Parameters
----------
x : N x D numpy.ndarray
Expand All @@ -277,7 +300,7 @@ def map(self, x, inverse=False, index=0):
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).

"""
raise NotImplementedError
return x

def to_filename(self, filename, fmt='X5'):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
Expand All @@ -294,6 +317,127 @@ def _to_hdf5(self, x5_root):
raise NotImplementedError


class TransformChain(TransformBase):
"""Implements the concatenation of transforms."""

__slots__ = ['_transforms']

def __init__(self, transforms=None):
"""Initialize a chain of transforms."""
self._transforms = None
if transforms is not None:
self.transforms = transforms

def __add__(self, b):
"""
Compose this and other transforms.

Example
-------
>>> T1 = TransformBase()
>>> added = T1 + TransformBase() + TransformBase()
>>> isinstance(added, TransformChain)
True

>>> len(added.transforms)
3

"""
self.append(b)
return self

def __getitem__(self, i):
"""
Enable indexed access of transform chains.

Example
-------
>>> T1 = TransformBase()
>>> chain = T1 + TransformBase()
>>> chain[0] is T1
True

"""
return self.transforms[i]

def __len__(self):
"""Enable using len()."""
return len(self.transforms)

@property
def transforms(self):
"""Get the internal list of transforms."""
return self._transforms

@transforms.setter
def transforms(self, value):
self._transforms = _as_chain(value)
if self.transforms[0].reference:
self.reference = self.transforms[0].reference

def append(self, x):
"""
Concatenate one element to the chain.

Example
-------
>>> chain = TransformChain(transforms=TransformBase())
>>> chain.append((TransformBase(), TransformBase()))
>>> len(chain)
3

"""
self.transforms += _as_chain(x)

def insert(self, i, x):
"""
Insert an item at a given position.

Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain.insert(1, TransformBase())
>>> len(chain)
3

>>> chain.insert(1, TransformChain(chain))
>>> len(chain)
6

"""
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]

def map(self, x, inverse=False, index=0):
"""
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.

Example
-------
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]

>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]

>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
TransformError:

"""
if not self.transforms:
raise TransformError('Cannot apply an empty transforms chain.')

transforms = self.transforms
if inverse:
transforms = reversed(self.transforms)

for xfm in transforms:
x = xfm(x, inverse=inverse)

return x


def _as_homogeneous(xyz, dtype='float32', dim=3):
"""
Convert 2D and 3D coordinates into homogeneous coordinates.
Expand Down Expand Up @@ -324,3 +468,12 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
def _apply_affine(x, affine, dim):
"""Get the image array's indexes corresponding to coordinates."""
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T


def _as_chain(x):
"""Convert a value into a transform chain."""
if isinstance(x, TransformChain):
return x.transforms
if isinstance(x, Iterable):
return list(x)
return [x]
8 changes: 8 additions & 0 deletions nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,11 @@ def test_Affine_to_x5(tmpdir, data_path):
aff.reference = data_path / 'someones_anatomy.nii.gz'
with h5py.File('withref-xfm.x5', 'w') as f:
aff._to_hdf5(f.create_group('Affine'))


def test_concatenation(data_path):
"""Check concatenation of affines."""
aff = ntl.Affine(reference=data_path / 'someones_anatomy.nii.gz')
x = [(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)]
assert np.all((aff + ntl.Affine())(x) == x)
assert np.all((aff + ntl.Affine())(x, inverse=True) == x)