Skip to content

Commit 6292daf

Browse files
Julien MarabottoJulien Marabotto
Julien Marabotto
authored and
Julien Marabotto
committed
fix: pass tests, serialization implemented
1 parent 1616a35 commit 6292daf

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

nitransforms/resampling.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -102,15 +102,7 @@ def apply(
102102
xfm_nvols = len(transform)
103103
else:
104104
xfm_nvols = transform.ndim
105-
"""
106-
if data_nvols == 1 and xfm_nvols > 1:
107-
data = data[..., np.newaxis]
108-
elif data_nvols != xfm_nvols:
109-
raise ValueError(
110-
"The fourth dimension of the data does not match the transform's shape."
111-
)
112-
RESAMPLING FAILS. SUGGEST:
113-
"""
105+
114106
if data.ndim < transform.ndim:
115107
data = data[..., np.newaxis]
116108
elif data_nvols > 1 and data_nvols != xfm_nvols:
@@ -119,26 +111,38 @@ def apply(
119111
)
120112

121113
serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
122-
serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols
114+
serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols
115+
123116
if serialize_4d:
124-
for t, xfm_t in enumerate(transform):
125-
ras2vox = ~Affine(spatialimage.affine)
126-
input_dtype = get_obj_dtype(spatialimage.dataobj)
127-
output_dtype = output_dtype or input_dtype
117+
# Avoid opening the data array just yet
118+
input_dtype = get_obj_dtype(spatialimage.dataobj)
119+
output_dtype = output_dtype or input_dtype
120+
121+
# Prepare physical coordinates of input (grid, points)
122+
xcoords = _ref.ndcoords.astype("f4").T
123+
124+
# Invert target's (moving) affine once
125+
ras2vox = ~Affine(spatialimage.affine)
126+
dataobj = (
127+
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
128+
if spatialimage.ndim in (2, 3)
129+
else None
130+
)
128131

132+
# Order F ensures individual volumes are contiguous in memory
133+
# Also matches NIfTI, making final save more efficient
134+
resampled = np.zeros(
135+
(xcoords.shape[0], len(transform)), dtype=output_dtype, order="F"
136+
)
137+
138+
for t, xfm_t in enumerate(transform):
129139
# Map the input coordinates on to timepoint t of the target (moving)
130-
xcoords = _ref.ndcoords.astype("f4").T
131140
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
132141

133142
# Calculate corresponding voxel coordinates
134143
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
135144

136145
# Interpolate
137-
dataobj = (
138-
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
139-
if spatialimage.ndim in (2, 3)
140-
else None
141-
)
142146
resampled[..., t] = ndi.map_coordinates(
143147
(
144148
dataobj

0 commit comments

Comments
 (0)