Skip to content

Commit cb7d41a

Browse files
committed
enh(TransformChain): base implementation of transforms chains (composition)
Closes #20
1 parent 77a3a3b commit cb7d41a

File tree

1 file changed

+155
-2
lines changed

1 file changed

+155
-2
lines changed

nitransforms/base.py

+155-2
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Common interface for transforms."""
1010
from pathlib import Path
11+
from collections.abc import Iterable
1112
import numpy as np
1213
import h5py
1314
import warnings
1415
from nibabel.loadsave import load
1516
from nibabel.nifti1 import intent_codes as INTENT_CODES
1617
from nibabel.cifti2 import Cifti2Image
17-
1818
from scipy import ndimage as ndi
1919

2020
EQUALITY_TOL = 1e-5
2121

2222

23+
class TransformError(ValueError):
24+
"""A custom exception for transforms."""
25+
26+
2327
class SpatialReference:
2428
"""Factory to create spatial references."""
2529

@@ -172,6 +176,23 @@ def __call__(self, x, inverse=False, index=0):
172176
"""Apply y = f(x)."""
173177
return self.map(x, inverse=inverse, index=index)
174178

179+
def __add__(self, b):
180+
"""
181+
Compose this and other transforms.
182+
183+
Example
184+
-------
185+
>>> T1 = TransformBase()
186+
>>> added = T1 + TransformBase()
187+
>>> isinstance(added, TransformChain)
188+
True
189+
190+
>>> len(added.transforms)
191+
2
192+
193+
"""
194+
return TransformChain(transforms=[self, b])
195+
175196
@property
176197
def reference(self):
177198
"""Access a reference space where data will be resampled onto."""
@@ -262,6 +283,8 @@ def map(self, x, inverse=False, index=0):
262283
r"""
263284
Apply :math:`y = f(x)`.
264285
286+
TransformBase implements the identity transform.
287+
265288
Parameters
266289
----------
267290
x : N x D numpy.ndarray
@@ -277,7 +300,7 @@ def map(self, x, inverse=False, index=0):
277300
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
278301
279302
"""
280-
raise NotImplementedError
303+
return x
281304

282305
def to_filename(self, filename, fmt='X5'):
283306
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
@@ -294,6 +317,127 @@ def _to_hdf5(self, x5_root):
294317
raise NotImplementedError
295318

296319

320+
class TransformChain(TransformBase):
321+
"""Implements the concatenation of transforms."""
322+
323+
__slots__ = ['_transforms']
324+
325+
def __init__(self, transforms=None):
326+
"""Initialize a chain of transforms."""
327+
self._transforms = None
328+
if transforms is not None:
329+
self.transforms = transforms
330+
331+
def __add__(self, b):
332+
"""
333+
Compose this and other transforms.
334+
335+
Example
336+
-------
337+
>>> T1 = TransformBase()
338+
>>> added = T1 + TransformBase() + TransformBase()
339+
>>> isinstance(added, TransformChain)
340+
True
341+
342+
>>> len(added.transforms)
343+
3
344+
345+
"""
346+
self.append(b)
347+
return self
348+
349+
def __getitem__(self, i):
350+
"""
351+
Enable indexed access of transform chains.
352+
353+
Example
354+
-------
355+
>>> T1 = TransformBase()
356+
>>> chain = T1 + TransformBase()
357+
>>> chain[0] == T1
358+
True
359+
360+
"""
361+
return self.transforms[i]
362+
363+
def __len__(self):
364+
"""Enable using len()."""
365+
return len(self.transforms)
366+
367+
@property
368+
def transforms(self):
369+
"""Get the internal list of transforms."""
370+
return self._transforms
371+
372+
@transforms.setter
373+
def transforms(self, value):
374+
self._transforms = _as_chain(value)
375+
if self.transforms[0].reference:
376+
self.reference = self.transforms[0].reference
377+
378+
def append(self, x):
379+
"""
380+
Concatenate one element to the chain.
381+
382+
Example
383+
-------
384+
>>> chain = TransformChain(transforms=TransformBase())
385+
>>> chain.append((TransformBase(), TransformBase()))
386+
>>> len(chain)
387+
3
388+
389+
"""
390+
self.transforms += _as_chain(x)
391+
392+
def insert(self, i, x):
393+
"""
394+
Insert an item at a given position.
395+
396+
Example
397+
-------
398+
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
399+
>>> chain.insert(1, TransformBase())
400+
>>> len(chain)
401+
3
402+
403+
>>> chain.insert(1, TransformChain(chain))
404+
>>> len(chain)
405+
6
406+
407+
"""
408+
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]
409+
410+
def map(self, x, inverse=False, index=0):
411+
"""
412+
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
413+
414+
Example
415+
-------
416+
>>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
417+
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
418+
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
419+
420+
>>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
421+
[(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
422+
423+
>>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
424+
Traceback (most recent call last):
425+
TransformError:
426+
427+
"""
428+
if not self.transforms:
429+
raise TransformError('Cannot apply an empty transforms chain.')
430+
431+
transforms = self.transforms
432+
if inverse:
433+
transforms = reversed(self.transforms)
434+
435+
for xfm in transforms:
436+
x = xfm(x, inverse=inverse)
437+
438+
return x
439+
440+
297441
def _as_homogeneous(xyz, dtype='float32', dim=3):
298442
"""
299443
Convert 2D and 3D coordinates into homogeneous coordinates.
@@ -324,3 +468,12 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
324468
def _apply_affine(x, affine, dim):
325469
"""Get the image array's indexes corresponding to coordinates."""
326470
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T
471+
472+
473+
def _as_chain(x):
474+
"""Convert a value into a transform chain."""
475+
if isinstance(x, TransformChain):
476+
return x.transforms
477+
if isinstance(x, Iterable):
478+
return list(x)
479+
return [x]

0 commit comments

Comments
 (0)