Skip to content

Commit 38bb388

Browse files
committed
enh: create process pool
1 parent 754785f commit 38bb388

File tree

1 file changed

+56
-27
lines changed

1 file changed

+56
-27
lines changed

nitransforms/resampling.py

+56-27
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
1010

11-
from functools import partial
11+
from os import cpu_count
12+
from concurrent.futures import ProcessPoolExecutor, as_completed
1213
from pathlib import Path
1314
import numpy as np
1415
from nibabel.loadsave import load as _nbload
@@ -26,6 +27,25 @@
2627
"""Minimum number of volumes to automatically serialize 4D transforms."""
2728

2829

30+
def _apply_volume(
31+
index,
32+
data,
33+
targets,
34+
order=3,
35+
mode="constant",
36+
cval=0.0,
37+
prefilter=True,
38+
):
39+
return index, ndi.map_coordinates(
40+
data,
41+
targets,
42+
order=order,
43+
mode=mode,
44+
cval=cval,
45+
prefilter=prefilter,
46+
)
47+
48+
2949
def apply(
3050
transform,
3151
spatialimage,
@@ -136,38 +156,47 @@ def apply(
136156
else None
137157
)
138158

139-
map_coordinates = partial(
140-
ndi.map_coordinates,
141-
order=order,
142-
mode=mode,
143-
cval=cval,
144-
prefilter=prefilter,
145-
)
159+
if njobs is None:
160+
njobs = cpu_count()
161+
162+
with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
163+
results = []
164+
for t in range(n_resamplings):
165+
xfm_t = transform if n_resamplings == 1 else transform[t]
146166

147-
def _apply_volume(index, data, transform, targets=None):
148-
xfm_t = transform if n_resamplings == 1 else transform[index]
167+
if targets is None:
168+
targets = ImageGrid(spatialimage).index( # data should be an image
169+
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
170+
)
149171

150-
if targets is None:
151-
targets = ImageGrid(spatialimage).index( # data should be an image
152-
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
172+
data_t = (
173+
data
174+
if data is not None
175+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
153176
)
154177

155-
data_t = (
156-
data
157-
if data is not None
158-
else spatialimage.dataobj[..., index].astype(input_dtype, copy=False)
159-
)
160-
return map_coordinates(data_t, targets)
178+
results.append(
179+
executor.submit(
180+
_apply_volume,
181+
t,
182+
data_t,
183+
targets,
184+
order=order,
185+
mode=mode,
186+
cval=cval,
187+
prefilter=prefilter,
188+
)
189+
)
161190

162-
# Order F ensures individual volumes are contiguous in memory
163-
# Also matches NIfTI, making final save more efficient
164-
resampled = np.zeros(
165-
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
166-
)
167-
for t in range(n_resamplings):
168-
# Interpolate
169-
resampled[..., t] = _apply_volume(t, data, transform, targets=targets)
191+
# Order F ensures individual volumes are contiguous in memory
192+
# Also matches NIfTI, making final save more efficient
193+
resampled = np.zeros(
194+
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
195+
)
170196

197+
for future in as_completed(results):
198+
t, resampled_t = future.result()
199+
resampled[..., t] = resampled_t
171200
else:
172201
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
173202

0 commit comments

Comments
 (0)