8
8
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9
9
"""Common interface for transforms."""
10
10
from pathlib import Path
11
+ from collections .abc import Iterable
11
12
import numpy as np
12
13
import h5py
13
14
import warnings
14
15
from nibabel .loadsave import load
15
16
from nibabel .nifti1 import intent_codes as INTENT_CODES
16
17
from nibabel .cifti2 import Cifti2Image
17
-
18
18
from scipy import ndimage as ndi
19
19
20
20
EQUALITY_TOL = 1e-5
21
21
22
22
23
+ class TransformError (ValueError ):
24
+ """A custom exception for transforms."""
25
+
26
+
23
27
class SpatialReference :
24
28
"""Factory to create spatial references."""
25
29
@@ -172,6 +176,23 @@ def __call__(self, x, inverse=False, index=0):
172
176
"""Apply y = f(x)."""
173
177
return self .map (x , inverse = inverse , index = index )
174
178
179
+ def __add__ (self , b ):
180
+ """
181
+ Compose this and other transforms.
182
+
183
+ Example
184
+ -------
185
+ >>> T1 = TransformBase()
186
+ >>> added = T1 + TransformBase()
187
+ >>> isinstance(added, TransformChain)
188
+ True
189
+
190
+ >>> len(added.transforms)
191
+ 2
192
+
193
+ """
194
+ return TransformChain (transforms = [self , b ])
195
+
175
196
@property
176
197
def reference (self ):
177
198
"""Access a reference space where data will be resampled onto."""
@@ -262,6 +283,8 @@ def map(self, x, inverse=False, index=0):
262
283
r"""
263
284
Apply :math:`y = f(x)`.
264
285
286
+ TransformBase implements the identity transform.
287
+
265
288
Parameters
266
289
----------
267
290
x : N x D numpy.ndarray
@@ -277,7 +300,7 @@ def map(self, x, inverse=False, index=0):
277
300
Transformed (mapped) RAS+ coordinates (i.e., physical coordinates).
278
301
279
302
"""
280
- raise NotImplementedError
303
+ return x
281
304
282
305
def to_filename (self , filename , fmt = 'X5' ):
283
306
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
@@ -294,6 +317,127 @@ def _to_hdf5(self, x5_root):
294
317
raise NotImplementedError
295
318
296
319
320
+ class TransformChain (TransformBase ):
321
+ """Implements the concatenation of transforms."""
322
+
323
+ __slots__ = ['_transforms' ]
324
+
325
+ def __init__ (self , transforms = None ):
326
+ """Initialize a chain of transforms."""
327
+ self ._transforms = None
328
+ if transforms is not None :
329
+ self .transforms = transforms
330
+
331
+ def __add__ (self , b ):
332
+ """
333
+ Compose this and other transforms.
334
+
335
+ Example
336
+ -------
337
+ >>> T1 = TransformBase()
338
+ >>> added = T1 + TransformBase() + TransformBase()
339
+ >>> isinstance(added, TransformChain)
340
+ True
341
+
342
+ >>> len(added.transforms)
343
+ 3
344
+
345
+ """
346
+ self .append (b )
347
+ return self
348
+
349
+ def __getitem__ (self , i ):
350
+ """
351
+ Enable indexed access of transform chains.
352
+
353
+ Example
354
+ -------
355
+ >>> T1 = TransformBase()
356
+ >>> chain = T1 + TransformBase()
357
+ >>> chain[0] == T1
358
+ True
359
+
360
+ """
361
+ return self .transforms [i ]
362
+
363
+ def __len__ (self ):
364
+ """Enable using len()."""
365
+ return len (self .transforms )
366
+
367
+ @property
368
+ def transforms (self ):
369
+ """Get the internal list of transforms."""
370
+ return self ._transforms
371
+
372
+ @transforms .setter
373
+ def transforms (self , value ):
374
+ self ._transforms = _as_chain (value )
375
+ if self .transforms [0 ].reference :
376
+ self .reference = self .transforms [0 ].reference
377
+
378
+ def append (self , x ):
379
+ """
380
+ Concatenate one element to the chain.
381
+
382
+ Example
383
+ -------
384
+ >>> chain = TransformChain(transforms=TransformBase())
385
+ >>> chain.append((TransformBase(), TransformBase()))
386
+ >>> len(chain)
387
+ 3
388
+
389
+ """
390
+ self .transforms += _as_chain (x )
391
+
392
+ def insert (self , i , x ):
393
+ """
394
+ Insert an item at a given position.
395
+
396
+ Example
397
+ -------
398
+ >>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
399
+ >>> chain.insert(1, TransformBase())
400
+ >>> len(chain)
401
+ 3
402
+
403
+ >>> chain.insert(1, TransformChain(chain))
404
+ >>> len(chain)
405
+ 6
406
+
407
+ """
408
+ self .transforms = self .transforms [:i ] + _as_chain (x ) + self .transforms [i :]
409
+
410
+ def map (self , x , inverse = False , index = 0 ):
411
+ """
412
+ Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
413
+
414
+ Example
415
+ -------
416
+ >>> chain = TransformChain(transforms=[TransformBase(), TransformBase()])
417
+ >>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)])
418
+ [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
419
+
420
+ >>> chain([(0., 0., 0.), (1., 1., 1.), (-1., -1., -1.)], inverse=True)
421
+ [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0), (-1.0, -1.0, -1.0)]
422
+
423
+ >>> TransformChain()((0., 0., 0.)) # doctest: +IGNORE_EXCEPTION_DETAIL
424
+ Traceback (most recent call last):
425
+ TransformError:
426
+
427
+ """
428
+ if not self .transforms :
429
+ raise TransformError ('Cannot apply an empty transforms chain.' )
430
+
431
+ transforms = self .transforms
432
+ if inverse :
433
+ transforms = reversed (self .transforms )
434
+
435
+ for xfm in transforms :
436
+ x = xfm (x , inverse = inverse )
437
+
438
+ return x
439
+
440
+
297
441
def _as_homogeneous (xyz , dtype = 'float32' , dim = 3 ):
298
442
"""
299
443
Convert 2D and 3D coordinates into homogeneous coordinates.
@@ -324,3 +468,12 @@ def _as_homogeneous(xyz, dtype='float32', dim=3):
324
468
def _apply_affine (x , affine , dim ):
325
469
"""Get the image array's indexes corresponding to coordinates."""
326
470
return affine .dot (_as_homogeneous (x , dim = dim ).T )[:dim , ...].T
471
+
472
+
473
+ def _as_chain (x ):
474
+ """Convert a value into a transform chain."""
475
+ if isinstance (x , TransformChain ):
476
+ return x .transforms
477
+ if isinstance (x , Iterable ):
478
+ return list (x )
479
+ return [x ]
0 commit comments