12
12
import numpy as np
13
13
import h5py
14
14
import warnings
15
- from nibabel .loadsave import load
15
+ from nibabel .loadsave import load as _nbload
16
+ from nibabel import funcs as _nbfuncs
16
17
from nibabel .nifti1 import intent_codes as INTENT_CODES
17
18
from nibabel .cifti2 import Cifti2Image
18
19
from scipy import ndimage as ndi
19
20
20
21
EQUALITY_TOL = 1e-5
21
22
22
23
23
- class TransformError (ValueError ):
24
+ class TransformError (TypeError ):
24
25
"""A custom exception for transforms."""
25
26
26
27
@@ -51,7 +52,7 @@ def __init__(self, dataset):
51
52
return
52
53
53
54
if isinstance (dataset , (str , Path )):
54
- dataset = load (str (dataset ))
55
+ dataset = _nbload (str (dataset ))
55
56
56
57
if hasattr (dataset , 'numDA' ): # Looks like a Gifti file
57
58
_das = dataset .get_arrays_from_intent (INTENT_CODES ['pointset' ])
@@ -96,14 +97,18 @@ class ImageGrid(SampledSpatialData):
96
97
def __init__ (self , image ):
97
98
"""Create a gridded sampling reference."""
98
99
if isinstance (image , (str , Path )):
99
- image = load ( str (image ))
100
+ image = _nbfuncs . squeeze_image ( _nbload ( str (image ) ))
100
101
101
102
self ._affine = image .affine
102
103
self ._shape = image .shape
104
+
103
105
self ._ndim = getattr (image , 'ndim' , len (image .shape ))
106
+ if self ._ndim == 4 :
107
+ self ._shape = image .shape [:3 ]
108
+ self ._ndim = 3
104
109
105
110
self ._npoints = getattr (image , 'npoints' ,
106
- np .prod (image . shape ))
111
+ np .prod (self . _shape ))
107
112
self ._ndindex = None
108
113
self ._coords = None
109
114
self ._inverse = getattr (image , 'inverse' ,
@@ -168,13 +173,15 @@ class TransformBase(object):
168
173
169
174
__slots__ = ['_reference' ]
170
175
171
- def __init__ (self ):
176
+ def __init__ (self , reference = None ):
172
177
"""Instantiate a transform."""
173
178
self ._reference = None
179
+ if reference :
180
+ self .reference = reference
174
181
175
- def __call__ (self , x , inverse = False , index = 0 ):
182
+ def __call__ (self , x , inverse = False ):
176
183
"""Apply y = f(x)."""
177
- return self .map (x , inverse = inverse , index = index )
184
+ return self .map (x , inverse = inverse )
178
185
179
186
def __add__ (self , b ):
180
187
"""
@@ -246,13 +253,13 @@ def apply(self, spatialimage, reference=None,
246
253
247
254
"""
248
255
if reference is not None and isinstance (reference , (str , Path )):
249
- reference = load (str (reference ))
256
+ reference = _nbload (str (reference ))
250
257
251
258
_ref = self .reference if reference is None \
252
259
else SpatialReference .factory (reference )
253
260
254
261
if isinstance (spatialimage , (str , Path )):
255
- spatialimage = load (str (spatialimage ))
262
+ spatialimage = _nbload (str (spatialimage ))
256
263
257
264
data = np .asanyarray (spatialimage .dataobj )
258
265
output_dtype = output_dtype or data .dtype
@@ -279,7 +286,7 @@ def apply(self, spatialimage, reference=None,
279
286
280
287
return resampled
281
288
282
- def map (self , x , inverse = False , index = 0 ):
289
+ def map (self , x , inverse = False ):
283
290
r"""
284
291
Apply :math:`y = f(x)`.
285
292
@@ -291,8 +298,6 @@ def map(self, x, inverse=False, index=0):
291
298
Input RAS+ coordinates (i.e., physical coordinates).
292
299
inverse : bool
293
300
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
294
- index : int, optional
295
- Transformation index
296
301
297
302
Returns
298
303
-------
@@ -407,7 +412,7 @@ def insert(self, i, x):
407
412
"""
408
413
self .transforms = self .transforms [:i ] + _as_chain (x ) + self .transforms [i :]
409
414
410
- def map (self , x , inverse = False , index = 0 ):
415
+ def map (self , x , inverse = False ):
411
416
"""
412
417
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
413
418
0 commit comments