Skip to content

Commit 4cfc8f8

Browse files
committed
fix: fixed error types and increased test coverage
1 parent 07ade2b commit 4cfc8f8

File tree

3 files changed

+56
-14
lines changed

3 files changed

+56
-14
lines changed

nitransforms/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,9 @@ def apply(
267267
self.reference if reference is None else SpatialReference.factory(reference)
268268
)
269269

270+
if _ref is None:
271+
raise TransformError("Cannot apply transform without reference")
272+
270273
if isinstance(spatialimage, (str, Path)):
271274
spatialimage = _nbload(str(spatialimage))
272275

nitransforms/nonlinear.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,15 @@ def __init__(self, coefficients, reference=None, order=3):
153153

154154
def to_field(self, reference=None, dtype="float32"):
155155
"""Generate a displacements deformation field from this B-Spline field."""
156-
reference = _ensure_image(reference)
157-
_ref = self.reference if reference is None else SpatialReference.factory(reference)
156+
_ref = (
157+
self.reference if reference is None else
158+
ImageGrid(_ensure_image(reference))
159+
)
158160
if _ref is None:
159-
raise ValueError("A reference must be defined")
161+
raise TransformError("A reference must be defined")
160162

161163
ndim = self._coeffs.shape[-1]
162164

163-
# If locations to be interpolated are on a grid, use faster tensor-bspline calculation
164165
if self._weights is None:
165166
self._weights = grid_bspline_weights(_ref, self._knots)
166167

@@ -185,21 +186,17 @@ def apply(
185186
):
186187
"""Apply a B-Spline transform on input data."""
187188

188-
if reference is not None:
189-
reference = _ensure_image(reference)
190-
191189
_ref = (
192-
self.reference if reference is None else SpatialReference.factory(reference)
190+
self.reference if reference is None else
191+
SpatialReference.factory(_ensure_image(reference))
193192
)
194-
195-
if isinstance(spatialimage, (str, Path)):
196-
spatialimage = _nbload(str(spatialimage))
193+
spatialimage = _ensure_image(spatialimage)
197194

198195
# If locations to be interpolated are not on a grid, run map()
199196
if not isinstance(_ref, ImageGrid):
200197
return super().apply(
201198
spatialimage,
202-
reference=reference,
199+
reference=_ref,
203200
order=order,
204201
mode=mode,
205202
cval=cval,
@@ -208,7 +205,7 @@ def apply(
208205
)
209206

210207
# If locations to be interpolated are on a grid, generate a displacements field
211-
return self.to_field().apply(
208+
return self.to_field(reference=reference).apply(
212209
spatialimage,
213210
reference=reference,
214211
order=order,

nitransforms/tests/test_nonlinear.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_itk_disp_load(size):
4141
ITKDisplacementsField.from_image(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
4242

4343

44-
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3)])
44+
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3), (20, 20, 20, 1, 4)])
4545
def test_displacements_bad_sizes(size):
4646
"""Checks field sizes."""
4747
with pytest.raises(TransformError):
@@ -58,6 +58,48 @@ def test_itk_disp_load_intent():
5858
assert field.header.get_intent()[0] == "vector"
5959

6060

61+
def test_displacements_init():
62+
DisplacementsFieldTransform(
63+
np.zeros((10, 10, 10, 3)),
64+
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
65+
)
66+
67+
with pytest.raises(TransformError):
68+
DisplacementsFieldTransform(np.zeros((10, 10, 10, 3)))
69+
with pytest.raises(TransformError):
70+
DisplacementsFieldTransform(
71+
np.zeros((10, 10, 10, 3)),
72+
reference=np.zeros((10, 10, 10, 3)),
73+
)
74+
75+
76+
def test_bsplines_init():
77+
with pytest.raises(TransformError):
78+
BSplineFieldTransform(
79+
nb.Nifti1Image(np.zeros((10, 10, 10, 4)), np.eye(4), None),
80+
reference=nb.Nifti1Image(np.zeros((10, 10, 10)), np.eye(4), None),
81+
)
82+
83+
84+
def test_bsplines_references(testdata_path):
85+
with pytest.raises(TransformError):
86+
BSplineFieldTransform(
87+
testdata_path / "someones_bspline_coefficients.nii.gz"
88+
).to_field()
89+
90+
with pytest.raises(TransformError):
91+
BSplineFieldTransform(
92+
testdata_path / "someones_bspline_coefficients.nii.gz"
93+
).apply(testdata_path / "someones_anatomy.nii.gz")
94+
95+
BSplineFieldTransform(
96+
testdata_path / "someones_bspline_coefficients.nii.gz"
97+
).apply(
98+
testdata_path / "someones_anatomy.nii.gz",
99+
reference=testdata_path / "someones_anatomy.nii.gz"
100+
)
101+
102+
61103
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
62104
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
63105
@pytest.mark.parametrize("axis", [0, 1, 2, (0, 1), (1, 2), (0, 1, 2)])

0 commit comments

Comments
 (0)