diff --git a/nitransforms/base.py b/nitransforms/base.py index 9a1600a0..25fd88e0 100644 --- a/nitransforms/base.py +++ b/nitransforms/base.py @@ -92,7 +92,7 @@ def shape(self): class ImageGrid(SampledSpatialData): """Class to represent spaces of gridded data (images).""" - __slots__ = ["_affine", "_inverse", "_ndindex"] + __slots__ = ["_affine", "_inverse", "_ndindex", "_header"] def __init__(self, image): """Create a gridded sampling reference.""" @@ -101,6 +101,7 @@ def __init__(self, image): self._affine = image.affine self._shape = image.shape + self._header = getattr(image, "header", None) self._ndim = getattr(image, "ndim", len(image.shape)) if self._ndim >= 4: @@ -117,6 +118,11 @@ def affine(self): """Access the indexes-to-RAS affine.""" return self._affine + @property + def header(self): + """Access the original reference's header.""" + return self._header + @property def inverse(self): """Access the RAS-to-indexes affine.""" @@ -293,12 +299,15 @@ def apply( ) if isinstance(_ref, ImageGrid): # If reference is grid, reshape + hdr = None + if _ref.header is not None: + hdr = _ref.header.copy() + hdr.set_data_dtype(output_dtype) moved = spatialimage.__class__( resampled.reshape(_ref.shape).astype(output_dtype), _ref.affine, - spatialimage.header + hdr, ) - moved.set_data_dtype(output_dtype) return moved return resampled