Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Ensure input dtype is kept after resampling #153

Merged
merged 2 commits into from
Feb 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,10 @@ def apply(
if isinstance(spatialimage, (str, Path)):
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(spatialimage.dataobj)
data = np.asanyarray(
spatialimage.dataobj,
dtype=spatialimage.get_data_dtype()
)
output_dtype = output_dtype or data.dtype
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
Expand All @@ -288,9 +291,11 @@ def apply(

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
moved = spatialimage.__class__(
resampled.reshape(_ref.shape), _ref.affine, spatialimage.header
resampled.reshape(_ref.shape).astype(output_dtype),
_ref.affine,
spatialimage.header
)
moved.header.set_data_dtype(output_dtype)
moved.set_data_dtype(output_dtype)
return moved

return resampled
Expand Down
11 changes: 9 additions & 2 deletions nitransforms/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,26 @@ def _to_hdf5(klass, x5_root):
monkeypatch.setattr(TransformBase, "_to_hdf5", _to_hdf5)
fname = testdata_path / "someones_anatomy.nii.gz"

img = nb.load(fname)
imgdata = np.asanyarray(img.dataobj, dtype=img.get_data_dtype())

# Test identity transform
xfm = TransformBase()
xfm.reference = fname
assert xfm.ndim == 3
moved = xfm.apply(fname, order=0)
assert np.all(nb.load(str(fname)).get_fdata() == moved.get_fdata())
assert np.all(
imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())
)

# Test identity transform - setting reference
xfm = TransformBase()
xfm.reference = fname
assert xfm.ndim == 3
moved = xfm.apply(str(fname), reference=fname, order=0)
assert np.all(nb.load(str(fname)).get_fdata() == moved.get_fdata())
assert np.all(
imgdata == np.asanyarray(moved.dataobj, dtype=moved.get_data_dtype())
)

# Test applying to Gifti
gii = nb.gifti.GiftiImage(
Expand Down
17 changes: 10 additions & 7 deletions nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient
diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj)

assert np.sqrt((diff ** 2).mean()) < RMSE_TOL
brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)

cmd = APPLY_LINEAR_CMD[sw_tool](
transform=os.path.abspath(xfm_fname),
Expand All @@ -224,23 +225,25 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient
resampled=os.path.abspath("resampled.nii.gz"),
)

brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)

exit_code = check_call([cmd], shell=True)
assert exit_code == 0
sw_moved = nb.load("resampled.nii.gz")
sw_moved.set_data_dtype(img.get_data_dtype())

nt_moved = xfm.apply(img, order=0)
diff = (sw_moved.get_fdata() - nt_moved.get_fdata())
diff[~brainmask] = 0.0
diff[np.abs(diff) < 1e-3] = 0
diff = (
np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype())
- np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
)

# A certain tolerance is necessary because of resampling at borders
assert np.sqrt((diff ** 2).mean()) < RMSE_TOL
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL

nt_moved = xfm.apply("img.nii.gz", order=0)
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
diff = (
np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype())
- np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
)
# A certain tolerance is necessary because of resampling at borders
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL

Expand Down
10 changes: 7 additions & 3 deletions nitransforms/tests/test_manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..manip import load as _load, TransformChain
from ..linear import Affine
from .test_nonlinear import (
TESTS_BORDER_TOLERANCE,
RMSE_TOL,
APPLY_NONLINEAR_CMD,
)

Expand All @@ -38,7 +38,11 @@ def test_itk_h5(tmp_path, testdata_path):

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD["itk"](
transform=xfm_fname, reference=ref_fname, moving=img_fname
transform=xfm_fname,
reference=ref_fname,
moving=img_fname,
output="resampled.nii.gz",
extra="",
)

# skip test if command is not available on host
Expand All @@ -54,7 +58,7 @@ def test_itk_h5(tmp_path, testdata_path):
nt_moved.to_filename("nt_resampled.nii.gz")
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
assert (np.abs(diff) > 1e-3).sum() / diff.size < RMSE_TOL


@pytest.mark.parametrize("ext0", ["lta", "tfm"])
Expand Down
68 changes: 56 additions & 12 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@
from ..io.itk import ITKDisplacementsField


TESTS_BORDER_TOLERANCE = 0.05
RMSE_TOL = 0.05
APPLY_NONLINEAR_CMD = {
"itk": """\
antsApplyTransforms -d 3 -r {reference} -i {moving} \
-o resampled.nii.gz -n NearestNeighbor -t {transform} --float\
-o {output} -n NearestNeighbor -t {transform} {extra}\
""".format,
"afni": """\
3dNwarpApply -nwarp {transform} -source {moving} \
-master {reference} -interp NN -prefix resampled.nii.gz
-master {reference} -interp NN -prefix {output} {extra}\
""".format,
'fsl': """\
applywarp -i {moving} -r {reference} -o resampled.nii.gz \
applywarp -i {moving} -r {reference} -o {output} {extra}\
-w {transform} --interp=nn""".format,
}

Expand Down Expand Up @@ -56,13 +56,23 @@ def test_itk_disp_load_intent():
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
def test_displacements_field1(
tmp_path,
get_testdata,
get_testmask,
image_orientation,
sw_tool,
axis,
):
"""Check a translation-only field on one or more axes, different image orientations."""
if (image_orientation, sw_tool) == ("oblique", "afni") and axis in ((1, 2), (0, 1, 2)):
pytest.skip("AFNI Deoblique unsupported.")
os.chdir(str(tmp_path))
nii = get_testdata[image_orientation]
msk = get_testmask[image_orientation]
nii.to_filename("reference.nii.gz")
msk.to_filename("mask.nii.gz")

fieldmap = np.zeros(
(*nii.shape[:3], 1, 3) if sw_tool != "fsl" else (*nii.shape[:3], 3),
dtype="float32",
Expand All @@ -83,24 +93,50 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=os.path.abspath(xfm_fname),
reference=tmp_path / "reference.nii.gz",
moving=tmp_path / "reference.nii.gz",
reference=tmp_path / "mask.nii.gz",
moving=tmp_path / "mask.nii.gz",
output=tmp_path / "resampled_brainmask.nii.gz",
extra="--output-data-type uchar" if sw_tool == "itk" else "",
)

# skip test if command is not available on host
exe = cmd.split(" ", 1)[0]
if not shutil.which(exe):
pytest.skip("Command {} not found on host".format(exe))

# resample mask
exit_code = check_call([cmd], shell=True)
assert exit_code == 0
sw_moved_mask = nb.load("resampled_brainmask.nii.gz")
nt_moved_mask = xfm.apply(msk, order=0)
nt_moved_mask.set_data_dtype(msk.get_data_dtype())
diff = np.asanyarray(sw_moved_mask.dataobj) - np.asanyarray(nt_moved_mask.dataobj)

assert np.sqrt((diff ** 2).mean()) < RMSE_TOL
brainmask = np.asanyarray(nt_moved_mask.dataobj, dtype=bool)

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=os.path.abspath(xfm_fname),
reference=tmp_path / "reference.nii.gz",
moving=tmp_path / "reference.nii.gz",
output=tmp_path / "resampled.nii.gz",
extra="--output-data-type uchar" if sw_tool == "itk" else ""
)

exit_code = check_call([cmd], shell=True)
assert exit_code == 0
sw_moved = nb.load("resampled.nii.gz")

nt_moved = xfm.apply(nii, order=0)
nt_moved.to_filename("nt_resampled.nii.gz")
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
diff = (
np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype())
- np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
)
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL


@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
Expand All @@ -116,7 +152,11 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
transform=xfm_fname, reference=img_fname, moving=img_fname
transform=xfm_fname,
reference=img_fname,
moving=img_fname,
output="resampled.nii.gz",
extra="",
)

# skip test if command is not available on host
Expand All @@ -130,6 +170,10 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):

nt_moved = xfm.apply(img_fname, order=0)
nt_moved.to_filename("nt_resampled.nii.gz")
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
diff = (
np.asanyarray(sw_moved.dataobj, dtype=sw_moved.get_data_dtype())
- np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
)
# A certain tolerance is necessary because of resampling at borders
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE
assert np.sqrt((diff ** 2).mean()) < RMSE_TOL