GraphStore definition + Data and HeteroData integration #4816

Status: Merged (Jun 21, 2022), 33 commits. Changes shown below are from 30 commits.

Commits:
ff7066a  init (mananshah99, Jun 15, 2022)
1688803  update (mananshah99, Jun 15, 2022)
60181f2  Merge branch 'master' of https://github.com/pyg-team/pytorch_geometri… (mananshah99, Jun 17, 2022)
59198b6  fix, todo: failing tests (mananshah99, Jun 17, 2022)
82a10e0  fix to_heterodata (mananshah99, Jun 17, 2022)
996591f  CHANGELOG (mananshah99, Jun 17, 2022)
a2a099f  fix (mananshah99, Jun 17, 2022)
cdd1a72  update (mananshah99, Jun 17, 2022)
2991fbc  fix (mananshah99, Jun 17, 2022)
4038892  init (mananshah99, Jun 17, 2022)
cc0b39c  init 2 (mananshah99, Jun 17, 2022)
3918748  lint (mananshah99, Jun 17, 2022)
5824d90  CHANGELOG (mananshah99, Jun 17, 2022)
396c72d  fix (mananshah99, Jun 17, 2022)
a46f60f  update materialized_graph (mananshah99, Jun 17, 2022)
403f03f  lint (mananshah99, Jun 18, 2022)
4750cb6  merge (mananshah99, Jun 20, 2022)
aa46523  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 20, 2022)
a3f083e  isort (mananshah99, Jun 20, 2022)
04ee394  Merge branch 'feature_store_pt2' of https://github.com/pyg-team/pytor… (mananshah99, Jun 20, 2022)
385d8a8  MaterializedGraph -> GraphStore (mananshah99, Jun 20, 2022)
5e6badb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 20, 2022)
aa94d8c  update (mananshah99, Jun 20, 2022)
d1e6ee0  Merge branch 'feature_store_pt2' of https://github.com/pyg-team/pytor… (mananshah99, Jun 20, 2022)
3cf6d95  update (mananshah99, Jun 20, 2022)
d857d5c  Union[Tensor, Tensor]: (mananshah99, Jun 21, 2022)
b49728e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 21, 2022)
653dfd8  updates (mananshah99, Jun 21, 2022)
92f82cb  Merge branch 'feature_store_pt2' of https://github.com/pyg-team/pytor… (mananshah99, Jun 21, 2022)
5c1d38e  adj, adj_t respect is_sorted (mananshah99, Jun 21, 2022)
acf6c4a  comments (mananshah99, Jun 21, 2022)
471dc5a  view (mananshah99, Jun 21, 2022)
f8b8d9e  update (mananshah99, Jun 21, 2022)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
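The new entry reflects `Data` (and `HeteroData`) now implementing the `GraphStore` interface. As a quick orientation, here is a minimal usage sketch distilled from the tests added in this PR; the toy graph and the `'coo'` layout string come straight from the diffs below:

```python
import torch
from torch_geometric.data import Data

data = Data()
row, col = torch.tensor([0, 1]), torch.tensor([1, 2])

# Store the edge index in COO layout via the new `GraphStore` interface:
data.put_edge_index((row, col), layout='coo')

# ...and read it back as a (row, col) tuple:
row2, col2 = data.get_edge_index('coo')
assert torch.equal(row, row2) and torch.equal(col, col2)
```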
34 changes: 34 additions & 0 deletions test/data/test_data.py
@@ -3,9 +3,11 @@
import pytest
import torch
import torch.multiprocessing as mp
import torch_sparse

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.typing import EdgeTensorType


def test_data():
@@ -264,3 +266,35 @@ def test_basic_feature_store():
    assert 'x' in data.__dict__['_store']
    data.remove_tensor(attr_name='x', index=None)
    assert 'x' not in data.__dict__['_store']


# Graph Store #################################################################


def test_basic_graph_store():
    data = Data()

    edge_index = torch.LongTensor([[0, 1], [1, 2]])
    adj = torch_sparse.SparseTensor(row=edge_index[0], col=edge_index[1])

    def assert_edge_tensor_type_equal(expected: EdgeTensorType,
                                      actual: EdgeTensorType):
        assert len(expected) == len(actual)
        for i in range(len(expected)):
            assert torch.equal(expected[i], actual[i])

    # We put all three tensor types: COO, CSR, and CSC, and we get them back
    # to confirm that `GraphStore` works as intended.
    coo = adj.coo()[:-1]
    csr = adj.csr()[:-1]
    csc = adj.t().csr()[:-1]

    # Put:
    data.put_edge_index(coo, layout='coo')
    data.put_edge_index(csr, layout='csr')
    data.put_edge_index(csc, layout='csc')

    # Get:
    assert_edge_tensor_type_equal(coo, data.get_edge_index('coo'))
    assert_edge_tensor_type_equal(csr, data.get_edge_index('csr'))
    assert_edge_tensor_type_equal(csc, data.get_edge_index('csc'))
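For orientation (not part of the diff): per `edge_layout_to_attr_name` in the `torch_geometric/data/data.py` changes below, each layout is stored on a dedicated `Data` attribute, so after the three puts above the following should hold:

```python
from torch_sparse import SparseTensor

# Expected storage targets after the three `put_edge_index` calls above
# (a sketch based on `edge_layout_to_attr_name` further down in this PR):
assert isinstance(data.edge_index, torch.Tensor)  # COO: stacked (row, col)
assert isinstance(data.adj, SparseTensor)         # CSR: (rowptr, col)
assert isinstance(data.adj_t, SparseTensor)       # CSC: (colptr, row)
```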
2 changes: 1 addition & 1 deletion test/data/test_feature_store.py
@@ -66,7 +66,7 @@ def __init__(self, attr_name=_field_status.UNSET,
class MyFeatureStoreNoGroupName(MyFeatureStore):
    def __init__(self):
        super().__init__()
-        self._attr_cls = MyTensorAttrNoGroupName
+        self._tensor_attr_cls = MyTensorAttrNoGroupName

    @staticmethod
    def key(attr: TensorAttr) -> str:
53 changes: 53 additions & 0 deletions test/data/test_graph_store.py
@@ -0,0 +1,53 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data.graph_store import (
    EdgeAttr,
    EdgeTensorType,
    GraphStore,
)


class MyGraphStore(GraphStore):
    def __init__(self):
        super().__init__()
        self.store = {}

    @staticmethod
    def key(attr: EdgeAttr) -> str:
        return (attr.edge_type or '<default>') + str(attr.layout)

    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        self.store[MyGraphStore.key(edge_attr)] = edge_index

    def _get_edge_index(self, edge_attr: EdgeAttr) -> EdgeTensorType:
        return self.store.get(MyGraphStore.key(edge_attr), None)


def test_graph_store():
    graph_store = MyGraphStore()
    edge_index = torch.LongTensor([[0, 1], [1, 2]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1])

    def assert_edge_tensor_type_equal(expected: EdgeTensorType,
                                      actual: EdgeTensorType):
        assert len(expected) == len(actual)
        for i in range(len(expected)):
            assert torch.equal(expected[i], actual[i])

    # We put all three tensor types: COO, CSR, and CSC, and we get them back
    # to confirm that `GraphStore` works as intended.
    coo = adj.coo()[:-1]
    csr = adj.csr()[:-1]
    csc = adj.t().csr()[:-1]

    # Put:
    graph_store['edge', 'coo'] = coo
    graph_store['edge', 'csr'] = csr
    graph_store['edge', 'csc'] = csc

    # Get:
    assert_edge_tensor_type_equal(coo, graph_store['edge', 'coo'])
    assert_edge_tensor_type_equal(csr, graph_store['edge', 'csr'])
    assert_edge_tensor_type_equal(csc, graph_store['edge', 'csc'])
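A side note on the indexing used above: the tuple keys are presumably cast to `EdgeAttr` objects, mirroring how `FeatureStore` casts its keys, so the dict-style puts and gets should be equivalent to the explicit calls used in the `Data`/`HeteroData` tests. A hedged sketch of that equivalence (assuming keyword-argument casting; not part of this diff):

```python
# Equivalent explicit form of `graph_store['edge', 'coo'] = coo`:
graph_store.put_edge_index(coo, edge_type='edge', layout='coo')

# Equivalent explicit form of `graph_store['edge', 'coo']`:
row, col = graph_store.get_edge_index(edge_type='edge', layout='coo')
```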
37 changes: 37 additions & 0 deletions test/data/test_hetero_data.py
@@ -2,9 +2,11 @@

import pytest
import torch
import torch_sparse

from torch_geometric.data import HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import EdgeTensorType

x_paper = torch.randn(10, 16)
x_author = torch.randn(5, 32)
@@ -427,3 +429,38 @@ def test_basic_feature_store():
    assert 'x' in data['paper'].__dict__['_mapping']
    data.remove_tensor(group_name='paper', attr_name='x', index=None)
    assert 'x' not in data['paper'].__dict__['_mapping']


# Graph Store #################################################################


def test_basic_graph_store():
    data = HeteroData()

    edge_index = torch.LongTensor([[0, 1], [1, 2]])
    adj = torch_sparse.SparseTensor(row=edge_index[0], col=edge_index[1])

    def assert_edge_tensor_type_equal(expected: EdgeTensorType,
                                      actual: EdgeTensorType):
        assert len(expected) == len(actual)
        for i in range(len(expected)):
            assert torch.equal(expected[i], actual[i])

    # We put all three tensor types: COO, CSR, and CSC, and we get them back
    # to confirm that `GraphStore` works as intended.
    coo = adj.coo()[:-1]
    csr = adj.csr()[:-1]
    csc = adj.t().csr()[:-1]

    # Put:
    data.put_edge_index(coo, layout='coo', edge_type='1')
    data.put_edge_index(csr, layout='csr', edge_type='2')
    data.put_edge_index(csc, layout='csc', edge_type='3')

    # Get:
    assert_edge_tensor_type_equal(
        coo, data.get_edge_index(layout='coo', edge_type='1'))
    assert_edge_tensor_type_equal(
        csr, data.get_edge_index(layout='csr', edge_type='2'))
    assert_edge_tensor_type_equal(
        csc, data.get_edge_index(layout='csc', edge_type='3'))
93 changes: 90 additions & 3 deletions torch_geometric/data/data.py
@@ -24,14 +24,22 @@
    TensorAttr,
    _field_status,
)
from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore
from torch_geometric.data.storage import (
    BaseStorage,
    EdgeStorage,
    GlobalStorage,
    NodeStorage,
)
from torch_geometric.deprecation import deprecated
-from torch_geometric.typing import EdgeType, NodeType, OptTensor
+from torch_geometric.typing import (
+    Adj,
+    EdgeTensorType,
+    EdgeType,
+    FeatureTensorType,
+    NodeType,
+    OptTensor,
+)
from torch_geometric.utils import subgraph


@@ -316,7 +324,17 @@ def __init__(self, attr_name=_field_status.UNSET,
        super().__init__(None, attr_name, index)


-class Data(BaseData, FeatureStore):
+@dataclass
+class DataEdgeAttr(EdgeAttr):
+    r"""Edge attribute class for `Data`, which does not require an
+    `edge_type`."""
+    def __init__(self, layout: EdgeLayout, is_sorted: bool = False,
+                 edge_type: EdgeType = None):
+        # Treat edge_type as optional, and move it to the end:
+        super().__init__(edge_type, layout, is_sorted)
+
+
+class Data(BaseData, FeatureStore, GraphStore):
r"""A data object describing a homogeneous graph.
The data object can hold node-level, link-level and graph-level attributes.
In general, :class:`~torch_geometric.data.Data` tries to mimic the
@@ -366,7 +384,11 @@ def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 pos: OptTensor = None, **kwargs):
        # `Data` doesn't support group_name, so we need to adjust `TensorAttr`
        # accordingly here to avoid requiring `group_name` to be set:
-        super().__init__(attr_cls=DataTensorAttr)
+        super().__init__(tensor_attr_cls=DataTensorAttr)
+
+        # `Data` doesn't support edge_type, so we need to adjust `EdgeAttr`
+        # accordingly here to avoid requiring `edge_type` to be set:
+        GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr)

        self.__dict__['_store'] = GlobalStorage(_parent=self)

@@ -755,10 +777,75 @@ def _remove_tensor(self, attr: TensorAttr) -> bool:
    def __len__(self) -> int:
        return BaseData.__len__(self)

    # GraphStore interface ####################################################

    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        # Convert the edge index to a recognizable format:
        attr_name = edge_layout_to_attr_name(edge_attr.layout)
        attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
        setattr(self, attr_name, attr_val)
        return True

    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
        # Get the requested format and the Adj tensor associated with it:
        attr_name = edge_layout_to_attr_name(edge_attr.layout)
        attr_val = getattr(self._store, attr_name, None)
        if attr_val is not None:
            # Convert from Adj type to Tuple[Tensor, Tensor]:
            attr_val = adj_type_to_edge_tensor_type(edge_attr.layout, attr_val)
        return attr_val


###############################################################################


def edge_layout_to_attr_name(layout: EdgeLayout) -> str:
    r"""Maps `EdgeLayout`s to their corresponding PyG key names."""
    return {
        EdgeLayout.COO: 'edge_index',
        EdgeLayout.CSR: 'adj',
        EdgeLayout.CSC: 'adj_t',
    }[layout]


def edge_tensor_type_to_adj_type(
    attr: EdgeAttr,
    tensor_tuple: EdgeTensorType,
) -> Adj:
    r"""Converts an EdgeTensorType tensor tuple to a PyG Adj tensor."""
    if attr.layout == EdgeLayout.COO:
        # COO: 2 x n
        return torch.stack(tensor_tuple)
    elif attr.layout == EdgeLayout.CSR:
        # CSR: (rowptr, col)
        return SparseTensor(rowptr=tensor_tuple[0], col=tensor_tuple[1],
                            is_sorted=attr.is_sorted)
    elif attr.layout == EdgeLayout.CSC:
        # CSC: (colptr, row); this is a transposed adjacency matrix, so
        # rowptr is the compressed column and col is the uncompressed row:
        return SparseTensor(rowptr=tensor_tuple[0], col=tensor_tuple[1],
                            is_sorted=attr.is_sorted)
    raise ValueError(f"Bad layout: got {attr.layout}")


def adj_type_to_edge_tensor_type(layout: EdgeLayout,
                                 edge_index: Adj) -> EdgeTensorType:
    r"""Converts a PyG Adj tensor to an EdgeTensorType equivalent."""
    if isinstance(edge_index, Tensor):
        return (edge_index[0], edge_index[1])
    if layout == EdgeLayout.COO:
        row, col, _ = edge_index.coo()
        return (row, col)
    elif layout == EdgeLayout.CSR:
        rowptr, col, _ = edge_index.csr()
        return (rowptr, col)
    else:
        # CSC is just adj_t.csr():
        colptr, row, _ = edge_index.csr()
        return (colptr, row)


def size_repr(key: Any, value: Any, indent: int = 0) -> str:
    pad = ' ' * indent
    if isinstance(value, Tensor) and value.dim() == 0:
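Taken together, these helpers round-trip between tuple-based `EdgeTensorType` values and PyG `Adj` representations. A minimal sketch of that round trip, assuming `EdgeAttr` can be constructed as `EdgeAttr(edge_type, layout)` with `is_sorted` defaulting to `False` (everything else below comes from this diff):

```python
import torch
from torch_geometric.data.data import (
    adj_type_to_edge_tensor_type,
    edge_tensor_type_to_adj_type,
)
from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout

# A (row, col) tuple for a three-node chain graph:
coo = (torch.tensor([0, 1]), torch.tensor([1, 2]))

# Tuple -> Adj: for COO, this stacks (row, col) into a 2 x num_edges Tensor:
attr = EdgeAttr(None, EdgeLayout.COO)  # assumed constructor order
adj = edge_tensor_type_to_adj_type(attr, coo)
assert adj.size() == (2, 2)

# Adj -> tuple: recovers the original (row, col) pair:
row, col = adj_type_to_edge_tensor_type(EdgeLayout.COO, adj)
assert torch.equal(row, coo[0]) and torch.equal(col, coo[1])
```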
20 changes: 10 additions & 10 deletions torch_geometric/data/feature_store.py
@@ -239,13 +239,13 @@ def __repr__(self) -> str:


class FeatureStore(MutableMapping):
def __init__(self, attr_cls: Any = TensorAttr):
def __init__(self, tensor_attr_cls: Any = TensorAttr):
r"""Initializes the feature store. Implementor classes can customize
the ordering and required nature of their :class:`TensorAttr` tensor
attributes by subclassing :class:`TensorAttr` and passing the subclass
as :obj:`attr_cls`."""
super().__init__()
self.__dict__['_attr_cls'] = attr_cls
self.__dict__['_tensor_attr_cls'] = tensor_attr_cls

    # Core (CRUD) #############################################################

@@ -270,7 +270,7 @@ def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:
        Returns:
            bool: Whether insertion was successful.
        """
-        attr = self._attr_cls.cast(*args, **kwargs)
+        attr = self._tensor_attr_cls.cast(*args, **kwargs)
        if not attr.is_fully_specified():
            raise ValueError(f"The input TensorAttr '{attr}' is not fully "
                             f"specified. Please fully specify the input by "
@@ -310,7 +310,7 @@ def to_type(tensor: FeatureTensorType) -> FeatureTensorType:
                return tensor.numpy()
            return tensor

-        attr = self._attr_cls.cast(*args, **kwargs)
+        attr = self._tensor_attr_cls.cast(*args, **kwargs)
        if isinstance(attr.index, slice):
            if attr.index.start == attr.index.stop == attr.index.step is None:
                attr.index = None
@@ -341,7 +341,7 @@ def remove_tensor(self, *args, **kwargs) -> bool:
        Returns:
            bool: Whether deletion was successful.
        """
-        attr = self._attr_cls.cast(*args, **kwargs)
+        attr = self._tensor_attr_cls.cast(*args, **kwargs)
        if not attr.is_fully_specified():
            raise ValueError(f"The input TensorAttr '{attr}' is not fully "
                             f"specified. Please fully specify the input by "
@@ -366,7 +366,7 @@ def update_tensor(self, tensor: FeatureTensorType, *args,
        Returns:
            bool: Whether the update was successful.
        """
-        attr = self._attr_cls.cast(*args, **kwargs)
+        attr = self._tensor_attr_cls.cast(*args, **kwargs)
        self.remove_tensor(attr)
        return self.put_tensor(tensor, attr)
@@ -375,7 +375,7 @@ def update_tensor(self, tensor: FeatureTensorType, *args,
    def view(self, *args, **kwargs) -> AttrView:
        r"""Returns an :class:`AttrView` of the feature store, with the defined
        attributes set."""
-        attr = self._attr_cls.cast(*args, **kwargs)
+        attr = self._tensor_attr_cls.cast(*args, **kwargs)
        return AttrView(self, attr)

    # Python built-ins ########################################################
@@ -384,7 +384,7 @@ def __setitem__(self, key: TensorAttr, value: FeatureTensorType):
r"""Supports store[tensor_attr] = tensor."""
# CastMixin will handle the case of key being a tuple or TensorAttr
# object:
key = self._attr_cls.cast(key)
key = self._tensor_attr_cls.cast(key)
# We need to fully specify the key for __setitem__ as it does not make
# sense to work with a view here:
key.fully_specify()
@@ -403,7 +403,7 @@ def __getitem__(self, key: TensorAttr) -> Any:
"""
# CastMixin will handle the case of key being a tuple or TensorAttr
# object:
attr = self._attr_cls.cast(key)
attr = self._tensor_attr_cls.cast(key)
if attr.is_fully_specified():
return self.get_tensor(attr)
# If the view is not fully specified, return a :class:`AttrView`:
@@ -413,7 +413,7 @@ def __delitem__(self, key: TensorAttr):
r"""Supports del store[tensor_attr]."""
# CastMixin will handle the case of key being a tuple or TensorAttr
# object:
key = self._attr_cls.cast(key)
key = self._tensor_attr_cls.cast(key)
key.fully_specify()
self.remove_tensor(key)

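With the rename in place, the tensor-attribute hook on `FeatureStore` (`tensor_attr_cls`) is clearly distinct from the edge-attribute hook on `GraphStore` (`edge_attr_cls`), and a class can customize both, as `Data` does above. A minimal sketch of wiring the two hooks on a custom store, assuming the constructor keywords shown in this diff (the abstract `_put_*`/`_get_*` methods are omitted and would need real implementations):

```python
from torch_geometric.data.feature_store import FeatureStore, TensorAttr
from torch_geometric.data.graph_store import EdgeAttr, GraphStore


class MyCombinedStore(FeatureStore, GraphStore):
    def __init__(self):
        # Each base class casts its own attribute type, so keys like
        # store['x'] and store['edge', 'coo'] route through the right class:
        FeatureStore.__init__(self, tensor_attr_cls=TensorAttr)
        GraphStore.__init__(self, edge_attr_cls=EdgeAttr)

    # _put_tensor/_get_tensor/_remove_tensor and
    # _put_edge_index/_get_edge_index would be implemented here.
```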