Skip to content

Commit 436ae1d

Browse files
oestebaneffigies
andcommitted
enh: port from process pool into asyncio concurrent
Co-authored-by: Chris Markiewicz <[email protected]>
1 parent 7c7608f commit 436ae1d

File tree

1 file changed

+55
-92
lines changed

1 file changed

+55
-92
lines changed

nitransforms/resampling.py

+55-92
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
1010

11+
import asyncio
1112
from os import cpu_count
12-
from concurrent.futures import ProcessPoolExecutor, as_completed
13+
from functools import partial
1314
from pathlib import Path
14-
from typing import Tuple
15+
from typing import Callable, TypeVar
1516

1617
import numpy as np
1718
from nibabel.loadsave import load as _nbload
@@ -27,65 +28,19 @@
2728
_as_homogeneous,
2829
)
2930

31+
R = TypeVar("R")
32+
3033
SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8
3134
"""Minimum number of volumes to automatically serialize 4D transforms."""
3235

3336

34-
def _apply_volume(
35-
index: int,
36-
data: np.ndarray,
37-
targets: np.ndarray,
38-
order: int = 3,
39-
mode: str = "constant",
40-
cval: float = 0.0,
41-
prefilter: bool = True,
42-
) -> Tuple[int, np.ndarray]:
43-
"""
44-
Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization.
37+
async def worker(job: Callable[[], R], semaphore) -> R:
38+
async with semaphore:
39+
loop = asyncio.get_running_loop()
40+
return await loop.run_in_executor(None, job)
4541

46-
Parameters
47-
----------
48-
index : :obj:`int`
49-
The index of the volume to apply the interpolation to.
50-
data : :obj:`~numpy.ndarray`
51-
The input data array.
52-
targets : :obj:`~numpy.ndarray`
53-
The target coordinates for mapping.
54-
order : :obj:`int`, optional
55-
The order of the spline interpolation, default is 3.
56-
The order has to be in the range 0-5.
57-
mode : :obj:`str`, optional
58-
Determines how the input image is extended when the resamplings overflows
59-
a border. One of ``'constant'``, ``'reflect'``, ``'nearest'``, ``'mirror'``,
60-
or ``'wrap'``. Default is ``'constant'``.
61-
cval : :obj:`float`, optional
62-
Constant value for ``mode='constant'``. Default is 0.0.
63-
prefilter: :obj:`bool`, optional
64-
Determines if the image's data array is prefiltered with
65-
a spline filter before interpolation. The default is ``True``,
66-
which will create a temporary *float64* array of filtered values
67-
if *order > 1*. If setting this to ``False``, the output will be
68-
slightly blurred if *order > 1*, unless the input is prefiltered,
69-
i.e. it is the result of calling the spline filter on the original
70-
input.
71-
72-
Returns
73-
-------
74-
(:obj:`int`, :obj:`~numpy.ndarray`)
75-
The index and the array resulting from the interpolation.
76-
77-
"""
78-
return index, ndi.map_coordinates(
79-
data,
80-
targets,
81-
order=order,
82-
mode=mode,
83-
cval=cval,
84-
prefilter=prefilter,
85-
)
8642

87-
88-
def apply(
43+
async def apply(
8944
transform: TransformBase,
9045
spatialimage: str | Path | SpatialImage,
9146
reference: str | Path | SpatialImage = None,
@@ -94,9 +49,9 @@ def apply(
9449
cval: float = 0.0,
9550
prefilter: bool = True,
9651
output_dtype: np.dtype = None,
97-
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
98-
njobs: int = None,
9952
dtype_width: int = 8,
53+
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
54+
max_concurrent: int = min(cpu_count(), 12),
10055
) -> SpatialImage | np.ndarray:
10156
"""
10257
Apply a transformation to an image, resampling on the reference spatial object.
@@ -118,15 +73,15 @@ def apply(
11873
or ``'wrap'``. Default is ``'constant'``.
11974
cval : :obj:`float`, optional
12075
Constant value for ``mode='constant'``. Default is 0.0.
121-
prefilter: :obj:`bool`, optional
76+
prefilter : :obj:`bool`, optional
12277
Determines if the image's data array is prefiltered with
12378
a spline filter before interpolation. The default is ``True``,
12479
which will create a temporary *float64* array of filtered values
12580
if *order > 1*. If setting this to ``False``, the output will be
12681
slightly blurred if *order > 1*, unless the input is prefiltered,
12782
i.e. it is the result of calling the spline filter on the original
12883
input.
129-
output_dtype: :obj:`~numpy.dtype`, optional
84+
output_dtype : :obj:`~numpy.dtype`, optional
13085
The dtype of the returned array or image, if specified.
13186
If ``None``, the default behavior is to use the effective dtype of
13287
the input image. If slope and/or intercept are defined, the effective
@@ -135,10 +90,17 @@ def apply(
13590
If ``reference`` is defined, then the return value is an image, with
13691
a data array of the effective dtype but with the on-disk dtype set to
13792
the input image's on-disk dtype.
138-
dtype_width: :obj:`int`
93+
dtype_width : :obj:`int`
13994
Cap the width of the input data type to the given number of bytes.
14095
This argument is intended to work as a way to implement lower memory
14196
requirements in resampling.
97+
serialize_nvols : :obj:`int`
98+
Minimum number of volumes in a 3D+t (that is, a series of 3D transformations
99+
independent in time) to resample on a one-by-one basis.
100+
Serialized resampling can be executed concurrently (parallelized) with
101+
the argument ``max_concurrent``.
102+
max_concurrent : :obj:`int`
103+
Maximum number of 3D resamplings to be executed concurrently.
142104
143105
Returns
144106
-------
@@ -201,46 +163,47 @@ def apply(
201163
else None
202164
)
203165

204-
njobs = cpu_count() if njobs is None or njobs < 1 else njobs
166+
# Order F ensures individual volumes are contiguous in memory
167+
# Also matches NIfTI, making final save more efficient
168+
resampled = np.zeros(
169+
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
170+
)
205171

206-
with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
207-
results = []
208-
for t in range(n_resamplings):
209-
xfm_t = transform if n_resamplings == 1 else transform[t]
172+
semaphore = asyncio.Semaphore(max_concurrent)
210173

211-
if targets is None:
212-
targets = ImageGrid(spatialimage).index( # data should be an image
213-
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
214-
)
174+
tasks = []
175+
for t in range(n_resamplings):
176+
xfm_t = transform if n_resamplings == 1 else transform[t]
215177

216-
data_t = (
217-
data
218-
if data is not None
219-
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
178+
if targets is None:
179+
targets = ImageGrid(spatialimage).index( # data should be an image
180+
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
220181
)
221182

222-
results.append(
223-
executor.submit(
224-
_apply_volume,
225-
t,
226-
data_t,
227-
targets,
228-
order=order,
229-
mode=mode,
230-
cval=cval,
231-
prefilter=prefilter,
183+
data_t = (
184+
data
185+
if data is not None
186+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
187+
)
188+
189+
tasks.append(
190+
asyncio.create_task(
191+
worker(
192+
partial(
193+
ndi.map_coordinates,
194+
data_t,
195+
targets,
196+
output=resampled[..., t],
197+
order=order,
198+
mode=mode,
199+
cval=cval,
200+
prefilter=prefilter,
201+
),
202+
semaphore,
232203
)
233204
)
234-
235-
# Order F ensures individual volumes are contiguous in memory
236-
# Also matches NIfTI, making final save more efficient
237-
resampled = np.zeros(
238-
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
239205
)
240-
241-
for future in as_completed(results):
242-
t, resampled_t = future.result()
243-
resampled[..., t] = resampled_t
206+
await asyncio.gather(*tasks)
244207
else:
245208
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
246209

0 commit comments

Comments
 (0)