Skip to content

Commit e5c8f86

Browse files
committed
Merge pull request #28 from oesteban/enh/add-base-tests
ENH: Write tests pulling up the coverage of base submodule
2 parents 6cceb63 + 21669d1 commit e5c8f86

File tree

3 files changed

+51
-11
lines changed

3 files changed

+51
-11
lines changed

nitransforms/base.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
import numpy as np
1212
import h5py
13+
import warnings
1314
from nibabel.loadsave import load
1415

1516
from scipy import ndimage as ndi
@@ -110,12 +111,6 @@ def __init__(self):
110111
"""Instantiate a transform."""
111112
self._reference = None
112113

113-
def __eq__(self, other):
114-
"""Overload equals operator."""
115-
if not self._reference == other._reference:
116-
return False
117-
return np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
118-
119114
def __call__(self, x, inverse=False, index=0):
120115
"""Apply y = f(x)."""
121116
return self.map(x, inverse=inverse, index=index)
@@ -124,7 +119,7 @@ def __call__(self, x, inverse=False, index=0):
124119
def reference(self):
125120
"""Access a reference space where data will be resampled onto."""
126121
if self._reference is None:
127-
raise ValueError('Reference space not set')
122+
warnings.warn('Reference space not set')
128123
return self._reference
129124

130125
@reference.setter

nitransforms/linear.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
#
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Linear transforms."""
10-
import sys
1110
import numpy as np
1211
from pathlib import Path
1312

1413
from nibabel.loadsave import load as loadimg
1514
from nibabel.affines import from_matvec, voxel_sizes, obliquity
16-
from .base import TransformBase, _as_homogeneous
15+
from .base import TransformBase, _as_homogeneous, EQUALITY_TOL
1716
from .patched import shape_zoom_affine
1817
from . import io
1918

@@ -67,6 +66,22 @@ def __init__(self, matrix=None, reference=None):
6766
reference = loadimg(reference)
6867
self.reference = reference
6968

69+
def __eq__(self, other):
70+
"""
71+
Overload equals operator.
72+
73+
Examples
74+
--------
75+
>>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
76+
>>> xfm2 = Affine(xfm1.matrix)
77+
>>> xfm1 == xfm2
78+
True
79+
80+
"""
81+
if not self._reference == other._reference:
82+
return False
83+
return np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
84+
7085
@property
7186
def matrix(self):
7287
"""Access the internal representation of this affine."""

nitransforms/tests/test_base.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Tests of the base module."""
22
import numpy as np
3+
import nibabel as nb
34
import pytest
5+
import h5py
46

5-
from ..base import ImageGrid
7+
from ..base import ImageGrid, TransformBase
68

79

810
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@@ -30,9 +32,37 @@ def test_ImageGrid(get_testdata, image_orientation):
3032
assert idxs.shape[1] == coords.shape[1] == img.nvox == np.prod(im.shape)
3133

3234

33-
def test_ImageGrid_load(data_path, get_testdata):
35+
def test_ImageGrid_utils(tmpdir, data_path, get_testdata):
3436
"""Check that images can be objects or paths and equality."""
37+
tmpdir.chdir()
38+
3539
im1 = get_testdata['RAS']
3640
im2 = data_path / 'someones_anatomy.nii.gz'
3741

3842
assert ImageGrid(im1) == ImageGrid(im2)
43+
44+
with h5py.File('xfm.x5', 'w') as f:
45+
ImageGrid(im1)._to_hdf5(f.create_group('Reference'))
46+
47+
48+
def test_TransformBase(monkeypatch, data_path, tmpdir):
49+
"""Check the correctness of TransformBase components."""
50+
tmpdir.chdir()
51+
52+
def _fakemap(klass, x, inverse=False, index=0):
53+
return x
54+
55+
def _to_hdf5(klass, x5_root):
56+
return None
57+
58+
monkeypatch.setattr(TransformBase, 'map', _fakemap)
59+
monkeypatch.setattr(TransformBase, '_to_hdf5', _to_hdf5)
60+
fname = str(data_path / 'someones_anatomy.nii.gz')
61+
62+
xfm = TransformBase()
63+
xfm.reference = fname
64+
assert xfm.ndim == 3
65+
moved = xfm.resample(fname, order=0)
66+
assert np.all(nb.load(fname).get_fdata() == moved.get_fdata())
67+
68+
xfm.to_filename('data.x5')

0 commit comments

Comments
 (0)