diff --git a/nitransforms/io/itk.py b/nitransforms/io/itk.py index 142dc336..bd0c1409 100644 --- a/nitransforms/io/itk.py +++ b/nitransforms/io/itk.py @@ -1,6 +1,8 @@ """Read/write ITK transforms.""" +import warnings import numpy as np from scipy.io import savemat as _save_mat +from nibabel.loadsave import load as loadimg from nibabel.affines import from_matvec from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError @@ -245,3 +247,37 @@ def from_string(cls, string): _self.xforms.append(cls._inner_type.from_string( '#%s' % xfm)) return _self + + +class ITKDisplacementsField: + """A data structure representing displacements fields.""" + + @classmethod + def from_filename(cls, filename): + """Import a displacements field from a NIfTI file.""" + imgobj = loadimg(str(filename)) + return cls.from_image(imgobj) + + @classmethod + def from_image(cls, imgobj): + """Import a displacements field from a NIfTI file.""" + _hdr = imgobj.header.copy() + _shape = _hdr.get_data_shape() + + if ( + len(_shape) != 5 or + _shape[-2] != 1 or + not _shape[-1] in (2, 3) + ): + raise TransformFileError( + 'Displacements field "%s" does not come from ITK.' % + imgobj.file_map['image'].filename) + + if _hdr.get_intent()[0] != 'vector': + warnings.warn('Incorrect intent identified.') + _hdr.set_intent('vector') + + _field = np.squeeze(np.asanyarray(imgobj.dataobj)) + _field[..., (0, 1)] *= -1.0 + + return imgobj.__class__(_field, imgobj.affine, _hdr) diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 58d9ac4c..22853beb 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -22,19 +22,17 @@ class DisplacementsFieldTransform(TransformBase): def __init__(self, field, reference=None): """Create a dense deformation field transform.""" - super(DisplacementsFieldTransform, self).__init__() + super().__init__() self._field = np.asanyarray(field.dataobj) ndim = self._field.ndim - 1 - if len(self._field.shape[:-1]) != ndim: + if self._field.shape[-1] != ndim: raise ValueError( - 'Number of components of the deformation field does ' - 'not match the number of dimensions') + 'The number of components of the displacements (%d) does not ' + 'the number of dimensions (%d)' % (self._field.shape[-1], ndim)) - if reference is None: - reference = field.__class__(np.zeros(self._field.shape[:-1]), - field.affine, field.header) - self.reference = reference + self.reference = field.__class__(np.zeros(self._field.shape[:-1]), + field.affine, field.header) def map(self, x, inverse=False, index=0): r""" diff --git a/nitransforms/tests/test_nonlinear.py b/nitransforms/tests/test_nonlinear.py index 91818d74..f4b9af71 100644 --- a/nitransforms/tests/test_nonlinear.py +++ b/nitransforms/tests/test_nonlinear.py @@ -8,7 +8,9 @@ import numpy as np import nibabel as nb +from ..io.base import TransformFileError from ..nonlinear import DisplacementsFieldTransform +from ..io.itk import ITKDisplacementsField TESTS_BORDER_TOLERANCE = 0.05 APPLY_NONLINEAR_CMD = { @@ -19,34 +21,88 @@ } +@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 3)]) +def test_itk_disp_load(size): + """Checks field sizes.""" + with pytest.raises(TransformFileError): + ITKDisplacementsField.from_image( + nb.Nifti1Image(np.zeros(size), None, None)) + + +@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 1, 3)]) +def test_displacements_bad_sizes(size): + """Checks field sizes.""" + with pytest.raises(ValueError): + DisplacementsFieldTransform( + nb.Nifti1Image(np.zeros(size), None, None)) + + +def test_itk_disp_load_intent(): + """Checks whether the NIfTI intent is fixed.""" + with pytest.warns(UserWarning): + field = ITKDisplacementsField.from_image( + nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), None, None)) + + assert field.header.get_intent()[0] == 'vector' + + +@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique']) @pytest.mark.parametrize('sw_tool', ['itk']) -def test_displacements_field(tmp_path, data_path, sw_tool): +@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)]) +def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis): + """Check a translation-only field on one or more axes, different image orientations.""" os.chdir(str(tmp_path)) - img_fname = str(data_path / 'tpl-OASIS30ANTs_T1w.nii.gz') - xfm_fname = str( - data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz') - ants_warp = nb.load(xfm_fname) - hdr = ants_warp.header.copy() + nii = get_testdata[image_orientation] + nii.to_filename('reference.nii.gz') + fieldmap = np.zeros((*nii.shape[:3], 1, 3), dtype='float32') + fieldmap[..., axis] = -10.0 + + _hdr = nii.header.copy() + _hdr.set_intent('vector') + _hdr.set_data_dtype('float32') - # fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj)) xfm_fname = 'warp.nii.gz' - nii = nb.load(img_fname) - fieldmap = np.zeros((*nii.shape[:3], 1, 3)) - fieldmap[..., 2] = -10.0 - # fieldmap = np.flip(np.flip(fieldmap, 1), 0) - ants_warp = nb.Nifti1Image(fieldmap, nii.affine, hdr) - ants_warp.to_filename(xfm_fname) - fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj)) - field = nb.Nifti1Image( - fieldmap, - ants_warp.affine, ants_warp.header - ) - - xfm = DisplacementsFieldTransform(field) + field = nb.Nifti1Image(fieldmap, nii.affine, _hdr) + field.to_filename(xfm_fname) + + xfm = DisplacementsFieldTransform( + ITKDisplacementsField.from_image(field)) # Then apply the transform and cross-check with software cmd = APPLY_NONLINEAR_CMD[sw_tool]( transform=os.path.abspath(xfm_fname), + reference=tmp_path / 'reference.nii.gz', + moving=tmp_path / 'reference.nii.gz') + + # skip test if command is not available on host + exe = cmd.split(" ", 1)[0] + if not shutil.which(exe): + pytest.skip("Command {} not found on host".format(exe)) + + exit_code = check_call([cmd], shell=True) + assert exit_code == 0 + sw_moved = nb.load('resampled.nii.gz') + + nt_moved = xfm.apply(nii, order=0) + nt_moved.to_filename('nt_resampled.nii.gz') + diff = sw_moved.get_fdata() - nt_moved.get_fdata() + # A certain tolerance is necessary because of resampling at borders + assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE + + +@pytest.mark.parametrize('sw_tool', ['itk']) +def test_displacements_field2(tmp_path, data_path, sw_tool): + """Check a translation-only field on one or more axes, different image orientations.""" + os.chdir(str(tmp_path)) + img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz' + xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz' + + xfm = DisplacementsFieldTransform( + ITKDisplacementsField.from_filename(xfm_fname)) + + # Then apply the transform and cross-check with software + cmd = APPLY_NONLINEAR_CMD[sw_tool]( + transform=xfm_fname, reference=img_fname, moving=img_fname)