Skip to content

Commit 955cd38

Browse files
authored
Merge pull request #59 from oesteban/enh/transform-map
ENH: Support for transforms mappings (e.g., head-motion correction)
2 parents 391f28d + da44997 commit 955cd38

12 files changed

+504
-61
lines changed

nitransforms/base.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@
1212
import numpy as np
1313
import h5py
1414
import warnings
15-
from nibabel.loadsave import load
15+
from nibabel.loadsave import load as _nbload
16+
from nibabel import funcs as _nbfuncs
1617
from nibabel.nifti1 import intent_codes as INTENT_CODES
1718
from nibabel.cifti2 import Cifti2Image
1819
from scipy import ndimage as ndi
1920

2021
EQUALITY_TOL = 1e-5
2122

2223

23-
class TransformError(ValueError):
24+
class TransformError(TypeError):
2425
"""A custom exception for transforms."""
2526

2627

@@ -51,7 +52,7 @@ def __init__(self, dataset):
5152
return
5253

5354
if isinstance(dataset, (str, Path)):
54-
dataset = load(str(dataset))
55+
dataset = _nbload(str(dataset))
5556

5657
if hasattr(dataset, 'numDA'): # Looks like a Gifti file
5758
_das = dataset.get_arrays_from_intent(INTENT_CODES['pointset'])
@@ -96,14 +97,18 @@ class ImageGrid(SampledSpatialData):
9697
def __init__(self, image):
9798
"""Create a gridded sampling reference."""
9899
if isinstance(image, (str, Path)):
99-
image = load(str(image))
100+
image = _nbfuncs.squeeze_image(_nbload(str(image)))
100101

101102
self._affine = image.affine
102103
self._shape = image.shape
104+
103105
self._ndim = getattr(image, 'ndim', len(image.shape))
106+
if self._ndim == 4:
107+
self._shape = image.shape[:3]
108+
self._ndim = 3
104109

105110
self._npoints = getattr(image, 'npoints',
106-
np.prod(image.shape))
111+
np.prod(self._shape))
107112
self._ndindex = None
108113
self._coords = None
109114
self._inverse = getattr(image, 'inverse',
@@ -168,13 +173,15 @@ class TransformBase(object):
168173

169174
__slots__ = ['_reference']
170175

171-
def __init__(self):
176+
def __init__(self, reference=None):
172177
"""Instantiate a transform."""
173178
self._reference = None
179+
if reference:
180+
self.reference = reference
174181

175-
def __call__(self, x, inverse=False, index=0):
182+
def __call__(self, x, inverse=False):
176183
"""Apply y = f(x)."""
177-
return self.map(x, inverse=inverse, index=index)
184+
return self.map(x, inverse=inverse)
178185

179186
def __add__(self, b):
180187
"""
@@ -246,13 +253,13 @@ def apply(self, spatialimage, reference=None,
246253
247254
"""
248255
if reference is not None and isinstance(reference, (str, Path)):
249-
reference = load(str(reference))
256+
reference = _nbload(str(reference))
250257

251258
_ref = self.reference if reference is None \
252259
else SpatialReference.factory(reference)
253260

254261
if isinstance(spatialimage, (str, Path)):
255-
spatialimage = load(str(spatialimage))
262+
spatialimage = _nbload(str(spatialimage))
256263

257264
data = np.asanyarray(spatialimage.dataobj)
258265
output_dtype = output_dtype or data.dtype
@@ -279,7 +286,7 @@ def apply(self, spatialimage, reference=None,
279286

280287
return resampled
281288

282-
def map(self, x, inverse=False, index=0):
289+
def map(self, x, inverse=False):
283290
r"""
284291
Apply :math:`y = f(x)`.
285292
@@ -291,8 +298,6 @@ def map(self, x, inverse=False, index=0):
291298
Input RAS+ coordinates (i.e., physical coordinates).
292299
inverse : bool
293300
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
294-
index : int, optional
295-
Transformation index
296301
297302
Returns
298303
-------
@@ -407,7 +412,7 @@ def insert(self, i, x):
407412
"""
408413
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]
409414

410-
def map(self, x, inverse=False, index=0):
415+
def map(self, x, inverse=False):
411416
"""
412417
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
413418

nitransforms/io/fsl.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Read/write FSL's transforms."""
22
import os
33
import numpy as np
4+
from pathlib import Path
45
from nibabel.affines import voxel_sizes
56

67
from .base import BaseLinearTransformList, LinearParameters, TransformFileError
@@ -63,13 +64,11 @@ class FSLLinearTransformArray(BaseLinearTransformList):
6364

6465
def to_filename(self, filename):
6566
"""Store this transform to a file with the appropriate format."""
66-
if len(self.xforms) == 1:
67-
self.xforms[0].to_filename(filename)
68-
return
69-
67+
output_dir = Path(filename).parent
68+
output_dir.mkdir(exist_ok=True, parents=True)
7069
for i, xfm in enumerate(self.xforms):
71-
with open('%s.%03d' % (filename, i), 'w') as f:
72-
f.write(xfm.to_string())
70+
(output_dir / '.'.join((str(filename), '%03d' % i))).write_text(
71+
xfm.to_string())
7372

7473
def to_ras(self, moving=None, reference=None):
7574
"""Return a nitransforms' internal RAS matrix."""

0 commit comments

Comments
 (0)