Skip to content

Commit 8ba34c9

Browse files
authored
Merge pull request #220 from nipy/enh/reenable-parallelization-apply-214-parallel
ENH: Parallelize serialized 3D+t transforms
2 parents 4c06174 + 38bb388 commit 8ba34c9

File tree

1 file changed

+54
-20
lines changed

1 file changed

+54
-20
lines changed

nitransforms/resampling.py

+54-20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
1010

11+
from os import cpu_count
12+
from concurrent.futures import ProcessPoolExecutor, as_completed
1113
from pathlib import Path
1214
import numpy as np
1315
from nibabel.loadsave import load as _nbload
@@ -25,6 +27,25 @@
2527
"""Minimum number of volumes to automatically serialize 4D transforms."""
2628

2729

30+
def _apply_volume(
31+
index,
32+
data,
33+
targets,
34+
order=3,
35+
mode="constant",
36+
cval=0.0,
37+
prefilter=True,
38+
):
39+
return index, ndi.map_coordinates(
40+
data,
41+
targets,
42+
order=order,
43+
mode=mode,
44+
cval=cval,
45+
prefilter=prefilter,
46+
)
47+
48+
2849
def apply(
2950
transform,
3051
spatialimage,
@@ -135,34 +156,47 @@ def apply(
135156
else None
136157
)
137158

138-
# Order F ensures individual volumes are contiguous in memory
139-
# Also matches NIfTI, making final save more efficient
140-
resampled = np.zeros(
141-
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
142-
)
159+
if njobs is None:
160+
njobs = cpu_count()
143161

144-
for t in range(n_resamplings):
145-
xfm_t = transform if n_resamplings == 1 else transform[t]
162+
with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
163+
results = []
164+
for t in range(n_resamplings):
165+
xfm_t = transform if n_resamplings == 1 else transform[t]
146166

147-
if targets is None:
148-
targets = ImageGrid(spatialimage).index( # data should be an image
149-
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
150-
)
167+
if targets is None:
168+
targets = ImageGrid(spatialimage).index( # data should be an image
169+
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
170+
)
151171

152-
# Interpolate
153-
resampled[..., t] = ndi.map_coordinates(
154-
(
172+
data_t = (
155173
data
156174
if data is not None
157175
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
158-
),
159-
targets,
160-
order=order,
161-
mode=mode,
162-
cval=cval,
163-
prefilter=prefilter,
176+
)
177+
178+
results.append(
179+
executor.submit(
180+
_apply_volume,
181+
t,
182+
data_t,
183+
targets,
184+
order=order,
185+
mode=mode,
186+
cval=cval,
187+
prefilter=prefilter,
188+
)
189+
)
190+
191+
# Order F ensures individual volumes are contiguous in memory
192+
# Also matches NIfTI, making final save more efficient
193+
resampled = np.zeros(
194+
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
164195
)
165196

197+
for future in as_completed(results):
198+
t, resampled_t = future.result()
199+
resampled[..., t] = resampled_t
166200
else:
167201
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
168202

0 commit comments

Comments
 (0)