Skip to content

Commit ea514d8

Browse files
committed
MAINT: Add new test to check DisplacementsFieldTransforms
1 parent ead87dd commit ea514d8

10 files changed

+67
-8
lines changed

nitransforms/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
transform
1818
"""
1919
from .linear import Affine
20-
from .nonlinear import DeformationFieldTransform
20+
from .nonlinear import DisplacementsFieldTransform
2121

2222

23-
__all__ = ['Affine', 'DeformationFieldTransform']
23+
__all__ = ['Affine', 'DisplacementsFieldTransform']

nitransforms/base.py

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""Common interface for transforms."""
1010
import numpy as np
1111
import h5py
12+
from nibabel.loadsave import load
1213

1314
from scipy import ndimage as ndi
1415

@@ -168,6 +169,9 @@ def resample(self, moving, order=3, mode='constant', cval=0.0, prefilter=True,
168169
The moving imaged after resampling to reference space.
169170
170171
"""
172+
if isinstance(moving, str):
173+
moving = load(moving)
174+
171175
moving_data = np.asanyarray(moving.dataobj)
172176
if output_dtype is None:
173177
output_dtype = moving_data.dtype

nitransforms/nonlinear.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""Nonlinear transforms."""
1010
from warnings import warn
1111
import numpy as np
12-
from scipy import ndimage as ndi
12+
# from scipy import ndimage as ndi
1313
# from gridbspline.maths import cubic
1414

1515
from .base import ImageGrid, TransformBase
@@ -18,14 +18,14 @@
1818
# vbspl = np.vectorize(cubic)
1919

2020

21-
class DeformationFieldTransform(TransformBase):
21+
class DisplacementsFieldTransform(TransformBase):
2222
"""Represents a dense field of displacements (one vector per voxel)."""
2323

2424
__slots__ = ['_field']
2525

2626
def __init__(self, field, reference=None):
2727
"""Create a dense deformation field transform."""
28-
super(DeformationFieldTransform, self).__init__()
28+
super(DisplacementsFieldTransform, self).__init__()
2929
self._field = np.asanyarray(field.dataobj)
3030

3131
ndim = self._field.ndim - 1
@@ -38,7 +38,8 @@ def __init__(self, field, reference=None):
3838
# displacement vector per voxel in output (reference)
3939
# space
4040
if reference is None:
41-
reference = four_to_three(field)[0]
41+
reference = field.__class__(np.zeros(self._field.shape[:-1]),
42+
field.affine, field.header)
4243
elif reference.shape[:ndim] != field.shape[:ndim]:
4344
raise ValueError(
4445
'Reference ({}) and field ({}) must have the same '
@@ -71,7 +72,7 @@ def map(self, x, inverse=False, index=0):
7172
>>> field = np.zeros((10, 10, 10, 3))
7273
>>> field[..., 0] = 4.0
7374
>>> fieldimg = nb.Nifti1Image(field, np.diag([2., 2., 2., 1.]))
74-
>>> xfm = DeformationFieldTransform(fieldimg)
75+
>>> xfm = DisplacementsFieldTransform(fieldimg)
7576
>>> xfm([4.0, 4.0, 4.0]).tolist()
7677
[[8.0, 4.0, 4.0]]
7778
18.8 MB
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

nitransforms/tests/test_transform.py nitransforms/tests/test_linear.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
3-
"""Tests of the transform module."""
3+
"""Tests of linear transforms."""
44
import os
55
import pytest
66
import numpy as np

nitransforms/tests/test_nonlinear.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
"""Tests of nonlinear transforms."""
4+
import os
5+
import shutil
6+
from subprocess import check_call
7+
import pytest
8+
9+
import numpy as np
10+
import nibabel as nb
11+
from ..nonlinear import DisplacementsFieldTransform
12+
13+
TESTS_BORDER_TOLERANCE = 0.05
14+
APPLY_NONLINEAR_CMD = {
15+
'itk': """\
16+
antsApplyTransforms -d 3 -r {reference} -i {moving} \
17+
-o resampled.nii.gz -n NearestNeighbor -t {transform} --float\
18+
""".format,
19+
}
20+
21+
22+
@pytest.mark.parametrize('sw_tool', ['itk'])
23+
def test_displacements_field(tmp_path, data_path, sw_tool):
24+
os.chdir(str(tmp_path))
25+
img_fname = os.path.join(data_path, 'tpl-OASIS30ANTs_T1w.nii.gz')
26+
xfm_fname = os.path.join(
27+
data_path, 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz')
28+
ants_warp = nb.load(xfm_fname)
29+
field = nb.Nifti1Image(
30+
np.squeeze(np.asanyarray(ants_warp.dataobj)),
31+
ants_warp.affine, ants_warp.header
32+
)
33+
34+
xfm = DisplacementsFieldTransform(field)
35+
36+
# Then apply the transform and cross-check with software
37+
cmd = APPLY_NONLINEAR_CMD[sw_tool](
38+
transform=os.path.abspath(xfm_fname),
39+
reference=img_fname,
40+
moving=img_fname)
41+
42+
# skip test if command is not available on host
43+
exe = cmd.split(" ", 1)[0]
44+
if not shutil.which(exe):
45+
pytest.skip("Command {} not found on host".format(exe))
46+
47+
exit_code = check_call([cmd], shell=True)
48+
assert exit_code == 0
49+
sw_moved = nb.load('resampled.nii.gz')
50+
51+
nt_moved = xfm.resample(img_fname, order=0)
52+
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
53+
# A certain tolerance is necessary because of resampling at borders
54+
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE

0 commit comments

Comments
 (0)