Skip to content

Commit b9777b9

Browse files
committed
enh: added a couple of minor tests for ITK-arrays' to_filename
1 parent 55c98b6 commit b9777b9

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

nitransforms/io/itk.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def to_string(self, banner=None):
8383
return string % self
8484

8585
@classmethod
86-
def from_binary(cls, byte_stream, index=None):
86+
def from_binary(cls, byte_stream, index=0):
8787
"""Read the struct from a matlab binary file."""
8888
mdict = _read_mat(byte_stream)
8989
return cls.from_matlab_dict(mdict, index=index)
@@ -96,14 +96,12 @@ def from_fileobj(cls, fileobj, check=True):
9696
return cls.from_string(fileobj.read())
9797

9898
@classmethod
99-
def from_matlab_dict(cls, mdict, index=None):
99+
def from_matlab_dict(cls, mdict, index=0):
100100
"""Read the struct from a matlab dictionary."""
101101
tf = cls()
102102
sa = tf.structarr
103-
if index is not None:
104-
raise NotImplementedError
105103

106-
sa['index'] = 0
104+
sa['index'] = index
107105
parameters = np.eye(4, dtype='f4')
108106
parameters[:3, :3] = mdict['AffineTransform_float_3_3'][:-3].reshape((3, 3))
109107
parameters[:3, 3] = mdict['AffineTransform_float_3_3'][-3:].flatten()
@@ -184,7 +182,7 @@ def __getitem__(self, idx):
184182
def to_filename(self, filename):
185183
"""Store this transform to a file with the appropriate format."""
186184
if str(filename).endswith('.mat'):
187-
raise NotImplementedError
185+
raise TransformFileError('Please use the ITK\'s new .h5 format.')
188186

189187
with open(str(filename), 'w') as f:
190188
f.write(self.to_string())

nitransforms/tests/test_io.py

+10
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ def test_ITKLinearTransformArray(tmpdir, data_path):
108108
itk.ITKLinearTransformArray.from_string(
109109
'\n'.join(text.splitlines()[1:]))
110110

111+
itklist.to_filename('copy.tfm')
112+
with open('copy.tfm') as f:
113+
copytext = f.read()
114+
assert text == copytext
115+
111116
itklist = itk.ITKLinearTransformArray(
112117
xforms=[np.random.normal(size=(4, 4))
113118
for _ in range(4)])
@@ -128,6 +133,11 @@ def test_ITKLinearTransformArray(tmpdir, data_path):
128133
assert np.allclose(xfm.structarr['parameters'][:3, ...],
129134
xfm2.structarr['parameters'][:3, ...])
130135

136+
with pytest.raises(TransformFileError):
137+
itklist.to_filename('matlablist.mat')
138+
139+
140+
131141

132142
@pytest.mark.parametrize('matlab_ver', ['4', '5'])
133143
def test_read_mat1(tmpdir, matlab_ver):

0 commit comments

Comments
 (0)