Skip to content

Commit 3138de0

Browse files
committed
fix: ensure input dtype is kept after resampling
Resolves: #152.
1 parent 26c2b2e commit 3138de0

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

nitransforms/base.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ def apply(
270270
if isinstance(spatialimage, (str, Path)):
271271
spatialimage = _nbload(str(spatialimage))
272272

273-
data = np.asanyarray(spatialimage.dataobj)
273+
data = np.asanyarray(
274+
spatialimage.dataobj,
275+
dtype=spatialimage.get_data_dtype()
276+
)
274277
output_dtype = output_dtype or data.dtype
275278
targets = ImageGrid(spatialimage).index( # data should be an image
276279
_as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
@@ -288,9 +291,11 @@ def apply(
288291

289292
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
290293
moved = spatialimage.__class__(
291-
resampled.reshape(_ref.shape), _ref.affine, spatialimage.header
294+
resampled.reshape(_ref.shape).astype(output_dtype),
295+
_ref.affine,
296+
spatialimage.header
292297
)
293-
moved.header.set_data_dtype(output_dtype)
298+
moved.set_data_dtype(output_dtype)
294299
return moved
295300

296301
return resampled

0 commit comments

Comments
 (0)