Skip to content

Commit 6e345f5

Browse files
committed
enh: add support for resampling Gifti files
1 parent f553349 commit 6e345f5

File tree

2 files changed

+165
-46
lines changed

2 files changed

+165
-46
lines changed

nitransforms/base.py

+109-43
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,82 @@
1212
import h5py
1313
import warnings
1414
from nibabel.loadsave import load
15+
from nibabel.nifti1 import intent_codes as INTENT_CODES
16+
from nibabel.cifti2 import Cifti2Image
1517

1618
from scipy import ndimage as ndi
1719

1820
EQUALITY_TOL = 1e-5
1921

2022

21-
class ImageGrid(object):
23+
class SpatialReference(object):
24+
"""Factory to create spatial references."""
25+
26+
@staticmethod
27+
def factory(dataset):
28+
"""Create a reference for spatial transforms."""
29+
try:
30+
return SampledSpatialData(dataset)
31+
except ValueError:
32+
return ImageGrid(dataset)
33+
34+
35+
class SampledSpatialData(object):
36+
"""Represent sampled spatial data: regularly gridded (images) and surfaces."""
37+
38+
__slots__ = ['_ndim', '_coords', '_npoints', '_shape']
39+
40+
def __init__(self, dataset):
41+
"""Create a sampling reference."""
42+
self._shape = None
43+
44+
if isinstance(dataset, SampledSpatialData):
45+
self._coords = dataset.ndcoords.copy()
46+
self._npoints, self._ndim = self._coords.shape
47+
return
48+
49+
if isinstance(dataset, (str, Path)):
50+
dataset = load(str(dataset))
51+
52+
if hasattr(dataset, 'numDA'): # Looks like a Gifti file
53+
_das = dataset.get_arrays_from_intent(INTENT_CODES['pointset'])
54+
if not _das:
55+
raise TypeError(
56+
'Input Gifti file does not contain reference coordinates.')
57+
self._coords = np.vstack([da.data for da in _das])
58+
self._npoints, self._ndim = self._coords.shape
59+
return
60+
61+
if isinstance(dataset, Cifti2Image):
62+
raise NotImplementedError
63+
64+
raise ValueError('Dataset could not be interpreted as an irregular sample.')
65+
66+
@property
67+
def npoints(self):
68+
"""Access the total number of voxels."""
69+
return self._npoints
70+
71+
@property
72+
def ndim(self):
73+
"""Access the number of dimensions."""
74+
return self._ndim
75+
76+
@property
77+
def ndcoords(self):
78+
"""List the physical coordinates of this sample."""
79+
return self._coords
80+
81+
@property
82+
def shape(self):
83+
"""Access the space's size of each dimension."""
84+
return self._shape
85+
86+
87+
class ImageGrid(SampledSpatialData):
2288
"""Class to represent spaces of gridded data (images)."""
2389

24-
__slots__ = ['_affine', '_shape', '_ndim', '_ndindex', '_coords', '_nvox',
25-
'_inverse']
90+
__slots__ = ['_affine', '_inverse', '_ndindex']
2691

2792
def __init__(self, image):
2893
"""Create a gridded sampling reference."""
@@ -31,11 +96,14 @@ def __init__(self, image):
3196

3297
self._affine = image.affine
3398
self._shape = image.shape
34-
self._ndim = len(image.shape)
35-
self._nvox = np.prod(image.shape) # Do not access data array
99+
self._ndim = getattr(image, 'ndim', len(image.shape))
100+
101+
self._npoints = getattr(image, 'npoints',
102+
np.prod(image.shape))
36103
self._ndindex = None
37104
self._coords = None
38-
self._inverse = np.linalg.inv(image.affine)
105+
self._inverse = getattr(image, 'inverse',
106+
np.linalg.inv(image.affine))
39107

40108
@property
41109
def affine(self):
@@ -47,28 +115,13 @@ def inverse(self):
47115
"""Access the RAS-to-indexes affine."""
48116
return self._inverse
49117

50-
@property
51-
def shape(self):
52-
"""Access the space's size of each dimension."""
53-
return self._shape
54-
55-
@property
56-
def ndim(self):
57-
"""Access the number of dimensions."""
58-
return self._ndim
59-
60-
@property
61-
def nvox(self):
62-
"""Access the total number of voxels."""
63-
return self._nvox
64-
65118
@property
66119
def ndindex(self):
67120
"""List the indexes corresponding to the space grid."""
68121
if self._ndindex is None:
69122
indexes = tuple([np.arange(s) for s in self._shape])
70123
self._ndindex = np.array(np.meshgrid(
71-
*indexes, indexing='ij')).reshape(self._ndim, self._nvox)
124+
*indexes, indexing='ij')).reshape(self._ndim, self._npoints)
72125
return self._ndindex
73126

74127
@property
@@ -77,7 +130,7 @@ def ndcoords(self):
77130
if self._coords is None:
78131
self._coords = np.tensordot(
79132
self._affine,
80-
np.vstack((self.ndindex, np.ones((1, self._nvox)))),
133+
np.vstack((self.ndindex, np.ones((1, self._npoints)))),
81134
axes=1
82135
)[:3, ...]
83136
return self._coords
@@ -131,16 +184,19 @@ def ndim(self):
131184
"""Access the dimensions of the reference space."""
132185
return self.reference.ndim
133186

134-
def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
135-
output_dtype=None):
187+
def apply(self, spatialimage, reference=None,
188+
order=3, mode='constant', cval=0.0, prefilter=True, output_dtype=None):
136189
"""
137-
Resample the moving image in reference space.
190+
Apply a transformation to an image, resampling on the reference spatial object.
138191
139192
Parameters
140193
----------
141-
moving : `spatialimage`
194+
spatialimage : `spatialimage`
142195
The image object containing the data to be resampled in reference
143196
space
197+
reference : spatial object
198+
The image, surface, or combination thereof containing the coordinates
199+
of samples that will be sampled.
144200
order : int, optional
145201
The order of the spline interpolation, default is 3.
146202
The order has to be in the range 0-5.
@@ -150,7 +206,7 @@ def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
150206
cval : float, optional
151207
Constant value for ``mode='constant'``. Default is 0.0.
152208
prefilter: bool, optional
153-
Determines if the moving image's data array is prefiltered with
209+
Determines if the image's data array is prefiltered with
154210
a spline filter before interpolation. The default is ``True``,
155211
which will create a temporary *float64* array of filtered values
156212
if *order > 1*. If setting this to ``False``, the output will be
@@ -160,21 +216,27 @@ def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
160216
161217
Returns
162218
-------
163-
moved_image : `spatialimage`
164-
The moving imaged after resampling to reference space.
219+
resampled : `spatialimage` or ndarray
220+
The data imaged after resampling to reference space.
165221
166222
"""
167-
if isinstance(moving, str):
168-
moving = load(moving)
223+
if reference is not None and isinstance(reference, (str, Path)):
224+
reference = load(reference)
225+
226+
_ref = self.reference if reference is None \
227+
else SpatialReference.factory(reference)
169228

170-
moving_data = np.asanyarray(moving.dataobj)
171-
output_dtype = output_dtype or moving_data.dtype
172-
targets = ImageGrid(moving).index(
173-
_as_homogeneous(self.map(self.reference.ndcoords.T),
174-
dim=self.reference.ndim))
229+
if isinstance(spatialimage, str):
230+
spatialimage = load(spatialimage)
175231

176-
moved = ndi.map_coordinates(
177-
moving_data,
232+
data = np.asanyarray(spatialimage.dataobj)
233+
output_dtype = output_dtype or data.dtype
234+
targets = ImageGrid(spatialimage).index( # data should be an image
235+
_as_homogeneous(self.map(_ref.ndcoords.T),
236+
dim=_ref.ndim))
237+
238+
resampled = ndi.map_coordinates(
239+
data,
178240
targets.T,
179241
output=output_dtype,
180242
order=order,
@@ -183,10 +245,14 @@ def apply(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
183245
prefilter=prefilter,
184246
)
185247

186-
moved_image = moving.__class__(moved.reshape(self.reference.shape),
187-
self.reference.affine, moving.header)
188-
moved_image.header.set_data_dtype(output_dtype)
189-
return moved_image
248+
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
249+
moved = spatialimage.__class__(
250+
resampled.reshape(_ref.shape),
251+
_ref.affine, spatialimage.header)
252+
moved.header.set_data_dtype(output_dtype)
253+
return moved
254+
255+
return resampled
190256

191257
def map(self, x, inverse=False, index=0):
192258
r"""

nitransforms/tests/test_base.py

+56-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,20 @@
44
import pytest
55
import h5py
66

7-
from ..base import ImageGrid, TransformBase
7+
from ..base import SpatialReference, SampledSpatialData, ImageGrid, TransformBase
8+
9+
10+
def test_SpatialReference(data_path):
11+
"""Ensure the reference factory is working properly."""
12+
obj1 = data_path / 'someones_anatomy.nii.gz'
13+
obj2 = data_path / 'sub-200148_hemi-R_pial.surf.gii'
14+
15+
assert isinstance(SpatialReference.factory(obj1), ImageGrid)
16+
assert isinstance(SpatialReference.factory(str(obj1)), ImageGrid)
17+
assert isinstance(SpatialReference.factory(nb.load(str(obj1))), ImageGrid)
18+
assert isinstance(SpatialReference.factory(obj2), SampledSpatialData)
19+
assert isinstance(SpatialReference.factory(str(obj2)), SampledSpatialData)
20+
assert isinstance(SpatialReference.factory(nb.load(str(obj2))), SampledSpatialData)
821

922

1023
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@@ -29,7 +42,10 @@ def test_ImageGrid(get_testdata, image_orientation):
2942
coords = img.ndcoords
3043
assert len(idxs.shape) == len(coords.shape) == 2
3144
assert idxs.shape[0] == coords.shape[0] == img.ndim == 3
32-
assert idxs.shape[1] == coords.shape[1] == img.nvox == np.prod(im.shape)
45+
assert idxs.shape[1] == coords.shape[1] == img.npoints == np.prod(im.shape)
46+
47+
img2 = ImageGrid(img)
48+
assert img2 == img
3349

3450

3551
def test_ImageGrid_utils(tmpdir, data_path, get_testdata):
@@ -39,7 +55,9 @@ def test_ImageGrid_utils(tmpdir, data_path, get_testdata):
3955
im1 = get_testdata['RAS']
4056
im2 = data_path / 'someones_anatomy.nii.gz'
4157

42-
assert ImageGrid(im1) == ImageGrid(im2)
58+
ig = ImageGrid(im1)
59+
assert ig == ImageGrid(im2)
60+
assert ig.shape is not None
4361

4462
with h5py.File('xfm.x5', 'w') as f:
4563
ImageGrid(im1)._to_hdf5(f.create_group('Reference'))
@@ -59,10 +77,45 @@ def _to_hdf5(klass, x5_root):
5977
monkeypatch.setattr(TransformBase, '_to_hdf5', _to_hdf5)
6078
fname = str(data_path / 'someones_anatomy.nii.gz')
6179

80+
# Test identity transform
6281
xfm = TransformBase()
6382
xfm.reference = fname
6483
assert xfm.ndim == 3
6584
moved = xfm.apply(fname, order=0)
6685
assert np.all(nb.load(fname).get_fdata() == moved.get_fdata())
6786

87+
# Test identity transform - setting reference
88+
xfm = TransformBase()
89+
xfm.reference = fname
90+
assert xfm.ndim == 3
91+
moved = xfm.apply(fname, reference=fname, order=0)
92+
assert np.all(nb.load(fname).get_fdata() == moved.get_fdata())
93+
94+
# Test applying to Gifti
95+
gii = nb.gifti.GiftiImage(darrays=[
96+
nb.gifti.GiftiDataArray(
97+
data=xfm.reference.ndcoords,
98+
intent=nb.nifti1.intent_codes['pointset'])]
99+
)
100+
giimoved = xfm.apply(fname, reference=gii, order=0)
101+
assert np.allclose(giimoved.reshape(xfm.reference.shape), moved.get_fdata())
102+
103+
# Test to_filename
68104
xfm.to_filename('data.x5')
105+
106+
107+
def test_SampledSpatialData(data_path):
108+
"""Check the reference generated by cifti files."""
109+
gii = data_path / 'sub-200148_hemi-R_pial.surf.gii'
110+
111+
ssd = SampledSpatialData(gii)
112+
assert ssd.npoints == 249277
113+
assert ssd.ndim == 3
114+
assert ssd.ndcoords.shape == (249277, 3)
115+
assert ssd.shape is None
116+
117+
ssd2 = SampledSpatialData(ssd)
118+
assert ssd2.npoints == 249277
119+
assert ssd2.ndim == 3
120+
assert ssd2.ndcoords.shape == (249277, 3)
121+
assert ssd2.shape is None

0 commit comments

Comments
 (0)