20
20
ImageGrid ,
21
21
SpatialReference ,
22
22
_as_homogeneous ,
23
+ EQUALITY_TOL ,
23
24
)
24
25
25
26
26
- class DisplacementsFieldTransform (TransformBase ):
27
- """Represents a dense field of displacements (one vector per voxel)."""
27
+ class DeformationFieldTransform (TransformBase ):
28
+ """Represents a dense field of deformed locations (corresponding to each voxel)."""
28
29
29
30
__slots__ = ["_field" ]
30
31
@@ -34,8 +35,8 @@ def __init__(self, field, reference=None):
34
35
35
36
Example
36
37
-------
37
- >>> DisplacementsFieldTransform (test_dir / "someones_displacement_field.nii.gz")
38
- <DisplacementFieldTransform [3D] (57, 67, 56)>
38
+ >>> DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
39
+ <DeformationFieldTransform [3D] (57, 67, 56)>
39
40
40
41
"""
41
42
super ().__init__ ()
@@ -59,13 +60,13 @@ def __init__(self, field, reference=None):
59
60
ndim = self ._field .ndim - 1
60
61
if self ._field .shape [- 1 ] != ndim :
61
62
raise TransformError (
62
- "The number of components of the displacements (%d) does not "
63
+ "The number of components of the displacements (%d) does not match "
63
64
"the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
64
65
)
65
66
66
67
def __repr__ (self ):
67
68
"""Beautify the python representation."""
68
- return f"<DisplacementFieldTransform [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
69
+ return f"<{ self . __class__ . __name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
69
70
70
71
def map (self , x , inverse = False ):
71
72
r"""
@@ -92,12 +93,12 @@ def map(self, x, inverse=False):
92
93
93
94
Examples
94
95
--------
95
- >>> xfm = DisplacementsFieldTransform (test_dir / "someones_displacement_field.nii.gz")
96
+ >>> xfm = DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
96
97
>>> xfm.map([-6.5, -36., -19.5]).tolist()
97
- [[-6.5 , -36.475167989730835, -19.5 ]]
98
+ [[0.0 , -0.47516798973083496, 0.0 ]]
98
99
99
100
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
100
- [[-6.5 , -36.475167989730835, -19.5 ], [-1 .0, -42.038356602191925, -11.25 ]]
101
+ [[0.0 , -0.47516798973083496, 0.0 ], [0 .0, -0.538356602191925, 0.0 ]]
101
102
102
103
"""
103
104
@@ -108,7 +109,76 @@ def map(self, x, inverse=False):
108
109
if np .any (np .abs (ijk - indexes ) > 0.05 ):
109
110
warnings .warn ("Some coordinates are off-grid of the displacements field." )
110
111
indexes = tuple (tuple (i ) for i in indexes .T )
111
- return x + self ._field [indexes ]
112
+ return self ._field [indexes ]
113
+
114
+ def __matmul__ (self , b ):
115
+ """
116
+ Compose with a transform on the right.
117
+
118
+ Examples
119
+ --------
120
+ >>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
121
+ >>> xfm2 = xfm @ TransformBase()
122
+ >>> xfm == xfm2
123
+ True
124
+
125
+ """
126
+ retval = b .map (
127
+ self ._field .reshape ((- 1 , self ._field .shape [- 1 ]))
128
+ ).reshape (self ._field .shape )
129
+ return DeformationFieldTransform (retval , reference = self .reference )
130
+
131
+ def __eq__ (self , other ):
132
+ """
133
+ Overload equals operator.
134
+
135
+ Examples
136
+ --------
137
+ >>> xfm1 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
138
+ >>> xfm2 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
139
+ >>> xfm1 == xfm2
140
+ True
141
+
142
+ """
143
+ _eq = np .allclose (self ._field , other ._field , rtol = EQUALITY_TOL )
144
+ if _eq and self ._reference != other ._reference :
145
+ warnings .warn ("Fields are equal, but references do not match." )
146
+ return _eq
147
+
148
+
149
+ class DisplacementsFieldTransform (DeformationFieldTransform ):
150
+ """
151
+ Represents a dense field of displacements (one vector per voxel).
152
+
153
+ Converting to a field of deformations is straightforward by just adding the corresponding
154
+ displacement to the :math:`(x, y, z)` coordinates of each voxel.
155
+ Numerically, deformation fields are less susceptible to rounding errors
156
+ than displacements fields.
157
+ SPM generally prefers deformations for that reason.
158
+
159
+ """
160
+
161
+ __slots__ = ["_displacements" ]
162
+
163
+ def __init__ (self , field , reference = None ):
164
+ """
165
+ Create a transform supported by a field of voxel-wise displacements.
166
+
167
+ Example
168
+ -------
169
+ >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
170
+ >>> xfm
171
+ <DisplacementsFieldTransform[3D] (57, 67, 56)>
172
+
173
+ >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174
+ [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
175
+
176
+ """
177
+ super ().__init__ (field , reference = reference )
178
+ self ._displacements = self ._field
179
+ # Convert from displacements to deformations fields
180
+ # (just add the origin to the displacements vector)
181
+ self ._field += self .reference .ndcoords .T .reshape (self ._field .shape )
112
182
113
183
@classmethod
114
184
def from_filename (cls , filename , fmt = "X5" ):
0 commit comments