Skip to content

Commit 50a8e64

Browse files
committed
fix: parallelize resampling of 4D images
1 parent fe4faf1 commit 50a8e64

File tree

1 file changed

+70
-11
lines changed

1 file changed

+70
-11
lines changed

sdcflows/transform.py

+70-11
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
# https://www.nipreps.org/community/licensing/
2222
#
2323
"""The :math:`B_0` unwarping transform formalism."""
24+
import os
25+
from functools import partial
26+
import asyncio
2427
from pathlib import Path
25-
from typing import Sequence, Union
28+
from typing import Callable, List, Sequence, Union
2629

2730
import attr
2831
import numpy as np
@@ -39,6 +42,41 @@
3942
from niworkflows.interfaces.nibabel import reorient_image
4043

4144

45+
async def worker(data: np.ndarray, coordinates: np.ndarray, func: Callable) -> np.ndarray:
46+
loop = asyncio.get_running_loop()
47+
result = await loop.run_in_executor(None, func, data, coordinates)
48+
return result
49+
50+
51+
async def map_coordinates_thread_pool(
52+
fulldataset: np.ndarray,
53+
coordinates: np.ndarray,
54+
num_workers: int,
55+
func: Callable = ndi.map_coordinates,
56+
) -> List[np.ndarray]:
57+
results = []
58+
tasks = []
59+
60+
out_shape = fulldataset.shape[:-1]
61+
out_dtype = fulldataset.dtype
62+
63+
# Create a worker task for each chunk
64+
for volume in np.rollaxis(fulldataset, 3, 0):
65+
task = asyncio.create_task(worker(volume, coordinates, func))
66+
tasks.append(task)
67+
68+
# Wait for all tasks to complete
69+
await asyncio.gather(*tasks)
70+
71+
# Collect the results an
72+
results = np.rollaxis(np.array([
73+
np.array(task.result(), dtype=out_dtype).reshape(out_shape)
74+
for task in tasks
75+
]), 0, 4)
76+
77+
return results
78+
79+
4280
@attr.s(slots=True)
4381
class B0FieldTransform:
4482
"""Represents and applies the transform to correct for susceptibility distortions."""
@@ -164,7 +202,8 @@ def apply(
164202
cval: float = 0.0,
165203
prefilter: bool = True,
166204
output_dtype: Union[str, np.dtype] = None,
167-
# num_threads: int = None,
205+
num_threads: int = os.cpu_count(),
206+
allow_negative: bool = False,
168207
):
169208
"""
170209
Apply a transformation to an image, resampling on the reference spatial object.
@@ -215,25 +254,29 @@ def apply(
215254
if isinstance(moving, (str, bytes, Path)):
216255
moving = nb.load(moving)
217256

218-
# TODO: not sure this is necessary - instead check it matches self.mapped.
257+
# Make sure the data array has all cosines positive (i.e., no axes are flipped)
219258
moving, axcodes = ensure_positive_cosines(moving)
220259

221260
self.fit(moving)
222261
fmap = self.mapped.get_fdata().copy()
223262

224-
# Reverse mapped if reversed blips
225-
if pe_dir.endswith("-"):
226-
fmap *= -1.0
227-
228263
# Generate warp field
229264
pe_axis = "ijk".index(pe_dir[0])
230265

266+
axis_flip = axcodes[pe_axis] in ("LPI")
267+
pe_flip = pe_dir.endswith("-")
268+
269+
# Displacements are reversed if either is true (after ensuring positive cosines)
270+
if axis_flip ^ pe_flip:
271+
fmap *= -1.0
272+
231273
# Map voxel coordinates applying the VSM
232274
ijk_axis = tuple([np.arange(s) for s in fmap.shape])
233275
voxcoords = np.array(
234276
np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32"
235277
).reshape(3, -1)
236278

279+
# TODO: we probably want to do this within each resampling thread
237280
if xfms is not None:
238281
mov_ras2vox = np.linalg.inv(moving.affine)
239282
# Map coordinates from reference to time-step
@@ -252,19 +295,35 @@ def apply(
252295

253296
# Prepare data
254297
data = np.squeeze(np.asanyarray(moving.dataobj))
255-
output_dtype = output_dtype or data.dtype
298+
299+
if data.ndim == 3:
300+
data = data[..., np.newaxis]
301+
302+
output_dtype = output_dtype or moving.header.get_data_dtype()
256303

257304
# Resample
258-
resampled = ndi.map_coordinates(
259-
data,
260-
voxcoords,
305+
map_coordinates = partial(
306+
ndi.map_coordinates,
261307
output=output_dtype,
262308
order=order,
263309
mode=mode,
264310
cval=cval,
265311
prefilter=prefilter,
312+
)
313+
314+
resampled = np.array(
315+
asyncio.run(map_coordinates_thread_pool(
316+
data,
317+
voxcoords,
318+
num_threads,
319+
map_coordinates
320+
)),
321+
dtype=output_dtype,
266322
).reshape(moving.shape)
267323

324+
if not allow_negative:
325+
resampled[resampled < 0] = cval
326+
268327
moved = moving.__class__(resampled, moving.affine, moving.header)
269328
moved.header.set_data_dtype(output_dtype)
270329
return reorient_image(moved, axcodes)

0 commit comments

Comments
 (0)