Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: First implementation of AFNI displacement fields #50

Merged
merged 10 commits into from
Nov 14, 2019
31 changes: 30 additions & 1 deletion nitransforms/io/afni.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from nibabel.affines import obliquity, voxel_sizes

from ..patched import shape_zoom_affine
from .base import BaseLinearTransformList, LinearParameters, TransformFileError
from .base import (
BaseLinearTransformList,
DisplacementsField,
LinearParameters,
TransformFileError,
)

LPS = np.diag([-1, -1, 1, 1])
OBLIQUITY_THRESHOLD_DEG = 0.01
Expand Down Expand Up @@ -119,5 +124,29 @@ def from_string(cls, string):
return _self


class AFNIDisplacementsField(DisplacementsField):
"""A data structure representing displacements fields."""

@classmethod
def from_image(cls, imgobj):
"""Import a displacements field from a NIfTI file."""
hdr = imgobj.header.copy()
shape = hdr.get_data_shape()

if (
len(shape) != 5 or
shape[-2] != 1 or
not shape[-1] in (2, 3)
):
raise TransformFileError(
'Displacements field "%s" does not come from AFNI.' %
imgobj.file_map['image'].filename)

field = np.squeeze(np.asanyarray(imgobj.dataobj))
field[..., (0, 1)] *= -1.0

return imgobj.__class__(field, imgobj.affine, hdr)


def _is_oblique(affine, thres=OBLIQUITY_THRESHOLD_DEG):
return (obliquity(affine).min() * 180 / pi) > thres
16 changes: 16 additions & 0 deletions nitransforms/io/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Read/write linear transforms."""
import numpy as np
from nibabel import load as loadimg
from scipy.io.matlab.miobase import get_matfile_version
from scipy.io.matlab.mio4 import MatFile4Reader
from scipy.io.matlab.mio5 import MatFile5Reader
Expand Down Expand Up @@ -157,6 +158,21 @@ def from_string(cls, string):
raise NotImplementedError


class DisplacementsField:
"""A data structure representing displacements fields."""

@classmethod
def from_filename(cls, filename):
"""Import a displacements field from a NIfTI file."""
imgobj = loadimg(str(filename))
return cls.from_image(imgobj)

@classmethod
def from_image(cls, imgobj):
"""Import a displacements field from a nibabel image object."""
raise NotImplementedError


def _read_mat(byte_stream):
mjv, _ = get_matfile_version(byte_stream)
if mjv == 0:
Expand Down
36 changes: 18 additions & 18 deletions nitransforms/io/itk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from scipy.io import savemat as _save_mat
from nibabel.loadsave import load as loadimg
from nibabel.affines import from_matvec
from .base import BaseLinearTransformList, LinearParameters, _read_mat, TransformFileError
from .base import (
BaseLinearTransformList,
DisplacementsField,
LinearParameters,
TransformFileError,
_read_mat,
)

LPS = np.diag([-1, -1, 1, 1])

Expand Down Expand Up @@ -249,35 +255,29 @@ def from_string(cls, string):
return _self


class ITKDisplacementsField:
class ITKDisplacementsField(DisplacementsField):
"""A data structure representing displacements fields."""

@classmethod
def from_filename(cls, filename):
"""Import a displacements field from a NIfTI file."""
imgobj = loadimg(str(filename))
return cls.from_image(imgobj)

@classmethod
def from_image(cls, imgobj):
"""Import a displacements field from a NIfTI file."""
_hdr = imgobj.header.copy()
_shape = _hdr.get_data_shape()
hdr = imgobj.header.copy()
shape = hdr.get_data_shape()

if (
len(_shape) != 5 or
_shape[-2] != 1 or
not _shape[-1] in (2, 3)
len(shape) != 5 or
shape[-2] != 1 or
not shape[-1] in (2, 3)
):
raise TransformFileError(
'Displacements field "%s" does not come from ITK.' %
imgobj.file_map['image'].filename)

if _hdr.get_intent()[0] != 'vector':
if hdr.get_intent()[0] != 'vector':
warnings.warn('Incorrect intent identified.')
_hdr.set_intent('vector')
hdr.set_intent('vector')

_field = np.squeeze(np.asanyarray(imgobj.dataobj))
_field[..., (0, 1)] *= -1.0
field = np.squeeze(np.asanyarray(imgobj.dataobj))
field[..., (0, 1)] *= -1.0

return imgobj.__class__(_field, imgobj.affine, _hdr)
return imgobj.__class__(field, imgobj.affine, hdr)
Binary file not shown.
18 changes: 18 additions & 0 deletions nitransforms/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import filecmp
import nibabel as nb
from nibabel.eulerangles import euler2mat
from nibabel.affines import from_matvec
from scipy.io import loadmat, savemat
Expand Down Expand Up @@ -321,3 +322,20 @@ def _mockreturn(arg):
with pytest.raises(TransformFileError):
with open('val.mat', 'rb') as f:
_read_mat(f)

@pytest.mark.parametrize('sw_tool', ['afni'])
def test_Displacements(sw_tool):
"""Test displacements fields."""

if sw_tool == 'afni':
field = nb.Nifti1Image(np.zeros((10, 10, 10)), None, None)
with pytest.raises(TransformFileError):
afni.AFNIDisplacementsField.from_image(field)

field = nb.Nifti1Image(np.zeros((10, 10, 10, 2, 3)), None, None)
with pytest.raises(TransformFileError):
afni.AFNIDisplacementsField.from_image(field)

field = nb.Nifti1Image(np.zeros((10, 10, 10, 1, 4)), None, None)
with pytest.raises(TransformFileError):
afni.AFNIDisplacementsField.from_image(field)
33 changes: 24 additions & 9 deletions nitransforms/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
from ..io.base import TransformFileError
from ..nonlinear import DisplacementsFieldTransform
from ..io.itk import ITKDisplacementsField
from ..io.afni import AFNIDisplacementsField

TESTS_BORDER_TOLERANCE = 0.05
APPLY_NONLINEAR_CMD = {
'itk': """\
antsApplyTransforms -d 3 -r {reference} -i {moving} \
-o resampled.nii.gz -n NearestNeighbor -t {transform} --float\
""".format,
'afni': """\
3dNwarpApply -nwarp {transform} -source {moving} \
-master {reference} -interp NN -prefix resampled.nii.gz
""".format,
}

Expand Down Expand Up @@ -46,8 +51,9 @@ def test_itk_disp_load_intent():
assert field.header.get_intent()[0] == 'vector'


@pytest.mark.xfail(reason="Oblique datasets not fully implemented")
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@pytest.mark.parametrize('sw_tool', ['itk'])
@pytest.mark.parametrize('sw_tool', ['itk', 'afni'])
@pytest.mark.parametrize('axis', [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])
def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool, axis):
"""Check a translation-only field on one or more axes, different image orientations."""
Expand All @@ -58,15 +64,20 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
fieldmap[..., axis] = -10.0

_hdr = nii.header.copy()
_hdr.set_intent('vector')
if sw_tool in ('itk', ):
_hdr.set_intent('vector')
_hdr.set_data_dtype('float32')

xfm_fname = 'warp.nii.gz'
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
field.to_filename(xfm_fname)

xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_image(field))
if sw_tool == 'itk':
xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_image(field))
elif sw_tool == 'afni':
xfm = DisplacementsFieldTransform(
AFNIDisplacementsField.from_image(field))

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
Expand All @@ -90,15 +101,19 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
assert (np.abs(diff) > 1e-3).sum() / diff.size < TESTS_BORDER_TOLERANCE


@pytest.mark.parametrize('sw_tool', ['itk'])
@pytest.mark.parametrize('sw_tool', ['itk', 'afni'])
def test_displacements_field2(tmp_path, data_path, sw_tool):
"""Check a translation-only field on one or more axes, different image orientations."""
os.chdir(str(tmp_path))
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp.nii.gz'

xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_filename(xfm_fname))
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp_{}.nii.gz'.format(sw_tool)

if sw_tool == 'itk':
xfm = DisplacementsFieldTransform(
ITKDisplacementsField.from_filename(xfm_fname))
elif sw_tool == 'afni':
xfm = DisplacementsFieldTransform(
AFNIDisplacementsField.from_filename(xfm_fname))

# Then apply the transform and cross-check with software
cmd = APPLY_NONLINEAR_CMD[sw_tool](
Expand Down