Skip to content

Commit 6fbcbc1

Browse files
authored
Merge pull request #138 from poldracklab/enh/bring-bspline
ENH: Base implementation of B-Spline transforms
2 parents f703445 + f232acf commit 6fbcbc1

File tree

7 files changed

+400
-94
lines changed

7 files changed

+400
-94
lines changed

docs/_api/interp.rst

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
=====================
2+
Interpolation methods
3+
=====================
4+
5+
.. automodule:: nitransforms.interp
6+
:members:
7+
8+
-------------
9+
Method groups
10+
-------------
11+
12+
^^^^^^^^^
13+
B-Splines
14+
^^^^^^^^^
15+
16+
.. automodule:: nitransforms.interp.bspline
17+
:members:

docs/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ Information on specific functions, classes, and methods for developers.
1010
_api/linear
1111
_api/manip
1212
_api/nonlinear
13+
_api/interp
1314
_api/patched

nitransforms/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(self, image):
103103
self._shape = image.shape
104104

105105
self._ndim = getattr(image, "ndim", len(image.shape))
106-
if self._ndim == 4:
106+
if self._ndim >= 4:
107107
self._shape = image.shape[:3]
108108
self._ndim = 3
109109

@@ -267,6 +267,9 @@ def apply(
267267
self.reference if reference is None else SpatialReference.factory(reference)
268268
)
269269

270+
if _ref is None:
271+
raise TransformError("Cannot apply transform without reference")
272+
270273
if isinstance(spatialimage, (str, Path)):
271274
spatialimage = _nbload(str(spatialimage))
272275

nitransforms/interp/__init__.py

Whitespace-only changes.

nitransforms/interp/bspline.py

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4+
#
5+
# See COPYING file distributed along with the NiBabel package for the
6+
# copyright and license terms.
7+
#
8+
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9+
"""Interpolate with 3D tensor-product B-Spline basis."""
10+
import numpy as np
11+
import nibabel as nb
12+
from scipy.sparse import csr_matrix, kron
13+
14+
15+
def _cubic_bspline(d, order=3):
16+
"""Evaluate the cubic bspline at distance d from the center."""
17+
if order != 3:
18+
raise NotImplementedError
19+
20+
return np.piecewise(
21+
d,
22+
[d < 1.0, d >= 1.0],
23+
[
24+
lambda d: (4.0 - 6.0 * d ** 2 + 3.0 * d ** 3) / 6.0,
25+
lambda d: (2.0 - d) ** 3 / 6.0,
26+
],
27+
)
28+
29+
30+
def grid_bspline_weights(target_grid, ctrl_grid):
31+
r"""
32+
Evaluate tensor-product B-Spline weights on a grid.
33+
34+
For each of the :math:`N` input locations :math:`\mathbf{x} = (x_i, x_j, x_k)`
35+
and :math:`K` control points or *knots* :math:`\mathbf{c} =(c_i, c_j, c_k)`,
36+
the tensor-product cubic B-Spline kernel weights are calculated:
37+
38+
.. math::
39+
\Psi^3(\mathbf{x}, \mathbf{c}) =
40+
\beta^3(x_i - c_i) \cdot \beta^3(x_j - c_j) \cdot \beta^3(x_k - c_k),
41+
\label{eq:bspline_weights}\tag{1}
42+
43+
where each :math:`\beta^3` represents the cubic B-Spline for one dimension.
44+
The 1D B-Spline kernel implementation uses :obj:`numpy.piecewise`, and is based on the
45+
closed-form given by Eq. (6) of [Unser1999]_.
46+
47+
By iterating over dimensions, the data samples that fall outside of the compact
48+
support of the tensor-product kernel associated to each control point can be filtered
49+
out and dismissed to lighten computation.
50+
51+
Finally, the resulting weights matrix :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
52+
can be easily identified in Eq. :math:`\eqref{eq:1}` and used as the design matrix
53+
for approximation of data.
54+
55+
Parameters
56+
----------
57+
target_grid : :obj:`~nitransforms.base.ImageGrid` or :obj:`nibabel.spatialimages`
58+
Regular grid of :math:`N` locations at which tensor B-Spline basis will be evaluated.
59+
ctrl_grid : :obj:`~nitransforms.base.ImageGrid` or :obj:`nibabel.spatialimages`
60+
Regular grid of :math:`K` control points (knot) where B-Spline basis are defined.
61+
62+
Returns
63+
-------
64+
weights : :obj:`numpy.ndarray` (:math:`K \times N`)
65+
A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
66+
for the *N* voxels of the target EPI, for each of the total *K* knots.
67+
This sparse matrix can be directly used as design matrix for the fitting
68+
step of approximation/extrapolation.
69+
70+
"""
71+
shape = target_grid.shape[:3]
72+
ctrl_sp = nb.affines.voxel_sizes(ctrl_grid.affine)[:3]
73+
ras2ijk = np.linalg.inv(ctrl_grid.affine)
74+
# IJK index in the control point image of the first index in the target image
75+
origin = nb.affines.apply_affine(ras2ijk, [tuple(target_grid.affine[:3, 3])])[0]
76+
77+
wd = []
78+
for i, (o, n, sp) in enumerate(
79+
zip(origin, shape, nb.affines.voxel_sizes(target_grid.affine)[:3])
80+
):
81+
# Locations of voxels in target image in control point image
82+
locations = np.arange(0, n, dtype="float16") * sp / ctrl_sp[i] + o
83+
knots = np.arange(0, ctrl_grid.shape[i], dtype="float16")
84+
distance = np.abs(locations[np.newaxis, ...] - knots[..., np.newaxis])
85+
86+
within_support = distance < 2.0
87+
d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
88+
bs_w = _cubic_bspline(d_vals)
89+
weights = np.zeros_like(distance, dtype="float32")
90+
weights[within_support] = bs_w[d_idxs]
91+
wd.append(csr_matrix(weights))
92+
93+
return kron(kron(wd[0], wd[1]), wd[2])

0 commit comments

Comments
 (0)