Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Rename resample() with apply() #30

Merged
merged 4 commits into from
Oct 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 109 additions & 43 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,82 @@
import h5py
import warnings
from nibabel.loadsave import load
from nibabel.nifti1 import intent_codes as INTENT_CODES
from nibabel.cifti2 import Cifti2Image

from scipy import ndimage as ndi

EQUALITY_TOL = 1e-5


class ImageGrid(object):
class SpatialReference:
"""Factory to create spatial references."""

@staticmethod
def factory(dataset):
"""Create a reference for spatial transforms."""
try:
return SampledSpatialData(dataset)
except ValueError:
return ImageGrid(dataset)


class SampledSpatialData:
"""Represent sampled spatial data: regularly gridded (images) and surfaces."""

__slots__ = ['_ndim', '_coords', '_npoints', '_shape']

def __init__(self, dataset):
"""Create a sampling reference."""
self._shape = None

if isinstance(dataset, SampledSpatialData):
self._coords = dataset.ndcoords.copy()
self._npoints, self._ndim = self._coords.shape
return

if isinstance(dataset, (str, Path)):
dataset = load(str(dataset))

if hasattr(dataset, 'numDA'): # Looks like a Gifti file
_das = dataset.get_arrays_from_intent(INTENT_CODES['pointset'])
if not _das:
raise TypeError(
'Input Gifti file does not contain reference coordinates.')
self._coords = np.vstack([da.data for da in _das])
self._npoints, self._ndim = self._coords.shape
return

if isinstance(dataset, Cifti2Image):
raise NotImplementedError

raise ValueError('Dataset could not be interpreted as an irregular sample.')

@property
def npoints(self):
"""Access the total number of voxels."""
return self._npoints

@property
def ndim(self):
"""Access the number of dimensions."""
return self._ndim

@property
def ndcoords(self):
"""List the physical coordinates of this sample."""
return self._coords

@property
def shape(self):
"""Access the space's size of each dimension."""
return self._shape


class ImageGrid(SampledSpatialData):
"""Class to represent spaces of gridded data (images)."""

__slots__ = ['_affine', '_shape', '_ndim', '_ndindex', '_coords', '_nvox',
'_inverse']
__slots__ = ['_affine', '_inverse', '_ndindex']

def __init__(self, image):
"""Create a gridded sampling reference."""
Expand All @@ -31,11 +96,14 @@ def __init__(self, image):

self._affine = image.affine
self._shape = image.shape
self._ndim = len(image.shape)
self._nvox = np.prod(image.shape) # Do not access data array
self._ndim = getattr(image, 'ndim', len(image.shape))

self._npoints = getattr(image, 'npoints',
np.prod(image.shape))
self._ndindex = None
self._coords = None
self._inverse = np.linalg.inv(image.affine)
self._inverse = getattr(image, 'inverse',
np.linalg.inv(image.affine))

@property
def affine(self):
Expand All @@ -47,28 +115,13 @@ def inverse(self):
"""Access the RAS-to-indexes affine."""
return self._inverse

@property
def shape(self):
"""Access the space's size of each dimension."""
return self._shape

@property
def ndim(self):
"""Access the number of dimensions."""
return self._ndim

@property
def nvox(self):
"""Access the total number of voxels."""
return self._nvox

@property
def ndindex(self):
"""List the indexes corresponding to the space grid."""
if self._ndindex is None:
indexes = tuple([np.arange(s) for s in self._shape])
self._ndindex = np.array(np.meshgrid(
*indexes, indexing='ij')).reshape(self._ndim, self._nvox)
*indexes, indexing='ij')).reshape(self._ndim, self._npoints)
return self._ndindex

@property
Expand All @@ -77,7 +130,7 @@ def ndcoords(self):
if self._coords is None:
self._coords = np.tensordot(
self._affine,
np.vstack((self.ndindex, np.ones((1, self._nvox)))),
np.vstack((self.ndindex, np.ones((1, self._npoints)))),
axes=1
)[:3, ...]
return self._coords
Expand Down Expand Up @@ -131,16 +184,19 @@ def ndim(self):
"""Access the dimensions of the reference space."""
return self.reference.ndim

def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
output_dtype=None):
def apply(self, spatialimage, reference=None,
order=3, mode='constant', cval=0.0, prefilter=True, output_dtype=None):
"""
Resample the moving image in reference space.
Apply a transformation to an image, resampling on the reference spatial object.

Parameters
----------
moving : `spatialimage`
spatialimage : `spatialimage`
The image object containing the data to be resampled in reference
space
reference : spatial object, optional
The image, surface, or combination thereof containing the coordinates
of samples that will be sampled.
order : int, optional
The order of the spline interpolation, default is 3.
The order has to be in the range 0-5.
Expand All @@ -150,7 +206,7 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
cval : float, optional
Constant value for ``mode='constant'``. Default is 0.0.
prefilter: bool, optional
Determines if the moving image's data array is prefiltered with
Determines if the image's data array is prefiltered with
a spline filter before interpolation. The default is ``True``,
which will create a temporary *float64* array of filtered values
if *order > 1*. If setting this to ``False``, the output will be
Expand All @@ -160,21 +216,27 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,

Returns
-------
moved_image : `spatialimage`
The moving imaged after resampling to reference space.
resampled : `spatialimage` or ndarray
The data imaged after resampling to reference space.

"""
if isinstance(moving, str):
moving = load(moving)
if reference is not None and isinstance(reference, (str, Path)):
reference = load(reference)

_ref = self.reference if reference is None \
else SpatialReference.factory(reference)

moving_data = np.asanyarray(moving.dataobj)
output_dtype = output_dtype or moving_data.dtype
targets = ImageGrid(moving).index(
_as_homogeneous(self.map(self.reference.ndcoords.T),
dim=self.reference.ndim))
if isinstance(spatialimage, (str, Path)):
spatialimage = load(spatialimage)

moved = ndi.map_coordinates(
moving_data,
data = np.asanyarray(spatialimage.dataobj)
output_dtype = output_dtype or data.dtype
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(self.map(_ref.ndcoords.T),
dim=_ref.ndim))

resampled = ndi.map_coordinates(
data,
targets.T,
output=output_dtype,
order=order,
Expand All @@ -183,10 +245,14 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
prefilter=prefilter,
)

moved_image = moving.__class__(moved.reshape(self.reference.shape),
self.reference.affine, moving.header)
moved_image.header.set_data_dtype(output_dtype)
return moved_image
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
moved = spatialimage.__class__(
resampled.reshape(_ref.shape),
_ref.affine, spatialimage.header)
moved.header.set_data_dtype(output_dtype)
return moved

return resampled

def map(self, x, inverse=False, index=0):
r"""
Expand Down
23 changes: 0 additions & 23 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,26 +145,3 @@ def map(self, x, inverse=False, index=0):
# def _map_voxel(self, index, moving=None):
# """Apply ijk' = f_ijk((i, j, k)), equivalent to the above with indexes."""
# return tuple(self._moving[index + self.__s])

# def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
# output_dtype=None):
# """
# Resample the ``moving`` image applying the deformation field.

# Examples
# --------
# >>> ref = nb.load(os.path.join(datadir, 'someones_anatomy.nii.gz'))
# >>> coeffs = np.zeros((6, 6, 6, 3))
# >>> coeffs[2, 2, 2, ...] = [10.0, -20.0, 0]
# >>> aff = ref.affine
# >>> aff[:3, :3] = aff[:3, :3].dot(np.eye(3) * np.array(
# ... ref.header.get_zooms()[:3]) / 6.0
# ... )
# >>> coeffsimg = nb.Nifti1Image(coeffs, ref.affine, ref.header)
# >>> xfm = BSplineFieldTransform(ref, coeffsimg) # doctest: +SKIP
# >>> new = xfm.resample(ref) # doctest: +SKIP

# """
# self._cache_moving()
# return super(BSplineFieldTransform, self).resample(
# moving, order=order, mode=mode, cval=cval, prefilter=prefilter)
9 changes: 9 additions & 0 deletions nitransforms/tests/data/sub-200148_hemi-R_pial.surf.gii

Large diffs are not rendered by default.

66 changes: 62 additions & 4 deletions nitransforms/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,20 @@
import pytest
import h5py

from ..base import ImageGrid, TransformBase
from ..base import SpatialReference, SampledSpatialData, ImageGrid, TransformBase


def test_SpatialReference(data_path):
"""Ensure the reference factory is working properly."""
obj1 = data_path / 'someones_anatomy.nii.gz'
obj2 = data_path / 'sub-200148_hemi-R_pial.surf.gii'

assert isinstance(SpatialReference.factory(obj1), ImageGrid)
assert isinstance(SpatialReference.factory(str(obj1)), ImageGrid)
assert isinstance(SpatialReference.factory(nb.load(str(obj1))), ImageGrid)
assert isinstance(SpatialReference.factory(obj2), SampledSpatialData)
assert isinstance(SpatialReference.factory(str(obj2)), SampledSpatialData)
assert isinstance(SpatialReference.factory(nb.load(str(obj2))), SampledSpatialData)


@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
Expand All @@ -29,7 +42,10 @@ def test_ImageGrid(get_testdata, image_orientation):
coords = img.ndcoords
assert len(idxs.shape) == len(coords.shape) == 2
assert idxs.shape[0] == coords.shape[0] == img.ndim == 3
assert idxs.shape[1] == coords.shape[1] == img.nvox == np.prod(im.shape)
assert idxs.shape[1] == coords.shape[1] == img.npoints == np.prod(im.shape)

img2 = ImageGrid(img)
assert img2 == img


def test_ImageGrid_utils(tmpdir, data_path, get_testdata):
Expand All @@ -39,7 +55,9 @@ def test_ImageGrid_utils(tmpdir, data_path, get_testdata):
im1 = get_testdata['RAS']
im2 = data_path / 'someones_anatomy.nii.gz'

assert ImageGrid(im1) == ImageGrid(im2)
ig = ImageGrid(im1)
assert ig == ImageGrid(im2)
assert ig.shape is not None

with h5py.File('xfm.x5', 'w') as f:
ImageGrid(im1)._to_hdf5(f.create_group('Reference'))
Expand All @@ -59,10 +77,50 @@ def _to_hdf5(klass, x5_root):
monkeypatch.setattr(TransformBase, '_to_hdf5', _to_hdf5)
fname = str(data_path / 'someones_anatomy.nii.gz')

# Test identity transform
xfm = TransformBase()
xfm.reference = fname
assert xfm.ndim == 3
moved = xfm.resample(fname, order=0)
moved = xfm.apply(fname, order=0)
assert np.all(nb.load(fname).get_fdata() == moved.get_fdata())

# Test identity transform - setting reference
xfm = TransformBase()
xfm.reference = fname
assert xfm.ndim == 3
moved = xfm.apply(fname, reference=fname, order=0)
assert np.all(nb.load(fname).get_fdata() == moved.get_fdata())

# Test applying to Gifti
gii = nb.gifti.GiftiImage(darrays=[
nb.gifti.GiftiDataArray(
data=xfm.reference.ndcoords,
intent=nb.nifti1.intent_codes['pointset'])]
)
giimoved = xfm.apply(fname, reference=gii, order=0)
assert np.allclose(giimoved.reshape(xfm.reference.shape), moved.get_fdata())

# Test to_filename
xfm.to_filename('data.x5')


def test_SampledSpatialData(data_path):
"""Check the reference generated by cifti files."""
gii = data_path / 'sub-200148_hemi-R_pial.surf.gii'

ssd = SampledSpatialData(gii)
assert ssd.npoints == 249277
assert ssd.ndim == 3
assert ssd.ndcoords.shape == (249277, 3)
assert ssd.shape is None

ssd2 = SampledSpatialData(ssd)
assert ssd2.npoints == 249277
assert ssd2.ndim == 3
assert ssd2.ndcoords.shape == (249277, 3)
assert ssd2.shape is None

# check what happens with an empty gifti
with pytest.raises(TypeError):
gii = nb.gifti.GiftiImage()
SampledSpatialData(gii)
2 changes: 1 addition & 1 deletion nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_apply_linear_transform(
assert exit_code == 0
sw_moved = nb.load('resampled.nii.gz')

nt_moved = xfm.resample(img, order=0)
nt_moved = xfm.apply(img, order=0)
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
2 changes: 1 addition & 1 deletion nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_displacements_field(tmp_path, data_path, sw_tool):
assert exit_code == 0
sw_moved = nb.load('resampled.nii.gz')

nt_moved = xfm.resample(img_fname, order=0)
nt_moved = xfm.apply(img_fname, order=0)
nt_moved.to_filename('nt_resampled.nii.gz')
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
# A certain tolerance is necessary because of resampling at borders
Expand Down