Skip to content

Commit d2e08da

Browse files
authored
Merge pull request #13 from matthew-brett/refactor-create-empty-header
MRG: refactoring for tractogram headers
2 parents 0d0ece1 + 0bacbf5 commit d2e08da

File tree

4 files changed

+41
-36
lines changed

4 files changed

+41
-36
lines changed

nibabel/streamlines/tck.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ def __init__(self, tractogram, header=None):
7272
This is in contrast with TRK's internal convention where it would
7373
have referred to a corner.
7474
"""
75-
if header is None:
76-
header = self.create_empty_header()
77-
7875
super(TckFile, self).__init__(tractogram, header)
7976

8077
@classmethod
@@ -103,7 +100,7 @@ def is_correct_format(cls, fileobj):
103100

104101
@classmethod
105102
def create_empty_header(cls):
106-
""" Return an empty compliant TCK header. """
103+
""" Return an empty compliant TCK header as dict """
107104
header = {}
108105

109106
# Default values

nibabel/streamlines/tests/test_tractogram_file.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from nose.tools import assert_raises
1+
""" Test tractogramFile base class
2+
"""
23

34
from ..tractogram import Tractogram
45
from ..tractogram_file import TractogramFile
56

7+
from nose.tools import assert_raises, assert_equal
8+
69

710
def test_subclassing_tractogram_file():
811

@@ -37,7 +40,7 @@ def create_empty_header(cls):
3740

3841
assert_raises(TypeError, DummyTractogramFile, Tractogram())
3942

40-
# Missing 'create_empty_header' method
43+
# Now we have everything required.
4144
class DummyTractogramFile(TractogramFile):
4245
@classmethod
4346
def is_correct_format(cls, fileobj):
@@ -50,13 +53,16 @@ def load(cls, fileobj, lazy_load=True):
5053
def save(self, fileobj):
5154
pass
5255

53-
assert_raises(TypeError, DummyTractogramFile, Tractogram())
56+
# No error
57+
dtf = DummyTractogramFile(Tractogram())
58+
59+
# Default create_empty_header is empty dict
60+
assert_equal(dtf.header, {})
5461

5562

5663
def test_tractogram_file():
5764
assert_raises(NotImplementedError, TractogramFile.is_correct_format, "")
5865
assert_raises(NotImplementedError, TractogramFile.load, "")
59-
assert_raises(NotImplementedError, TractogramFile.create_empty_header)
6066

6167
# Testing calling the 'save' method of `TractogramFile` object.
6268
class DummyTractogramFile(TractogramFile):
@@ -72,10 +78,6 @@ def load(cls, fileobj, lazy_load=True):
7278
def save(self, fileobj):
7379
pass
7480

75-
@classmethod
76-
def create_empty_header(cls):
77-
return None
78-
7981
assert_raises(NotImplementedError,
8082
super(DummyTractogramFile,
8183
DummyTractogramFile(Tractogram)).save, "")

nibabel/streamlines/tractogram_file.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class TractogramFile(with_metaclass(ABCMeta)):
3939

4040
def __init__(self, tractogram, header=None):
4141
self._tractogram = tractogram
42-
self._header = {} if header is None else header
42+
self._header = self.create_empty_header() if header is None else header
4343

4444
@property
4545
def tractogram(self):
@@ -77,10 +77,10 @@ def is_correct_format(cls, fileobj):
7777
"""
7878
raise NotImplementedError()
7979

80-
@abstractclassmethod
80+
@classmethod
8181
def create_empty_header(cls):
8282
""" Returns an empty header for this streamlines file format. """
83-
raise NotImplementedError()
83+
return {}
8484

8585
@abstractclassmethod
8686
def load(cls, fileobj, lazy_load=True):

nibabel/streamlines/trk.py

+27-21
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def get_affine_trackvis_to_rasmm(header):
7272
7373
Parameters
7474
----------
75-
header : dict
76-
Dict containing trackvis header.
75+
header : dict or ndarray
76+
Dict or numpy structured array containing trackvis header.
7777
7878
Returns
7979
-------
@@ -101,9 +101,12 @@ def get_affine_trackvis_to_rasmm(header):
101101
# If the voxel order implied by the affine does not match the voxel
102102
# order in the TRK header, change the orientation.
103103
# voxel (header) -> voxel (affine)
104-
header_ornt = asstr(header[Field.VOXEL_ORDER])
104+
vox_order = header[Field.VOXEL_ORDER]
105+
# Input header can be dict or structured array
106+
if hasattr(vox_order, 'item'): # structured array
107+
vox_order = header[Field.VOXEL_ORDER].item()
105108
affine_ornt = "".join(aff2axcodes(header[Field.VOXEL_TO_RASMM]))
106-
header_ornt = axcodes2ornt(header_ornt)
109+
header_ornt = axcodes2ornt(vox_order.decode('latin1'))
107110
affine_ornt = axcodes2ornt(affine_ornt)
108111
ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
109112
M = nib.orientations.inv_ornt_aff(ornt, header[Field.DIMENSIONS])
@@ -235,10 +238,6 @@ def __init__(self, tractogram, header=None):
235238
and *mm* space where coordinate (0,0,0) refers to the center
236239
of the voxel.
237240
"""
238-
if header is None:
239-
header_rec = self.create_empty_header()
240-
header = dict(zip(header_rec.dtype.names, header_rec[0]))
241-
242241
super(TrkFile, self).__init__(tractogram, header)
243242

244243
@classmethod
@@ -266,20 +265,28 @@ def is_correct_format(cls, fileobj):
266265
return magic_number == cls.MAGIC_NUMBER
267266

268267
@classmethod
269-
def create_empty_header(cls):
270-
""" Return an empty compliant TRK header. """
271-
header = np.zeros(1, dtype=header_2_dtype)
268+
def _default_structarr(cls):
269+
""" Return an empty compliant TRK header as numpy structured array
270+
"""
271+
st_arr = np.zeros((), dtype=header_2_dtype)
272272

273273
# Default values
274-
header[Field.MAGIC_NUMBER] = cls.MAGIC_NUMBER
275-
header[Field.VOXEL_SIZES] = np.array((1, 1, 1), dtype="f4")
276-
header[Field.DIMENSIONS] = np.array((1, 1, 1), dtype="h")
277-
header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype="f4")
278-
header[Field.VOXEL_ORDER] = b"RAS"
279-
header['version'] = 2
280-
header['hdr_size'] = cls.HEADER_SIZE
274+
st_arr[Field.MAGIC_NUMBER] = cls.MAGIC_NUMBER
275+
st_arr[Field.VOXEL_SIZES] = np.array((1, 1, 1), dtype="f4")
276+
st_arr[Field.DIMENSIONS] = np.array((1, 1, 1), dtype="h")
277+
st_arr[Field.VOXEL_TO_RASMM] = np.eye(4, dtype="f4")
278+
st_arr[Field.VOXEL_ORDER] = b"RAS"
279+
st_arr['version'] = 2
280+
st_arr['hdr_size'] = cls.HEADER_SIZE
281281

282-
return header
282+
return st_arr
283+
284+
@classmethod
285+
def create_empty_header(cls):
286+
""" Return an empty compliant TRK header as dict
287+
"""
288+
st_arr = cls._default_structarr()
289+
return dict(zip(st_arr.dtype.names, st_arr.tolist()))
283290

284291
@classmethod
285292
def load(cls, fileobj, lazy_load=False):
@@ -388,7 +395,7 @@ def save(self, fileobj):
388395
of the TRK header data).
389396
"""
390397
# Enforce little-endian byte order for header
391-
header = self.create_empty_header().newbyteorder('<')
398+
header = self._default_structarr().newbyteorder('<')
392399

393400
# Override hdr's fields by those contained in `header`.
394401
for k, v in self.header.items():
@@ -406,7 +413,6 @@ def save(self, fileobj):
406413
nb_scalars = 0
407414
nb_properties = 0
408415

409-
header = header[0]
410416
with Opener(fileobj, mode="wb") as f:
411417
# Keep track of the beginning of the header.
412418
beginning = f.tell()

0 commit comments

Comments
 (0)