@@ -92,7 +92,7 @@ def shape(self):
92
92
class ImageGrid (SampledSpatialData ):
93
93
"""Class to represent spaces of gridded data (images)."""
94
94
95
- __slots__ = ["_affine" , "_inverse" , "_ndindex" ]
95
+ __slots__ = ["_affine" , "_inverse" , "_ndindex" , "_header" ]
96
96
97
97
def __init__ (self , image ):
98
98
"""Create a gridded sampling reference."""
@@ -101,6 +101,7 @@ def __init__(self, image):
101
101
102
102
self ._affine = image .affine
103
103
self ._shape = image .shape
104
+ self ._header = getattr (image , "header" , None )
104
105
105
106
self ._ndim = getattr (image , "ndim" , len (image .shape ))
106
107
if self ._ndim >= 4 :
@@ -117,6 +118,11 @@ def affine(self):
117
118
"""Access the indexes-to-RAS affine."""
118
119
return self ._affine
119
120
121
+ @property
122
+ def header (self ):
123
+ """Access the original reference's header."""
124
+ return self ._header
125
+
120
126
@property
121
127
def inverse (self ):
122
128
"""Access the RAS-to-indexes affine."""
@@ -293,12 +299,15 @@ def apply(
293
299
)
294
300
295
301
if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
302
+ hdr = None
303
+ if _ref .header is not None :
304
+ hdr = _ref .header .copy ()
305
+ hdr .set_data_dtype (output_dtype )
296
306
moved = spatialimage .__class__ (
297
307
resampled .reshape (_ref .shape ).astype (output_dtype ),
298
308
_ref .affine ,
299
- spatialimage . header
309
+ hdr ,
300
310
)
301
- moved .set_data_dtype (output_dtype )
302
311
return moved
303
312
304
313
return resampled
0 commit comments