Skip to content

Commit e3e0327

Browse files
committed
enh: create process pool
1 parent d05fae0 commit e3e0327

File tree

1 file changed

+56
-25
lines changed

1 file changed

+56
-25
lines changed

nitransforms/resampling.py

+56-25
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
1010

11-
from concurrent.futures import ProcessPoolExecutor
12-
from functools import partial
11+
from os import cpu_count
12+
from concurrent.futures import ProcessPoolExecutor, as_completed
1313
from pathlib import Path
1414
import numpy as np
1515
from nibabel.loadsave import load as _nbload
@@ -27,6 +27,25 @@
2727
"""Minimum number of volumes to automatically serialize 4D transforms."""
2828

2929

30+
def _apply_volume(
31+
index,
32+
data,
33+
targets,
34+
order=3,
35+
mode="constant",
36+
cval=0.0,
37+
prefilter=True,
38+
):
39+
return index, ndi.map_coordinates(
40+
data,
41+
targets,
42+
order=order,
43+
mode=mode,
44+
cval=cval,
45+
prefilter=prefilter,
46+
)
47+
48+
3049
def apply(
3150
transform,
3251
spatialimage,
@@ -137,35 +156,47 @@ def apply(
137156
else None
138157
)
139158

140-
map_coordinates = partial(
141-
ndi.map_coordinates,
142-
order=order,
143-
mode=mode,
144-
cval=cval,
145-
prefilter=prefilter,
146-
)
159+
if njobs is None:
160+
njobs = cpu_count()
147161

148-
def _apply_volume(index, data, transform, targets=None):
149-
xfm_t = transform if n_resamplings == 1 else transform[index]
162+
with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
163+
results = []
164+
for t in range(n_resamplings):
165+
xfm_t = transform if n_resamplings == 1 else transform[t]
150166

151-
if targets is None:
152-
targets = ImageGrid(spatialimage).index( # data should be an image
153-
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
167+
if targets is None:
168+
targets = ImageGrid(spatialimage).index( # data should be an image
169+
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
170+
)
171+
172+
data_t = (
173+
data
174+
if data is not None
175+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
154176
)
155177

156-
data_t = (
157-
data
158-
if data is not None
159-
else spatialimage.dataobj[..., index].astype(input_dtype, copy=False)
160-
)
161-
return map_coordinates(data_t, targets)
178+
results.append(
179+
executor.submit(
180+
_apply_volume,
181+
t,
182+
data_t,
183+
targets,
184+
order=order,
185+
mode=mode,
186+
cval=cval,
187+
prefilter=prefilter,
188+
)
189+
)
162190

163-
with ProcessPoolExecutor(max_workers=njobs) as executor:
164-
results = executor.map(
165-
_apply_volume,
166-
[(t, data, transform, targets) for t in range(n_resamplings)],
191+
# Order F ensures individual volumes are contiguous in memory
192+
# Also matches NIfTI, making final save more efficient
193+
resampled = np.zeros(
194+
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
167195
)
168-
resampled = np.stack(list(results), -1)
196+
197+
for future in as_completed(results):
198+
t, resampled_t = future.result()
199+
resampled[..., t] = resampled_t
169200
else:
170201
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
171202

0 commit comments

Comments
 (0)