Skip to content

Commit 001c3f9

Browse files
committed
fix: interface internal streamlining toward getting tests to pass
fix: interface internal streamlining toward getting tests to pass
1 parent 851df88 commit 001c3f9

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

sdcflows/interfaces/bspline.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,30 @@ def _run_interface(self, runtime):
378378
)
379379

380380
# We can now write out the fieldmap
381-
# unwarp.mapped.to_filename(out_field)
382-
# self._results["out_field"] = out_field
381+
self._results["out_field"] = fname_presuffix(
382+
self.inputs.in_data,
383+
suffix="_field",
384+
newpath=runtime.cwd,
385+
)
386+
unwarp.mapped.to_filename(self._results["out_field"])
383387

384388
# HMC matrices are only necessary when reslicing the data (i.e., apply())
385389
# Check the length of in_xfms matches that of in_data
386-
self._results["out_corrected"] = unwarp.apply(
390+
self._results["out_corrected"] = fname_presuffix(
391+
self.inputs.in_data,
392+
suffix="_sdc",
393+
newpath=runtime.cwd,
394+
)
395+
396+
unwarp.apply(
387397
self.inputs.in_data,
388398
self.inputs.pe_dir,
389399
self.inputs.ro_time,
390-
xfms=self.inputs.in_xfms,
400+
xfms=self.inputs.in_xfms if isdefined(self.inputs.in_xfms) else None,
391401
num_threads=(
392402
None if not isdefined(self.inputs.num_threads) else self.inputs.num_threads
393403
),
394-
)
404+
).to_filename(self._results["out_corrected"])
395405
return runtime
396406

397407

sdcflows/transform.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def apply(
261261
cval: float = 0.0,
262262
prefilter: bool = True,
263263
output_dtype: Union[str, np.dtype] = None,
264-
num_threads: int = os.cpu_count(),
264+
num_threads: int = None,
265265
allow_negative: bool = False,
266266
):
267267
"""
@@ -329,12 +329,13 @@ def apply(
329329

330330
# Prepare data
331331
data = np.squeeze(np.asanyarray(moving.dataobj))
332+
ndim = min(data.ndim, 3)
332333
output_dtype = output_dtype or moving.header.get_data_dtype()
333334

334335
# Reference image's voxel coordinates (in voxel units)
335336
voxcoords = nt.linear.Affine(
336337
reference=moving
337-
).reference.ndindex.reshape((data.ndim - 1, *data.shape[:-1])).astype("float32")
338+
).reference.ndindex.reshape((ndim, *data.shape[:ndim])).astype("float32")
338339

339340
# The VSM is just the displacements field given in index coordinates
340341
# voxcoords is the deformation field, i.e., the target position of each voxel
@@ -358,7 +359,7 @@ def apply(
358359
mode=mode,
359360
cval=cval,
360361
prefilter=prefilter,
361-
max_concurrent=num_threads,
362+
max_concurrent=num_threads or min(os.cpu_count(), 12),
362363
))
363364

364365
if not allow_negative:

0 commit comments

Comments
 (0)