Skip to content

Commit 23daabb

Browse files
committed
Merge remote-tracking branch 'jmarabotto/patch/oesteban-pr' into enh/reenable-parallelization-apply-214
2 parents 86b3d11 + 6292daf commit 23daabb

File tree

1 file changed

+70
-31
lines changed

1 file changed

+70
-31
lines changed

nitransforms/resampling.py

+70-31
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from pathlib import Path
1212
import numpy as np
1313
from nibabel.loadsave import load as _nbload
14+
from nibabel.arrayproxy import get_obj_dtype
1415
from scipy import ndimage as ndi
1516

17+
from nitransforms.linear import Affine, LinearTransformsMapping
1618
from nitransforms.base import (
1719
ImageGrid,
1820
TransformError,
@@ -23,9 +25,6 @@
2325
SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
2426
"""Minimum number of volumes to automatically serialize 4D transforms."""
2527

26-
class NotImplementedWarning(UserWarning):
27-
"""A custom class for warnings."""
28-
2928

3029
def apply(
3130
transform,
@@ -99,49 +98,89 @@ def apply(
9998

10099
data = np.asanyarray(spatialimage.dataobj)
101100
data_nvols = 1 if data.ndim < 4 else data.shape[-1]
101+
102102
xfm_nvols = len(transform)
103103

104104
if data_nvols == 1 and xfm_nvols > 1:
105105
data = data[..., np.newaxis]
106106
elif data_nvols != xfm_nvols:
107107
raise ValueError(
108-
"The fourth dimension of the data does not match the tranform's shape."
108+
"The fourth dimension of the data does not match the transform's shape."
109109
)
110110

111111
serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
112-
serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols
112+
serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols
113+
113114
if serialize_4d:
114-
warn(
115-
"4D transforms serialization into 3D+t not implemented",
116-
NotImplementedWarning,
117-
stacklevel=2,
115+
# Avoid opening the data array just yet
116+
input_dtype = get_obj_dtype(spatialimage.dataobj)
117+
output_dtype = output_dtype or input_dtype
118+
119+
# Prepare physical coordinates of input (grid, points)
120+
xcoords = _ref.ndcoords.astype("f4").T
121+
122+
# Invert target's (moving) affine once
123+
ras2vox = ~Affine(spatialimage.affine)
124+
dataobj = (
125+
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
126+
if spatialimage.ndim in (2, 3)
127+
else None
118128
)
119129

120-
# For model-based nonlinear transforms, generate the corresponding dense field
121-
if hasattr(transform, "to_field") and callable(transform.to_field):
122-
targets = ImageGrid(spatialimage).index(
123-
_as_homogeneous(
124-
transform.to_field(reference=reference).map(_ref.ndcoords.T),
125-
dim=_ref.ndim,
126-
)
130+
# Order F ensures individual volumes are contiguous in memory
131+
# Also matches NIfTI, making final save more efficient
132+
resampled = np.zeros(
133+
(xcoords.shape[0], len(transform)), dtype=output_dtype, order="F"
127134
)
135+
136+
for t, xfm_t in enumerate(transform):
137+
# Map the input coordinates on to timepoint t of the target (moving)
138+
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
139+
140+
# Calculate corresponding voxel coordinates
141+
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
142+
143+
# Interpolate
144+
resampled[..., t] = ndi.map_coordinates(
145+
(
146+
dataobj
147+
if dataobj is not None
148+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
149+
),
150+
yvoxels.T,
151+
output=output_dtype,
152+
order=order,
153+
mode=mode,
154+
cval=cval,
155+
prefilter=prefilter,
156+
)
157+
128158
else:
129-
targets = ImageGrid(spatialimage).index( # data should be an image
130-
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
131-
)
159+
# For model-based nonlinear transforms, generate the corresponding dense field
160+
if hasattr(transform, "to_field") and callable(transform.to_field):
161+
targets = ImageGrid(spatialimage).index(
162+
_as_homogeneous(
163+
transform.to_field(reference=reference).map(_ref.ndcoords.T),
164+
dim=_ref.ndim,
165+
)
166+
)
167+
else:
168+
targets = ImageGrid(spatialimage).index( # data should be an image
169+
_as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim)
170+
)
132171

133-
if transform.ndim == 4:
134-
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
135-
136-
resampled = ndi.map_coordinates(
137-
data,
138-
targets,
139-
output=output_dtype,
140-
order=order,
141-
mode=mode,
142-
cval=cval,
143-
prefilter=prefilter,
144-
)
172+
if transform.ndim == 4:
173+
targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T
174+
175+
resampled = ndi.map_coordinates(
176+
data,
177+
targets,
178+
output=output_dtype,
179+
order=order,
180+
mode=mode,
181+
cval=cval,
182+
prefilter=prefilter,
183+
)
145184

146185
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
147186
hdr = None

0 commit comments

Comments
 (0)