Skip to content

Commit 455f337

Browse files
authored
Merge pull request #35 from oesteban/maint/transfrom-io-refactor
ENH: More comprehensive implementation of ITK affines I/O
2 parents 45a0a2f + 474c85b commit 455f337

File tree

6 files changed

+277
-73
lines changed

6 files changed

+277
-73
lines changed

nitransforms/io/base.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,40 @@
11
"""Read/write linear transforms."""
2-
import numpy as np
3-
from nibabel.wrapstruct import LabeledWrapStruct as LWS
2+
from scipy.io.matlab.miobase import get_matfile_version
3+
from scipy.io.matlab.mio4 import MatFile4Reader
4+
from scipy.io.matlab.mio5 import MatFile5Reader
45

6+
from ..patched import LabeledWrapStruct
57

6-
class LabeledWrapStruct(LWS):
7-
def __setitem__(self, item, value):
8-
self._structarr[item] = np.asanyarray(value)
8+
9+
class TransformFileError(Exception):
10+
"""A custom exception for transform files."""
911

1012

1113
class StringBasedStruct(LabeledWrapStruct):
14+
"""File data structure from text files."""
15+
1216
def __init__(self,
1317
binaryblock=None,
1418
endianness=None,
1519
check=True):
16-
if binaryblock is not None and getattr(binaryblock, 'dtype',
17-
None) == self.dtype:
20+
"""Create a data structure based off of a string."""
21+
_dtype = getattr(binaryblock, 'dtype', None)
22+
if binaryblock is not None and _dtype == self.dtype:
1823
self._structarr = binaryblock.copy()
1924
return
2025
super(StringBasedStruct, self).__init__(binaryblock, endianness, check)
2126

2227
def __array__(self):
28+
"""Return the internal structure array."""
2329
return self._structarr
30+
31+
32+
def _read_mat(byte_stream):
33+
mjv, _ = get_matfile_version(byte_stream)
34+
if mjv == 0:
35+
reader = MatFile4Reader(byte_stream)
36+
elif mjv == 1:
37+
reader = MatFile5Reader(byte_stream)
38+
elif mjv == 2:
39+
raise TransformFileError('Please use HDF reader for matlab v7.3 files')
40+
return reader.get_variables()

nitransforms/io/itk.py

+133-31
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Read/write ITK transforms."""
22
import numpy as np
3-
from .base import StringBasedStruct
3+
from scipy.io import savemat as _save_mat
4+
from nibabel.affines import from_matvec
5+
from .base import StringBasedStruct, _read_mat, TransformFileError
6+
7+
LPS = np.diag([-1, -1, 1, 1])
48

59

610
class ITKLinearTransform(StringBasedStruct):
@@ -13,20 +17,24 @@ class ITKLinearTransform(StringBasedStruct):
1317
('offset', 'f4', 3), # Center of rotation
1418
])
1519
dtype = template_dtype
20+
# files_types = (('string', '.tfm'), ('binary', '.mat'))
21+
# valid_exts = ('.tfm', '.mat')
1622

17-
def __init__(self):
23+
def __init__(self, parameters=None, offset=None):
1824
"""Initialize with default offset and index."""
1925
super().__init__()
20-
self.structarr['offset'] = [0, 0, 0]
21-
self.structarr['index'] = 1
26+
self.structarr['index'] = 0
27+
self.structarr['offset'] = offset or [0, 0, 0]
2228
self.structarr['parameters'] = np.eye(4)
29+
if parameters is not None:
30+
self.structarr['parameters'] = parameters
2331

2432
def __str__(self):
2533
"""Generate a string representation."""
2634
sa = self.structarr
2735
lines = [
2836
'#Transform {:d}'.format(sa['index']),
29-
'Transform: MatrixOffsetTransformBase_double_3_3',
37+
'Transform: AffineTransform_float_3_3',
3038
'Parameters: {}'.format(' '.join(
3139
['%g' % p
3240
for p in sa['parameters'][:3, :3].reshape(-1).tolist() +
@@ -36,21 +44,84 @@ def __str__(self):
3644
]
3745
return '\n'.join(lines)
3846

47+
def to_filename(self, filename):
48+
"""Store this transform to a file with the appropriate format."""
49+
if str(filename).endswith('.mat'):
50+
sa = self.structarr
51+
affine = np.array(np.hstack((
52+
sa['parameters'][:3, :3].reshape(-1),
53+
sa['parameters'][:3, 3]))[..., np.newaxis], dtype='f4')
54+
fixed = np.array(sa['offset'][..., np.newaxis], dtype='f4')
55+
mdict = {
56+
'AffineTransform_float_3_3': affine,
57+
'fixed': fixed,
58+
}
59+
_save_mat(str(filename), mdict, format='4')
60+
return
61+
62+
with open(str(filename), 'w') as f:
63+
f.write(self.to_string())
64+
65+
def to_ras(self):
66+
"""Return a nitransforms internal RAS+ matrix."""
67+
sa = self.structarr
68+
matrix = sa['parameters']
69+
offset = sa['offset']
70+
c_neg = from_matvec(np.eye(3), offset * -1.0)
71+
c_pos = from_matvec(np.eye(3), offset)
72+
return LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS))))
73+
3974
def to_string(self, banner=None):
4075
"""Convert to a string directly writeable to file."""
4176
string = '%s'
4277

4378
if banner is None:
44-
banner = self.structarr['index'] == 1
79+
banner = self.structarr['index'] == 0
4580

4681
if banner:
4782
string = '#Insight Transform File V1.0\n%s'
4883
return string % self
4984

5085
@classmethod
51-
def from_string(klass, string):
86+
def from_binary(cls, byte_stream, index=0):
87+
"""Read the struct from a matlab binary file."""
88+
mdict = _read_mat(byte_stream)
89+
return cls.from_matlab_dict(mdict, index=index)
90+
91+
@classmethod
92+
def from_fileobj(cls, fileobj, check=True):
93+
"""Read the struct from a file object."""
94+
if fileobj.name.endswith('.mat'):
95+
return cls.from_binary(fileobj)
96+
return cls.from_string(fileobj.read())
97+
98+
@classmethod
99+
def from_matlab_dict(cls, mdict, index=0):
100+
"""Read the struct from a matlab dictionary."""
101+
tf = cls()
102+
sa = tf.structarr
103+
104+
sa['index'] = index
105+
parameters = np.eye(4, dtype='f4')
106+
parameters[:3, :3] = mdict['AffineTransform_float_3_3'][:-3].reshape((3, 3))
107+
parameters[:3, 3] = mdict['AffineTransform_float_3_3'][-3:].flatten()
108+
sa['parameters'] = parameters
109+
sa['offset'] = mdict['fixed'].flatten()
110+
return tf
111+
112+
@classmethod
113+
def from_ras(cls, ras, index=0):
114+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
115+
tf = cls()
116+
sa = tf.structarr
117+
sa['index'] = index
118+
sa['parameters'] = LPS.dot(ras.dot(LPS))
119+
return tf
120+
121+
@classmethod
122+
def from_string(cls, string):
52123
"""Read the struct from string."""
53-
tf = klass()
124+
tf = cls()
54125
sa = tf.structarr
55126
lines = [l for l in string.splitlines()
56127
if l.strip()]
@@ -61,19 +132,14 @@ def from_string(klass, string):
61132
parameters = np.eye(4, dtype='f4')
62133
sa['index'] = int(lines[0][lines[0].index('T'):].split()[1])
63134
sa['offset'] = np.genfromtxt([lines[3].split(':')[-1].encode()],
64-
dtype=klass.dtype['offset'])
135+
dtype=cls.dtype['offset'])
65136
vals = np.genfromtxt([lines[2].split(':')[-1].encode()],
66137
dtype='f4')
67138
parameters[:3, :3] = vals[:-3].reshape((3, 3))
68139
parameters[:3, 3] = vals[-3:]
69140
sa['parameters'] = parameters
70141
return tf
71142

72-
@classmethod
73-
def from_fileobj(klass, fileobj, check=True):
74-
"""Read the struct from a file object."""
75-
return klass.from_string(fileobj.read())
76-
77143

78144
class ITKLinearTransformArray(StringBasedStruct):
79145
"""A string-based structure for series of ITK linear transforms."""
@@ -89,33 +155,74 @@ def __init__(self,
89155
check=True):
90156
"""Initialize with (optionally) a list of transforms."""
91157
super().__init__(binaryblock, endianness, check)
92-
self._xforms = []
93-
for i, mat in enumerate(xforms or []):
94-
xfm = ITKLinearTransform()
95-
xfm['parameters'] = mat
96-
xfm['index'] = i + 1
97-
self._xforms.append(xfm)
158+
self.xforms = [ITKLinearTransform(parameters=mat)
159+
for mat in xforms or []]
160+
161+
@property
162+
def xforms(self):
163+
"""Get the list of internal ITKLinearTransforms."""
164+
return self._xforms
165+
166+
@xforms.setter
167+
def xforms(self, value):
168+
self._xforms = list(value)
169+
170+
# Update indexes
171+
for i, val in enumerate(self.xforms):
172+
val['index'] = i
98173

99174
def __getitem__(self, idx):
100175
"""Allow dictionary access to the transforms."""
101176
if idx == 'xforms':
102177
return self._xforms
103178
if idx == 'nxforms':
104179
return len(self._xforms)
105-
return super().__getitem__(idx)
180+
raise KeyError(idx)
181+
182+
def to_filename(self, filename):
183+
"""Store this transform to a file with the appropriate format."""
184+
if str(filename).endswith('.mat'):
185+
raise TransformFileError("Please use the ITK's new .h5 format.")
186+
187+
with open(str(filename), 'w') as f:
188+
f.write(self.to_string())
189+
190+
def to_ras(self):
191+
"""Return a nitransforms' internal RAS matrix."""
192+
return np.stack([xfm.to_ras() for xfm in self.xforms])
106193

107194
def to_string(self):
108195
"""Convert to a string directly writeable to file."""
109196
strings = []
110-
for i, xfm in enumerate(self._xforms):
111-
xfm.structarr['index'] = i + 1
197+
for i, xfm in enumerate(self.xforms):
198+
xfm.structarr['index'] = i
112199
strings.append(xfm.to_string())
113200
return '\n'.join(strings)
114201

115202
@classmethod
116-
def from_string(klass, string):
203+
def from_binary(cls, byte_stream):
204+
"""Read the struct from a matlab binary file."""
205+
raise TransformFileError("Please use the ITK's new .h5 format.")
206+
207+
@classmethod
208+
def from_fileobj(cls, fileobj, check=True):
209+
"""Read the struct from a file object."""
210+
if fileobj.name.endswith('.mat'):
211+
return cls.from_binary(fileobj)
212+
return cls.from_string(fileobj.read())
213+
214+
@classmethod
215+
def from_ras(cls, ras):
216+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
217+
_self = cls()
218+
_self.xforms = [ITKLinearTransform.from_ras(ras[i, ...], i)
219+
for i in range(ras.shape[0])]
220+
return _self
221+
222+
@classmethod
223+
def from_string(cls, string):
117224
"""Read the struct from string."""
118-
_self = klass()
225+
_self = cls()
119226
lines = [l.strip() for l in string.splitlines()
120227
if l.strip()]
121228

@@ -124,11 +231,6 @@ def from_string(klass, string):
124231

125232
string = '\n'.join(lines[1:])
126233
for xfm in string.split('#')[1:]:
127-
_self._xforms.append(ITKLinearTransform.from_string(
234+
_self.xforms.append(ITKLinearTransform.from_string(
128235
'#%s' % xfm))
129236
return _self
130-
131-
@classmethod
132-
def from_fileobj(klass, fileobj, check=True):
133-
"""Read the struct from a file object."""
134-
return klass.from_string(fileobj.read())

nitransforms/linear.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313

1414
from nibabel.loadsave import load as loadimg
15-
from nibabel.affines import from_matvec, voxel_sizes, obliquity
15+
from nibabel.affines import voxel_sizes, obliquity
1616
from .base import TransformBase, _as_homogeneous, EQUALITY_TOL
1717
from .patched import shape_zoom_affine
1818
from . import io
@@ -140,10 +140,8 @@ def _to_hdf5(self, x5_root):
140140
def to_filename(self, filename, fmt='X5', moving=None):
141141
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
142142
if fmt.lower() in ['itk', 'ants', 'elastix']:
143-
itkobj = io.itk.ITKLinearTransformArray(
144-
xforms=[LPS.dot(m.dot(LPS)) for m in self.matrix])
145-
with open(filename, 'w') as f:
146-
f.write(itkobj.to_string())
143+
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
144+
itkobj.to_filename(filename)
147145
return filename
148146

149147
if fmt.lower() == 'afni':
@@ -235,19 +233,11 @@ def to_filename(self, filename, fmt='X5', moving=None):
235233

236234
def load(filename, fmt='X5', reference=None):
237235
"""Load a linear transform."""
238-
if fmt.lower() in ['itk', 'ants', 'elastix', 'nifty']:
236+
if fmt.lower() in ('itk', 'ants', 'elastix'):
239237
with open(filename) as itkfile:
240238
itkxfm = io.itk.ITKLinearTransformArray.from_fileobj(
241239
itkfile)
242-
243-
matlist = []
244-
for xfm in itkxfm['xforms']:
245-
matrix = xfm['parameters']
246-
offset = xfm['offset']
247-
c_neg = from_matvec(np.eye(3), offset * -1.0)
248-
c_pos = from_matvec(np.eye(3), offset)
249-
matlist.append(LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS)))))
250-
matrix = np.stack(matlist)
240+
matrix = itkxfm.to_ras()
251241
# elif fmt.lower() == 'afni':
252242
# parameters = LPS.dot(self.matrix.dot(LPS))
253243
# parameters = parameters[:3, :].reshape(-1).tolist()

nitransforms/patched.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from nibabel.wrapstruct import LabeledWrapStruct as LWS
23

34

45
def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
@@ -63,3 +64,8 @@ def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
6364
aff[:3, :3] = np.diag(zooms)
6465
aff[:3, -1] = -origin * zooms
6566
return aff
67+
68+
69+
class LabeledWrapStruct(LWS):
70+
def __setitem__(self, item, value):
71+
self._structarr[item] = np.asanyarray(value)

0 commit comments

Comments
 (0)