Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: First implementation of loading and applying ITK displacements fields #42

Merged
merged 1 commit into from
Oct 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions nitransforms/io/itk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Read/write ITK transforms."""
import warnings
import numpy as np
from scipy.io import savemat as _save_mat
from nibabel.loadsave import load as loadimg
from nibabel.affines import from_matvec
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError

Expand Down Expand Up @@ -245,3 +247,37 @@ def from_string(cls, string):
_self.xforms.append(cls._inner_type.from_string(
'#%s' % xfm))
return _self


class ITKDisplacementsField:
"""A data structure representing displacements fields."""

@classmethod
def from_filename(cls, filename):
"""Import a displacements field from a NIfTI file."""
imgobj = loadimg(str(filename))
return cls.from_image(imgobj)

@classmethod
def from_image(cls, imgobj):
"""Import a displacements field from a NIfTI file."""
_hdr = imgobj.header.copy()
_shape = _hdr.get_data_shape()

if (
len(_shape) != 5 or
_shape[-2] != 1 or
not _shape[-1] in (2, 3)
):
raise TransformFileError(
'Displacements field "%s" does not come from ITK.' %
imgobj.file_map['image'].filename)

if _hdr.get_intent()[0] != 'vector':
warnings.warn('Incorrect intent identified.')
_hdr.set_intent('vector')

_field = np.squeeze(np.asanyarray(imgobj.dataobj))
_field[..., (0, 1)] *= -1.0

return imgobj.__class__(_field, imgobj.affine, _hdr)
14 changes: 6 additions & 8 deletions nitransforms/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,17 @@ class DisplacementsFieldTransform(TransformBase):

def __init__(self, field, reference=None):
"""Create a dense deformation field transform."""
super(DisplacementsFieldTransform, self).__init__()
super().__init__()
self._field = np.asanyarray(field.dataobj)

ndim = self._field.ndim - 1
if len(self._field.shape[:-1]) != ndim:
if self._field.shape[-1] != ndim:
raise ValueError(
'Number of components of the deformation field does '
'not match the number of dimensions')
'The number of components of the displacements (%d) does not '
'the number of dimensions (%d)' % (self._field.shape[-1], ndim))

if reference is None:
reference = field.__class__(np.zeros(self._field.shape[:-1]),
field.affine, field.header)
self.reference = reference
self.reference = field.__class__(np.zeros(self._field.shape[:-1]),
field.affine, field.header)

def map(self, x, inverse=False, index=0):
r"""
Expand Down
96 changes: 76 additions & 20 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import numpy as np
import nibabel as nb
from ..io.base import TransformFileError
from ..nonlinear import DisplacementsFieldTransform
from ..io.itk import ITKDisplacementsField

TESTS_BORDER_TOLERANCE = 0.05
APPLY_NONLINEAR_CMD = {
Expand All @@ -19,34 +21,88 @@
}


@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 3)])
def test_itk_disp_load(size):
"""Checks field sizes."""
with pytest.raises(TransformFileError):
ITKDisplacementsField.from_image(
nb.Nifti1Image(np.zeros(size), None, None))


@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 1, 3)])
def test_displacements_bad_sizes(size):
"""Checks field sizes."""
with pytest.raises(ValueError):
DisplacementsFieldTransform(
nb.Nifti1Image(np.zeros(size), None, None))


def test_itk_disp_load_intent():
"""Checks whether the NIfTI intent is fixed."""
with pytest.warns(UserWarning):
field = ITKDisplacementsField.from_image(
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), None, None))

assert field.header.get_intent()[0] == 'vector'


@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@pytest.mark.parametrize('sw_tool', ['itk'])
def test_displacements_field(tmp_path, data_path, sw_tool):
@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
"""Check a translation-only field on one or more axes, different image orientations."""
os.chdir(str(tmp_path))
img_fname = str(data_path / 'tpl-OASIS30ANTs_T1w.nii.gz')
xfm_fname = str(
data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz')
ants_warp = nb.load(xfm_fname)
hdr = ants_warp.header.copy()
nii = get_testdata[image_orientation]
nii.to_filename('reference.nii.gz')
fieldmap = np.zeros((*nii.shape[:3], 1, 3), dtype='float32')
fieldmap[..., axis] = -10.0

_hdr = nii.header.copy()
_hdr.set_intent('vector')
_hdr.set_data_dtype('float32')

# fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
xfm_fname = 'warp.nii.gz'
nii = nb.load(img_fname)
fieldmap = np.zeros((*nii.shape[:3], 1, 3))
fieldmap[..., 2] = -10.0
# fieldmap = np.flip(np.flip(fieldmap, 1), 0)
ants_warp = nb.Nifti1Image(fieldmap, nii.affine, hdr)
ants_warp.to_filename(xfm_fname)
fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
field = nb.Nifti1Image(
fieldmap,
ants_warp.affine, ants_warp.header
)

xfm = DisplacementsFieldTransform(field)
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
field.to_filename(xfm_fname)

xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_image(field))

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=os.path.abspath(xfm_fname),
reference=tmp_path / 'reference.nii.gz',
moving=tmp_path / 'reference.nii.gz')

# skip test if command is not available on host
exe = cmd.split(" ", 1)[0]
if not shutil.which(exe):
pytest.skip("Command {} not found on host".format(exe))

exit_code = check_call([cmd], shell=True)
assert exit_code == 0
sw_moved = nb.load('resampled.nii.gz')

nt_moved = xfm.apply(nii, order=0)
nt_moved.to_filename('nt_resampled.nii.gz')
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE


@pytest.mark.parametrize('sw_tool', ['itk'])
def test_displacements_field2(tmp_path, data_path, sw_tool):
"""Check a translation-only field on one or more axes, different image orientations."""
os.chdir(str(tmp_path))
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz'

xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_filename(xfm_fname))

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=xfm_fname,
reference=img_fname,
moving=img_fname)

Expand Down