diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 4f91c24c..ef92eb6c 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -21,6 +21,7 @@ SpatialReference, _as_homogeneous, ) +from scipy.ndimage import map_coordinates class DenseFieldTransform(TransformBase): @@ -132,6 +133,13 @@ def map(self, x, inverse=False): >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() [[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]] + >>> np.array_str( + ... xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]), + ... precision=3, + ... suppress_small=True, + ... ) + '[[ 0. -0.482 0. ]\n [ 0. -0.538 0. ]]' + >>> xfm = DenseFieldTransform( ... test_dir / "someones_displacement_field.nii.gz", ... is_deltas=True, @@ -139,16 +147,34 @@ def map(self, x, inverse=False): >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist() [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]] + >>> np.array_str( + ... xfm.map([[-6.7, -36.3, -19.2], [-1., -41.5, -11.25]]), + ... precision=3, + ... suppress_small=True, + ... ) + '[[ -6.7 -36.782 -19.2 ]\n [ -1. -42.038 -11.25 ]]' + """ if inverse is True: raise NotImplementedError ijk = self.reference.index(x) indexes = np.round(ijk).astype("int") - if np.any(np.abs(ijk - indexes) > 0.05): - warnings.warn("Some coordinates are off-grid of the field.") - indexes = tuple(tuple(i) for i in indexes.T) - return self._field[indexes] + + if np.all(np.abs(ijk - indexes) < 1e-3): + indexes = tuple(tuple(i) for i in indexes.T) + return self._field[indexes] + + return np.vstack(( + map_coordinates( + self._field[..., i], + ijk.T, + order=3, + mode="constant", + cval=0, + prefilter=True, + ) for i in range(self.reference.ndim) + )).T def __matmul__(self, b): """