diff --git a/nibabel/cifti2/caretspec.py b/nibabel/cifti2/caretspec.py new file mode 100644 index 000000000..6e32fb1d7 --- /dev/null +++ b/nibabel/cifti2/caretspec.py @@ -0,0 +1,217 @@ +# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +"""Read / write access to CaretSpecFile format + +The format of CaretSpecFiles does not seem to have any independent +documentation. + +Code can be found here [0], and a DTD was worked out in this email thread [1]. + +[0]: https://github.com/Washington-University/workbench/tree/master/src/Files +[1]: https://groups.google.com/a/humanconnectome.org/g/hcp-users/c/EGuwdaTVFuw/m/tg7a_-7mAQAJ +""" +import xml.etree.ElementTree as et +from urllib.parse import urlparse + +import nibabel as nb +from nibabel import pointset as ps +from nibabel import xmlutils as xml +from nibabel.caret import CaretMetaData + + +class CaretSpecDataFile(xml.XmlSerializable): + """DataFile + + * Attributes + + * Structure - A string from the BrainStructure list to identify + what structure this element refers to (usually left cortex, + right cortex, or cerebellum). + * DataFileType - A string from the DataFileType list + * Selected - A boolean + + * Child Elements: [NA] + * Text Content: A URI + * Parent Element - CaretSpecFile + + Attributes + ---------- + structure : str + Name of brain structure + data_file_type : str + Type of data file + selected : bool + Used for workbench internals + uri : str + URI of data file + """ + + def __init__(self, structure=None, data_file_type=None, selected=None, uri=None): + super().__init__() + self.structure = structure + self.data_file_type = data_file_type + self.selected = selected + self.uri = uri + + if data_file_type == 'SURFACE': + self.__class__ = SurfaceDataFile + + def _to_xml_element(self): + data_file = xml.Element('DataFile') + data_file.attrib['Structure'] = str(self.structure) + data_file.attrib['DataFileType'] = str(self.data_file_type) + data_file.attrib['Selected'] = 'true' if self.selected else 'false' + data_file.text = self.uri + return data_file + + def __repr__(self): + return self.to_xml().decode() + + +class SurfaceDataFile(ps.TriangularMesh, CaretSpecDataFile): + _gifti = None + _coords = None + _triangles = None + + def _get_gifti(self): + if self._gifti is None: + parts = urlparse(self.uri) + if parts.scheme == 'file': + self._gifti = nb.load(parts.path) + elif parts.scheme == '': + self._gifti = nb.load(self.uri) + else: + self._gifti = nb.GiftiImage.from_url(self.uri) + return self._gifti + + def get_triangles(self, name=None): + if self._triangles is None: + gifti = self._get_gifti() + self._triangles = gifti.agg_data('triangle') + return self._triangles + + def get_coords(self, name=None): + if self._coords is None: + gifti = self._get_gifti() + self._coords = gifti.agg_data('pointset') + return self._coords + + +class CaretSpecFile(xml.XmlSerializable): + """Class for CaretSpecFile XML documents + + These are used to identify related surfaces and volumes for use with CIFTI-2 + data files. + """ + + def __init__(self, metadata=None, data_files=(), version='1.0'): + super().__init__() + if metadata is not None: + metadata = CaretMetaData(metadata) + self.metadata = metadata + self.data_files = list(data_files) + self.version = version + + def _to_xml_element(self): + caret_spec = xml.Element('CaretSpecFile') + caret_spec.attrib['Version'] = str(self.version) + if self.metadata is not None: + caret_spec.append(self.metadata._to_xml_element()) + for data_file in self.data_files: + caret_spec.append(data_file._to_xml_element()) + return caret_spec + + def to_xml(self, enc='UTF-8', **kwargs): + ele = self._to_xml_element() + et.indent(ele, ' ') + return et.tostring(ele, enc, xml_declaration=True, short_empty_elements=False, **kwargs) + + def __eq__(self, other): + return self.to_xml() == other.to_xml() + + @classmethod + def from_filename(klass, fname, **kwargs): + parser = CaretSpecParser(**kwargs) + with open(fname, 'rb') as fobj: + parser.parse(fptr=fobj) + return parser.caret_spec + + +class CaretSpecParser(xml.XmlParser): + def __init__(self, encoding=None, buffer_size=3500000, verbose=0): + super().__init__(encoding=encoding, buffer_size=buffer_size, verbose=verbose) + self.struct_state = [] + + self.caret_spec = None + + # where to write CDATA: + self.write_to = None + + # Collecting char buffer fragments + self._char_blocks = [] + + def StartElementHandler(self, name, attrs): + self.flush_chardata() + if name == 'CaretSpecFile': + self.caret_spec = CaretSpecFile(version=attrs['Version']) + elif name == 'MetaData': + self.caret_spec.metadata = CaretMetaData() + elif name == 'MD': + self.struct_state.append({}) + elif name in ('Name', 'Value'): + self.write_to = name + elif name == 'DataFile': + selected_map = {'true': True, 'false': False} + data_file = CaretSpecDataFile( + structure=attrs['Structure'], + data_file_type=attrs['DataFileType'], + selected=selected_map[attrs['Selected']], + ) + self.caret_spec.data_files.append(data_file) + self.struct_state.append(data_file) + self.write_to = 'DataFile' + + def EndElementHandler(self, name): + self.flush_chardata() + if name == 'MD': + MD = self.struct_state.pop() + self.caret_spec.metadata[MD['Name']] = MD['Value'] + elif name in ('Name', 'Value'): + self.write_to = None + elif name == 'DataFile': + self.struct_state.pop() + self.write_to = None + + def CharacterDataHandler(self, data): + """Collect character data chunks pending collation + + The parser breaks the data up into chunks of size depending on the + buffer_size of the parser. A large bit of character data, with standard + parser buffer_size (such as 8K) can easily span many calls to this + function. We thus collect the chunks and process them when we hit start + or end tags. + """ + if self._char_blocks is None: + self._char_blocks = [] + self._char_blocks.append(data) + + def flush_chardata(self): + """Collate and process collected character data""" + if self._char_blocks is None: + return + + data = ''.join(self._char_blocks).strip() + # Reset the char collector + self._char_blocks = None + # Process data + if self.write_to in ('Name', 'Value'): + self.struct_state[-1][self.write_to] = data + + elif self.write_to == 'DataFile': + self.struct_state[-1].uri = data diff --git a/nibabel/cifti2/tests/test_caretspec.py b/nibabel/cifti2/tests/test_caretspec.py new file mode 100644 index 000000000..604808c3f --- /dev/null +++ b/nibabel/cifti2/tests/test_caretspec.py @@ -0,0 +1,34 @@ +import unittest +from pathlib import Path + +from nibabel.cifti2.caretspec import * +from nibabel.optpkg import optional_package +from nibabel.testing import data_path + +requests, has_requests, _ = optional_package('requests') + + +def test_CaretSpecFile(): + fsLR = CaretSpecFile.from_filename(Path(data_path) / 'fsLR.wb.spec') + + assert fsLR.metadata == {} + assert fsLR.version == '1.0' + assert len(fsLR.data_files) == 5 + + for df in fsLR.data_files: + assert isinstance(df, CaretSpecDataFile) + if df.data_file_type == 'SURFACE': + assert isinstance(df, SurfaceDataFile) + + +@unittest.skipUnless(has_requests, reason='Test fetches from URL') +def test_SurfaceDataFile(): + fsLR = CaretSpecFile.from_filename(Path(data_path) / 'fsLR.wb.spec') + df = fsLR.data_files[0] + assert df.data_file_type == 'SURFACE' + try: + coords, triangles = df.get_mesh() + except IOError: + raise unittest.SkipTest(reason='Broken URL') + assert coords.shape == (32492, 3) + assert triangles.shape == (64980, 3) diff --git a/nibabel/coordimage.py b/nibabel/coordimage.py new file mode 100644 index 000000000..d418d4743 --- /dev/null +++ b/nibabel/coordimage.py @@ -0,0 +1,185 @@ +import numpy as np + +import nibabel as nib +import nibabel.pointset as ps +from nibabel.fileslice import fill_slicer + + +class CoordinateImage: + """ + Attributes + ---------- + header : a file-specific header + coordaxis : ``CoordinateAxis`` + dataobj : array-like + """ + + def __init__(self, data, coordaxis, header=None): + self.data = data + self.coordaxis = coordaxis + self.header = header + + @property + def shape(self): + return self.data.shape + + def __getitem__(self, slicer): + if isinstance(slicer, str): + slicer = self.coordaxis.get_indices(slicer) + elif isinstance(slicer, list): + slicer = np.hstack([self.coordaxis.get_indices(sub) for sub in slicer]) + + if isinstance(slicer, range): + slicer = slice(slicer.start, slicer.stop, slicer.step) + + data = self.data + if not isinstance(slicer, slice): + data = np.asanyarray(data) + return self.__class__(data[slicer], self.coordaxis[slicer], header=self.header.copy()) + + @classmethod + def from_image(klass, img): + coordaxis = CoordinateAxis.from_header(img.header) + if isinstance(img, nib.Cifti2Image): + if img.ndim != 2: + raise ValueError('Can only interpret 2D images') + for i in img.header.mapped_indices: + if isinstance(img.header.get_axis(i), nib.cifti2.BrainModelAxis): + break + # Reinterpret data ordering based on location of coordinate axis + data = img.dataobj.copy() + data.order = ['F', 'C'][i] + if i == 1: + data._shape = data._shape[::-1] + return klass(data, coordaxis, img.header) + + +class CoordinateAxis: + """ + Attributes + ---------- + parcels : list of ``Parcel`` objects + """ + + def __init__(self, parcels): + self.parcels = parcels + + def load_structures(self, mapping): + """ + Associate parcels to ``Pointset`` structures + """ + raise NotImplementedError + + def __getitem__(self, slicer): + """ + Return a sub-sampled CoordinateAxis containing structures + matching the indices provided. + """ + if slicer is Ellipsis or isinstance(slicer, slice) and slicer == slice(None): + return self + elif isinstance(slicer, slice): + slicer = fill_slicer(slicer, len(self)) + start, stop, step = slicer.start, slicer.stop, slicer.step + else: + raise TypeError(f'Indexing type not supported: {type(slicer)}') + + subparcels = [] + pstop = 0 + for parcel in self.parcels: + pstart, pstop = pstop, pstop + len(parcel) + if pstop < start: + continue + if pstart >= stop: + break + if start < pstart: + substart = (start - pstart) % step + else: + substart = start - pstart + subparcels.append(parcel[substart : stop - pstart : step]) + return CoordinateAxis(subparcels) + + def get_indices(self, parcel, indices=None): + """ + Return the indices in the full axis that correspond to the + requested parcel. If indices are provided, further subsample + the requested parcel. + """ + subseqs = [] + idx = 0 + for p in self.parcels: + if p.name == parcel: + subseqs.append(range(idx, idx + len(p))) + idx += len(p) + if not subseqs: + return () + if indices: + return np.hstack(subseqs)[indices] + if len(subseqs) == 1: + return subseqs[0] + return np.hstack(subseqs) + + def __len__(self): + return sum(len(parcel) for parcel in self.parcels) + + # Hacky factory method for now + @classmethod + def from_header(klass, hdr): + parcels = [] + if isinstance(hdr, nib.Cifti2Header): + axes = [hdr.get_axis(i) for i in hdr.mapped_indices] + for ax in axes: + if isinstance(ax, nib.cifti2.BrainModelAxis): + break + else: + raise ValueError('No BrainModelAxis, cannot create CoordinateAxis') + for name, slicer, struct in ax.iter_structures(): + if struct.volume_shape: + substruct = ps.NdGrid(struct.volume_shape, struct.affine) + indices = struct.voxel + else: + substruct = None + indices = struct.vertex + parcels.append(Parcel(name, substruct, indices)) + + return klass(parcels) + + +class Parcel: + """ + Attributes + ---------- + name : str + structure : ``Pointset`` + indices : object that selects a subset of coordinates in structure + """ + + def __init__(self, name, structure, indices): + self.name = name + self.structure = structure + self.indices = indices + + def __repr__(self): + return f'' + + def __len__(self): + return len(self.indices) + + def __getitem__(self, slicer): + return self.__class__(self.name, self.structure, self.indices[slicer]) + + +class GeometryCollection: + """ + Attributes + ---------- + structures : dict + Mapping from structure names to ``Pointset`` + """ + + def __init__(self, structures): + self.structures = structures + + @classmethod + def from_spec(klass, pathlike): + """Load a collection of geometries from a specification.""" + raise NotImplementedError diff --git a/nibabel/pointset.py b/nibabel/pointset.py index 759a0b15e..b30e30caa 100644 --- a/nibabel/pointset.py +++ b/nibabel/pointset.py @@ -47,7 +47,12 @@ def __array__(self, dtype: None = ..., /) -> np.ndarray[ty.Any, np.dtype[ty.Any] def __array__(self, dtype: _DType, /) -> np.ndarray[ty.Any, _DType]: ... -@dataclass +class HasMeshAttrs(ty.Protocol): + coordinates: CoordinateArray + triangles: CoordinateArray + + +@dataclass(init=False) class Pointset: """A collection of points described by coordinates. @@ -64,7 +69,7 @@ class Pointset: coordinates: CoordinateArray affine: np.ndarray - homogeneous: bool = False + homogeneous: bool # Force use of __rmatmul__ with numpy arrays __array_priority__ = 99 @@ -147,6 +152,82 @@ def get_coords(self, *, as_homogeneous: bool = False): return coords +@dataclass(init=False) +class TriangularMesh(Pointset): + triangles: CoordinateArray + + def __init__( + self, + coordinates: CoordinateArray, + triangles: CoordinateArray, + affine: np.ndarray | None = None, + homogeneous: bool = False, + ): + super().__init__(coordinates, affine=affine, homogeneous=homogeneous) + self.triangles = triangles + + @classmethod + def from_tuple( + cls, + mesh: tuple[CoordinateArray, CoordinateArray], + affine: np.ndarray | None = None, + homogeneous: bool = False, + **kwargs, + ) -> Self: + return cls(mesh[0], mesh[1], affine=affine, homogeneous=homogeneous, **kwargs) + + @classmethod + def from_object( + cls, + mesh: HasMeshAttrs, + affine: np.ndarray | None = None, + homogeneous: bool = False, + **kwargs, + ) -> Self: + return cls( + mesh.coordinates, mesh.triangles, affine=affine, homogeneous=homogeneous, **kwargs + ) + + @property + def n_triangles(self): + """Number of faces + + Subclasses should override with more efficient implementations. + """ + return self.triangles.shape[0] + + def get_triangles(self): + """Mx3 array of indices into coordinate table""" + return np.asanyarray(self.triangles) + + def get_mesh(self, *, as_homogeneous: bool = False): + return self.get_coords(as_homogeneous=as_homogeneous), self.get_triangles() + + +class CoordinateFamilyMixin(Pointset): + def __init__(self, *args, name='original', **kwargs): + mapping = kwargs.pop('mapping', {}) + super().__init__(*args, **kwargs) + self._coords = {name: self.coordinates, **mapping} + + def get_names(self): + """List of surface names that can be passed to :meth:`with_name`""" + return list(self._coords) + + def with_name(self, name: str) -> Self: + new_coords = self._coords[name] + if new_coords is self.coordinates: + return self + # Make a copy, preserving all dataclass fields + new = replace(self, coordinates=new_coords) + # Conserve exact _coords mapping + new._coords = self._coords + return new + + def add_coordinates(self, name, coordinates): + self._coords[name] = coordinates + + class Grid(Pointset): r"""A regularly-spaced collection of coordinates diff --git a/nibabel/tests/data/fsLR.wb.spec b/nibabel/tests/data/fsLR.wb.spec new file mode 100644 index 000000000..b7ad95831 --- /dev/null +++ b/nibabel/tests/data/fsLR.wb.spec @@ -0,0 +1,30 @@ + + + + + + https://raw.githubusercontent.com/mgxd/brainplot/master/brainplot/Conte69_Atlas/Conte69.L.midthickness.32k_fs_LR.surf.gii + + + https://raw.githubusercontent.com/mgxd/brainplot/master/brainplot/Conte69_Atlas/Conte69.R.midthickness.32k_fs_LR.surf.gii + + + https://raw.githubusercontent.com/mgxd/brainplot/master/brainplot/Conte69_Atlas/Conte69.L.very_inflated.32k_fs_LR.surf.gii + + + https://raw.githubusercontent.com/mgxd/brainplot/master/brainplot/Conte69_Atlas/Conte69.R.very_inflated.32k_fs_LR.surf.gii + + + https://templateflow.s3.amazonaws.com/tpl-MNI152NLin6Asym/tpl-MNI152NLin6Asym_res-02_T1w.nii.gz + + diff --git a/nibabel/tests/test_coordimage.py b/nibabel/tests/test_coordimage.py new file mode 100644 index 000000000..7318074f9 --- /dev/null +++ b/nibabel/tests/test_coordimage.py @@ -0,0 +1,80 @@ +import os +from pathlib import Path + +import nibabel as nb +from nibabel import coordimage as ci +from nibabel import pointset as ps +from nibabel.tests.nibabel_data import get_nibabel_data + +from .test_pointset import FreeSurferHemisphere + +CIFTI2_DATA = Path(get_nibabel_data()) / 'nitest-cifti2' + + +class FreeSurferSubject(ci.GeometryCollection): + @classmethod + def from_subject(klass, subject_id, subjects_dir=None): + """Load a FreeSurfer subject by ID""" + if subjects_dir is None: + subjects_dir = os.environ['SUBJECTS_DIR'] + return klass.from_spec(Path(subjects_dir) / subject_id) + + @classmethod + def from_spec(klass, pathlike): + """Load a FreeSurfer subject from its directory structure""" + subject_dir = Path(pathlike) + surfs = subject_dir / 'surf' + structures = { + 'lh': FreeSurferHemisphere.from_filename(surfs / 'lh.white'), + 'rh': FreeSurferHemisphere.from_filename(surfs / 'rh.white'), + } + subject = klass(structures) + subject._subject_dir = subject_dir + return subject + + +class CaretSpec(ci.GeometryCollection): + @classmethod + def from_spec(klass, pathlike): + from nibabel.cifti2.caretspec import CaretSpecFile + + csf = CaretSpecFile.from_filename(pathlike) + structures = { + df.structure: df.uri + for df in csf.data_files + if df.selected # Use selected to avoid overloading for now + } + wbspec = klass(structures) + wbspec._specfile = csf + return wbspec + + +def test_Cifti2Image_as_CoordImage(): + ones = nb.load(CIFTI2_DATA / 'ones.dscalar.nii') + assert ones.shape == (1, 91282) + cimg = ci.CoordinateImage.from_image(ones) + assert cimg.shape == (91282, 1) + + caxis = cimg.coordaxis + assert len(caxis) == 91282 + assert caxis[...] is caxis + assert caxis[:] is caxis + + subaxis = caxis[:100] + assert len(subaxis) == 100 + assert len(subaxis.parcels) == 1 + subaxis = caxis[100:] + assert len(subaxis) == len(caxis) - 100 + assert len(subaxis.parcels) == len(caxis.parcels) + subaxis = caxis[100:-100] + assert len(subaxis) == len(caxis) - 200 + assert len(subaxis.parcels) == len(caxis.parcels) + + lh_img = cimg['CIFTI_STRUCTURE_CORTEX_LEFT'] + assert len(lh_img.coordaxis.parcels) == 1 + assert lh_img.shape == (29696, 1) + + # # Not working yet. + # cortex_img = cimg[["CIFTI_STRUCTURE_CORTEX_LEFT", "CIFTI_STRUCTURE_CORTEX_RIGHT"]] + # assert len(cortex_img.coordaxis.parcels) == 2 + # assert cortex_img.shape == (59412, 1) diff --git a/nibabel/tests/test_pointset.py b/nibabel/tests/test_pointset.py index f4f0e4361..850c2869e 100644 --- a/nibabel/tests/test_pointset.py +++ b/nibabel/tests/test_pointset.py @@ -1,3 +1,4 @@ +from collections import namedtuple from math import prod from pathlib import Path @@ -9,7 +10,7 @@ from nibabel.fileslice import strided_scalar from nibabel.optpkg import optional_package from nibabel.spatialimages import SpatialImage -from nibabel.tests.nibabel_data import get_nibabel_data +from nibabel.tests.nibabel_data import get_nibabel_data, needs_nibabel_data h5, has_h5py, _ = optional_package('h5py') @@ -179,3 +180,247 @@ def test_to_mask(self): ], ) assert np.array_equal(mask_img.affine, np.eye(4)) + + +class TestTriangularMeshes: + def test_api(self): + # Tetrahedron + coords = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ] + ) + triangles = np.array( + [ + [0, 2, 1], + [0, 3, 2], + [0, 1, 3], + [1, 2, 3], + ] + ) + + mesh = namedtuple('mesh', ('coordinates', 'triangles'))(coords, triangles) + + tm1 = ps.TriangularMesh(coords, triangles) + tm2 = ps.TriangularMesh.from_tuple(mesh) + tm3 = ps.TriangularMesh.from_object(mesh) + + assert np.allclose(tm1.affine, np.eye(4)) + assert np.allclose(tm2.affine, np.eye(4)) + assert np.allclose(tm3.affine, np.eye(4)) + + assert tm1.homogeneous is False + assert tm2.homogeneous is False + assert tm3.homogeneous is False + + assert (tm1.n_coords, tm1.dim) == (4, 3) + assert (tm2.n_coords, tm2.dim) == (4, 3) + assert (tm3.n_coords, tm3.dim) == (4, 3) + + assert tm1.n_triangles == 4 + assert tm2.n_triangles == 4 + assert tm3.n_triangles == 4 + + out_coords, out_tris = tm1.get_mesh() + # Currently these are the exact arrays, but I don't think we should + # bake that assumption into the tests + assert np.allclose(out_coords, coords) + assert np.allclose(out_tris, triangles) + + +class TestCoordinateFamilyMixin(TestPointsets): + def test_names(self): + coords = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + ] + ) + cfm = ps.CoordinateFamilyMixin(coords) + + assert cfm.get_names() == ['original'] + assert np.allclose(cfm.with_name('original').coordinates, coords) + + cfm.add_coordinates('shifted', coords + 1) + assert set(cfm.get_names()) == {'original', 'shifted'} + shifted = cfm.with_name('shifted') + assert np.allclose(shifted.coordinates, coords + 1) + assert set(shifted.get_names()) == {'original', 'shifted'} + original = shifted.with_name('original') + assert np.allclose(original.coordinates, coords) + + # Avoid duplicating objects + assert original.with_name('original') is original + # But don't try too hard + assert original.with_name('original') is not cfm + + # with_name() preserves the exact coordinate mapping of the source object. + # Modifications of one are immediately available to all others. + # This is currently an implementation detail, and the expectation is that + # a family will be created once and then queried, but this behavior could + # potentially become confusing or relied upon. + # Change with care. + shifted.add_coordinates('shifted-again', coords + 2) + shift2 = shifted.with_name('shifted-again') + shift3 = cfm.with_name('shifted-again') + + +class H5ArrayProxy: + def __init__(self, file_like, dataset_name): + self.file_like = file_like + self.dataset_name = dataset_name + with h5.File(file_like, 'r') as h5f: + arr = h5f[dataset_name] + self._shape = arr.shape + self._dtype = arr.dtype + + @property + def is_proxy(self): + return True + + @property + def shape(self): + return self._shape + + @property + def ndim(self): + return len(self.shape) + + @property + def dtype(self): + return self._dtype + + def __array__(self, dtype=None): + with h5.File(self.file_like, 'r') as h5f: + return np.asanyarray(h5f[self.dataset_name], dtype) + + def __getitem__(self, slicer): + with h5.File(self.file_like, 'r') as h5f: + return h5f[self.dataset_name][slicer] + + +class H5Geometry(ps.CoordinateFamilyMixin, ps.TriangularMesh): + """Simple Geometry file structure that combines a single topology + with one or more coordinate sets + """ + + @classmethod + def from_filename(klass, pathlike): + coords = {} + with h5.File(pathlike, 'r') as h5f: + triangles = H5ArrayProxy(pathlike, '/topology') + for name in h5f['coordinates']: + coords[name] = H5ArrayProxy(pathlike, f'/coordinates/{name}') + self = klass(next(iter(coords.values())), triangles, mapping=coords) + return self + + def to_filename(self, pathlike): + with h5.File(pathlike, 'w') as h5f: + h5f.create_dataset('/topology', data=self.get_triangles()) + for name, coord in self._coords.items(): + h5f.create_dataset(f'/coordinates/{name}', data=coord) + + +class FSGeometryProxy: + def __init__(self, pathlike): + self._file_like = str(Path(pathlike)) + self._offset = None + self._vnum = None + self._fnum = None + + def _peek(self): + from nibabel.freesurfer.io import _fread3 + + with open(self._file_like, 'rb') as fobj: + magic = _fread3(fobj) + if magic != 16777214: + raise NotImplementedError('Triangle files only!') + fobj.readline() + fobj.readline() + self._vnum = np.fromfile(fobj, '>i4', 1)[0] + self._fnum = np.fromfile(fobj, '>i4', 1)[0] + self._offset = fobj.tell() + + @property + def vnum(self): + if self._vnum is None: + self._peek() + return self._vnum + + @property + def fnum(self): + if self._fnum is None: + self._peek() + return self._fnum + + @property + def offset(self): + if self._offset is None: + self._peek() + return self._offset + + @auto_attr + def coordinates(self): + return ArrayProxy(self._file_like, ((self.vnum, 3), '>f4', self.offset), order='C') + + @auto_attr + def triangles(self): + return ArrayProxy( + self._file_like, + ((self.fnum, 3), '>i4', self.offset + 12 * self.vnum), + order='C', + ) + + +class FreeSurferHemisphere(ps.CoordinateFamilyMixin, ps.TriangularMesh): + @classmethod + def from_filename(klass, pathlike): + path = Path(pathlike) + hemi, default = path.name.split('.') + self = klass.from_object(FSGeometryProxy(path), name=default) + mesh_names = ( + 'orig', + 'white', + 'smoothwm', + 'pial', + 'inflated', + 'sphere', + 'midthickness', + 'graymid', + ) # Often created + + for mesh in mesh_names: + if mesh != default: + fpath = path.parent / f'{hemi}.{mesh}' + if fpath.exists(): + self.add_coordinates(mesh, FSGeometryProxy(fpath).coordinates) + return self + + +@needs_nibabel_data('nitest-freesurfer') +def test_FreeSurferHemisphere(): + lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white') + assert lh.n_coords == 163842 + assert lh.n_triangles == 327680 + + +@skipUnless(has_h5py, reason='Test requires h5py') +@needs_nibabel_data('nitest-freesurfer') +def test_make_H5Geometry(tmp_path): + lh = FreeSurferHemisphere.from_filename(FS_DATA / 'fsaverage/surf/lh.white') + h5geo = H5Geometry.from_object(lh) + for name in ('white', 'pial'): + h5geo.add_coordinates(name, lh.with_name(name).coordinates) + h5geo.to_filename(tmp_path / 'geometry.h5') + + rt_h5geo = H5Geometry.from_filename(tmp_path / 'geometry.h5') + assert set(h5geo._coords) == set(rt_h5geo._coords) + assert np.array_equal( + lh.with_name('white').get_coords(), rt_h5geo.with_name('white').get_coords() + ) + assert np.array_equal(lh.get_triangles(), rt_h5geo.get_triangles())