diff --git a/nitransforms/base.py b/nitransforms/base.py index 81ed1a5e..4bccf4eb 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -7,6 +7,7 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Common interface for transforms.""" + from pathlib import Path import numpy as np import h5py @@ -146,13 +147,13 @@ def from_arrays(cls, coordinates, triangles): darrays = [ nb.gifti.GiftiDataArray( coordinates.astype(np.float32), - intent=nb.nifti1.intent_codes['NIFTI_INTENT_POINTSET'], - datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_FLOAT32'], + intent=nb.nifti1.intent_codes["NIFTI_INTENT_POINTSET"], + datatype=nb.nifti1.data_type_codes["NIFTI_TYPE_FLOAT32"], ), nb.gifti.GiftiDataArray( triangles.astype(np.int32), - intent=nb.nifti1.intent_codes['NIFTI_INTENT_TRIANGLE'], - datatype=nb.nifti1.data_type_codes['NIFTI_TYPE_INT32'], + intent=nb.nifti1.intent_codes["NIFTI_INTENT_TRIANGLE"], + datatype=nb.nifti1.data_type_codes["NIFTI_TYPE_INT32"], ), ] gii = nb.gifti.GiftiImage(darrays=darrays) @@ -251,14 +252,57 @@ class TransformBase: __slots__ = ( "_reference", "_ndim", + "_affine", + "_shape", + "_header", + "_grid", + "_mapping", + "_hdf5_dct", + "_x5_dct", ) - def __init__(self, reference=None): + x5_struct = { + "TransformGroup/0": { + "Type": None, + "Transform": None, + "Metadata": None, + "Inverse": None, + }, + "TransformGroup/0/Domain": {"Grid": None, "Size": None, "Mapping": None}, + "TransformGroup/1": {}, + "TransformChain": {}, + } + + def __init__( + self, + x5=None, + hdf5=None, + nifti=None, + shape=None, + affine=None, + header=None, + reference=None, + ): """Instantiate a transform.""" + self._reference = None if reference: self.reference = reference + if nifti is not None: + self._x5_dct = self.init_x5_structure(nifti) + elif hdf5: + self.update_x5_structure(hdf5) + elif x5: + self.update_x5_structure(x5) + self._shape = shape + self._affine = affine + self._header = header + + # TO-DO + self._grid = None + self._mapping = None + def __call__(self, x, inverse=False): """Apply y = f(x).""" return self.map(x, inverse=inverse) @@ -295,6 +339,12 @@ def ndim(self): """Access the dimensions of the reference space.""" raise TypeError("TransformBase has no dimensions") + def init_x5_structure(self, xfm_data=None): + self.x5_struct["TransformGroup/0/Transform"] = xfm_data + + def update_x5_structure(self, hdf5_struct=None): + self.x5_struct.update(hdf5_struct) + def map(self, x, inverse=False): r""" Apply :math:`y = f(x)`. @@ -316,33 +366,68 @@ def map(self, x, inverse=False): """ return x - def to_filename(self, filename, fmt="X5"): - """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" - with h5py.File(filename, "w") as out_file: - out_file.attrs["Format"] = "X5" - out_file.attrs["Version"] = np.uint16(1) - root = out_file.create_group("/0") - self._to_hdf5(root) - - return filename - - def _to_hdf5(self, x5_root): - """Serialize this object into the x5 file format.""" - raise NotImplementedError - def apply(self, *args, **kwargs): """Apply the transform to a dataset. Deprecated. Please use ``nitransforms.resampling.apply`` instead. """ - message = ( - "The `apply` method is deprecated. Please use `nitransforms.resampling.apply` instead." - ) + message = "The `apply` method is deprecated. Please use `nitransforms.resampling.apply` instead." warnings.warn(message, DeprecationWarning, stacklevel=2) from .resampling import apply return apply(self, *args, **kwargs) + def _to_hdf5(self, x5_root): + """Serialize this object into the x5 file format.""" + transform_group = x5_root.create_group("TransformGroup") + + """Group '0' containing Affine transform""" + transform_0 = transform_group.create_group("0") + + transform_0.attrs["Type"] = "Affine" + transform_0.create_dataset("Transform", data=self._matrix) + transform_0.create_dataset("Inverse", data=np.linalg.inv(self._matrix)) + + metadata = {"key": "value"} + transform_0.attrs["Metadata"] = str(metadata) + + """sub-group 'Domain' contained within group '0' """ + domain_group = transform_0.create_group("Domain") + domain_group.attrs["Grid"] = self.grid + domain_group.create_dataset("Size", data=_as_homogeneous(self._reference.shape)) + domain_group.create_dataset("Mapping", data=self.map) + + raise NotImplementedError + + def read_x5(self, x5_root): + variables = {} + with h5py.File(x5_root, "r") as f: + f.visititems( + lambda filename, x5_root: self._from_hdf5(filename, x5_root, variables) + ) + + _transform = variables["TransformGroup/0/Transform"] + _inverse = variables["TransformGroup/0/Inverse"] + _size = variables["TransformGroup/0/Domain/Size"] + _map = variables["TransformGroup/0/Domain/Mapping"] + + return _transform, _inverse, _size, _map + + def _from_hdf5(self, name, x5_root, storage): + if isinstance(x5_root, h5py.Dataset): + storage[name] = { + "type": "dataset", + "attrs": dict(x5_root.attrs), + "shape": x5_root.shape, + "data": x5_root[()], # Read the data + } + elif isinstance(x5_root, h5py.Group): + storage[name] = { + "type": "group", + "attrs": dict(x5_root.attrs), + "members": {}, + } + def _as_homogeneous(xyz, dtype="float32", dim=3): """ diff --git a/nitransforms/cli.py b/nitransforms/cli.py index 8f8f5ce0..7fb5e468 100644 --- a/nitransforms/cli.py +++ b/nitransforms/cli.py @@ -2,11 +2,13 @@ import os from textwrap import dedent +from nitransforms.base import TransformBase +from nitransforms.io.base import xfm_loader +from nitransforms.linear import load as linload +from nitransforms.nonlinear import load as nlinload +from nitransforms.resampling import apply -from .linear import load as linload -from .nonlinear import load as nlinload -from .resampling import apply - +import pprint def cli_apply(pargs): """ @@ -32,8 +34,8 @@ def cli_apply(pargs): xfm = ( nlinload(pargs.transform, fmt=fmt) - if pargs.nonlinear else - linload(pargs.transform, fmt=fmt) + if pargs.nonlinear + else linload(pargs.transform, fmt=fmt) ) # ensure a reference is set @@ -47,8 +49,43 @@ def cli_apply(pargs): cval=pargs.cval, prefilter=pargs.prefilter, ) - moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}") + # moved.to_filename(pargs.out or f"nt_{os.path.basename(pargs.moving)}") + + +def cli_xfm_util(pargs): + """ """ + + xfm_data = xfm_loader(pargs.transform) + xfm_x5 = TransformBase(**xfm_data) + + if pargs.info: + pprint.pprint(xfm_x5.x5_struct) + print(f"Shape:\n{xfm_x5._shape}") + print(f"Affine:\n{xfm_x5._affine}") + + if pargs.x5: + filename = f"{os.path.basename(pargs.transform).split('.')[0]}.x5" + xfm_x5.to_filename(filename) + print(f"Writing out {filename}") + + +def cli_xfm_util(pargs): + """ + """ + + xfm_data = xfm_loader(pargs.transform) + xfm_x5 = TransformBase(**xfm_data) + + if pargs.info: + pprint.pprint(xfm_x5.x5_struct) + print(f"Shape:\n{xfm_x5._shape}") + print(f"Affine:\n{xfm_x5._affine}") + if pargs.x5: + filename = f"{os.path.basename(pargs.transform).split('.')[0]}.x5" + xfm_x5.to_filename(filename) + print(f"Writing out {filename}") + def get_parser(): desc = dedent( @@ -58,6 +95,7 @@ def get_parser(): Commands: apply Apply a transformation to an image + xfm_util Assorted transform utilities For command specific information, use 'nt -h'. """ @@ -122,6 +160,17 @@ def _add_subparser(name, description): help="Determines if the image's data array is prefiltered with a spline filter before " "interpolation (default: True)", ) + + xfm_util = _add_subparser("xfm_util", cli_xfm_util.__doc__) + xfm_util.set_defaults(func=cli_xfm_util) + xfm_util.add_argument("transform", help="The transform file") + xfm_util.add_argument( + "--info", action="store_true", help="Get information about the transform" + ) + xfm_util.add_argument( + "--x5", action="store_true", help="Convert transform to .x5 file format." + ) + return parser, subparsers @@ -135,3 +184,7 @@ def main(pargs=None): subparser = subparsers.choices[pargs.command] subparser.print_help() raise (e) + + +if __name__ == "__main__": + main() diff --git a/nitransforms/io/__init__.py b/nitransforms/io/__init__.py index c38d11c2..0b6117be 100644 --- a/nitransforms/io/__init__.py +++ b/nitransforms/io/__init__.py @@ -1,7 +1,7 @@ # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Read and write transforms.""" -from nitransforms.io import afni, fsl, itk, lta +from nitransforms.io import afni, fsl, itk, lta, x5 from nitransforms.io.base import TransformIOError, TransformFileError __all__ = [ @@ -22,6 +22,7 @@ "fs": (lta, "FSLinearTransform"), "fsl": (fsl, "FSLLinearTransform"), "afni": (afni, "AFNILinearTransform"), + "x5": (x5, "X5Transform"), } diff --git a/nitransforms/io/base.py b/nitransforms/io/base.py index 3c923426..3c805bbe 100644 --- a/nitransforms/io/base.py +++ b/nitransforms/io/base.py @@ -1,11 +1,142 @@ """Read/write linear transforms.""" + from pathlib import Path import numpy as np +import nibabel as nb from nibabel import load as loadimg +import h5py + from ..patched import LabeledWrapStruct +def get_xfm_filetype(xfm_file): + path = Path(xfm_file) + ext = path.suffix + if ext == ".gz" and path.name.endswith(".nii.gz"): + return "nifti" + + file_types = { + ".nii": "nifti", + ".h5": "hdf5", + ".x5": "x5", + ".txt": "txt", + ".mat": "txt", + } + return file_types.get(ext, "unknown") + + +def gather_fields(x5=None, hdf5=None, nifti=None, shape=None, affine=None, header=None): + xfm_fields = { + "x5": x5, + "hdf5": hdf5, + "nifti": nifti, + "header": header, + "shape": shape, + "affine": affine, + } + return xfm_fields + + +def load_nifti(nifti_file): + nifti_xfm = nb.load(nifti_file) + xfm_data = nifti_xfm.get_fdata() + shape = nifti_xfm.shape + affine = nifti_xfm.affine + header = getattr(nifti_xfm, "header", None) + return gather_fields(nifti=xfm_data, shape=shape, affine=affine, header=header) + +def load_hdf5(hdf5_file): + storage = {} + + def get_hdf5_items(name, x5_root): + if isinstance(x5_root, h5py.Dataset): + storage[name] = { + "type": "dataset", + "attrs": dict(x5_root.attrs), + "shape": x5_root.shape, + "data": x5_root[()], + } + elif isinstance(x5_root, h5py.Group): + storage[name] = { + "type": "group", + "attrs": dict(x5_root.attrs), + "members": {}, + } + + with h5py.File(hdf5_file, "r") as f: + f.visititems(get_hdf5_items) + if storage: + hdf5_storage = {"hdf5": storage} + return hdf5_storage + + +def load_x5(x5_file): + load_hdf5(x5_file) + + +def load_mat(mat_file): + affine_matrix = np.loadtxt(mat_file) + affine = nb.affines.from_matvec(affine_matrix[:, :3], affine_matrix[:, 3]) + return gather_fields(affine=affine) + + +def xfm_loader(xfm_file): + loaders = { + "nifti": load_nifti, + "hdf5": load_hdf5, + "x5": load_x5, + "txt": load_mat, + "mat": load_mat, + } + xfm_filetype = get_xfm_filetype(xfm_file) + loader = loaders.get(xfm_filetype) + if loader is None: + raise ValueError(f"Unsupported file type: {xfm_filetype}") + return loader(xfm_file) + +def to_filename(self, filename, fmt="X5"): + """Store the transform in BIDS-Transforms HDF5 file format (.x5).""" + with h5py.File(filename, "w") as out_file: + out_file.attrs["Format"] = "X5" + out_file.attrs["Version"] = np.uint16(1) + root = out_file.create_group("/0") + self._to_hdf5(root) + + return filename + +def _to_hdf5(self, x5_root): + """Serialize this object into the x5 file format.""" + transform_group = x5_root.create_group("TransformGroup") + + """Group '0' containing Affine transform""" + transform_0 = transform_group.create_group("0") + transform_0.attrs["Type"] = "Affine" + transform_0.create_dataset("Transform", data=self._affine) + transform_0.create_dataset("Inverse", data=np.linalg.inv(self._affine)) + + metadata = {"key": "value"} + transform_0.attrs["Metadata"] = str(metadata) + + """sub-group 'Domain' contained within group '0' """ + domain_group = transform_0.create_group("Domain") + # domain_group.attrs["Grid"] = self._grid + # domain_group.create_dataset("Size", data=_as_homogeneous(self._reference.shape)) + # domain_group.create_dataset("Mapping", data=self.mapping) + + +def _from_x5(self, x5_root): + variables = {} + + x5_root.visititems(lambda name, x5_root: loader(name, x5_root, variables)) + + _transform = variables["TransformGroup/0/Transform"] + _inverse = variables["TransformGroup/0/Inverse"] + _size = variables["TransformGroup/0/Domain/Size"] + _mapping = variables["TransformGroup/0/Domain/Mapping"] + + return _transform, _inverse, _size, _map + class TransformIOError(IOError): """General I/O exception while reading/writing transforms.""" diff --git a/nitransforms/io/x5.py b/nitransforms/io/x5.py new file mode 100644 index 00000000..24331db7 --- /dev/null +++ b/nitransforms/io/x5.py @@ -0,0 +1,48 @@ +"""Read/write x5 transforms.""" + +from h5py import File as H5File +from nitransforms.io.base import ( + BaseLinearTransformList, +) + + +class X5Transform: + """A string-based structure for X5 linear transforms.""" + + _transform = None + + def __init__(self, parameters=None, offset=None): + return + + def __str__(self): + return + + @classmethod + def from_filename(cls, filename): + """Read the struct from a X5 file given its path.""" + if str(filename).endswith(".h5"): + with H5File(str(filename), "r") as hdf: + return cls.from_h5obj(hdf) + + @classmethod + def from_h5obj(cls, h5obj): + """Read the transformations in an X5 file.""" + xfm_list = list(h5obj.keys()) + + xfm = xfm_list["Transform"] + inv = xfm_list["Inverse"] + coords = xfm_list["Size"] + map = xfm_list["Mapping"] + + return xfm, inv, coords, map + + +class X5LinearTransformArray(BaseLinearTransformList): + """A string-based structure for series of X5 linear transforms.""" + + _inner_type = X5Transform + + @property + def xforms(self): + """Get the list of internal X5LinearTransforms.""" + return self._xforms diff --git a/nitransforms/linear.py b/nitransforms/linear.py index 71df6a16..a296f360 100644 --- a/nitransforms/linear.py +++ b/nitransforms/linear.py @@ -7,6 +7,7 @@ # ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## """Linear transforms.""" + import warnings import numpy as np from pathlib import Path @@ -180,15 +181,42 @@ def map(self, x, inverse=False): affine = self._inverse return affine.dot(coords).T[..., :-1] - def _to_hdf5(self, x5_root): + def _to_hdf5(self, x, x5_root): """Serialize this object into the x5 file format.""" - xform = x5_root.create_dataset("Transform", data=[self._matrix]) - xform.attrs["Type"] = "affine" - x5_root.create_dataset("Inverse", data=[(~self).matrix]) + transgrp = x5_root.create_group("TransformGroup") + affine = self._x5group_affine(transgrp) + coords = self._x5group_domain(x, affine) if self._reference: self.reference._to_hdf5(x5_root.create_group("Reference")) + return # nothing? + + def _x5group_affine(self, TransformGroup): + """Create group "0" for affine in x5_root/TransformGroup/ according to x5 file format""" + aff = TransformGroup.create_group("0") + aff.attrs["Type"] = "affine" # Should have shape {scalar} + aff.attrs["Metadata"] = ( + "metadata" # This is a draft for metadata. Should have shape {scalar} + ) + aff.create_dataset("Transform", data=[self._matrix]) # Should have shape {3,4} + aff.create_dataset("Inverse", data=[(~self).matrix]) # Should have shape {4,3} + return aff + + def _x5group_domain(self, x, transform): + """Create group "Domain" in x5_root/TransformGroup/0/ according to x5 file format""" + coords = transform.create_group("Domain") + coords.attrs["Grid"] = ( + "grid" # How do I interpet this 'grid'? Should have shape {scalar} + ) + coords.create_dataset( + "Size", data=_as_homogeneous(x, dim=self._matrix.shape[0] - 1).T + ) # Should have shape {3} + coords.create_dataset( + "Mapping", data=[self.map(self, x)] + ) # Should have shape {4,4} + return coords + def to_filename(self, filename, fmt="X5", moving=None): """Store the transform in the requested output format.""" writer = get_linear_factory(fmt, is_array=False)