Skip to content

Commit 9bfa598

Browse files
authored
Merge pull request #54 from mgxd/enh/nlin-load
ENH: Facilitate loading of displacements field transforms
2 parents 0ec1eb3 + 77f983c commit 9bfa598

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

nitransforms/nonlinear.py

+14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import warnings
1111
import numpy as np
1212
from .base import TransformBase
13+
from . import io
1314

1415
# from .base import ImageGrid
1516
# from nibabel.funcs import four_to_three
@@ -75,6 +76,19 @@ def map(self, x, inverse=False, index=0):
7576
indexes = tuple(tuple(i) for i in indexes.T)
7677
return x + self._field[indexes]
7778

79+
@classmethod
80+
def from_filename(cls, filename, fmt='X5'):
81+
if fmt == 'afni':
82+
_factory = io.afni.AFNIDisplacementsField
83+
elif fmt == 'itk':
84+
_factory = io.itk.ITKDisplacementsField
85+
else:
86+
raise NotImplementedError
87+
88+
return cls(_factory.from_filename(filename))
89+
90+
91+
load = DisplacementsFieldTransform.from_filename
7892

7993
# class BSplineFieldTransform(TransformBase):
8094
# """Represent a nonlinear transform parameterized by BSpline basis."""

nitransforms/tests/test_nonlinear.py

+3-13
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import nibabel as nb
1111
from ..io.base import TransformFileError
12-
from ..nonlinear import DisplacementsFieldTransform
12+
from ..nonlinear import DisplacementsFieldTransform, load as nlload
1313
from ..io.itk import ITKDisplacementsField
1414
from ..io.afni import AFNIDisplacementsField
1515

@@ -72,12 +72,7 @@ def test_displacements_field1(tmp_path, get_testdata, image_orientation, sw_tool
7272
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
7373
field.to_filename(xfm_fname)
7474

75-
if sw_tool == 'itk':
76-
xfm = DisplacementsFieldTransform(
77-
ITKDisplacementsField.from_image(field))
78-
elif sw_tool == 'afni':
79-
xfm = DisplacementsFieldTransform(
80-
AFNIDisplacementsField.from_image(field))
75+
xfm = nlload(xfm_fname, fmt=sw_tool)
8176

8277
# Then apply the transform and cross-check with software
8378
cmd = APPLY_NONLINEAR_CMD[sw_tool](
@@ -108,12 +103,7 @@ def test_displacements_field2(tmp_path, data_path, sw_tool):
108103
img_fname = data_path / 'tpl-OASIS30ANTs_T1w.nii.gz'
109104
xfm_fname = data_path / 'ds-005_sub-01_from-OASIS_to-T1_warp_{}.nii.gz'.format(sw_tool)
110105

111-
if sw_tool == 'itk':
112-
xfm = DisplacementsFieldTransform(
113-
ITKDisplacementsField.from_filename(xfm_fname))
114-
elif sw_tool == 'afni':
115-
xfm = DisplacementsFieldTransform(
116-
AFNIDisplacementsField.from_filename(xfm_fname))
106+
xfm = nlload(xfm_fname, fmt=sw_tool)
117107

118108
# Then apply the transform and cross-check with software
119109
cmd = APPLY_NONLINEAR_CMD[sw_tool](

0 commit comments

Comments
 (0)