From 21669d1ab23e1d3f59daccc31879836f90d505af Mon Sep 17 00:00:00 2001
From: oesteban <code@oscaresteban.es>
Date: Wed, 23 Oct 2019 16:59:53 -0700
Subject: [PATCH] ENH: Write tests pulling up the coverage of base submodule

---
 nitransforms/base.py            |  9 ++-------
 nitransforms/linear.py          | 19 +++++++++++++++++--
 nitransforms/tests/test_base.py | 33 +++++++++++++++++++++++++++++++--
 3 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/nitransforms/base.py b/nitransforms/base.py
index 6e6a4d23..154f8e33 100644
--- a/nitransforms/base.py
+++ b/nitransforms/base.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 import numpy as np
 import h5py
+import warnings
 from nibabel.loadsave import load
 
 from scipy import ndimage as ndi
@@ -110,12 +111,6 @@ def __init__(self):
         """Instantiate a transform."""
         self._reference = None
 
-    def __eq__(self, other):
-        """Overload equals operator."""
-        if not self._reference == other._reference:
-            return False
-        return np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
-
     def __call__(self, x, inverse=False, index=0):
         """Apply y = f(x)."""
         return self.map(x, inverse=inverse, index=index)
@@ -124,7 +119,7 @@ def __call__(self, x, inverse=False, index=0):
     def reference(self):
         """Access a reference space where data will be resampled onto."""
         if self._reference is None:
-            raise ValueError('Reference space not set')
+            warnings.warn('Reference space not set')
         return self._reference
 
     @reference.setter
diff --git a/nitransforms/linear.py b/nitransforms/linear.py
index df0a2f59..46fe46cc 100644
--- a/nitransforms/linear.py
+++ b/nitransforms/linear.py
@@ -7,13 +7,12 @@
 #
 ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
 """Linear transforms."""
-import sys
 import numpy as np
 from pathlib import Path
 
 from nibabel.loadsave import load as loadimg
 from nibabel.affines import from_matvec, voxel_sizes, obliquity
-from .base import TransformBase, _as_homogeneous
+from .base import TransformBase, _as_homogeneous, EQUALITY_TOL
 from .patched import shape_zoom_affine
 from . import io
 
@@ -67,6 +66,22 @@ def __init__(self, matrix=None, reference=None):
                 reference = loadimg(reference)
             self.reference = reference
 
+    def __eq__(self, other):
+        """
+        Overload equals operator.
+
+        Examples
+        --------
+        >>> xfm1 = Affine([[1, 0, 0, 4], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+        >>> xfm2 = Affine(xfm1.matrix)
+        >>> xfm1 == xfm2
+        True
+
+        """
+        if not self._reference == other._reference:
+            return False
+        return np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
+
     @property
     def matrix(self):
         """Access the internal representation of this affine."""
diff --git a/nitransforms/tests/test_base.py b/nitransforms/tests/test_base.py
index 6035afcb..2856f42e 100644
--- a/nitransforms/tests/test_base.py
+++ b/nitransforms/tests/test_base.py
@@ -1,8 +1,10 @@
 """Tests of the base module."""
 import numpy as np
+import nibabel as nb
 import pytest
+import h5py
 
-from ..base import ImageGrid
+from ..base import ImageGrid, TransformBase
 
 
 @pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@@ -30,9 +32,36 @@ def test_ImageGrid(get_testdata, image_orientation):
     assert idxs.shape[1] == coords.shape[1] == img.nvox == np.prod(im.shape)
 
 
-def test_ImageGrid_load(data_path, get_testdata):
+def test_ImageGrid_utils(tmpdir, data_path, get_testdata):
     """Check that images can be objects or paths and equality."""
+    tmpdir.chdir()
+
     im1 = get_testdata['RAS']
     im2 = data_path / 'someones_anatomy.nii.gz'
 
     assert ImageGrid(im1) == ImageGrid(im2)
+
+    with h5py.File('xfm.x5', 'w') as f:
+        ImageGrid(im1)._to_hdf5(f.create_group('Reference'))
+
+
+def test_TransformBase(monkeypatch, data_path, tmpdir):
+    """Check the correctness of TransformBase components."""
+    tmpdir.chdir()
+
+    def _fakemap(klass, x, inverse=False, index=0):
+        return x
+
+    def _to_hdf5(klass, x5_root):
+        return None
+
+    monkeypatch.setattr(TransformBase, 'map', _fakemap)
+    monkeypatch.setattr(TransformBase, '_to_hdf5', _to_hdf5)
+    nii = nb.load(str(data_path / 'someones_anatomy.nii.gz'))
+    xfm = TransformBase()
+    xfm.reference = str(data_path / 'someones_anatomy.nii.gz')
+    assert xfm.ndim == 3
+    moved = xfm.resample(nii, order=0)
+    assert np.all(nii.get_fdata() == moved.get_fdata())
+
+    xfm.to_filename('data.x5')