Skip to content

Commit 77a3a3b

Browse files
authored
Merge pull request #42 from oesteban/fix/32
ENH: First implementation of loading and applying ITK displacements fields
2 parents 04ce632 + 96280eb commit 77a3a3b

File tree

3 files changed

+118
-28
lines changed

3 files changed

+118
-28
lines changed

nitransforms/io/itk.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Read/write ITK transforms."""
2+
import warnings
23
import numpy as np
34
from scipy.io import savemat as _save_mat
5+
from nibabel.loadsave import load as loadimg
46
from nibabel.affines import from_matvec
57
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError
68

@@ -245,3 +247,37 @@ def from_string(cls, string):
245247
_self.xforms.append(cls._inner_type.from_string(
246248
'#%s' % xfm))
247249
return _self
250+
251+
252+
class ITKDisplacementsField:
253+
"""A data structure representing displacements fields."""
254+
255+
@classmethod
256+
def from_filename(cls, filename):
257+
"""Import a displacements field from a NIfTI file."""
258+
imgobj = loadimg(str(filename))
259+
return cls.from_image(imgobj)
260+
261+
@classmethod
262+
def from_image(cls, imgobj):
263+
"""Import a displacements field from a NIfTI file."""
264+
_hdr = imgobj.header.copy()
265+
_shape = _hdr.get_data_shape()
266+
267+
if (
268+
len(_shape) != 5 or
269+
_shape[-2] != 1 or
270+
not _shape[-1] in (2, 3)
271+
):
272+
raise TransformFileError(
273+
'Displacements field "%s" does not come from ITK.' %
274+
imgobj.file_map['image'].filename)
275+
276+
if _hdr.get_intent()[0] != 'vector':
277+
warnings.warn('Incorrect intent identified.')
278+
_hdr.set_intent('vector')
279+
280+
_field = np.squeeze(np.asanyarray(imgobj.dataobj))
281+
_field[..., (0, 1)] *= -1.0
282+
283+
return imgobj.__class__(_field, imgobj.affine, _hdr)

nitransforms/nonlinear.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,17 @@ class DisplacementsFieldTransform(TransformBase):
2222

2323
def __init__(self, field, reference=None):
2424
"""Create a dense deformation field transform."""
25-
super(DisplacementsFieldTransform, self).__init__()
25+
super().__init__()
2626
self._field = np.asanyarray(field.dataobj)
2727

2828
ndim = self._field.ndim - 1
29-
if len(self._field.shape[:-1]) != ndim:
29+
if self._field.shape[-1] != ndim:
3030
raise ValueError(
31-
'Number of components of the deformation field does '
32-
'not match the number of dimensions')
31+
'The number of components of the displacements (%d) does not '
32+
'the number of dimensions (%d)' % (self._field.shape[-1], ndim))
3333

34-
if reference is None:
35-
reference = field.__class__(np.zeros(self._field.shape[:-1]),
36-
field.affine, field.header)
37-
self.reference = reference
34+
self.reference = field.__class__(np.zeros(self._field.shape[:-1]),
35+
field.affine, field.header)
3836

3937
def map(self, x, inverse=False, index=0):
4038
r"""

nitransforms/tests/test_nonlinear.py

+76-20
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
import numpy as np
1010
import nibabel as nb
11+
from ..io.base import TransformFileError
1112
from ..nonlinear import DisplacementsFieldTransform
13+
from ..io.itk import ITKDisplacementsField
1214

1315
TESTS_BORDER_TOLERANCE = 0.05
1416
APPLY_NONLINEAR_CMD = {
@@ -19,34 +21,88 @@
1921
}
2022

2123

24+
@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 3)])
25+
def test_itk_disp_load(size):
26+
"""Checks field sizes."""
27+
with pytest.raises(TransformFileError):
28+
ITKDisplacementsField.from_image(
29+
nb.Nifti1Image(np.zeros(size), None, None))
30+
31+
32+
@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 1, 3)])
33+
def test_displacements_bad_sizes(size):
34+
"""Checks field sizes."""
35+
with pytest.raises(ValueError):
36+
DisplacementsFieldTransform(
37+
nb.Nifti1Image(np.zeros(size), None, None))
38+
39+
40+
def test_itk_disp_load_intent():
41+
"""Checks whether the NIfTI intent is fixed."""
42+
with pytest.warns(UserWarning):
43+
field = ITKDisplacementsField.from_image(
44+
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), None, None))
45+
46+
assert field.header.get_intent()[0] == 'vector'
47+
48+
49+
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
2250
@pytest.mark.parametrize('sw_tool', ['itk'])
23-
def test_displacements_field(tmp_path, data_path, sw_tool):
51+
@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
52+
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
53+
"""Check a translation-only field on one or more axes, different image orientations."""
2454
os.chdir(str(tmp_path))
25-
img_fname = str(data_path / 'tpl-OASIS30ANTs_T1w.nii.gz')
26-
xfm_fname = str(
27-
data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz')
28-
ants_warp = nb.load(xfm_fname)
29-
hdr = ants_warp.header.copy()
55+
nii = get_testdata[image_orientation]
56+
nii.to_filename('reference.nii.gz')
57+
fieldmap = np.zeros((*nii.shape[:3], 1, 3), dtype='float32')
58+
fieldmap[..., axis] = -10.0
59+
60+
_hdr = nii.header.copy()
61+
_hdr.set_intent('vector')
62+
_hdr.set_data_dtype('float32')
3063

31-
# fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
3264
xfm_fname = 'warp.nii.gz'
33-
nii = nb.load(img_fname)
34-
fieldmap = np.zeros((*nii.shape[:3], 1, 3))
35-
fieldmap[..., 2] = -10.0
36-
# fieldmap = np.flip(np.flip(fieldmap, 1), 0)
37-
ants_warp = nb.Nifti1Image(fieldmap, nii.affine, hdr)
38-
ants_warp.to_filename(xfm_fname)
39-
fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
40-
field = nb.Nifti1Image(
41-
fieldmap,
42-
ants_warp.affine, ants_warp.header
43-
)
44-
45-
xfm = DisplacementsFieldTransform(field)
65+
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
66+
field.to_filename(xfm_fname)
67+
68+
xfm = DisplacementsFieldTransform(
69+
ITKDisplacementsField.from_image(field))
4670

4771
# Then apply the transform and cross-check with software
4872
cmd = APPLY_NONLINEAR_CMD[sw_tool](
4973
transform=os.path.abspath(xfm_fname),
74+
reference=tmp_path / 'reference.nii.gz',
75+
moving=tmp_path / 'reference.nii.gz')
76+
77+
# skip test if command is not available on host
78+
exe = cmd.split(" ", 1)[0]
79+
if not shutil.which(exe):
80+
pytest.skip("Command {} not found on host".format(exe))
81+
82+
exit_code = check_call([cmd], shell=True)
83+
assert exit_code == 0
84+
sw_moved = nb.load('resampled.nii.gz')
85+
86+
nt_moved = xfm.apply(nii, order=0)
87+
nt_moved.to_filename('nt_resampled.nii.gz')
88+
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
89+
# A certain tolerance is necessary because of resampling at borders
90+
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
91+
92+
93+
@pytest.mark.parametrize('sw_tool', ['itk'])
94+
def test_displacements_field2(tmp_path, data_path, sw_tool):
95+
"""Check a translation-only field on one or more axes, different image orientations."""
96+
os.chdir(str(tmp_path))
97+
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
98+
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz'
99+
100+
xfm = DisplacementsFieldTransform(
101+
ITKDisplacementsField.from_filename(xfm_fname))
102+
103+
# Then apply the transform and cross-check with software
104+
cmd = APPLY_NONLINEAR_CMD[sw_tool](
105+
transform=xfm_fname,
50106
reference=img_fname,
51107
moving=img_fname)
52108

0 commit comments

Comments
 (0)