diff --git a/nitransforms/io/__init__.py b/nitransforms/io/__init__.py index ab5afeac..c38d11c2 100644 --- a/nitransforms/io/__init__.py +++ b/nitransforms/io/__init__.py @@ -1,11 +1,34 @@ # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Read and write transforms.""" -from . import afni, fsl, itk, lta +from nitransforms.io import afni, fsl, itk, lta +from nitransforms.io.base import TransformIOError, TransformFileError __all__ = [ "afni", "fsl", "itk", "lta", + "get_linear_factory", + "TransformFileError", + "TransformIOError", ] + +_IO_TYPES = { + "itk": (itk, "ITKLinearTransform"), + "ants": (itk, "ITKLinearTransform"), + "elastix": (itk, "ITKLinearTransform"), + "lta": (lta, "FSLinearTransform"), + "fs": (lta, "FSLinearTransform"), + "fsl": (fsl, "FSLLinearTransform"), + "afni": (afni, "AFNILinearTransform"), +} + + +def get_linear_factory(fmt, is_array=True): + """Return the type required by a given format.""" + if fmt.lower() not in _IO_TYPES: + raise TypeError(f"Unsupported transform format <{fmt}>.") + + module, classname = _IO_TYPES[fmt.lower()] + return getattr(module, f"{classname}{'Array' * is_array}") diff --git a/nitransforms/io/afni.py b/nitransforms/io/afni.py index 0b81e9b0..b7fc657b 100644 --- a/nitransforms/io/afni.py +++ b/nitransforms/io/afni.py @@ -95,12 +95,16 @@ def from_string(cls, string): if not lines: raise TransformFileError - parameters = np.vstack( - ( - np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)), - (0.0, 0.0, 0.0, 1.0), + try: + parameters = np.vstack( + ( + np.genfromtxt([lines[0].encode()], dtype="f8").reshape((3, 4)), + (0.0, 0.0, 0.0, 1.0), + ) ) - ) + except ValueError as e: + raise TransformFileError from e + sa["parameters"] = parameters return tf diff --git a/nitransforms/io/base.py b/nitransforms/io/base.py index 284304b6..6d1a7c8e 100644 --- a/nitransforms/io/base.py +++ b/nitransforms/io/base.py @@ -6,8 +6,12 @@ from ..patched import LabeledWrapStruct -class TransformFileError(Exception): - """A custom exception for transform files.""" +class TransformIOError(IOError): + """General I/O exception while reading/writing transforms.""" + + +class TransformFileError(TransformIOError): + """Specific I/O exception when a file does not meet the expected format.""" class StringBasedStruct(LabeledWrapStruct): diff --git a/nitransforms/io/fsl.py b/nitransforms/io/fsl.py index 03252557..3bd4deb1 100644 --- a/nitransforms/io/fsl.py +++ b/nitransforms/io/fsl.py @@ -10,6 +10,7 @@ BaseLinearTransformList, LinearParameters, DisplacementsField, + TransformIOError, TransformFileError, _ensure_image, ) @@ -40,7 +41,7 @@ def from_ras(cls, ras, moving=None, reference=None): moving = reference if reference is None: - raise ValueError("Cannot build FSL linear transform without a reference") + raise TransformIOError("Cannot build FSL linear transform without a reference") reference = _ensure_image(reference) moving = _ensure_image(moving) @@ -77,7 +78,7 @@ def from_string(cls, string): def to_ras(self, moving=None, reference=None): """Return a nitransforms internal RAS+ matrix.""" if reference is None: - raise ValueError("Cannot build FSL linear transform without a reference") + raise TransformIOError("Cannot build FSL linear transform without a reference") if moving is None: warnings.warn( diff --git a/nitransforms/io/itk.py b/nitransforms/io/itk.py index b45a84de..4ab80f82 100644 --- a/nitransforms/io/itk.py +++ b/nitransforms/io/itk.py @@ -8,6 +8,7 @@ BaseLinearTransformList, DisplacementsField, LinearParameters, + TransformIOError, TransformFileError, ) @@ -306,7 +307,7 @@ def from_filename(cls, filename): from h5py import File as H5File if not str(filename).endswith(".h5"): - raise RuntimeError("Extension is not .h5") + raise TransformFileError("Extension is not .h5") with H5File(str(filename)) as f: return cls.from_h5obj(f) @@ -355,7 +356,7 @@ def from_h5obj(cls, fileobj, check=True): ) continue - raise NotImplementedError( + raise TransformIOError( f"Unsupported transform type {xfm['TransformType'][0]}" ) diff --git a/nitransforms/linear.py b/nitransforms/linear.py index e81ead61..00acfafb 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -14,14 +14,14 @@ from nibabel.loadsave import load as _nbload -from .base import ( +from nitransforms.base import ( ImageGrid, TransformBase, SpatialReference, _as_homogeneous, EQUALITY_TOL, ) -from . import io +from nitransforms.io import get_linear_factory, TransformFileError class Affine(TransformBase): @@ -183,51 +183,40 @@ def _to_hdf5(self, x5_root): self.reference._to_hdf5(x5_root.create_group("Reference")) def to_filename(self, filename, fmt="X5", moving=None): - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" - if fmt.lower() in ["itk", "ants", "elastix"]: - itkobj = io.itk.ITKLinearTransform.from_ras(self.matrix) - itkobj.to_filename(filename) - return filename - - # Rest of the formats peek into moving and reference image grids - moving = ImageGrid(moving) if moving is not None else self.reference - - _factory = { - "afni": io.afni.AFNILinearTransform, - "fsl": io.fsl.FSLLinearTransform, - "lta": io.lta.FSLinearTransform, - "fs": io.lta.FSLinearTransform, - } - - if fmt not in _factory: - raise NotImplementedError(f"Unsupported format <{fmt}>") - - _factory[fmt].from_ras( - self.matrix, moving=moving, reference=self.reference - ).to_filename(filename) - return filename + """Store the transform in the requested output format.""" + writer = get_linear_factory(fmt, is_array=False) - @classmethod - def from_filename(cls, filename, fmt="X5", reference=None, moving=None): - """Create an affine from a transform file.""" if fmt.lower() in ("itk", "ants", "elastix"): - _factory = io.itk.ITKLinearTransformArray - elif fmt.lower() in ("lta", "fs"): - _factory = io.lta.FSLinearTransformArray - elif fmt.lower() == "fsl": - _factory = io.fsl.FSLLinearTransformArray - elif fmt.lower() == "afni": - _factory = io.afni.AFNILinearTransformArray + writer.from_ras(self.matrix).to_filename(filename) else: - raise NotImplementedError + # Rest of the formats peek into moving and reference image grids + writer.from_ras( + self.matrix, + reference=self.reference, + moving=ImageGrid(moving) if moving is not None else self.reference, + ).to_filename(filename) + return filename - struct = _factory.from_filename(filename) - matrix = struct.to_ras(reference=reference, moving=moving) - if cls == Affine: - if np.shape(matrix)[0] != 1: - raise TypeError("Cannot load transform array '%s'" % filename) - matrix = matrix[0] - return cls(matrix, reference=reference) + @classmethod + def from_filename(cls, filename, fmt=None, reference=None, moving=None): + """Create an affine from a transform file.""" + fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl") + + for potential_fmt in fmtlist: + try: + struct = get_linear_factory(potential_fmt).from_filename(filename) + matrix = struct.to_ras(reference=reference, moving=moving) + if cls == Affine: + if np.shape(matrix)[0] != 1: + raise TypeError("Cannot load transform array '%s'" % filename) + matrix = matrix[0] + return cls(matrix, reference=reference) + except (TransformFileError, FileNotFoundError): + continue + + raise TransformFileError( + f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})." + ) def __repr__(self): """ @@ -353,31 +342,18 @@ def map(self, x, inverse=False): return np.swapaxes(affine.dot(coords), 1, 2) def to_filename(self, filename, fmt="X5", moving=None): - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" - if fmt.lower() in ("itk", "ants", "elastix"): - itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix) - itkobj.to_filename(filename) - return filename + """Store the transform in the requested output format.""" + writer = get_linear_factory(fmt, is_array=True) - # Rest of the formats peek into moving and reference image grids - if moving is not None: - moving = ImageGrid(moving) + if fmt.lower() in ("itk", "ants", "elastix"): + writer.from_ras(self.matrix).to_filename(filename) else: - moving = self.reference - - _factory = { - "afni": io.afni.AFNILinearTransformArray, - "fsl": io.fsl.FSLLinearTransformArray, - "lta": io.lta.FSLinearTransformArray, - "fs": io.lta.FSLinearTransformArray, - } - - if fmt not in _factory: - raise NotImplementedError(f"Unsupported format <{fmt}>") - - _factory[fmt].from_ras( - self.matrix, moving=moving, reference=self.reference - ).to_filename(filename) + # Rest of the formats peek into moving and reference image grids + writer.from_ras( + self.matrix, + reference=self.reference, + moving=ImageGrid(moving) if moving is not None else self.reference, + ).to_filename(filename) return filename def apply( @@ -486,17 +462,17 @@ def apply( return resampled -def load(filename, fmt="X5", reference=None, moving=None): +def load(filename, fmt=None, reference=None, moving=None): """ Load a linear transform file. Examples -------- - >>> xfm = load(regress_dir / "affine-LAS.itk.tfm", fmt="itk") + >>> xfm = load(regress_dir / "affine-LAS.itk.tfm") >>> isinstance(xfm, Affine) True - >>> xfm = load(regress_dir / "itktflist.tfm", fmt="itk") + >>> xfm = load(regress_dir / "itktflist.tfm") >>> isinstance(xfm, LinearTransformsMapping) True diff --git a/nitransforms/tests/test_io.py b/nitransforms/tests/test_io.py index a2b9eaaf..6c153828 100644 --- a/nitransforms/tests/test_io.py +++ b/nitransforms/tests/test_io.py @@ -1,12 +1,14 @@ # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """I/O test cases.""" +import os from subprocess import check_call from io import StringIO import filecmp import shutil import numpy as np import pytest +from h5py import File as H5File import nibabel as nb from nibabel.eulerangles import euler2mat @@ -24,7 +26,7 @@ FSLinearTransform as LT, FSLinearTransformArray as LTA, ) -from ..io.base import LinearParameters, TransformFileError +from ..io.base import LinearParameters, TransformIOError, TransformFileError LPS = np.diag([-1, -1, 1, 1]) ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS)) @@ -224,7 +226,7 @@ def test_Linear_common(tmpdir, data_path, sw, image_orientation, get_testdata): # Test without images if sw == "fsl": - with pytest.raises(ValueError): + with pytest.raises(TransformIOError): factory.from_ras(RAS) else: xfm = factory.from_ras(RAS) @@ -408,7 +410,7 @@ def test_afni_Displacements(): afni.AFNIDisplacementsField.from_image(field) -def test_itk_h5(testdata_path): +def test_itk_h5(tmpdir, testdata_path): """Test displacements fields.""" assert ( len( @@ -422,7 +424,7 @@ def test_itk_h5(testdata_path): == 2 ) - with pytest.raises(RuntimeError): + with pytest.raises(TransformFileError): list( itk.ITKCompositeH5.from_filename( testdata_path @@ -430,6 +432,21 @@ def test_itk_h5(testdata_path): ) ) + tmpdir.chdir() + shutil.copy( + testdata_path / "ds-005_sub-01_from-T1w_to-MNI152NLin2009cAsym_mode-image_xfm.h5", + "test.h5", + ) + os.chmod("test.h5", 0o666) + + with H5File("test.h5", "r+") as h5file: + h5group = h5file["TransformGroup"] + xfm = h5group[list(h5group.keys())[1]] + xfm["TransformType"][0] = b"InventTransform" + + with pytest.raises(TransformIOError): + itk.ITKCompositeH5.from_filename("test.h5") + @pytest.mark.parametrize( "file_type, test_file", [(LTA, "from-fsnative_to-scanner_mode-image.lta")] diff --git a/nitransforms/tests/test_linear.py b/nitransforms/tests/test_linear.py index 4462b212..aed4a148 100644 --- a/nitransforms/tests/test_linear.py +++ b/nitransforms/tests/test_linear.py @@ -11,7 +11,8 @@ import nibabel as nb from nibabel.eulerangles import euler2mat from nibabel.affines import from_matvec -from .. import linear as nitl +from nitransforms import linear as nitl +from nitransforms import io from .utils import assert_affines_by_filename RMSE_TOL = 0.1 @@ -53,6 +54,18 @@ def test_linear_valueerror(): nitl.Affine(np.ones((4, 4))) +def test_linear_load_unsupported(data_path): + """Exercise loading transform without I/O implementation.""" + with pytest.raises(TypeError): + nitl.load(data_path / "itktflist2.tfm", fmt="X5") + + +def test_linear_load_mistaken(data_path): + """Exercise loading transform without I/O implementation.""" + with pytest.raises(io.TransformFileError): + nitl.load(data_path / "itktflist2.tfm", fmt="afni") + + def test_loadsave_itk(tmp_path, data_path, testdata_path): """Test idempotency.""" ref_file = testdata_path / "someones_anatomy.nii.gz" @@ -72,9 +85,13 @@ def test_loadsave_itk(tmp_path, data_path, testdata_path): ) +@pytest.mark.parametrize("autofmt", (False, True)) @pytest.mark.parametrize("fmt", ["itk", "fsl", "afni", "lta"]) -def test_loadsave(tmp_path, data_path, testdata_path, fmt): +def test_loadsave(tmp_path, data_path, testdata_path, autofmt, fmt): """Test idempotency.""" + supplied_fmt = None if autofmt else fmt + + # Load reference transform ref_file = testdata_path / "someones_anatomy.nii.gz" xfm = nitl.load(data_path / "itktflist2.tfm", fmt="itk") xfm.reference = ref_file @@ -84,33 +101,33 @@ def test_loadsave(tmp_path, data_path, testdata_path, fmt): if fmt == "fsl": # FSL should not read a transform without reference - with pytest.raises(ValueError): - nitl.load(fname, fmt=fmt) - nitl.load(fname, fmt=fmt, moving=ref_file) + with pytest.raises(io.TransformIOError): + nitl.load(fname, fmt=supplied_fmt) + nitl.load(fname, fmt=supplied_fmt, moving=ref_file) with pytest.warns(UserWarning): assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=fmt, reference=ref_file).matrix, + nitl.load(fname, fmt=supplied_fmt, reference=ref_file).matrix, ) assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, ) else: - assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file) + assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file) xfm.to_filename(fname, fmt=fmt, moving=ref_file) if fmt == "fsl": assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, rtol=1e-2, # FSL incurs into large errors due to rounding ) else: - assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file) + assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file) ref_file = testdata_path / "someones_anatomy.nii.gz" xfm = nitl.load(data_path / "affine-LAS.itk.tfm", fmt="itk") @@ -120,21 +137,21 @@ def test_loadsave(tmp_path, data_path, testdata_path, fmt): if fmt == "fsl": assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, rtol=1e-2, # FSL incurs into large errors due to rounding ) else: - assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file) + assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file) xfm.to_filename(fname, fmt=fmt, moving=ref_file) if fmt == "fsl": assert np.allclose( xfm.matrix, - nitl.load(fname, fmt=fmt, reference=ref_file, moving=ref_file).matrix, + nitl.load(fname, fmt=supplied_fmt, reference=ref_file, moving=ref_file).matrix, rtol=1e-2, # FSL incurs into large errors due to rounding ) else: - assert xfm == nitl.load(fname, fmt=fmt, reference=ref_file) + assert xfm == nitl.load(fname, fmt=supplied_fmt, reference=ref_file) @pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"]) @@ -152,7 +169,7 @@ def test_linear_save(tmpdir, data_path, get_testdata, image_orientation, sw_tool xfm = ( nitl.Affine(T) if (sw_tool, image_orientation) != ("afni", "oblique") else # AFNI is special when moving or reference are oblique - let io do the magic - nitl.Affine(nitl.io.afni.AFNILinearTransform.from_ras(T).to_ras( + nitl.Affine(io.afni.AFNILinearTransform.from_ras(T).to_ras( reference=img, moving=img, )) @@ -199,7 +216,7 @@ def test_apply_linear_transform(tmpdir, get_testdata, get_testmask, image_orient xfm_fname = "M.%s%s" % (sw_tool, ext) # Change reference dataset for AFNI & oblique if (sw_tool, image_orientation) == ("afni", "oblique"): - nitl.io.afni.AFNILinearTransform.from_ras( + io.afni.AFNILinearTransform.from_ras( T, moving=img, reference=img,