Skip to content

Commit bf7abdf

Browse files
committed
ENH: More comprehensive implementation of ITK affines I/O
1 parent f39e6fd commit bf7abdf

File tree

6 files changed

+235
-62
lines changed

6 files changed

+235
-62
lines changed

nitransforms/io/base.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,40 @@
11
"""Read/write linear transforms."""
2-
import numpy as np
3-
from nibabel.wrapstruct import LabeledWrapStruct as LWS
2+
from scipy.io.matlab.miobase import get_matfile_version
3+
from scipy.io.matlab.mio4 import MatFile4Reader # , MatFile4Writer
4+
from scipy.io.matlab.mio5 import MatFile5Reader # , MatFile5Writer
45

6+
from ..patched import LabeledWrapStruct
57

6-
class LabeledWrapStruct(LWS):
7-
def __setitem__(self, item, value):
8-
self._structarr[item] = np.asanyarray(value)
8+
9+
class TransformFileError(Exception):
10+
"""A custom exception for transform files."""
911

1012

1113
class StringBasedStruct(LabeledWrapStruct):
14+
"""File data structure from text files."""
15+
1216
def __init__(self,
1317
binaryblock=None,
1418
endianness=None,
1519
check=True):
16-
if binaryblock is not None and getattr(binaryblock, 'dtype',
17-
None) == self.dtype:
20+
"""Create a data structure based off of a string."""
21+
_dtype = getattr(binaryblock, 'dtype', None)
22+
if binaryblock is not None and _dtype == self.dtype:
1823
self._structarr = binaryblock.copy()
1924
return
2025
super(StringBasedStruct, self).__init__(binaryblock, endianness, check)
2126

2227
def __array__(self):
28+
"""Return the internal structure array."""
2329
return self._structarr
30+
31+
32+
def _read_mat(byte_stream):
33+
mjv, _ = get_matfile_version(byte_stream)
34+
if mjv == 0:
35+
reader = MatFile4Reader(byte_stream)
36+
elif mjv == 1:
37+
reader = MatFile5Reader(byte_stream)
38+
elif mjv == 2:
39+
raise TransformFileError('Please use HDF reader for matlab v7.3 files')
40+
return reader.get_variables()

nitransforms/io/itk.py

+128-28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Read/write ITK transforms."""
22
import numpy as np
3-
from .base import StringBasedStruct
3+
4+
from nibabel.affines import from_matvec
5+
from .base import StringBasedStruct, _read_mat
6+
7+
LPS = np.diag([-1, -1, 1, 1])
48

59

610
class ITKLinearTransform(StringBasedStruct):
@@ -13,20 +17,24 @@ class ITKLinearTransform(StringBasedStruct):
1317
('offset', 'f4', 3), # Center of rotation
1418
])
1519
dtype = template_dtype
20+
# files_types = (('string', '.tfm'), ('binary', '.mat'))
21+
# valid_exts = ('.tfm', '.mat')
1622

17-
def __init__(self):
23+
def __init__(self, parameters=None, offset=None):
1824
"""Initialize with default offset and index."""
1925
super().__init__()
20-
self.structarr['offset'] = [0, 0, 0]
2126
self.structarr['index'] = 1
27+
self.structarr['offset'] = offset or [0, 0, 0]
2228
self.structarr['parameters'] = np.eye(4)
29+
if parameters is not None:
30+
self.structarr['parameters'] = parameters
2331

2432
def __str__(self):
2533
"""Generate a string representation."""
2634
sa = self.structarr
2735
lines = [
2836
'#Transform {:d}'.format(sa['index']),
29-
'Transform: MatrixOffsetTransformBase_double_3_3',
37+
'Transform: AffineTransform_float_3_3',
3038
'Parameters: {}'.format(' '.join(
3139
['%g' % p
3240
for p in sa['parameters'][:3, :3].reshape(-1).tolist() +
@@ -36,6 +44,23 @@ def __str__(self):
3644
]
3745
return '\n'.join(lines)
3846

47+
def to_filename(self, filename):
48+
"""Store this transform to a file with the appropriate format."""
49+
if str(filename).endswith('.mat'):
50+
raise NotImplementedError
51+
52+
with open(str(filename), 'w') as f:
53+
f.write(self.to_string())
54+
55+
def to_ras(self):
56+
"""Return a nitransforms' internal RAS matrix."""
57+
sa = self.structarr
58+
matrix = sa['parameters']
59+
offset = sa['offset']
60+
c_neg = from_matvec(np.eye(3), offset * -1.0)
61+
c_pos = from_matvec(np.eye(3), offset)
62+
return LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS))))
63+
3964
def to_string(self, banner=None):
4065
"""Convert to a string directly writeable to file."""
4166
string = '%s'
@@ -48,9 +73,47 @@ def to_string(self, banner=None):
4873
return string % self
4974

5075
@classmethod
51-
def from_string(klass, string):
76+
def from_binary(cls, byte_stream, index=None):
77+
"""Read the struct from a matlab binary file."""
78+
mdict = _read_mat(byte_stream)
79+
return cls.from_matlab_dict(mdict, index=index)
80+
81+
@classmethod
82+
def from_fileobj(cls, fileobj, check=True):
83+
"""Read the struct from a file object."""
84+
if fileobj.name.endswith('.mat'):
85+
return cls.from_binary(fileobj)
86+
return cls.from_string(fileobj.read())
87+
88+
@classmethod
89+
def from_matlab_dict(cls, mdict, index=None):
90+
"""Read the struct from a matlab dictionary."""
91+
tf = cls()
92+
sa = tf.structarr
93+
if index is not None:
94+
raise NotImplementedError
95+
96+
sa['index'] = 1
97+
parameters = np.eye(4, dtype='f4')
98+
parameters[:3, :3] = mdict['AffineTransform_float_3_3'][:-3].reshape((3, 3))
99+
parameters[:3, 3] = mdict['AffineTransform_float_3_3'][-3:].flatten()
100+
sa['parameters'] = parameters
101+
sa['offset'] = mdict['fixed'].flatten()
102+
return tf
103+
104+
@classmethod
105+
def from_ras(cls, ras, index=0):
106+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
107+
tf = cls()
108+
sa = tf.structarr
109+
sa['index'] = index + 1
110+
sa['parameters'] = LPS.dot(ras.dot(LPS))
111+
return tf
112+
113+
@classmethod
114+
def from_string(cls, string):
52115
"""Read the struct from string."""
53-
tf = klass()
116+
tf = cls()
54117
sa = tf.structarr
55118
lines = [l for l in string.splitlines()
56119
if l.strip()]
@@ -61,19 +124,14 @@ def from_string(klass, string):
61124
parameters = np.eye(4, dtype='f4')
62125
sa['index'] = int(lines[0][lines[0].index('T'):].split()[1])
63126
sa['offset'] = np.genfromtxt([lines[3].split(':')[-1].encode()],
64-
dtype=klass.dtype['offset'])
127+
dtype=cls.dtype['offset'])
65128
vals = np.genfromtxt([lines[2].split(':')[-1].encode()],
66129
dtype='f4')
67130
parameters[:3, :3] = vals[:-3].reshape((3, 3))
68131
parameters[:3, 3] = vals[-3:]
69132
sa['parameters'] = parameters
70133
return tf
71134

72-
@classmethod
73-
def from_fileobj(klass, fileobj, check=True):
74-
"""Read the struct from a file object."""
75-
return klass.from_string(fileobj.read())
76-
77135

78136
class ITKLinearTransformArray(StringBasedStruct):
79137
"""A string-based structure for series of ITK linear transforms."""
@@ -89,33 +147,80 @@ def __init__(self,
89147
check=True):
90148
"""Initialize with (optionally) a list of transforms."""
91149
super().__init__(binaryblock, endianness, check)
92-
self._xforms = []
93-
for i, mat in enumerate(xforms or []):
94-
xfm = ITKLinearTransform()
95-
xfm['parameters'] = mat
96-
xfm['index'] = i + 1
97-
self._xforms.append(xfm)
150+
self.xforms = [ITKLinearTransform(parameters=mat)
151+
for mat in xforms or []]
152+
153+
@property
154+
def xforms(self):
155+
"""Get the list of internal ITKLinearTransforms."""
156+
return self._xforms
157+
158+
@xforms.setter
159+
def xforms(self, value):
160+
self._xforms = value
161+
162+
# Update indexes
163+
for i, val in enumerate(self._xforms):
164+
val['index'] = i + 1
98165

99166
def __getitem__(self, idx):
100167
"""Allow dictionary access to the transforms."""
101168
if idx == 'xforms':
102169
return self._xforms
103170
if idx == 'nxforms':
104171
return len(self._xforms)
105-
return super().__getitem__(idx)
172+
raise KeyError(idx)
173+
174+
def to_filename(self, filename):
175+
"""Store this transform to a file with the appropriate format."""
176+
if str(filename).endswith('.mat'):
177+
raise NotImplementedError
178+
179+
with open(str(filename), 'w') as f:
180+
f.write(self.to_string())
181+
182+
def to_ras(self):
183+
"""Return a nitransforms' internal RAS matrix."""
184+
return np.stack([xfm.to_ras() for xfm in self._xforms])
106185

107186
def to_string(self):
108187
"""Convert to a string directly writeable to file."""
109188
strings = []
110-
for i, xfm in enumerate(self._xforms):
189+
for i, xfm in enumerate(self.xforms):
111190
xfm.structarr['index'] = i + 1
112191
strings.append(xfm.to_string())
113192
return '\n'.join(strings)
114193

115194
@classmethod
116-
def from_string(klass, string):
195+
def from_binary(cls, byte_stream):
196+
"""Read the struct from a matlab binary file."""
197+
mdict = _read_mat(byte_stream)
198+
nxforms = mdict['fixed'].shape[0]
199+
200+
_self = cls()
201+
_self.xforms = [ITKLinearTransform.from_matlab_dict(mdict, i)
202+
for i in range(nxforms)]
203+
return _self
204+
205+
@classmethod
206+
def from_fileobj(cls, fileobj, check=True):
207+
"""Read the struct from a file object."""
208+
if fileobj.name.endswith('.mat'):
209+
return cls.from_binary(fileobj)
210+
return cls.from_string(fileobj.read())
211+
212+
@classmethod
213+
def from_ras(cls, ras):
214+
"""Create an ITK affine from a nitransform's RAS+ matrix."""
215+
_self = cls()
216+
_self.xforms = [ITKLinearTransform.from_ras(ras[i, ...], i)
217+
for i in range(ras.shape[0])]
218+
return _self
219+
220+
@classmethod
221+
def from_string(cls, string):
117222
"""Read the struct from string."""
118-
_self = klass()
223+
_self = cls()
119224
lines = [l.strip() for l in string.splitlines()
120225
if l.strip()]
121226

@@ -124,11 +229,6 @@ def from_string(klass, string):
124229

125230
string = '\n'.join(lines[1:])
126231
for xfm in string.split('#')[1:]:
127-
_self._xforms.append(ITKLinearTransform.from_string(
232+
_self.xforms.append(ITKLinearTransform.from_string(
128233
'#%s' % xfm))
129234
return _self
130-
131-
@classmethod
132-
def from_fileobj(klass, fileobj, check=True):
133-
"""Read the struct from a file object."""
134-
return klass.from_string(fileobj.read())

nitransforms/linear.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pathlib import Path
1212

1313
from nibabel.loadsave import load as loadimg
14-
from nibabel.affines import from_matvec, voxel_sizes, obliquity
14+
from nibabel.affines import voxel_sizes, obliquity
1515
from .base import TransformBase, _as_homogeneous, EQUALITY_TOL
1616
from .patched import shape_zoom_affine
1717
from . import io
@@ -78,7 +78,7 @@ def __eq__(self, other):
7878
True
7979
8080
"""
81-
if not self._reference == other._reference:
81+
if not self._reference == other.reference:
8282
return False
8383
return np.allclose(self.matrix, other.matrix, rtol=EQUALITY_TOL)
8484

@@ -133,10 +133,8 @@ def _to_hdf5(self, x5_root):
133133
def to_filename(self, filename, fmt='X5', moving=None):
134134
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
135135
if fmt.lower() in ['itk', 'ants', 'elastix']:
136-
itkobj = io.itk.ITKLinearTransformArray(
137-
xforms=[LPS.dot(m.dot(LPS)) for m in self.matrix])
138-
with open(filename, 'w') as f:
139-
f.write(itkobj.to_string())
136+
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
137+
itkobj.to_filename(filename)
140138
return filename
141139

142140
if fmt.lower() == 'afni':
@@ -228,19 +226,11 @@ def to_filename(self, filename, fmt='X5', moving=None):
228226

229227
def load(filename, fmt='X5', reference=None):
230228
"""Load a linear transform."""
231-
if fmt.lower() in ['itk', 'ants', 'elastix', 'nifty']:
229+
if fmt.lower() in ('itk', 'ants', 'elastix'):
232230
with open(filename) as itkfile:
233231
itkxfm = io.itk.ITKLinearTransformArray.from_fileobj(
234232
itkfile)
235-
236-
matlist = []
237-
for xfm in itkxfm['xforms']:
238-
matrix = xfm['parameters']
239-
offset = xfm['offset']
240-
c_neg = from_matvec(np.eye(3), offset * -1.0)
241-
c_pos = from_matvec(np.eye(3), offset)
242-
matlist.append(LPS.dot(c_pos.dot(matrix.dot(c_neg.dot(LPS)))))
243-
matrix = np.stack(matlist)
233+
matrix = itkxfm.to_ras()
244234
# elif fmt.lower() == 'afni':
245235
# parameters = LPS.dot(self.matrix.dot(LPS))
246236
# parameters = parameters[:3, :].reshape(-1).tolist()

nitransforms/patched.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from nibabel.wrapstruct import LabeledWrapStruct as LWS
23

34

45
def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
@@ -63,3 +64,8 @@ def shape_zoom_affine(shape, zooms, x_flip=True, y_flip=False):
6364
aff[:3, :3] = np.diag(zooms)
6465
aff[:3, -1] = -origin * zooms
6566
return aff
67+
68+
69+
class LabeledWrapStruct(LWS):
70+
def __setitem__(self, item, value):
71+
self._structarr[item] = np.asanyarray(value)

0 commit comments

Comments
 (0)