12
12
import h5py
13
13
import warnings
14
14
from nibabel .loadsave import load
15
+ from nibabel .nifti1 import intent_codes as INTENT_CODES
16
+ from nibabel .cifti2 import Cifti2Image
15
17
16
18
from scipy import ndimage as ndi
17
19
18
20
EQUALITY_TOL = 1e-5
19
21
20
22
21
- class ImageGrid (object ):
23
+ class SpatialReference (object ):
24
+ """Factory to create spatial references."""
25
+
26
+ @staticmethod
27
+ def factory (dataset ):
28
+ """Create a reference for spatial transforms."""
29
+ try :
30
+ return SampledSpatialData (dataset )
31
+ except ValueError :
32
+ return ImageGrid (dataset )
33
+
34
+
35
+ class SampledSpatialData (object ):
36
+ """Represent sampled spatial data: regularly gridded (images) and surfaces."""
37
+
38
+ __slots__ = ['_ndim' , '_coords' , '_npoints' , '_shape' ]
39
+
40
+ def __init__ (self , dataset ):
41
+ """Create a sampling reference."""
42
+ self ._shape = None
43
+
44
+ if isinstance (dataset , SampledSpatialData ):
45
+ self ._coords = dataset .ndcoords .copy ()
46
+ self ._npoints , self ._ndim = self ._coords .shape
47
+ return
48
+
49
+ if isinstance (dataset , (str , Path )):
50
+ dataset = load (str (dataset ))
51
+
52
+ if hasattr (dataset , 'numDA' ): # Looks like a Gifti file
53
+ _das = dataset .get_arrays_from_intent (INTENT_CODES ['pointset' ])
54
+ if not _das :
55
+ raise TypeError (
56
+ 'Input Gifti file does not contain reference coordinates.' )
57
+ self ._coords = np .vstack ([da .data for da in _das ])
58
+ self ._npoints , self ._ndim = self ._coords .shape
59
+ return
60
+
61
+ if isinstance (dataset , Cifti2Image ):
62
+ raise NotImplementedError
63
+
64
+ raise ValueError ('Dataset could not be interpreted as an irregular sample.' )
65
+
66
+ @property
67
+ def npoints (self ):
68
+ """Access the total number of voxels."""
69
+ return self ._npoints
70
+
71
+ @property
72
+ def ndim (self ):
73
+ """Access the number of dimensions."""
74
+ return self ._ndim
75
+
76
+ @property
77
+ def ndcoords (self ):
78
+ """List the physical coordinates of this sample."""
79
+ return self ._coords
80
+
81
+ @property
82
+ def shape (self ):
83
+ """Access the space's size of each dimension."""
84
+ return self ._shape
85
+
86
+
87
+ class ImageGrid (SampledSpatialData ):
22
88
"""Class to represent spaces of gridded data (images)."""
23
89
24
- __slots__ = ['_affine' , '_shape' , '_ndim' , '_ndindex' , '_coords' , '_nvox' ,
25
- '_inverse' ]
90
+ __slots__ = ['_affine' , '_inverse' , '_ndindex' ]
26
91
27
92
def __init__ (self , image ):
28
93
"""Create a gridded sampling reference."""
@@ -31,11 +96,14 @@ def __init__(self, image):
31
96
32
97
self ._affine = image .affine
33
98
self ._shape = image .shape
34
- self ._ndim = len (image .shape )
35
- self ._nvox = np .prod (image .shape ) # Do not access data array
99
+ self ._ndim = getattr (image , 'ndim' , len (image .shape ))
100
+
101
+ self ._npoints = getattr (image , 'npoints' ,
102
+ np .prod (image .shape ))
36
103
self ._ndindex = None
37
104
self ._coords = None
38
- self ._inverse = np .linalg .inv (image .affine )
105
+ self ._inverse = getattr (image , 'inverse' ,
106
+ np .linalg .inv (image .affine ))
39
107
40
108
@property
41
109
def affine (self ):
@@ -47,28 +115,13 @@ def inverse(self):
47
115
"""Access the RAS-to-indexes affine."""
48
116
return self ._inverse
49
117
50
- @property
51
- def shape (self ):
52
- """Access the space's size of each dimension."""
53
- return self ._shape
54
-
55
- @property
56
- def ndim (self ):
57
- """Access the number of dimensions."""
58
- return self ._ndim
59
-
60
- @property
61
- def nvox (self ):
62
- """Access the total number of voxels."""
63
- return self ._nvox
64
-
65
118
@property
66
119
def ndindex (self ):
67
120
"""List the indexes corresponding to the space grid."""
68
121
if self ._ndindex is None :
69
122
indexes = tuple ([np .arange (s ) for s in self ._shape ])
70
123
self ._ndindex = np .array (np .meshgrid (
71
- * indexes , indexing = 'ij' )).reshape (self ._ndim , self ._nvox )
124
+ * indexes , indexing = 'ij' )).reshape (self ._ndim , self ._npoints )
72
125
return self ._ndindex
73
126
74
127
@property
@@ -77,7 +130,7 @@ def ndcoords(self):
77
130
if self ._coords is None :
78
131
self ._coords = np .tensordot (
79
132
self ._affine ,
80
- np .vstack ((self .ndindex , np .ones ((1 , self ._nvox )))),
133
+ np .vstack ((self .ndindex , np .ones ((1 , self ._npoints )))),
81
134
axes = 1
82
135
)[:3 , ...]
83
136
return self ._coords
@@ -131,16 +184,19 @@ def ndim(self):
131
184
"""Access the dimensions of the reference space."""
132
185
return self .reference .ndim
133
186
134
- def apply (self , moving , order = 3 , mode = 'constant' , cval = 0.0 , prefilter = True ,
135
- output_dtype = None ):
187
+ def apply (self , spatialimage , reference = None ,
188
+ order = 3 , mode = 'constant' , cval = 0.0 , prefilter = True , output_dtype = None ):
136
189
"""
137
- Resample the moving image in reference space .
190
+ Apply a transformation to an image, resampling on the reference spatial object .
138
191
139
192
Parameters
140
193
----------
141
- moving : `spatialimage`
194
+ spatialimage : `spatialimage`
142
195
The image object containing the data to be resampled in reference
143
196
space
197
+ reference : spatial object
198
+ The image, surface, or combination thereof containing the coordinates
199
+ of samples that will be sampled.
144
200
order : int, optional
145
201
The order of the spline interpolation, default is 3.
146
202
The order has to be in the range 0-5.
@@ -150,7 +206,7 @@ def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
150
206
cval : float, optional
151
207
Constant value for ``mode='constant'``. Default is 0.0.
152
208
prefilter: bool, optional
153
- Determines if the moving image's data array is prefiltered with
209
+ Determines if the image's data array is prefiltered with
154
210
a spline filter before interpolation. The default is ``True``,
155
211
which will create a temporary *float64* array of filtered values
156
212
if *order > 1*. If setting this to ``False``, the output will be
@@ -160,21 +216,27 @@ def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
160
216
161
217
Returns
162
218
-------
163
- moved_image : `spatialimage`
164
- The moving imaged after resampling to reference space.
219
+ resampled : `spatialimage` or ndarray
220
+ The data imaged after resampling to reference space.
165
221
166
222
"""
167
- if isinstance (moving , str ):
168
- moving = load (moving )
223
+ if reference is not None and isinstance (reference , (str , Path )):
224
+ reference = load (reference )
225
+
226
+ _ref = self .reference if reference is None \
227
+ else SpatialReference .factory (reference )
169
228
170
- moving_data = np .asanyarray (moving .dataobj )
171
- output_dtype = output_dtype or moving_data .dtype
172
- targets = ImageGrid (moving ).index (
173
- _as_homogeneous (self .map (self .reference .ndcoords .T ),
174
- dim = self .reference .ndim ))
229
+ if isinstance (spatialimage , str ):
230
+ spatialimage = load (spatialimage )
175
231
176
- moved = ndi .map_coordinates (
177
- moving_data ,
232
+ data = np .asanyarray (spatialimage .dataobj )
233
+ output_dtype = output_dtype or data .dtype
234
+ targets = ImageGrid (spatialimage ).index ( # data should be an image
235
+ _as_homogeneous (self .map (_ref .ndcoords .T ),
236
+ dim = _ref .ndim ))
237
+
238
+ resampled = ndi .map_coordinates (
239
+ data ,
178
240
targets .T ,
179
241
output = output_dtype ,
180
242
order = order ,
@@ -183,10 +245,14 @@ def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
183
245
prefilter = prefilter ,
184
246
)
185
247
186
- moved_image = moving .__class__ (moved .reshape (self .reference .shape ),
187
- self .reference .affine , moving .header )
188
- moved_image .header .set_data_dtype (output_dtype )
189
- return moved_image
248
+ if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
249
+ moved = spatialimage .__class__ (
250
+ resampled .reshape (_ref .shape ),
251
+ _ref .affine , spatialimage .header )
252
+ moved .header .set_data_dtype (output_dtype )
253
+ return moved
254
+
255
+ return resampled
190
256
191
257
def map (self , x , inverse = False , index = 0 ):
192
258
r"""
0 commit comments