|
11 | 11 | from pathlib import Path
|
12 | 12 | import numpy as np
|
13 | 13 | from nibabel.loadsave import load as _nbload
|
| 14 | +from nibabel.arrayproxy import get_obj_dtype |
14 | 15 | from scipy import ndimage as ndi
|
15 | 16 |
|
| 17 | +from nitransforms.linear import Affine, LinearTransformsMapping |
16 | 18 | from nitransforms.base import (
|
17 | 19 | ImageGrid,
|
18 | 20 | TransformError,
|
|
23 | 25 | SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
|
24 | 26 | """Minimum number of volumes to automatically serialize 4D transforms."""
|
25 | 27 |
|
26 |
| -class NotImplementedWarning(UserWarning): |
27 |
| - """A custom class for warnings.""" |
28 |
| - |
29 | 28 |
|
30 | 29 | def apply(
|
31 | 30 | transform,
|
@@ -99,49 +98,89 @@ def apply(
|
99 | 98 |
|
100 | 99 | data = np.asanyarray(spatialimage.dataobj)
|
101 | 100 | data_nvols = 1 if data.ndim < 4 else data.shape[-1]
|
| 101 | + |
102 | 102 | xfm_nvols = len(transform)
|
103 | 103 |
|
104 | 104 | if data_nvols == 1 and xfm_nvols > 1:
|
105 | 105 | data = data[..., np.newaxis]
|
106 | 106 | elif data_nvols != xfm_nvols:
|
107 | 107 | raise ValueError(
|
108 |
| - "The fourth dimension of the data does not match the tranform's shape." |
| 108 | + "The fourth dimension of the data does not match the transform's shape." |
109 | 109 | )
|
110 | 110 |
|
111 | 111 | serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np.inf
|
112 |
| - serialize_4d = max(data_nvols, xfm_nvols) > serialize_nvols |
| 112 | + serialize_4d = max(data_nvols, xfm_nvols) >= serialize_nvols |
| 113 | + |
113 | 114 | if serialize_4d:
|
114 |
| - warn( |
115 |
| - "4D transforms serialization into 3D+t not implemented", |
116 |
| - NotImplementedWarning, |
117 |
| - stacklevel=2, |
| 115 | + # Avoid opening the data array just yet |
| 116 | + input_dtype = get_obj_dtype(spatialimage.dataobj) |
| 117 | + output_dtype = output_dtype or input_dtype |
| 118 | + |
| 119 | + # Prepare physical coordinates of input (grid, points) |
| 120 | + xcoords = _ref.ndcoords.astype("f4").T |
| 121 | + |
| 122 | + # Invert target's (moving) affine once |
| 123 | + ras2vox = ~Affine(spatialimage.affine) |
| 124 | + dataobj = ( |
| 125 | + np.asanyarray(spatialimage.dataobj, dtype=input_dtype) |
| 126 | + if spatialimage.ndim in (2, 3) |
| 127 | + else None |
118 | 128 | )
|
119 | 129 |
|
120 |
| - # For model-based nonlinear transforms, generate the corresponding dense field |
121 |
| - if hasattr(transform, "to_field") and callable(transform.to_field): |
122 |
| - targets = ImageGrid(spatialimage).index( |
123 |
| - _as_homogeneous( |
124 |
| - transform.to_field(reference=reference).map(_ref.ndcoords.T), |
125 |
| - dim=_ref.ndim, |
126 |
| - ) |
| 130 | + # Order F ensures individual volumes are contiguous in memory |
| 131 | + # Also matches NIfTI, making final save more efficient |
| 132 | + resampled = np.zeros( |
| 133 | + (xcoords.shape[0], len(transform)), dtype=output_dtype, order="F" |
127 | 134 | )
|
| 135 | + |
| 136 | + for t, xfm_t in enumerate(transform): |
| 137 | + # Map the input coordinates on to timepoint t of the target (moving) |
| 138 | + ycoords = xfm_t.map(xcoords)[..., : _ref.ndim] |
| 139 | + |
| 140 | + # Calculate corresponding voxel coordinates |
| 141 | + yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim] |
| 142 | + |
| 143 | + # Interpolate |
| 144 | + resampled[..., t] = ndi.map_coordinates( |
| 145 | + ( |
| 146 | + dataobj |
| 147 | + if dataobj is not None |
| 148 | + else spatialimage.dataobj[..., t].astype(input_dtype, copy=False) |
| 149 | + ), |
| 150 | + yvoxels.T, |
| 151 | + output=output_dtype, |
| 152 | + order=order, |
| 153 | + mode=mode, |
| 154 | + cval=cval, |
| 155 | + prefilter=prefilter, |
| 156 | + ) |
| 157 | + |
128 | 158 | else:
|
129 |
| - targets = ImageGrid(spatialimage).index( # data should be an image |
130 |
| - _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) |
131 |
| - ) |
| 159 | + # For model-based nonlinear transforms, generate the corresponding dense field |
| 160 | + if hasattr(transform, "to_field") and callable(transform.to_field): |
| 161 | + targets = ImageGrid(spatialimage).index( |
| 162 | + _as_homogeneous( |
| 163 | + transform.to_field(reference=reference).map(_ref.ndcoords.T), |
| 164 | + dim=_ref.ndim, |
| 165 | + ) |
| 166 | + ) |
| 167 | + else: |
| 168 | + targets = ImageGrid(spatialimage).index( # data should be an image |
| 169 | + _as_homogeneous(transform.map(_ref.ndcoords.T), dim=_ref.ndim) |
| 170 | + ) |
132 | 171 |
|
133 |
| - if transform.ndim == 4: |
134 |
| - targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T |
135 |
| - |
136 |
| - resampled = ndi.map_coordinates( |
137 |
| - data, |
138 |
| - targets, |
139 |
| - output=output_dtype, |
140 |
| - order=order, |
141 |
| - mode=mode, |
142 |
| - cval=cval, |
143 |
| - prefilter=prefilter, |
144 |
| - ) |
| 172 | + if transform.ndim == 4: |
| 173 | + targets = _as_homogeneous(targets.reshape(-2, targets.shape[0])).T |
| 174 | + |
| 175 | + resampled = ndi.map_coordinates( |
| 176 | + data, |
| 177 | + targets, |
| 178 | + output=output_dtype, |
| 179 | + order=order, |
| 180 | + mode=mode, |
| 181 | + cval=cval, |
| 182 | + prefilter=prefilter, |
| 183 | + ) |
145 | 184 |
|
146 | 185 | if isinstance(_ref, ImageGrid): # If reference is grid, reshape
|
147 | 186 | hdr = None
|
|
0 commit comments