Skip to content

Commit 07ade2b

Browse files
committed
fix: final amends to ensure tests pass
1 parent 4eca2fb commit 07ade2b

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

nitransforms/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(self, image):
103103
self._shape = image.shape
104104

105105
self._ndim = getattr(image, "ndim", len(image.shape))
106-
if self._ndim == 4:
106+
if self._ndim >= 4:
107107
self._shape = image.shape[:3]
108108
self._ndim = 3
109109

nitransforms/nonlinear.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def __init__(self, field, reference=None):
4848
)
4949

5050
try:
51-
self.reference = reference if reference is not None else field
51+
self.reference = ImageGrid(
52+
reference if reference is not None else field
53+
)
5254
except AttributeError:
5355
raise TransformError(
5456
"Field must be a spatial image if reference is not provided"
@@ -100,6 +102,7 @@ def map(self, x, inverse=False):
100102
[[-6.5, -36.475167989730835, -19.5], [-1.0, -42.038356602191925, -11.25]]
101103
102104
"""
105+
103106
if inverse is True:
104107
raise NotImplementedError
105108
ijk = self.reference.index(x)

nitransforms/tests/test_nonlinear.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88

99
import numpy as np
1010
import nibabel as nb
11-
from ..io.base import TransformFileError
12-
from ..nonlinear import (
11+
from nitransforms.base import TransformError
12+
from nitransforms.io.base import TransformFileError
13+
from nitransforms.nonlinear import (
1314
BSplineFieldTransform,
1415
DisplacementsFieldTransform,
1516
load as nlload,
@@ -43,7 +44,7 @@ def test_itk_disp_load(size):
4344
@pytest.mark.parametrize("size", [(20, 20, 20), (20, 20, 20, 2, 3)])
4445
def test_displacements_bad_sizes(size):
4546
"""Checks field sizes."""
46-
with pytest.raises(ValueError):
47+
with pytest.raises(TransformError):
4748
DisplacementsFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
4849

4950

0 commit comments

Comments
 (0)