Skip to content

Commit b274fbd

Browse files
authored
GraphStore definition + Data and HeteroData integration (#4816)
1 parent 4b30b6d commit b274fbd

10 files changed

+369
-24
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
89
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
910
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
1011
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))

test/data/test_data.py

+32
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
import torch
55
import torch.multiprocessing as mp
6+
import torch_sparse
67

78
import torch_geometric
89
from torch_geometric.data import Data
@@ -264,3 +265,34 @@ def test_basic_feature_store():
264265
assert 'x' in data.__dict__['_store']
265266
data.remove_tensor(attr_name='x', index=None)
266267
assert 'x' not in data.__dict__['_store']
268+
269+
270+
# Graph Store #################################################################
271+
272+
273+
def test_basic_graph_store():
274+
data = Data()
275+
276+
edge_index = torch.LongTensor([[0, 1], [1, 2]])
277+
adj = torch_sparse.SparseTensor(row=edge_index[0], col=edge_index[1])
278+
279+
def assert_equal_tensor_tuple(expected, actual):
280+
assert len(expected) == len(actual)
281+
for i in range(len(expected)):
282+
assert torch.equal(expected[i], actual[i])
283+
284+
# We put all three tensor types: COO, CSR, and CSC, and we get them back
285+
# to confirm that `GraphStore` works as intended.
286+
coo = adj.coo()[:-1]
287+
csr = adj.csr()[:-1]
288+
csc = adj.csc()[:-1]
289+
290+
# Put:
291+
data.put_edge_index(coo, layout='coo')
292+
data.put_edge_index(csr, layout='csr')
293+
data.put_edge_index(csc, layout='csc')
294+
295+
# Get:
296+
assert_equal_tensor_tuple(coo, data.get_edge_index('coo'))
297+
assert_equal_tensor_tuple(csr, data.get_edge_index('csr'))
298+
assert_equal_tensor_tuple(csc, data.get_edge_index('csc'))

test/data/test_feature_store.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, attr_name=_field_status.UNSET,
6666
class MyFeatureStoreNoGroupName(MyFeatureStore):
6767
def __init__(self):
6868
super().__init__()
69-
self._attr_cls = MyTensorAttrNoGroupName
69+
self._tensor_attr_cls = MyTensorAttrNoGroupName
7070

7171
@staticmethod
7272
def key(attr: TensorAttr) -> str:

test/data/test_graph_store.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch_sparse import SparseTensor
5+
6+
from torch_geometric.data.graph_store import (
7+
EdgeAttr,
8+
EdgeLayout,
9+
EdgeTensorType,
10+
GraphStore,
11+
)
12+
13+
14+
class MyGraphStore(GraphStore):
15+
def __init__(self):
16+
super().__init__()
17+
self.store = {}
18+
19+
@staticmethod
20+
def key(attr: EdgeAttr) -> str:
21+
return f"{attr.edge_type or '<default>'}_{attr.layout}"
22+
23+
def _put_edge_index(self, edge_index: EdgeTensorType,
24+
edge_attr: EdgeAttr) -> bool:
25+
self.store[MyGraphStore.key(edge_attr)] = edge_index
26+
27+
def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
28+
return self.store.get(MyGraphStore.key(edge_attr), None)
29+
30+
31+
def test_graph_store():
32+
graph_store = MyGraphStore()
33+
edge_index = torch.LongTensor([[0, 1], [1, 2]])
34+
adj = SparseTensor(row=edge_index[0], col=edge_index[1])
35+
36+
def assert_equal_tensor_tuple(expected, actual):
37+
assert len(expected) == len(actual)
38+
for i in range(len(expected)):
39+
assert torch.equal(expected[i], actual[i])
40+
41+
# We put all three tensor types: COO, CSR, and CSC, and we get them back
42+
# to confirm that `GraphStore` works as intended.
43+
coo = adj.coo()[:-1]
44+
csr = adj.csr()[:-1]
45+
csc = adj.csc()[:-1]
46+
47+
# Put:
48+
graph_store['edge', EdgeLayout.COO] = coo
49+
graph_store['edge', 'csr'] = csr
50+
graph_store['edge', 'csc'] = csc
51+
52+
# Get:
53+
assert_equal_tensor_tuple(coo, graph_store['edge', 'coo'])
54+
assert_equal_tensor_tuple(csr, graph_store['edge', 'csr'])
55+
assert_equal_tensor_tuple(csc, graph_store['edge', 'csc'])

test/data/test_hetero_data.py

+35
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pytest
44
import torch
5+
import torch_sparse
56

67
from torch_geometric.data import HeteroData
78
from torch_geometric.data.storage import EdgeStorage
@@ -427,3 +428,37 @@ def test_basic_feature_store():
427428
assert 'x' in data['paper'].__dict__['_mapping']
428429
data.remove_tensor(group_name='paper', attr_name='x', index=None)
429430
assert 'x' not in data['paper'].__dict__['_mapping']
431+
432+
433+
# Graph Store #################################################################
434+
435+
436+
def test_basic_graph_store():
437+
data = HeteroData()
438+
439+
edge_index = torch.LongTensor([[0, 1], [1, 2]])
440+
adj = torch_sparse.SparseTensor(row=edge_index[0], col=edge_index[1])
441+
442+
def assert_equal_tensor_tuple(expected, actual):
443+
assert len(expected) == len(actual)
444+
for i in range(len(expected)):
445+
assert torch.equal(expected[i], actual[i])
446+
447+
# We put all three tensor types: COO, CSR, and CSC, and we get them back
448+
# to confirm that `GraphStore` works as intended.
449+
coo = adj.coo()[:-1]
450+
csr = adj.csr()[:-1]
451+
csc = adj.csc()[:-1]
452+
453+
# Put:
454+
data.put_edge_index(coo, layout='coo', edge_type='1')
455+
data.put_edge_index(csr, layout='csr', edge_type='2')
456+
data.put_edge_index(csc, layout='csc', edge_type='3')
457+
458+
# Get:
459+
assert_equal_tensor_tuple(coo,
460+
data.get_edge_index(layout='coo', edge_type='1'))
461+
assert_equal_tensor_tuple(csr,
462+
data.get_edge_index(layout='csr', edge_type='2'))
463+
assert_equal_tensor_tuple(csc,
464+
data.get_edge_index(layout='csc', edge_type='3'))

torch_geometric/data/data.py

+95-3
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,22 @@
2424
TensorAttr,
2525
_field_status,
2626
)
27+
from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore
2728
from torch_geometric.data.storage import (
2829
BaseStorage,
2930
EdgeStorage,
3031
GlobalStorage,
3132
NodeStorage,
3233
)
3334
from torch_geometric.deprecation import deprecated
34-
from torch_geometric.typing import EdgeType, NodeType, OptTensor
35+
from torch_geometric.typing import (
36+
Adj,
37+
EdgeTensorType,
38+
EdgeType,
39+
FeatureTensorType,
40+
NodeType,
41+
OptTensor,
42+
)
3543
from torch_geometric.utils import subgraph
3644

3745

@@ -316,7 +324,17 @@ def __init__(self, attr_name=_field_status.UNSET,
316324
super().__init__(None, attr_name, index)
317325

318326

319-
class Data(BaseData, FeatureStore):
327+
@dataclass
328+
class DataEdgeAttr(EdgeAttr):
329+
r"""Edge attribute class for `Data`, which does not require a
330+
`edge_type`."""
331+
def __init__(self, layout: EdgeLayout, is_sorted: bool = False,
332+
edge_type: EdgeType = None):
333+
# Treat group_name as optional, and move it to the end
334+
super().__init__(edge_type, layout, is_sorted)
335+
336+
337+
class Data(BaseData, FeatureStore, GraphStore):
320338
r"""A data object describing a homogeneous graph.
321339
The data object can hold node-level, link-level and graph-level attributes.
322340
In general, :class:`~torch_geometric.data.Data` tries to mimic the
@@ -366,7 +384,11 @@ def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
366384
pos: OptTensor = None, **kwargs):
367385
# `Data` doesn't support group_name, so we need to adjust `TensorAttr`
368386
# accordingly here to avoid requiring `group_name` to be set:
369-
super().__init__(attr_cls=DataTensorAttr)
387+
super().__init__(tensor_attr_cls=DataTensorAttr)
388+
389+
# `Data` doesn't support edge_type, so we need to adjust `EdgeAttr`
390+
# accordingly here to avoid requiring `edge_type` to be set:
391+
GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr)
370392

371393
self.__dict__['_store'] = GlobalStorage(_parent=self)
372394

@@ -755,9 +777,79 @@ def _remove_tensor(self, attr: TensorAttr) -> bool:
755777
def __len__(self) -> int:
756778
return BaseData.__len__(self)
757779

780+
# GraphStore interface ####################################################
781+
782+
def _put_edge_index(self, edge_index: EdgeTensorType,
783+
edge_attr: EdgeAttr) -> bool:
784+
# Convert the edge index to a recognizable format:
785+
attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
786+
attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
787+
setattr(self, attr_name, attr_val)
788+
return True
789+
790+
def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
791+
# Get the requested format and the Adj tensor associated with it:
792+
attr_name = EDGE_LAYOUT_TO_ATTR_NAME[edge_attr.layout]
793+
attr_val = getattr(self._store, attr_name, None)
794+
if attr_val is not None:
795+
# Convert from Adj type to Tuple[Tensor, Tensor]
796+
attr_val = adj_type_to_edge_tensor_type(edge_attr.layout, attr_val)
797+
return attr_val
798+
758799

759800
###############################################################################
760801

802+
EDGE_LAYOUT_TO_ATTR_NAME = {
803+
EdgeLayout.COO: 'edge_index',
804+
EdgeLayout.CSR: 'adj',
805+
EdgeLayout.CSC: 'adj_t',
806+
}
807+
808+
809+
def edge_tensor_type_to_adj_type(
810+
attr: EdgeAttr,
811+
tensor_tuple: EdgeTensorType,
812+
) -> Adj:
813+
r"""Converts an EdgeTensorType tensor tuple to a PyG Adj tensor."""
814+
if attr.layout == EdgeLayout.COO:
815+
# COO: (row, col)
816+
if (tensor_tuple[0].storage().data_ptr() ==
817+
tensor_tuple[1].storage().data_ptr()):
818+
# Do not copy if the tensor tuple is constructed from the same
819+
# storage (instead, return a view):
820+
out = torch.empty(0, dtype=tensor_tuple[0].dtype)
821+
out.set_(tensor_tuple[0].storage(), storage_offset=0,
822+
size=tensor_tuple[0].size() + tensor_tuple[1].size())
823+
return out.view(2, -1)
824+
return torch.stack(tensor_tuple)
825+
elif attr.layout == EdgeLayout.CSR:
826+
# CSR: (rowptr, col)
827+
return SparseTensor(rowptr=tensor_tuple[0], col=tensor_tuple[1],
828+
is_sorted=True)
829+
elif attr.layout == EdgeLayout.CSC:
830+
# CSC: (row, colptr) this is a transposed adjacency matrix, so rowptr
831+
# is the compressed column and col is the uncompressed row.
832+
return SparseTensor(rowptr=tensor_tuple[1], col=tensor_tuple[0],
833+
is_sorted=True)
834+
raise ValueError(f"Bad edge layout (got '{attr.layout}')")
835+
836+
837+
def adj_type_to_edge_tensor_type(layout: EdgeLayout,
838+
edge_index: Adj) -> EdgeTensorType:
839+
r"""Converts a PyG Adj tensor to an EdgeTensorType equivalent."""
840+
if isinstance(edge_index, Tensor):
841+
return (edge_index[0], edge_index[1])
842+
if layout == EdgeLayout.COO:
843+
row, col, _ = edge_index.coo()
844+
return (row, col)
845+
elif layout == EdgeLayout.CSR:
846+
rowptr, col, _ = edge_index.csr()
847+
return (rowptr, col)
848+
else:
849+
# CSC is just adj_t.csr():
850+
colptr, row, _ = edge_index.csr()
851+
return (row, colptr)
852+
761853

762854
def size_repr(key: Any, value: Any, indent: int = 0) -> str:
763855
pad = ' ' * indent

torch_geometric/data/feature_store.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ def __repr__(self) -> str:
239239

240240

241241
class FeatureStore(MutableMapping):
242-
def __init__(self, attr_cls: Any = TensorAttr):
242+
def __init__(self, tensor_attr_cls: Any = TensorAttr):
243243
r"""Initializes the feature store. Implementor classes can customize
244244
the ordering and required nature of their :class:`TensorAttr` tensor
245245
attributes by subclassing :class:`TensorAttr` and passing the subclass
246246
as :obj:`attr_cls`."""
247247
super().__init__()
248-
self.__dict__['_attr_cls'] = attr_cls
248+
self.__dict__['_tensor_attr_cls'] = tensor_attr_cls
249249

250250
# Core (CRUD) #############################################################
251251

@@ -270,7 +270,7 @@ def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
270270
Returns:
271271
bool: Whether insertion was successful.
272272
"""
273-
attr = self._attr_cls.cast(*args, **kwargs)
273+
attr = self._tensor_attr_cls.cast(*args, **kwargs)
274274
if not attr.is_fully_specified():
275275
raise ValueError(f"The input TensorAttr '{attr}' is not fully "
276276
f"specified. Please fully specify the input by "
@@ -310,7 +310,7 @@ def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
310310
return tensor.numpy()
311311
return tensor
312312

313-
attr = self._attr_cls.cast(*args, **kwargs)
313+
attr = self._tensor_attr_cls.cast(*args, **kwargs)
314314
if isinstance(attr.index, slice):
315315
if attr.index.start == attr.index.stop == attr.index.step is None:
316316
attr.index = None
@@ -341,7 +341,7 @@ def remove_tensor(self, *args, **kwargs) -> bool:
341341
Returns:
342342
bool: Whether deletion was succesful.
343343
"""
344-
attr = self._attr_cls.cast(*args, **kwargs)
344+
attr = self._tensor_attr_cls.cast(*args, **kwargs)
345345
if not attr.is_fully_specified():
346346
raise ValueError(f"The input TensorAttr '{attr}' is not fully "
347347
f"specified. Please fully specify the input by "
@@ -366,7 +366,7 @@ def update_tensor(self, tensor: FeatureTensorType, *args,
366366
Returns:
367367
bool: Whether the update was succesful.
368368
"""
369-
attr = self._attr_cls.cast(*args, **kwargs)
369+
attr = self._tensor_attr_cls.cast(*args, **kwargs)
370370
self.remove_tensor(attr)
371371
return self.put_tensor(tensor, attr)
372372

@@ -375,7 +375,7 @@ def update_tensor(self, tensor: FeatureTensorType, *args,
375375
def view(self, *args, **kwargs) -> AttrView:
376376
r"""Returns an :class:`AttrView` of the feature store, with the defined
377377
attributes set."""
378-
attr = self._attr_cls.cast(*args, **kwargs)
378+
attr = self._tensor_attr_cls.cast(*args, **kwargs)
379379
return AttrView(self, attr)
380380

381381
# Python built-ins ########################################################
@@ -384,7 +384,7 @@ def __setitem__(self, key: TensorAttr, value: FeatureTensorType):
384384
r"""Supports store[tensor_attr] = tensor."""
385385
# CastMixin will handle the case of key being a tuple or TensorAttr
386386
# object:
387-
key = self._attr_cls.cast(key)
387+
key = self._tensor_attr_cls.cast(key)
388388
# We need to fully specify the key for __setitem__ as it does not make
389389
# sense to work with a view here:
390390
key.fully_specify()
@@ -403,7 +403,7 @@ def __getitem__(self, key: TensorAttr) -> Any:
403403
"""
404404
# CastMixin will handle the case of key being a tuple or TensorAttr
405405
# object:
406-
attr = self._attr_cls.cast(key)
406+
attr = self._tensor_attr_cls.cast(key)
407407
if attr.is_fully_specified():
408408
return self.get_tensor(attr)
409409
# If the view is not fully specified, return a :class:`AttrView`:
@@ -413,7 +413,7 @@ def __delitem__(self, key: TensorAttr):
413413
r"""Supports del store[tensor_attr]."""
414414
# CastMixin will handle the case of key being a tuple or TensorAttr
415415
# object:
416-
key = self._attr_cls.cast(key)
416+
key = self._tensor_attr_cls.cast(key)
417417
key.fully_specify()
418418
self.remove_tensor(key)
419419

0 commit comments

Comments
 (0)