Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GraphStore definition + Data and HeteroData integration #4816

Merged
merged 33 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ff7066a
init
mananshah99 Jun 15, 2022
1688803
update
mananshah99 Jun 15, 2022
60181f2
Merge branch 'master' of https://github.com/pyg-team/pytorch_geometri…
mananshah99 Jun 17, 2022
59198b6
fix, todo: failing tests
mananshah99 Jun 17, 2022
82a10e0
fix to_heterodata
mananshah99 Jun 17, 2022
996591f
CHANGELOG
mananshah99 Jun 17, 2022
a2a099f
fix
mananshah99 Jun 17, 2022
cdd1a72
update
mananshah99 Jun 17, 2022
2991fbc
fix
mananshah99 Jun 17, 2022
4038892
init
mananshah99 Jun 17, 2022
cc0b39c
init 2
mananshah99 Jun 17, 2022
3918748
lint
mananshah99 Jun 17, 2022
5824d90
CHANGELOG
mananshah99 Jun 17, 2022
396c72d
fix
mananshah99 Jun 17, 2022
a46f60f
update materialized_graph
mananshah99 Jun 17, 2022
403f03f
lint
mananshah99 Jun 18, 2022
4750cb6
merge
mananshah99 Jun 20, 2022
aa46523
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2022
a3f083e
isort
mananshah99 Jun 20, 2022
04ee394
Merge branch 'feature_store_pt2' of https://github.com/pyg-team/pytor…
mananshah99 Jun 20, 2022
385d8a8
MaterializedGraph -> GraphStore
mananshah99 Jun 20, 2022
5e6badb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2022
aa94d8c
update
mananshah99 Jun 20, 2022
d1e6ee0
Merge branch 'feature_store_pt2' of https://github.com/pyg-team/pytor…
mananshah99 Jun 20, 2022
3cf6d95
update
mananshah99 Jun 20, 2022
d857d5c
Union[Tensor, Tensor]:
mananshah99 Jun 21, 2022
b49728e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2022
653dfd8
updates
mananshah99 Jun 21, 2022
92f82cb
Merge branch 'feature_store_pt2' of https://github.com/pyg-team/pytor…
mananshah99 Jun 21, 2022
5c1d38e
adj, adj_t respect is_sorted
mananshah99 Jun 21, 2022
acf6c4a
comments
mananshah99 Jun 21, 2022
471dc5a
view
mananshah99 Jun 21, 2022
f8b8d9e
update
mananshah99 Jun 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `MaterializedGraph` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
- Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))
- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756))
Expand Down
49 changes: 49 additions & 0 deletions test/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import pytest
import torch
import torch.multiprocessing as mp
import torch_sparse

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data.materialized_graph import EdgeLayout


def test_data():
Expand Down Expand Up @@ -239,3 +241,50 @@ def my_attr1(self, value):
data.my_attr1 = 2
assert 'my_attr1' not in data._store
assert data.my_attr1 == 2


# Feature Store ###############################################################


def test_basic_feature_store():
    r"""Exercises the feature-store interface exposed by :class:`Data`."""
    data = Data()
    feat = torch.randn(20, 20)

    # Insert a full tensor and read it back through attribute access:
    assert data.put_tensor(copy.deepcopy(feat), attr_name='x', index=None)
    assert torch.equal(data.x, feat)

    # Overwrite a slice of the stored tensor via an index:
    feat[15:] = 0
    data.put_tensor(0, attr_name='x', index=slice(15, None, None))

    # Fetch the (modified) tensor and compare against the reference:
    fetched = data.get_tensor(attr_name='x', index=None)
    assert torch.equal(feat, fetched)

    # Delete the tensor and confirm it is gone from the underlying store:
    assert 'x' in data.__dict__['_store']
    data.remove_tensor(attr_name='x', index=None)
    assert 'x' not in data.__dict__['_store']


# Materialized Graph ##########################################################


def test_basic_materialized_graph():
    r"""Exercises the materialized-graph interface exposed by :class:`Data`."""
    data = Data()
    edge_index = torch.LongTensor([[1, 2, 3], [2, 3, 1]])
    adj = torch_sparse.SparseTensor.from_edge_index(edge_index)
    adj_t = adj.t()

    # COO layout:
    data.put_edge_index(edge_index, layout=EdgeLayout.COO)
    assert torch.equal(data.get_edge_index(layout=EdgeLayout.COO), edge_index)

    # CSR layout:
    data.put_edge_index(adj, layout=EdgeLayout.CSR)
    assert data.get_edge_index(layout=EdgeLayout.CSR) == adj

    # CSC layout:
    data.put_edge_index(adj_t, layout=EdgeLayout.CSC)
    assert data.get_edge_index(layout=EdgeLayout.CSC) == adj_t
2 changes: 1 addition & 1 deletion test/data/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, attr_name=_field_status.UNSET,
class MyFeatureStoreNoGroupName(MyFeatureStore):
def __init__(self):
    # Construct the base store, then swap in the attribute class that does
    # not require a `group_name`, so attribute views resolve without one.
    super().__init__()
    self._tensor_attr_cls = MyTensorAttrNoGroupName

@staticmethod
def key(attr: TensorAttr) -> str:
Expand Down
55 changes: 55 additions & 0 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import pytest
import torch
import torch_sparse

from torch_geometric.data import HeteroData
from torch_geometric.data.materialized_graph import EdgeLayout
from torch_geometric.data.storage import EdgeStorage

x_paper = torch.randn(10, 16)
Expand Down Expand Up @@ -400,3 +402,56 @@ def test_hetero_data_to_canonical():

with pytest.raises(TypeError, match="missing 1 required"):
data['user', 'product']


# Feature Store ###############################################################


def test_basic_feature_store():
    r"""Exercises the feature-store interface exposed by :class:`HeteroData`."""
    data = HeteroData()
    feat = torch.randn(20, 20)

    # Insert a full tensor into the 'paper' group and read it back:
    assert data.put_tensor(copy.deepcopy(feat), group_name='paper',
                           attr_name='x', index=None)
    assert torch.equal(data['paper'].x, feat)

    # Overwrite a slice of the stored tensor via an index:
    feat[15:] = 0
    data.put_tensor(0, group_name='paper', attr_name='x',
                    index=slice(15, None, None))

    # Fetch the (modified) tensor and compare against the reference:
    fetched = data.get_tensor(group_name='paper', attr_name='x', index=None)
    assert torch.equal(feat, fetched)

    # Delete the tensor and confirm it is gone from the group storage:
    assert 'x' in data['paper'].__dict__['_mapping']
    data.remove_tensor(group_name='paper', attr_name='x', index=None)
    assert 'x' not in data['paper'].__dict__['_mapping']


# Materialized Graph ##########################################################


def test_basic_materialized_graph():
    r"""Exercises the materialized-graph interface of :class:`HeteroData`."""
    data = HeteroData()
    edge_index = torch.LongTensor([[1, 2, 3], [2, 3, 1]])
    adj = torch_sparse.SparseTensor.from_edge_index(edge_index)
    adj_t = adj.t()

    # COO layout:
    data.put_edge_index(edge_index, layout=EdgeLayout.COO,
                        edge_type=('a', 'to', 'b'))
    out = data.get_edge_index(edge_type=('a', 'to', 'b'),
                              layout=EdgeLayout.COO)
    assert torch.equal(out, edge_index)

    # CSR layout:
    data.put_edge_index(adj, layout=EdgeLayout.CSR,
                        edge_type=('a', 'to', 'c'))
    assert data.get_edge_index(edge_type=('a', 'to', 'c'),
                               layout=EdgeLayout.CSR) == adj

    # CSC layout:
    data.put_edge_index(adj_t, layout=EdgeLayout.CSC,
                        edge_type=('a', 'to', 'd'))
    assert data.get_edge_index(edge_type=('a', 'to', 'd'),
                               layout=EdgeLayout.CSC) == adj_t
33 changes: 33 additions & 0 deletions test/data/test_materialized_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch

from torch_geometric.data.materialized_graph import (
EdgeAttr,
EdgeLayout,
EdgeTensorType,
MaterializedGraph,
)


class MyMaterializedGraph(MaterializedGraph):
    r"""A minimal in-memory :class:`MaterializedGraph` backed by a dict."""
    def __init__(self):
        super().__init__()
        self.store = {}

    @staticmethod
    def key(attr: EdgeAttr) -> str:
        # Fall back to a sentinel key when no edge type is given:
        return attr.edge_type or '<default>'

    def _put_edge_index(self, edge_index: EdgeTensorType,
                        edge_attr: EdgeAttr) -> bool:
        r"""Stores an edge index keyed by its edge attribute."""
        self.store[MyMaterializedGraph.key(edge_attr)] = edge_index
        # Fix: the method is annotated `-> bool` but previously returned
        # `None`; return `True` on success, matching the `Data` backend.
        return True

    def _get_edge_index(self, edge_attr: EdgeAttr) -> EdgeTensorType:
        r"""Obtains an edge index, or `None` if none was stored."""
        return self.store.get(MyMaterializedGraph.key(edge_attr), None)


def test_materialized_graph():
    r"""Round-trips a COO edge index through `MyMaterializedGraph`."""
    graph_store = MyMaterializedGraph()
    coo = torch.LongTensor([[0, 1], [1, 2]])
    graph_store.put_edge_index(coo, layout=EdgeLayout.COO, edge_type='a')
    out = graph_store.get_edge_index(edge_type='a', layout=EdgeLayout.COO)
    assert torch.equal(out, coo)
6 changes: 5 additions & 1 deletion torch_geometric/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ def __call__(cls, *args, **kwargs):
new_cls = base_cls
else:
name = f'{base_cls.__name__}{cls.__name__}'

class MetaResolver(type(cls), type(base_cls)):
pass

if name not in globals():
globals()[name] = type(name, (cls, base_cls), {})
globals()[name] = MetaResolver(name, (cls, base_cls), {})
new_cls = globals()[name]

params = list(inspect.signature(base_cls.__init__).parameters.items())
Expand Down
95 changes: 92 additions & 3 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (
Any,
Callable,
Expand All @@ -17,14 +18,30 @@
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data.feature_store import (
FeatureStore,
TensorAttr,
_field_status,
)
from torch_geometric.data.materialized_graph import (
EdgeAttr,
EdgeLayout,
MaterializedGraph,
)
from torch_geometric.data.storage import (
BaseStorage,
EdgeStorage,
GlobalStorage,
NodeStorage,
)
from torch_geometric.deprecation import deprecated
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.typing import (
EdgeTensorType,
EdgeType,
FeatureTensorType,
NodeType,
OptTensor,
)
from torch_geometric.utils import subgraph


Expand Down Expand Up @@ -300,7 +317,16 @@ def contains_self_loops(self) -> bool:
###############################################################################


class Data(BaseData):
@dataclass
class DataTensorAttr(TensorAttr):
    r"""Attribute class for `Data`, which does not require a `group_name`."""
    def __init__(self, attr_name=_field_status.UNSET,
                 index=_field_status.UNSET):
        # `Data` holds a single (implicit) group, so `group_name` is pinned
        # to `None` and dropped from this constructor's signature:
        super().__init__(group_name=None, attr_name=attr_name, index=index)


class Data(BaseData, FeatureStore, MaterializedGraph):
r"""A data object describing a homogeneous graph.
The data object can hold node-level, link-level and graph-level attributes.
In general, :class:`~torch_geometric.data.Data` tries to mimic the
Expand Down Expand Up @@ -348,7 +374,10 @@ class Data(BaseData):
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
edge_attr: OptTensor = None, y: OptTensor = None,
pos: OptTensor = None, **kwargs):
super().__init__()
# `Data` doesn't support group_name, so we need to adjust `TensorAttr`
# accordingly here to avoid requiring `group_name` to be set:
super(Data, self).__init__(tensor_attr_cls=DataTensorAttr)

self.__dict__['_store'] = GlobalStorage(_parent=self)

if x is not None:
Expand Down Expand Up @@ -384,6 +413,9 @@ def __setattr__(self, key: str, value: Any):
def __delattr__(self, key: str):
delattr(self._store, key)

# TODO consider supporting the feature store interface for
# __getitem__, __setitem__, and __delitem__ so, for example, we
# can accept key: Union[str, TensorAttr] in __getitem__.
def __getitem__(self, key: str) -> Any:
return self._store[key]

Expand Down Expand Up @@ -692,6 +724,63 @@ def num_faces(self) -> Optional[int]:
return self.face.size(self.__cat_dim__('face', self.face))
return None

# FeatureStore interface ##################################################

def items(self):
    r"""Returns an `(attr_name, value)` items view over the underlying
    global storage."""
    return self._store.items()

def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
    r"""Stores a feature tensor in node storage.

    If a tensor is already stored under `attr.attr_name` and an index is
    given, the stored tensor is updated in place at that index; otherwise
    the value is stored wholesale under the attribute name.
    """
    existing = getattr(self, attr.attr_name, None)
    if existing is None or attr.index is None:
        # Nothing stored yet (or no index given): write the whole tensor:
        setattr(self, attr.attr_name, tensor)
    else:
        # Both a stored tensor and an index exist: update in place:
        existing[attr.index] = tensor
    return True

def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
    r"""Obtains a feature tensor from node storage, or `None` if absent."""
    tensor = getattr(self, attr.attr_name, None)
    if tensor is None:
        return None
    # NOTE: `TensorAttr` nominally requires `index` to be set; indexing by
    # `None` is treated here as "no indexing at all", which is not quite in
    # line with Python semantics.
    if attr.index is None:
        return tensor
    return tensor[attr.index]

def _remove_tensor(self, attr: TensorAttr) -> bool:
    r"""Deletes a feature tensor from node storage.

    Returns `True` if the attribute existed and was removed, `False`
    otherwise.
    """
    if not hasattr(self, attr.attr_name):
        return False  # Nothing stored under this name.
    delattr(self, attr.attr_name)
    return True

def __len__(self) -> int:
    # Explicitly dispatch to `BaseData.__len__` to disambiguate between the
    # multiple base classes `Data` now inherits from (`BaseData`,
    # `FeatureStore`, `MaterializedGraph`), any of which may define
    # `__len__` — TODO confirm against the base-class definitions.
    return BaseData.__len__(self)

# MaterializedGraph interface #############################################

def _layout_to_attr_name(self, layout: EdgeLayout) -> str:
    r"""Maps an edge layout to the `Data` attribute name it is stored under.

    Raises a :class:`KeyError` for an unrecognized layout.
    """
    if layout == EdgeLayout.COO:
        return 'edge_index'
    if layout == EdgeLayout.CSR:
        return 'adj'
    if layout == EdgeLayout.CSC:
        return 'adj_t'
    raise KeyError(layout)

def _put_edge_index(self, edge_index: EdgeTensorType,
                    edge_attr: EdgeAttr) -> bool:
    r"""Stores an edge index under the layout-specific attribute name."""
    attr_name = self._layout_to_attr_name(edge_attr.layout)
    setattr(self, attr_name, edge_index)
    return True

def _get_edge_index(self, edge_attr: EdgeAttr) -> EdgeTensorType:
    r"""Obtains the edge index stored under the layout-specific name."""
    attr_name = self._layout_to_attr_name(edge_attr.layout)
    # Reads from the underlying storage directly, with no default fallback
    # (unlike `_get_tensor`, which returns `None` when absent).
    return getattr(self._store, attr_name)


###############################################################################

Expand Down
Loading