|
24 | 24 | TensorAttr,
|
25 | 25 | _field_status,
|
26 | 26 | )
|
| 27 | +from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore |
27 | 28 | from torch_geometric.data.storage import (
|
28 | 29 | BaseStorage,
|
29 | 30 | EdgeStorage,
|
30 | 31 | GlobalStorage,
|
31 | 32 | NodeStorage,
|
32 | 33 | )
|
33 | 34 | from torch_geometric.deprecation import deprecated
|
34 |
| -from torch_geometric.typing import EdgeType, NodeType, OptTensor |
| 35 | +from torch_geometric.typing import ( |
| 36 | + Adj, |
| 37 | + EdgeTensorType, |
| 38 | + EdgeType, |
| 39 | + FeatureTensorType, |
| 40 | + NodeType, |
| 41 | + OptTensor, |
| 42 | +) |
35 | 43 | from torch_geometric.utils import subgraph
|
36 | 44 |
|
37 | 45 |
|
@@ -316,7 +324,17 @@ def __init__(self, attr_name=_field_status.UNSET,
|
316 | 324 | super().__init__(None, attr_name, index)
|
317 | 325 |
|
318 | 326 |
|
319 |
| -class Data(BaseData, FeatureStore): |
@dataclass
class DataEdgeAttr(EdgeAttr):
    r"""Edge attribute class for :class:`Data`, which does not require an
    :obj:`edge_type`."""
    def __init__(self, layout: EdgeLayout, is_sorted: bool = False,
                 edge_type: Optional[EdgeType] = None):
        # Treat `edge_type` as optional and move it to the end of the
        # argument list; the base `EdgeAttr` expects it as the first
        # (required) argument, which `Data` cannot provide:
        super().__init__(edge_type, layout, is_sorted)
| 335 | + |
| 336 | + |
| 337 | +class Data(BaseData, FeatureStore, GraphStore): |
320 | 338 | r"""A data object describing a homogeneous graph.
|
321 | 339 | The data object can hold node-level, link-level and graph-level attributes.
|
322 | 340 | In general, :class:`~torch_geometric.data.Data` tries to mimic the
|
@@ -366,7 +384,11 @@ def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
|
366 | 384 | pos: OptTensor = None, **kwargs):
|
367 | 385 | # `Data` doesn't support group_name, so we need to adjust `TensorAttr`
|
368 | 386 | # accordingly here to avoid requiring `group_name` to be set:
|
369 |
| - super().__init__(attr_cls=DataTensorAttr) |
| 387 | + super().__init__(tensor_attr_cls=DataTensorAttr) |
| 388 | + |
| 389 | + # `Data` doesn't support edge_type, so we need to adjust `EdgeAttr` |
| 390 | + # accordingly here to avoid requiring `edge_type` to be set: |
| 391 | + GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr) |
370 | 392 |
|
371 | 393 | self.__dict__['_store'] = GlobalStorage(_parent=self)
|
372 | 394 |
|
@@ -755,9 +777,79 @@ def _remove_tensor(self, attr: TensorAttr) -> bool:
|
    def __len__(self) -> int:
        # Dispatch explicitly to `BaseData.__len__` rather than via the MRO —
        # NOTE(review): `Data` also inherits from `FeatureStore`/`GraphStore`,
        # so this presumably bypasses a competing `__len__` on another base;
        # confirm against those classes.
        return BaseData.__len__(self)
|
757 | 779 |
|
| 780 | + # GraphStore interface #################################################### |
| 781 | + |
| 782 | + def _put_edge_index(self, edge_index: EdgeTensorType, |
| 783 | + edge_attr: EdgeAttr) -> bool: |
| 784 | + # Convert the edge index to a recognizable format: |
| 785 | + attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout] |
| 786 | + attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index) |
| 787 | + setattr(self, attr_name, attr_val) |
| 788 | + return True |
| 789 | + |
| 790 | + def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]: |
| 791 | + # Get the requested format and the Adj tensor associated with it: |
| 792 | + attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout] |
| 793 | + attr_val = getattr(self._store, attr_name, None) |
| 794 | + if attr_val is not None: |
| 795 | + # Convert from Adj type to Tuple[Tensor, Tensor] |
| 796 | + attr_val = adj_type_to_edge_tensor_type(edge_attr.layout, attr_val) |
| 797 | + return attr_val |
| 798 | + |
758 | 799 |
|
759 | 800 | ###############################################################################
|
760 | 801 |
|
# Maps each supported edge layout to the `Data` attribute name under which
# that adjacency representation is stored:
EDGE_LAYOUT_TO_ATTR_NAME = {
    EdgeLayout.COO: 'edge_index',
    EdgeLayout.CSR: 'adj',
    EdgeLayout.CSC: 'adj_t',
}
| 807 | + |
| 808 | + |
def edge_tensor_type_to_adj_type(
    attr: EdgeAttr,
    tensor_tuple: EdgeTensorType,
) -> Adj:
    r"""Converts an :obj:`EdgeTensorType` tensor tuple to a PyG :obj:`Adj`
    representation: a :obj:`(2, num_edges)` tensor for COO, or a
    :class:`SparseTensor` for CSR/CSC, according to :obj:`attr.layout`.

    Raises:
        ValueError: If :obj:`attr.layout` is not a supported layout.
    """
    src, dst = tensor_tuple
    if attr.layout == EdgeLayout.COO:
        # COO: (row, col)
        if src.storage().data_ptr() == dst.storage().data_ptr():
            # Do not copy if the tensor tuple is constructed from the same
            # storage (instead, return a view). Note that `torch.Size`
            # addition concatenates shapes, so the target size must be given
            # explicitly as `(2, numel)` — NOT `src.size() + dst.size()` —
            # and the view must start at `src`'s own storage offset.
            # NOTE(review): this assumes `dst` directly follows `src` in the
            # shared storage (true when both come from slicing a
            # `(2, num_edges)` edge index); confirm for other callers.
            out = torch.empty(0, dtype=src.dtype)
            out.set_(src.storage(), storage_offset=src.storage_offset(),
                     size=(2, src.numel()))
            return out
        return torch.stack([src, dst], dim=0)
    elif attr.layout == EdgeLayout.CSR:
        # CSR: (rowptr, col)
        return SparseTensor(rowptr=src, col=dst, is_sorted=True)
    elif attr.layout == EdgeLayout.CSC:
        # CSC: (row, colptr) — this is a transposed adjacency matrix, so the
        # compressed `colptr` plays the `rowptr` role and `row` the `col`
        # role in the underlying (transposed) sparse tensor:
        return SparseTensor(rowptr=dst, col=src, is_sorted=True)
    raise ValueError(f"Bad edge layout (got '{attr.layout}')")
| 835 | + |
| 836 | + |
def adj_type_to_edge_tensor_type(layout: EdgeLayout,
                                 edge_index: Adj) -> EdgeTensorType:
    r"""Converts a PyG Adj tensor to an EdgeTensorType equivalent."""
    # A plain tensor is already a (2, num_edges) COO edge index; split it:
    if isinstance(edge_index, Tensor):
        return edge_index[0], edge_index[1]
    # Otherwise, `edge_index` is a sparse adjacency matrix:
    if layout == EdgeLayout.COO:
        row, col, _ = edge_index.coo()
        return row, col
    if layout == EdgeLayout.CSR:
        rowptr, col, _ = edge_index.csr()
        return rowptr, col
    # CSC is the CSR layout of the transposed adjacency (`adj_t`), so its
    # `csr()` yields (colptr, row); report them back in (row, colptr) order:
    colptr, row, _ = edge_index.csr()
    return row, colptr
| 852 | + |
761 | 853 |
|
762 | 854 | def size_repr(key: Any, value: Any, indent: int = 0) -> str:
|
763 | 855 | pad = ' ' * indent
|
|
0 commit comments