Skip to content

Commit 129d9e3

Browse files
committed
enh(displacementsfields): early implementation of itk importing and tests
Closes #32
1 parent 04ce632 commit 129d9e3

File tree

3 files changed

+106
-24
lines changed

3 files changed

+106
-24
lines changed

nitransforms/io/itk.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Read/write ITK transforms."""
2+
import warnings
23
import numpy as np
34
from scipy.io import savemat as _save_mat
5+
from nibabel.loadsave import load as loadimg
46
from nibabel.affines import from_matvec
57
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError
68

@@ -245,3 +247,37 @@ def from_string(cls, string):
245247
_self.xforms.append(cls._inner_type.from_string(
246248
'#%s' % xfm))
247249
return _self
250+
251+
252+
class ITKDisplacementsField:
253+
"""A data structure representing displacements fields."""
254+
255+
@classmethod
256+
def from_filename(cls, filename):
257+
"""Import a displacements field from a NIfTI file."""
258+
imgobj = loadimg(str(filename))
259+
return cls.from_image(imgobj)
260+
261+
@classmethod
262+
def from_image(cls, imgobj):
263+
"""Import a displacements field from a NIfTI file."""
264+
_hdr = imgobj.header.copy()
265+
_shape = _hdr.get_data_shape()
266+
267+
if (
268+
len(_shape) != 5 or
269+
_shape[-2] != 1 or
270+
not _shape[-1] in (2, 3)
271+
):
272+
raise TransformFileError(
273+
'Displacements field "%s" does not come from ITK.' %
274+
imgobj.file_map['image'].filename)
275+
276+
if _hdr.get_intent()[0] != 'vector':
277+
warnings.warn('Incorrect intent identified.')
278+
_hdr.set_intent('vector')
279+
280+
_field = np.squeeze(np.asanyarray(imgobj.dataobj))
281+
_field[..., (0, 1)] *= -1.0
282+
283+
return imgobj.__class__(_field, imgobj.affine, _hdr)

nitransforms/nonlinear.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ def __init__(self, field, reference=None):
3131
'Number of components of the deformation field does '
3232
'not match the number of dimensions')
3333

34-
if reference is None:
35-
reference = field.__class__(np.zeros(self._field.shape[:-1]),
36-
field.affine, field.header)
37-
self.reference = reference
34+
self.reference = field.__class__(np.zeros(self._field.shape[:-1]),
35+
field.affine, field.header)
3836

3937
def map(self, x, inverse=False, index=0):
4038
r"""

nitransforms/tests/test_nonlinear.py

+68-20
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
import numpy as np
1010
import nibabel as nb
11+
from ..io.base import TransformFileError
1112
from ..nonlinear import DisplacementsFieldTransform
13+
from ..io.itk import ITKDisplacementsField
1214

1315
TESTS_BORDER_TOLERANCE = 0.05
1416
APPLY_NONLINEAR_CMD = {
@@ -19,34 +21,80 @@
1921
}
2022

2123

24+
@pytest.mark.parametrize('size', [(20, 20, 20), (20, 20, 20, 3)])
25+
def test_itk_disp_load(size):
26+
"""Checks field sizes."""
27+
with pytest.raises(TransformFileError):
28+
ITKDisplacementsField.from_image(
29+
nb.Nifti1Image(np.zeros(size), None, None))
30+
31+
32+
def test_itk_disp_load_intent():
33+
"""Checks whether the NIfTI intent is fixed."""
34+
with pytest.warns(UserWarning):
35+
field = ITKDisplacementsField.from_image(
36+
nb.Nifti1Image(np.zeros((20, 20, 20, 1, 3)), None, None))
37+
38+
assert field.header.get_intent()[0] == 'vector'
39+
40+
41+
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
2242
@pytest.mark.parametrize('sw_tool', ['itk'])
23-
def test_displacements_field(tmp_path, data_path, sw_tool):
43+
@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
44+
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
45+
"""Check a translation-only field on one or more axes, different image orientations."""
2446
os.chdir(str(tmp_path))
25-
img_fname = str(data_path / 'tpl-OASIS30ANTs_T1w.nii.gz')
26-
xfm_fname = str(
27-
data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz')
28-
ants_warp = nb.load(xfm_fname)
29-
hdr = ants_warp.header.copy()
47+
nii = get_testdata[image_orientation]
48+
nii.to_filename('reference.nii.gz')
49+
fieldmap = np.zeros((*nii.shape[:3], 1, 3), dtype='float32')
50+
fieldmap[..., axis] = -10.0
51+
52+
_hdr = nii.header.copy()
53+
_hdr.set_intent('vector')
54+
_hdr.set_data_dtype('float32')
3055

31-
# fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
3256
xfm_fname = 'warp.nii.gz'
33-
nii = nb.load(img_fname)
34-
fieldmap = np.zeros((*nii.shape[:3], 1, 3))
35-
fieldmap[..., 2] = -10.0
36-
# fieldmap = np.flip(np.flip(fieldmap, 1), 0)
37-
ants_warp = nb.Nifti1Image(fieldmap, nii.affine, hdr)
38-
ants_warp.to_filename(xfm_fname)
39-
fieldmap = np.squeeze(np.asanyarray(ants_warp.dataobj))
40-
field = nb.Nifti1Image(
41-
fieldmap,
42-
ants_warp.affine, ants_warp.header
43-
)
44-
45-
xfm = DisplacementsFieldTransform(field)
57+
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
58+
field.to_filename(xfm_fname)
59+
60+
xfm = DisplacementsFieldTransform(
61+
ITKDisplacementsField.from_image(field))
4662

4763
# Then apply the transform and cross-check with software
4864
cmd = APPLY_NONLINEAR_CMD[sw_tool](
4965
transform=os.path.abspath(xfm_fname),
66+
reference=tmp_path / 'reference.nii.gz',
67+
moving=tmp_path / 'reference.nii.gz')
68+
69+
# skip test if command is not available on host
70+
exe = cmd.split(" ", 1)[0]
71+
if not shutil.which(exe):
72+
pytest.skip("Command {} not found on host".format(exe))
73+
74+
exit_code = check_call([cmd], shell=True)
75+
assert exit_code == 0
76+
sw_moved = nb.load('resampled.nii.gz')
77+
78+
nt_moved = xfm.apply(nii, order=0)
79+
nt_moved.to_filename('nt_resampled.nii.gz')
80+
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
81+
# A certain tolerance is necessary because of resampling at borders
82+
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
83+
84+
85+
@pytest.mark.parametrize('sw_tool', ['itk'])
86+
def test_displacements_field2(tmp_path, data_path, sw_tool):
87+
"""Check a translation-only field on one or more axes, different image orientations."""
88+
os.chdir(str(tmp_path))
89+
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
90+
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz'
91+
92+
xfm = DisplacementsFieldTransform(
93+
ITKDisplacementsField.from_filename(xfm_fname))
94+
95+
# Then apply the transform and cross-check with software
96+
cmd = APPLY_NONLINEAR_CMD[sw_tool](
97+
transform=xfm_fname,
5098
reference=img_fname,
5199
moving=img_fname)
52100

0 commit comments

Comments
 (0)