Skip to content

Commit 8a18581

Browse files
authored
Merge pull request #195 from jmarabotto/sandbox
ENH: Outsource ``apply()`` from transform objects
2 parents f28cd14 + 5b1736b commit 8a18581

9 files changed

+280
-366
lines changed

nitransforms/base.py

+9-98
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from nibabel import funcs as _nbfuncs
1616
from nibabel.nifti1 import intent_codes as INTENT_CODES
1717
from nibabel.cifti2 import Cifti2Image
18-
from scipy import ndimage as ndi
1918

2019
EQUALITY_TOL = 1e-5
2120

@@ -178,7 +177,10 @@ def __ne__(self, other):
178177
class TransformBase:
179178
"""Abstract image class to represent transforms."""
180179

181-
__slots__ = ("_reference", "_ndim",)
180+
__slots__ = (
181+
"_reference",
182+
"_ndim",
183+
)
182184

183185
def __init__(self, reference=None):
184186
"""Instantiate a transform."""
@@ -222,101 +224,6 @@ def ndim(self):
222224
"""Access the dimensions of the reference space."""
223225
raise TypeError("TransformBase has no dimensions")
224226

225-
def apply(
226-
self,
227-
spatialimage,
228-
reference=None,
229-
order=3,
230-
mode="constant",
231-
cval=0.0,
232-
prefilter=True,
233-
output_dtype=None,
234-
):
235-
"""
236-
Apply a transformation to an image, resampling on the reference spatial object.
237-
238-
Parameters
239-
----------
240-
spatialimage : `spatialimage`
241-
The image object containing the data to be resampled in reference
242-
space
243-
reference : spatial object, optional
244-
The image, surface, or combination thereof containing the coordinates
245-
of samples that will be sampled.
246-
order : int, optional
247-
The order of the spline interpolation, default is 3.
248-
The order has to be in the range 0-5.
249-
mode : {'constant', 'reflect', 'nearest', 'mirror', 'wrap'}, optional
250-
Determines how the input image is extended when the resamplings overflows
251-
a border. Default is 'constant'.
252-
cval : float, optional
253-
Constant value for ``mode='constant'``. Default is 0.0.
254-
prefilter: bool, optional
255-
Determines if the image's data array is prefiltered with
256-
a spline filter before interpolation. The default is ``True``,
257-
which will create a temporary *float64* array of filtered values
258-
if *order > 1*. If setting this to ``False``, the output will be
259-
slightly blurred if *order > 1*, unless the input is prefiltered,
260-
i.e. it is the result of calling the spline filter on the original
261-
input.
262-
output_dtype: dtype specifier, optional
263-
The dtype of the returned array or image, if specified.
264-
If ``None``, the default behavior is to use the effective dtype of
265-
the input image. If slope and/or intercept are defined, the effective
266-
dtype is float64, otherwise it is equivalent to the input image's
267-
``get_data_dtype()`` (on-disk type).
268-
If ``reference`` is defined, then the return value is an image, with
269-
a data array of the effective dtype but with the on-disk dtype set to
270-
the input image's on-disk dtype.
271-
272-
Returns
273-
-------
274-
resampled : `spatialimage` or ndarray
275-
The data imaged after resampling to reference space.
276-
277-
"""
278-
if reference is not None and isinstance(reference, (str, Path)):
279-
reference = _nbload(str(reference))
280-
281-
_ref = (
282-
self.reference if reference is None else SpatialReference.factory(reference)
283-
)
284-
285-
if _ref is None:
286-
raise TransformError("Cannot apply transform without reference")
287-
288-
if isinstance(spatialimage, (str, Path)):
289-
spatialimage = _nbload(str(spatialimage))
290-
291-
data = np.asanyarray(spatialimage.dataobj)
292-
targets = ImageGrid(spatialimage).index( # data should be an image
293-
_as_homogeneous(self.map(_ref.ndcoords.T), dim=_ref.ndim)
294-
)
295-
296-
resampled = ndi.map_coordinates(
297-
data,
298-
targets.T,
299-
output=output_dtype,
300-
order=order,
301-
mode=mode,
302-
cval=cval,
303-
prefilter=prefilter,
304-
)
305-
306-
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
307-
hdr = None
308-
if _ref.header is not None:
309-
hdr = _ref.header.copy()
310-
hdr.set_data_dtype(output_dtype or spatialimage.get_data_dtype())
311-
moved = spatialimage.__class__(
312-
resampled.reshape(_ref.shape),
313-
_ref.affine,
314-
hdr,
315-
)
316-
return moved
317-
318-
return resampled
319-
320227
def map(self, x, inverse=False):
321228
r"""
322229
Apply :math:`y = f(x)`.
@@ -382,4 +289,8 @@ def _as_homogeneous(xyz, dtype="float32", dim=3):
382289

383290
def _apply_affine(x, affine, dim):
384291
"""Get the image array's indexes corresponding to coordinates."""
385-
return affine.dot(_as_homogeneous(x, dim=dim).T)[:dim, ...].T
292+
return np.tensordot(
293+
affine,
294+
_as_homogeneous(x, dim=dim).T,
295+
axes=1,
296+
)[:dim, ...]

nitransforms/cli.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .linear import load as linload
77
from .nonlinear import load as nlinload
8+
from .resampling import apply
89

910

1011
def cli_apply(pargs):
@@ -38,7 +39,8 @@ def cli_apply(pargs):
3839
# ensure a reference is set
3940
xfm.reference = pargs.ref or pargs.moving
4041

41-
moved = xfm.apply(
42+
moved = apply(
43+
xfm,
4244
pargs.moving,
4345
order=pargs.order,
4446
mode=pargs.mode,

nitransforms/linear.py

+4-121
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,12 @@
1010
import warnings
1111
import numpy as np
1212
from pathlib import Path
13-
from scipy import ndimage as ndi
1413

15-
from nibabel.loadsave import load as _nbload
1614
from nibabel.affines import from_matvec
17-
from nibabel.arrayproxy import get_obj_dtype
1815

1916
from nitransforms.base import (
2017
ImageGrid,
2118
TransformBase,
22-
SpatialReference,
2319
_as_homogeneous,
2420
EQUALITY_TOL,
2521
)
@@ -113,6 +109,10 @@ def __invert__(self):
113109
"""
114110
return self.__class__(self._inverse)
115111

112+
def __len__(self):
113+
"""Enable using len()."""
114+
return 1 if self._matrix.ndim == 2 else len(self._matrix)
115+
116116
def __matmul__(self, b):
117117
"""
118118
Compose two Affines.
@@ -330,10 +330,6 @@ def __getitem__(self, i):
330330
"""Enable indexed access to the series of matrices."""
331331
return Affine(self.matrix[i, ...], reference=self._reference)
332332

333-
def __len__(self):
334-
"""Enable using len()."""
335-
return len(self._matrix)
336-
337333
def map(self, x, inverse=False):
338334
r"""
339335
Apply :math:`y = f(x)`.
@@ -402,119 +398,6 @@ def to_filename(self, filename, fmt="X5", moving=None):
402398
).to_filename(filename)
403399
return filename
404400

405-
def apply(
406-
self,
407-
spatialimage,
408-
reference=None,
409-
order=3,
410-
mode="constant",
411-
cval=0.0,
412-
prefilter=True,
413-
output_dtype=None,
414-
):
415-
"""
416-
Apply a transformation to an image, resampling on the reference spatial object.
417-
418-
Parameters
419-
----------
420-
spatialimage : `spatialimage`
421-
The image object containing the data to be resampled in reference
422-
space
423-
reference : spatial object, optional
424-
The image, surface, or combination thereof containing the coordinates
425-
of samples that will be sampled.
426-
order : int, optional
427-
The order of the spline interpolation, default is 3.
428-
The order has to be in the range 0-5.
429-
mode : {"constant", "reflect", "nearest", "mirror", "wrap"}, optional
430-
Determines how the input image is extended when the resamplings overflows
431-
a border. Default is "constant".
432-
cval : float, optional
433-
Constant value for ``mode="constant"``. Default is 0.0.
434-
prefilter: bool, optional
435-
Determines if the image's data array is prefiltered with
436-
a spline filter before interpolation. The default is ``True``,
437-
which will create a temporary *float64* array of filtered values
438-
if *order > 1*. If setting this to ``False``, the output will be
439-
slightly blurred if *order > 1*, unless the input is prefiltered,
440-
i.e. it is the result of calling the spline filter on the original
441-
input.
442-
443-
Returns
444-
-------
445-
resampled : `spatialimage` or ndarray
446-
The data imaged after resampling to reference space.
447-
448-
"""
449-
450-
if reference is not None and isinstance(reference, (str, Path)):
451-
reference = _nbload(str(reference))
452-
453-
_ref = (
454-
self.reference if reference is None else SpatialReference.factory(reference)
455-
)
456-
457-
if isinstance(spatialimage, (str, Path)):
458-
spatialimage = _nbload(str(spatialimage))
459-
460-
# Avoid opening the data array just yet
461-
input_dtype = get_obj_dtype(spatialimage.dataobj)
462-
output_dtype = output_dtype or input_dtype
463-
464-
# Prepare physical coordinates of input (grid, points)
465-
xcoords = _ref.ndcoords.astype("f4").T
466-
467-
# Invert target's (moving) affine once
468-
ras2vox = ~Affine(spatialimage.affine)
469-
470-
if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
471-
raise ValueError(
472-
"Attempting to apply %d transforms on a file with "
473-
"%d timepoints" % (len(self), spatialimage.shape[-1])
474-
)
475-
476-
# Order F ensures individual volumes are contiguous in memory
477-
# Also matches NIfTI, making final save more efficient
478-
resampled = np.zeros(
479-
(xcoords.shape[0], len(self)), dtype=output_dtype, order="F"
480-
)
481-
482-
dataobj = (
483-
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
484-
if spatialimage.ndim in (2, 3)
485-
else None
486-
)
487-
488-
for t, xfm_t in enumerate(self):
489-
# Map the input coordinates on to timepoint t of the target (moving)
490-
ycoords = xfm_t.map(xcoords)[..., : _ref.ndim]
491-
492-
# Calculate corresponding voxel coordinates
493-
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]
494-
495-
# Interpolate
496-
resampled[..., t] = ndi.map_coordinates(
497-
(
498-
dataobj
499-
if dataobj is not None
500-
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
501-
),
502-
yvoxels.T,
503-
output=output_dtype,
504-
order=order,
505-
mode=mode,
506-
cval=cval,
507-
prefilter=prefilter,
508-
)
509-
510-
if isinstance(_ref, ImageGrid): # If reference is grid, reshape
511-
newdata = resampled.reshape(_ref.shape + (len(self),))
512-
moved = spatialimage.__class__(newdata, _ref.affine, spatialimage.header)
513-
moved.header.set_data_dtype(output_dtype)
514-
return moved
515-
516-
return resampled
517-
518401

519402
def load(filename, fmt=None, reference=None, moving=None):
520403
"""

0 commit comments

Comments
 (0)