21
21
# https://www.nipreps.org/community/licensing/
22
22
#
23
23
"""The :math:`B_0` unwarping transform formalism."""
24
+ import os
25
+ from functools import partial
26
+ import asyncio
24
27
from pathlib import Path
25
- from typing import Sequence , Union
28
+ from typing import Callable , List , Sequence , Union
26
29
27
30
import attr
28
31
import numpy as np
39
42
from niworkflows .interfaces .nibabel import reorient_image
40
43
41
44
45
+ async def worker (data : np .ndarray , coordinates : np .ndarray , func : Callable ) -> np .ndarray :
46
+ loop = asyncio .get_running_loop ()
47
+ result = await loop .run_in_executor (None , func , data , coordinates )
48
+ return result
49
+
50
+
51
+ async def map_coordinates_thread_pool (
52
+ fulldataset : np .ndarray ,
53
+ coordinates : np .ndarray ,
54
+ num_workers : int ,
55
+ func : Callable = ndi .map_coordinates ,
56
+ ) -> List [np .ndarray ]:
57
+ results = []
58
+ tasks = []
59
+
60
+ out_shape = fulldataset .shape [:- 1 ]
61
+ out_dtype = fulldataset .dtype
62
+
63
+ # Create a worker task for each chunk
64
+ for volume in np .rollaxis (fulldataset , 3 , 0 ):
65
+ task = asyncio .create_task (worker (volume , coordinates , func ))
66
+ tasks .append (task )
67
+
68
+ # Wait for all tasks to complete
69
+ await asyncio .gather (* tasks )
70
+
71
+ # Collect the results an
72
+ results = np .rollaxis (np .array ([
73
+ np .array (task .result (), dtype = out_dtype ).reshape (out_shape )
74
+ for task in tasks
75
+ ]), 0 , 4 )
76
+
77
+ return results
78
+
79
+
42
80
@attr .s (slots = True )
43
81
class B0FieldTransform :
44
82
"""Represents and applies the transform to correct for susceptibility distortions."""
@@ -164,7 +202,8 @@ def apply(
164
202
cval : float = 0.0 ,
165
203
prefilter : bool = True ,
166
204
output_dtype : Union [str , np .dtype ] = None ,
167
- # num_threads: int = None,
205
+ num_threads : int = os .cpu_count (),
206
+ allow_negative : bool = False ,
168
207
):
169
208
"""
170
209
Apply a transformation to an image, resampling on the reference spatial object.
@@ -215,25 +254,29 @@ def apply(
215
254
if isinstance (moving , (str , bytes , Path )):
216
255
moving = nb .load (moving )
217
256
218
- # TODO: not sure this is necessary - instead check it matches self.mapped.
257
+ # Make sure the data array has all cosines positive (i.e., no axes are flipped)
219
258
moving , axcodes = ensure_positive_cosines (moving )
220
259
221
260
self .fit (moving )
222
261
fmap = self .mapped .get_fdata ().copy ()
223
262
224
- # Reverse mapped if reversed blips
225
- if pe_dir .endswith ("-" ):
226
- fmap *= - 1.0
227
-
228
263
# Generate warp field
229
264
pe_axis = "ijk" .index (pe_dir [0 ])
230
265
266
+ axis_flip = axcodes [pe_axis ] in ("LPI" )
267
+ pe_flip = pe_dir .endswith ("-" )
268
+
269
+ # Displacements are reversed if either is true (after ensuring positive cosines)
270
+ if axis_flip ^ pe_flip :
271
+ fmap *= - 1.0
272
+
231
273
# Map voxel coordinates applying the VSM
232
274
ijk_axis = tuple ([np .arange (s ) for s in fmap .shape ])
233
275
voxcoords = np .array (
234
276
np .meshgrid (* ijk_axis , indexing = "ij" ), dtype = "float32"
235
277
).reshape (3 , - 1 )
236
278
279
+ # TODO: we probably want to do this within each resampling thread
237
280
if xfms is not None :
238
281
mov_ras2vox = np .linalg .inv (moving .affine )
239
282
# Map coordinates from reference to time-step
@@ -252,19 +295,35 @@ def apply(
252
295
253
296
# Prepare data
254
297
data = np .squeeze (np .asanyarray (moving .dataobj ))
255
- output_dtype = output_dtype or data .dtype
298
+
299
+ if data .ndim == 3 :
300
+ data = data [..., np .newaxis ]
301
+
302
+ output_dtype = output_dtype or moving .header .get_data_dtype ()
256
303
257
304
# Resample
258
- resampled = ndi .map_coordinates (
259
- data ,
260
- voxcoords ,
305
+ map_coordinates = partial (
306
+ ndi .map_coordinates ,
261
307
output = output_dtype ,
262
308
order = order ,
263
309
mode = mode ,
264
310
cval = cval ,
265
311
prefilter = prefilter ,
312
+ )
313
+
314
+ resampled = np .array (
315
+ asyncio .run (map_coordinates_thread_pool (
316
+ data ,
317
+ voxcoords ,
318
+ num_threads ,
319
+ map_coordinates
320
+ )),
321
+ dtype = output_dtype ,
266
322
).reshape (moving .shape )
267
323
324
+ if not allow_negative :
325
+ resampled [resampled < 0 ] = cval
326
+
268
327
moved = moving .__class__ (resampled , moving .affine , moving .header )
269
328
moved .header .set_data_dtype (output_dtype )
270
329
return reorient_image (moved , axcodes )
0 commit comments