Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 426f3c6

Browse files
committedOct 28, 2019
ENH: More comprehensive implementation of ITK affines I/O
1 parent 45a0a2f commit 426f3c6

File tree

6 files changed

+250
-61
lines changed

6 files changed

+250
-61
lines changed
 

‎nitransforms/io/base.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,40 @@
11
"""Read/write linear transforms."""
2-
import numpy as np
3-
from nibabel.wrapstruct import LabeledWrapStruct as LWS
2+
from scipy.io.matlab.miobase import get_matfile_version
3+
from scipy.io.matlab.mio4 import MatFile4Reader # , MatFile4Writer
4+
from scipy.io.matlab.mio5 import MatFile5Reader # , MatFile5Writer
45

6+
from ..patched import LabeledWrapStruct
57

6-
class LabeledWrapStruct(LWS):
7-
def __setitem__(self, item, value):
8-
self._structarr[item] = np.asanyarray(value)
8+
9+
class TransformFileError(Exception):
10+
"""A custom exception for transform files."""
911

1012

1113
class StringBasedStruct(LabeledWrapStruct):
14+
"""File data structure from text files."""
15+
1216
def __init__(self,
1317
binaryblock=None,
1418
endianness=None,
1519
check=True):
16-
if binaryblock is not None and getattr(binaryblock, 'dtype',
17-
None) == self.dtype:
20+
"""Create a data structure based off of a string."""
21+
_dtype = getattr(binaryblock, 'dtype', None)
22+
if binaryblock is not None and _dtype == self.dtype:
1823
self._structarr = binaryblock.copy()
1924
return
2025
super(StringBasedStruct, self).__init__(binaryblock, endianness, check)
2126

2227
def __array__(self):
28+
"""Return the internal structure array."""
2329
return self._structarr
30+
31+
32+
def _read_mat(byte_stream):
33+
mjv, _ = get_matfile_version(byte_stream)
34+
if mjv == 0:
35+
reader = MatFile4Reader(byte_stream)
36+
elif mjv == 1:
37+
reader = MatFile5Reader(byte_stream)
38+
elif mjv == 2:
39+
raise TransformFileError('Please use HDF reader for matlab v7.3 files')
40+
return reader.get_variables()

‎nitransforms/io/itk.py

+138-28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Read/write ITK transforms."""
22
import numpy as np
3-
from .base import StringBasedStruct
3+
from scipy.io import savemat as _save_mat
4+
from nibabel.affines import from_matvec
5+
from .base import StringBasedStruct, _read_mat
6+
7+
LPS = np.diag([-1, -1, 1, 1])
48

59

610
class ITKLinearTransform(StringBasedStruct):
@@ -13,20 +17,24 @@ class ITKLinearTransform(StringBasedStruct):
1317
('offset', 'f4', 3), # Center of rotation
1418
])
1519
dtype = template_dtype
20+
# files_types = (('string', '.tfm'), ('binary', '.mat'))
21+
# valid_exts = ('.tfm', '.mat')
1622

17-
def __init__(self):
23+
def __init__(self, parameters=None, offset=None):
1824
"""Initialize with default offset and index."""
1925
super().__init__()
20-
self.structarr['offset'] = [0, 0, 0]
2126
self.structarr['index'] = 1
27+
self.structarr['offset'] = offset or [0, 0, 0]
2228
self.structarr['parameters'] = np.eye(4)
29+
if parameters is not None:
30+
self.structarr['parameters'] = parameters
2331

2432
def __str__(self):
2533
"""Generate a string representation."""
2634
sa = self.structarr
2735
lines = [
2836
'#Transform {:d}'.format(sa['index']),
29-
'Transform: MatrixOffsetTransformBase_double_3_3',
37+
'Transform: AffineTransform_float_3_3',
3038
'Parameters: {}'.format(' '.join(
3139
['%g' % p
3240
for p in sa['parameters'][:3, :3].reshape(-1).tolist() +
@@ -36,6 +44,33 @@ def __str__(self):
3644
]
3745
return '\n'.join(lines)
3846

47+
def to_filename(self, filename):
48+
"""Store this transform to a file with the appropriate format."""
49+
if str(filename).endswith('.mat'):
50+
sa = self.structarr
51+
affine = np.array(np.hstack((
52+
sa['parameters'][:3, :3].reshape(-1),
53+
sa['parameters'][:3, 3]))[..., np.newaxis], dtype='f4')
54+
fixed = np.array(sa['offset'][..., np.newaxis], dtype='f4')
55+
mdict = {
56+
'AffineTransform_float_3_3': affine,
57+
'fixed': fixed,
58+
}
59+
_save_mat(str(filename), mdict, format='4')
60+
return
61+
62+
with open(str(filename), 'w') as f:
63+
f.write(self.to_string())
64+
65+
def to_ras(self):
66+
"""Return a nitransforms' internal RAS matrix."""
67+
sa = self.structarr
68+
matrix = sa['parameters']
69+
offset = sa['offset']
70+
c_neg = from_matvec(np.eye(3), offset * -1.0)
71+
c_pos = from_matvec(np.eye(3), offset)
72+
return LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS))))
73+
3974
def to_string(self, banner=None):
4075
"""Convert to a string directly writeable to file."""
4176
string = '%s'
@@ -48,9 +83,47 @@ def to_string(self, banner=None):
4883
return string % self
4984

5085
@classmethod
51-
def from_string(klass, string):
86+
def from_binary(cls, byte_stream, index=None):
87+
"""Read the struct from a matlab binary file."""
88+
mdict = _read_mat(byte_stream)
89+
return cls.from_matlab_dict(mdict, index=index)
90+
91+
@classmethod
92+
def from_fileobj(cls, fileobj, check=True):
93+
"""Read the struct from a file object."""
94+
if fileobj.name.endswith('.mat'):
95+
return cls.from_binary(fileobj)
96+
return cls.from_string(fileobj.read())
97+
98+
@classmethod
99+
def from_matlab_dict(cls, mdict, index=None):
100+
"""Read the struct from a matlab dictionary."""
101+
tf = cls()
102+
sa = tf.structarr
103+
if index is not None:
104+
raise NotImplementedError
105+
106+
sa['index'] = 1
107+
parameters = np.eye(4, dtype='f4')
108+
parameters[:3, :3] = mdict['AffineTransform_float_3_3'][:-3].reshape((3, 3))
109+
parameters[:3, 3] = mdict['AffineTransform_float_3_3'][-3:].flatten()
110+
sa['parameters'] = parameters
111+
sa['offset'] = mdict['fixed'].flatten()
112+
return tf
113+
114+
@classmethod
115+
def from_ras(cls, ras, index=0):
116+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
117+
tf = cls()
118+
sa = tf.structarr
119+
sa['index'] = index + 1
120+
sa['parameters'] = LPS.dot(ras.dot(LPS))
121+
return tf
122+
123+
@classmethod
124+
def from_string(cls, string):
52125
"""Read the struct from string."""
53-
tf = klass()
126+
tf = cls()
54127
sa = tf.structarr
55128
lines = [l for l in string.splitlines()
56129
if l.strip()]
@@ -61,19 +134,14 @@ def from_string(klass, string):
61134
parameters = np.eye(4, dtype='f4')
62135
sa['index'] = int(lines[0][lines[0].index('T'):].split()[1])
63136
sa['offset'] = np.genfromtxt([lines[3].split(':')[-1].encode()],
64-
dtype=klass.dtype['offset'])
137+
dtype=cls.dtype['offset'])
65138
vals = np.genfromtxt([lines[2].split(':')[-1].encode()],
66139
dtype='f4')
67140
parameters[:3, :3] = vals[:-3].reshape((3, 3))
68141
parameters[:3, 3] = vals[-3:]
69142
sa['parameters'] = parameters
70143
return tf
71144

72-
@classmethod
73-
def from_fileobj(klass, fileobj, check=True):
74-
"""Read the struct from a file object."""
75-
return klass.from_string(fileobj.read())
76-
77145

78146
class ITKLinearTransformArray(StringBasedStruct):
79147
"""A string-based structure for series of ITK linear transforms."""
@@ -89,33 +157,80 @@ def __init__(self,
89157
check=True):
90158
"""Initialize with (optionally) a list of transforms."""
91159
super().__init__(binaryblock, endianness, check)
92-
self._xforms = []
93-
for i, mat in enumerate(xforms or []):
94-
xfm = ITKLinearTransform()
95-
xfm['parameters'] = mat
96-
xfm['index'] = i + 1
97-
self._xforms.append(xfm)
160+
self.xforms = [ITKLinearTransform(parameters=mat)
161+
for mat in xforms or []]
162+
163+
@property
164+
def xforms(self):
165+
"""Get the list of internal ITKLinearTransforms."""
166+
return self._xforms
167+
168+
@xforms.setter
169+
def xforms(self, value):
170+
self._xforms = value
171+
172+
# Update indexes
173+
for i, val in enumerate(self._xforms):
174+
val['index'] = i + 1
98175

99176
def __getitem__(self, idx):
100177
"""Allow dictionary access to the transforms."""
101178
if idx == 'xforms':
102179
return self._xforms
103180
if idx == 'nxforms':
104181
return len(self._xforms)
105-
return super().__getitem__(idx)
182+
raise KeyError(idx)
183+
184+
def to_filename(self, filename):
185+
"""Store this transform to a file with the appropriate format."""
186+
if str(filename).endswith('.mat'):
187+
raise NotImplementedError
188+
189+
with open(str(filename), 'w') as f:
190+
f.write(self.to_string())
191+
192+
def to_ras(self):
193+
"""Return a nitransforms' internal RAS matrix."""
194+
return np.stack([xfm.to_ras() for xfm in self._xforms])
106195

107196
def to_string(self):
108197
"""Convert to a string directly writeable to file."""
109198
strings = []
110-
for i, xfm in enumerate(self._xforms):
199+
for i, xfm in enumerate(self.xforms):
111200
xfm.structarr['index'] = i + 1
112201
strings.append(xfm.to_string())
113202
return '\n'.join(strings)
114203

115204
@classmethod
116-
def from_string(klass, string):
205+
def from_binary(cls, byte_stream):
206+
"""Read the struct from a matlab binary file."""
207+
mdict = _read_mat(byte_stream)
208+
nxforms = mdict['fixed'].shape[0]
209+
210+
_self = cls()
211+
_self.xforms = [ITKLinearTransform.from_matlab_dict(mdict, i)
212+
for i in range(nxforms)]
213+
return _self
214+
215+
@classmethod
216+
def from_fileobj(cls, fileobj, check=True):
217+
"""Read the struct from a file object."""
218+
if fileobj.name.endswith('.mat'):
219+
return cls.from_binary(fileobj)
220+
return cls.from_string(fileobj.read())
221+
222+
@classmethod
223+
def from_ras(cls, ras):
224+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
225+
_self = cls()
226+
_self.xforms = [ITKLinearTransform.from_ras(ras[i, ...], i)
227+
for i in range(ras.shape[0])]
228+
return _self
229+
230+
@classmethod
231+
def from_string(cls, string):
117232
"""Read the struct from string."""
118-
_self = klass()
233+
_self = cls()
119234
lines = [l.strip() for l in string.splitlines()
120235
if l.strip()]
121236

@@ -124,11 +239,6 @@ def from_string(klass, string):
124239

125240
string = '\n'.join(lines[1:])
126241
for xfm in string.split('#')[1:]:
127-
_self._xforms.append(ITKLinearTransform.from_string(
242+
_self.xforms.append(ITKLinearTransform.from_string(
128243
'#%s' % xfm))
129244
return _self
130-
131-
@classmethod
132-
def from_fileobj(klass, fileobj, check=True):
133-
"""Read the struct from a file object."""
134-
return klass.from_string(fileobj.read())

‎nitransforms/linear.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import numpy as np
1313

1414
from nibabel.loadsave import load as loadimg
15-
from nibabel.affines import from_matvec, voxel_sizes, obliquity
15+
from nibabel.affines import voxel_sizes, obliquity
1616
from .base import TransformBase, _as_homogeneous, EQUALITY_TOL
1717
from .patched import shape_zoom_affine
1818
from . import io
@@ -140,10 +140,8 @@ def _to_hdf5(self, x5_root):
140140
def to_filename(self, filename, fmt='X5', moving=None):
141141
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
142142
if fmt.lower() in ['itk', 'ants', 'elastix']:
143-
itkobj = io.itk.ITKLinearTransformArray(
144-
xforms=[LPS.dot(m.dot(LPS)) for m in self.matrix])
145-
with open(filename, 'w') as f:
146-
f.write(itkobj.to_string())
143+
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
144+
itkobj.to_filename(filename)
147145
return filename
148146

149147
if fmt.lower() == 'afni':
@@ -235,19 +233,11 @@ def to_filename(self, filename, fmt='X5', moving=None):
235233

236234
def load(filename, fmt='X5', reference=None):
237235
"""Load a linear transform."""
238-
if fmt.lower() in ['itk', 'ants', 'elastix', 'nifty']:
236+
if fmt.lower() in ('itk', 'ants', 'elastix'):
239237
with open(filename) as itkfile:
240238
itkxfm = io.itk.ITKLinearTransformArray.from_fileobj(
241239
itkfile)
242-
243-
matlist = []
244-
for xfm in itkxfm['xforms']:
245-
matrix = xfm['parameters']
246-
offset = xfm['offset']
247-
c_neg = from_matvec(np.eye(3), offset * -1.0)
248-
c_pos = from_matvec(np.eye(3), offset)
249-
matlist.append(LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS)))))
250-
matrix = np.stack(matlist)
240+
matrix = itkxfm.to_ras()
251241
# elif fmt.lower() == 'afni':
252242
# parameters = LPS.dot(self.matrix.dot(LPS))
253243
# parameters = parameters[:3, :].reshape(-1).tolist()

‎nitransforms/patched.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from nibabel.wrapstruct import LabeledWrapStruct as LWS
23

34

45
def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
@@ -63,3 +64,8 @@ def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
6364
aff[:3, :3] = np.diag(zooms)
6465
aff[:3, -1] = -origin * zooms
6566
return aff
67+
68+
69+
class LabeledWrapStruct(LWS):
70+
def __setitem__(self, item, value):
71+
self._structarr[item] = np.asanyarray(value)

‎nitransforms/tests/data/itktflist.tfm

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,45 @@
11
#Insight Transform File V1.0
22
#Transform 1
3-
Transform: MatrixOffsetTransformBase_double_3_3
3+
Transform: AffineTransform_float_3_3
44
Parameters: 1 0 0 0 1 0 0 0 1 0 0 0
55
FixedParameters: 10 10 10
66

77
#Transform 2
8-
Transform: MatrixOffsetTransformBase_double_3_3
8+
Transform: AffineTransform_float_3_3
99
Parameters: 1 0 0 0 1 0 0 0 1 0 0 0
1010
FixedParameters: 10 10 10
1111

1212
#Transform 3
13-
Transform: MatrixOffsetTransformBase_double_3_3
13+
Transform: AffineTransform_float_3_3
1414
Parameters: 1 0 0 0 1 0 0 0 1 0 0 0
1515
FixedParameters: 10 10 10
1616

1717
#Transform 4
18-
Transform: MatrixOffsetTransformBase_double_3_3
18+
Transform: AffineTransform_float_3_3
1919
Parameters: 1 0 0 0 1 0 0 0 1 0 0 0
2020
FixedParameters: 10 10 10
2121

2222
#Transform 5
23-
Transform: MatrixOffsetTransformBase_double_3_3
23+
Transform: AffineTransform_float_3_3
2424
Parameters: -1.53626 0.71973 -0.639856 -0.190759 -1.80082 -0.915885 0.502537 1.12532 0.275748 0.393413 1.13855 0.761131
2525
FixedParameters: -0.0993171 0.364984 1.99264
2626

2727
#Transform 6
28-
Transform: MatrixOffsetTransformBase_double_3_3
28+
Transform: AffineTransform_float_3_3
2929
Parameters: -0.130507 -1.03017 2.08189 -1.51723 1.37849 -0.0890962 -0.656323 0.242694 2.15801 -1.26689 0.367131 1.23616
3030
FixedParameters: 0.626607 0.15351 1.24982
3131

3232
#Transform 7
33-
Transform: MatrixOffsetTransformBase_double_3_3
33+
Transform: AffineTransform_float_3_3
3434
Parameters: -1.55395 -0.36383 -0.17749 1.3387 -0.384534 -0.901462 -1.06598 -0.448228 -1.07535 1.92599 0.454696 0.576697
3535
FixedParameters: -0.425602 0.333406 -1.14957
3636

3737
#Transform 8
38-
Transform: MatrixOffsetTransformBase_double_3_3
38+
Transform: AffineTransform_float_3_3
3939
Parameters: 0.723719 -1.05617 -0.800562 -2.47048 -1.76301 -1.4447 -0.749896 1.29774 -1.48893 1.02789 0.65017 -1.48326
4040
FixedParameters: 0.800882 -1.20202 1.25495
4141

4242
#Transform 9
43-
Transform: MatrixOffsetTransformBase_double_3_3
43+
Transform: AffineTransform_float_3_3
4444
Parameters: 1.24025 -0.77628 0.618013 -0.523829 1.09471 1.66921 0.73753 -1.33588 -0.627659 -0.449913 -0.00124181 0.21433
4545
FixedParameters: -0.226504 -0.877893 0.2608

‎nitransforms/tests/test_io.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,22 @@
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""I/O test cases."""
44
import numpy as np
5+
import pytest
56

7+
from nibabel.eulerangles import euler2mat
8+
from nibabel.affines import from_matvec
9+
from scipy.io import loadmat, savemat
610
from ..io import (
711
itk,
812
VolumeGeometry as VG,
913
LinearTransform as LT,
1014
LinearTransformArray as LTA,
1115
)
16+
from ..io.base import _read_mat, TransformFileError
17+
18+
LPS = np.diag([-1, -1, 1, 1])
19+
ITK_MAT = LPS.dot(np.ones((4, 4)).dot(LPS))
20+
1221

1322
def test_VolumeGeometry(tmpdir, get_testdata):
1423
vg = VG()
@@ -23,7 +32,7 @@ def test_VolumeGeometry(tmpdir, get_testdata):
2332
assert len(vg.to_string().split('\n')) == 8
2433

2534

26-
def test_LinearTransform(tmpdir, get_testdata):
35+
def test_LinearTransform(tmpdir):
2736
lt = LT()
2837
assert lt['m_L'].shape == (4, 4)
2938
assert np.all(lt['m_L'] == 0)
@@ -57,6 +66,27 @@ def test_LinearTransformArray(tmpdir, data_path):
5766
assert np.allclose(lta['xforms'][0]['m_L'], lta2['xforms'][0]['m_L'])
5867

5968

69+
def test_ITKLinearTransform(tmpdir, data_path):
70+
tmpdir.chdir()
71+
72+
matlabfile = str(data_path / 'ds-005_sub-01_from-T1_to-OASIS_affine.mat')
73+
mat = loadmat(matlabfile)
74+
with open(matlabfile, 'rb') as f:
75+
itkxfm = itk.ITKLinearTransform.from_fileobj(f)
76+
assert np.allclose(itkxfm['parameters'][:3, :3].flatten(),
77+
mat['AffineTransform_float_3_3'][:-3].flatten())
78+
assert np.allclose(itkxfm['offset'], mat['fixed'].reshape((3, )))
79+
itkxfm.to_filename('copy.mat')
80+
81+
with open('copy.mat', 'rb') as f:
82+
itkxfm2 = itk.ITKLinearTransform.from_fileobj(f)
83+
assert np.all(itkxfm['parameters'] == itkxfm2['parameters'])
84+
85+
rasmat = from_matvec(euler2mat(x=0.9, y=0.001, z=0.001), [4.0, 2.0, -1.0])
86+
itkxfm = itk.ITKLinearTransform.from_ras(rasmat)
87+
assert np.allclose(itkxfm['parameters'], ITK_MAT * rasmat)
88+
89+
6090
def test_ITKLinearTransformArray(tmpdir, data_path):
6191
tmpdir.chdir()
6292

@@ -67,14 +97,20 @@ def test_ITKLinearTransformArray(tmpdir, data_path):
6797

6898
assert itklist['nxforms'] == 9
6999
assert text == itklist.to_string()
100+
with pytest.raises(ValueError):
101+
itk.ITKLinearTransformArray.from_string(
102+
'\n'.join(text.splitlines()[1:]))
70103

71104
itklist = itk.ITKLinearTransformArray(
72-
xforms=[np.around(np.random.normal(size=(4, 4)), decimals=5)
105+
xforms=[np.random.normal(size=(4, 4))
73106
for _ in range(4)])
74107

75108
assert itklist['nxforms'] == 4
76109
assert itklist['xforms'][1].structarr['index'] == 2
77110

111+
with pytest.raises(KeyError):
112+
itklist['invalid_key']
113+
78114
xfm = itklist['xforms'][1]
79115
xfm['index'] = 1
80116
with open('extracted.tfm', 'w') as f:
@@ -84,3 +120,33 @@ def test_ITKLinearTransformArray(tmpdir, data_path):
84120
xfm2 = itk.ITKLinearTransform.from_fileobj(f)
85121
assert np.allclose(xfm.structarr['parameters'][:3, ...],
86122
xfm2.structarr['parameters'][:3, ...])
123+
124+
125+
@pytest.mark.parametrize('matlab_ver', ['4', '5'])
126+
def test_read_mat1(tmpdir, matlab_ver):
127+
"""Test read from matlab."""
128+
tmpdir.chdir()
129+
130+
savemat('val.mat', {'val': np.ones((3,))},
131+
format=matlab_ver)
132+
with open('val.mat', 'rb') as f:
133+
mdict = _read_mat(f)
134+
135+
assert np.all(mdict['val'] == np.ones((3,)))
136+
137+
138+
def test_read_mat2(tmpdir, monkeypatch):
139+
"""Check read matlab raises adequate errors."""
140+
from ..io import base
141+
142+
def _mockreturn(arg):
143+
return (2, 0)
144+
145+
tmpdir.chdir()
146+
savemat('val.mat', {'val': np.ones((3,))})
147+
148+
with monkeypatch.context() as m:
149+
m.setattr(base, 'get_matfile_version', _mockreturn)
150+
with pytest.raises(TransformFileError):
151+
with open('val.mat', 'rb') as f:
152+
_read_mat(f)

0 commit comments

Comments
 (0)
Please sign in to comment.