Skip to content

Commit 4b30b6d

Browse files
authored
Let Data and HeteroData implement FeatureStore (#4807)
1 parent c13d62c commit 4b30b6d

File tree

7 files changed

+182
-5
lines changed

7 files changed

+182
-5
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
89
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
910
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
1011
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))

test/data/test_data.py

+25
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,28 @@ def my_attr1(self, value):
239239
data.my_attr1 = 2
240240
assert 'my_attr1' not in data._store
241241
assert data.my_attr1 == 2
242+
243+
244+
# Feature Store ###############################################################
245+
246+
247+
def test_basic_feature_store():
    r"""Exercises the `FeatureStore` CRUD interface exposed by `Data`."""
    store = Data()
    feat = torch.randn(20, 20)

    # Insert a full tensor under the attribute name 'x':
    assert store.put_tensor(copy.deepcopy(feat), attr_name='x', index=None)
    assert torch.equal(store.x, feat)

    # Overwrite a slice of the stored tensor in-place:
    feat[15:] = 0
    store.put_tensor(0, attr_name='x', index=slice(15, None, None))

    # Read the tensor back and compare against the reference copy:
    retrieved = store.get_tensor(attr_name='x', index=None)
    assert torch.equal(feat, retrieved)

    # Delete the tensor and check it is gone from the underlying storage:
    assert 'x' in store.__dict__['_store']
    store.remove_tensor(attr_name='x', index=None)
    assert 'x' not in store.__dict__['_store']

test/data/test_hetero_data.py

+27
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,30 @@ def test_hetero_data_to_canonical():
400400

401401
with pytest.raises(TypeError, match="missing 1 required"):
402402
data['user', 'product']
403+
404+
405+
# Feature Store ###############################################################
406+
407+
408+
def test_basic_feature_store():
    r"""Exercises the `FeatureStore` CRUD interface exposed by `HeteroData`."""
    store = HeteroData()
    feat = torch.randn(20, 20)

    # Insert a full tensor for node group 'paper' under attribute 'x':
    assert store.put_tensor(copy.deepcopy(feat), group_name='paper',
                            attr_name='x', index=None)
    assert torch.equal(store['paper'].x, feat)

    # Overwrite a slice of the stored tensor in-place:
    feat[15:] = 0
    store.put_tensor(0, group_name='paper', attr_name='x',
                     index=slice(15, None, None))

    # Read the tensor back and compare against the reference copy:
    retrieved = store.get_tensor(group_name='paper', attr_name='x', index=None)
    assert torch.equal(feat, retrieved)

    # Delete the tensor and check it is gone from the group's storage:
    assert 'x' in store['paper'].__dict__['_mapping']
    store.remove_tensor(group_name='paper', attr_name='x', index=None)
    assert 'x' not in store['paper'].__dict__['_mapping']

torch_geometric/data/batch.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,16 @@ def __call__(cls, *args, **kwargs):
2323
new_cls = base_cls
2424
else:
2525
name = f'{base_cls.__name__}{cls.__name__}'
26+
27+
# NOTE `MetaResolver` is necessary to resolve metaclass conflict
28+
# problems between `DynamicInheritance` and the metaclass of
29+
# `base_cls`. In particular, it creates a new common metaclass
30+
# from the defined metaclasses.
31+
class MetaResolver(type(cls), type(base_cls)):
32+
pass
33+
2634
if name not in globals():
27-
globals()[name] = type(name, (cls, base_cls), {})
35+
globals()[name] = MetaResolver(name, (cls, base_cls), {})
2836
new_cls = globals()[name]
2937

3038
params = list(inspect.signature(base_cls.__init__).parameters.items())

torch_geometric/data/data.py

+65-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
from collections.abc import Mapping, Sequence
3+
from dataclasses import dataclass
34
from typing import (
45
Any,
56
Callable,
@@ -17,6 +18,12 @@
1718
from torch import Tensor
1819
from torch_sparse import SparseTensor
1920

21+
from torch_geometric.data.feature_store import (
22+
FeatureStore,
23+
FeatureTensorType,
24+
TensorAttr,
25+
_field_status,
26+
)
2027
from torch_geometric.data.storage import (
2128
BaseStorage,
2229
EdgeStorage,
@@ -300,7 +307,16 @@ def contains_self_loops(self) -> bool:
300307
###############################################################################
301308

302309

303-
class Data(BaseData):
310+
@dataclass
class DataTensorAttr(TensorAttr):
    r"""Attribute class for `Data`, which does not require a `group_name`.

    `Data` holds a single (homogeneous) group of attributes, so the
    `group_name` field of :class:`TensorAttr` is meaningless here and is
    pinned to :obj:`None`.
    """
    # NOTE A hand-written `__init__` deliberately replaces the
    # dataclass-generated one so that `attr_name` and `index` become the
    # leading parameters while `group_name` is fixed to `None`.
    def __init__(self, attr_name=_field_status.UNSET,
                 index=_field_status.UNSET):
        # Treat group_name as optional, and move it to the end
        super().__init__(None, attr_name, index)
317+
318+
319+
class Data(BaseData, FeatureStore):
304320
r"""A data object describing a homogeneous graph.
305321
The data object can hold node-level, link-level and graph-level attributes.
306322
In general, :class:`~torch_geometric.data.Data` tries to mimic the
@@ -348,7 +364,10 @@ class Data(BaseData):
348364
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
349365
edge_attr: OptTensor = None, y: OptTensor = None,
350366
pos: OptTensor = None, **kwargs):
351-
super().__init__()
367+
# `Data` doesn't support group_name, so we need to adjust `TensorAttr`
368+
# accordingly here to avoid requiring `group_name` to be set:
369+
super().__init__(attr_cls=DataTensorAttr)
370+
352371
self.__dict__['_store'] = GlobalStorage(_parent=self)
353372

354373
if x is not None:
@@ -384,6 +403,9 @@ def __setattr__(self, key: str, value: Any):
384403
def __delattr__(self, key: str):
385404
delattr(self._store, key)
386405

406+
# TODO consider supporting the feature store interface for
407+
# __getitem__, __setitem__, and __delitem__ so, for example, we
408+
# can accept key: Union[str, TensorAttr] in __getitem__.
387409
def __getitem__(self, key: str) -> Any:
388410
return self._store[key]
389411

@@ -692,6 +714,47 @@ def num_faces(self) -> Optional[int]:
692714
return self.face.size(self.__cat_dim__('face', self.face))
693715
return None
694716

717+
# FeatureStore interface ###########################################
718+
719+
def items(self):
    r"""Returns an `ItemsView` over the stored attributes in the `Data`
    object."""
    # Delegate directly to the internal storage so feature-store iteration
    # sees exactly the attributes held by this `Data` object.
    return self._store.items()
723+
724+
def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
    r"""Stores a feature tensor in node storage.

    If a tensor already exists under `attr.attr_name` and an index is
    given, the indexed slice is updated in-place; otherwise the whole
    tensor is (over)written as an attribute.
    """
    existing = getattr(self, attr.attr_name, None)
    if existing is None or attr.index is None:
        # No stored attribute (or no index given): store the whole tensor.
        setattr(self, attr.attr_name, tensor)
    else:
        # Attribute exists and an index was supplied: update in-place.
        existing[attr.index] = tensor
    return True
734+
735+
def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
    r"""Obtains a feature tensor from node storage."""
    tensor = getattr(self, attr.attr_name, None)
    if tensor is None:
        return None
    # TODO this behavior is a bit odd, since TensorAttr requires that we
    # set `index`. So, we assume here that indexing by `None` is
    # equivalent to not indexing at all, which is not in line with Python
    # semantics.
    if attr.index is None:
        return tensor
    return tensor[attr.index]
746+
747+
def _remove_tensor(self, attr: TensorAttr) -> bool:
    r"""Deletes a feature tensor from node storage.

    Returns :obj:`True` if an attribute was removed, :obj:`False` if no
    attribute with that name existed.
    """
    if not hasattr(self, attr.attr_name):
        return False
    delattr(self, attr.attr_name)
    return True
754+
755+
def __len__(self) -> int:
    # Explicitly dispatch to `BaseData.__len__`: `Data` inherits from both
    # `BaseData` and `FeatureStore`, so the call is pinned to avoid any
    # ambiguity in the multiple-inheritance resolution.
    return BaseData.__len__(self)
757+
695758

696759
###############################################################################
697760

torch_geometric/data/feature_store.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __init__(self, attr_cls: Any = TensorAttr):
245245
attributes by subclassing :class:`TensorAttr` and passing the subclass
246246
as :obj:`attr_cls`."""
247247
super().__init__()
248-
self._attr_cls = attr_cls
248+
self.__dict__['_attr_cls'] = attr_cls
249249

250250
# Core (CRUD) #############################################################
251251

torch_geometric/data/hetero_data.py

+54-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
from torch_sparse import SparseTensor
1111

1212
from torch_geometric.data.data import BaseData, Data, size_repr
13+
from torch_geometric.data.feature_store import (
14+
FeatureStore,
15+
FeatureTensorType,
16+
TensorAttr,
17+
)
1318
from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
1419
from torch_geometric.typing import EdgeType, NodeType, QueryType
1520
from torch_geometric.utils import bipartite_subgraph, is_undirected
@@ -18,7 +23,7 @@
1823
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]
1924

2025

21-
class HeteroData(BaseData):
26+
class HeteroData(BaseData, FeatureStore):
2227
r"""A data object describing a heterogeneous graph, holding multiple node
2328
and/or edge types in disjunct storage objects.
2429
Storage objects can hold either node-level, link-level or graph-level
@@ -92,6 +97,8 @@ class HeteroData(BaseData):
9297
DEFAULT_REL = 'to'
9398

9499
def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
100+
super().__init__()
101+
95102
self.__dict__['_global_store'] = BaseStorage(_parent=self)
96103
self.__dict__['_node_store_dict'] = {}
97104
self.__dict__['_edge_store_dict'] = {}
@@ -616,6 +623,52 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:
616623

617624
return data
618625

626+
# :obj:`FeatureStore` interface ###########################################
627+
628+
def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
    r"""Stores a feature tensor in node storage.

    If the target node group already holds a tensor under
    :obj:`attr.attr_name` and an index is given, the indexed slice is
    updated in-place; otherwise the tensor is stored as a (new) attribute
    of the group.
    """
    # `TensorAttr` allows `index` to stay unset; normalize it to `None` so
    # the checks below are uniform:
    if not attr.is_set('index'):
        attr.index = None

    out = self._node_store_dict.get(attr.group_name, None)
    if out:
        # Group name exists. Use a `None` default so that storing a *new*
        # attribute on an existing group does not raise `AttributeError`:
        val = getattr(out, attr.attr_name, None)
        if val is not None and attr.index is not None:
            # Attribute exists and an index was given: update the slice
            # in-place (mirrors `Data._put_tensor` semantics):
            val[attr.index] = tensor
        else:
            # No such attribute yet (or no index): store the whole tensor:
            setattr(self[attr.group_name], attr.attr_name, tensor)
    else:
        # No node storage found, just store tensor in new one:
        setattr(self[attr.group_name], attr.attr_name, tensor)
    return True
645+
646+
def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
    r"""Obtains a feature tensor from node storage."""
    tensor = getattr(self[attr.group_name], attr.attr_name, None)
    if tensor is None:
        return None
    # TODO this behavior is a bit odd, since TensorAttr requires that we
    # set `index`. So, we assume here that indexing by `None` is
    # equivalent to not indexing at all, which is not in line with Python
    # semantics.
    if attr.index is None:
        return tensor
    return tensor[attr.index]
657+
658+
def _remove_tensor(self, attr: TensorAttr) -> bool:
    r"""Deletes a feature tensor from node storage.

    Returns :obj:`True` if an attribute was removed from the group's
    storage, :obj:`False` if no such attribute existed.
    """
    group = self[attr.group_name]
    if not hasattr(group, attr.attr_name):
        return False
    delattr(group, attr.attr_name)
    return True
665+
666+
def __len__(self) -> int:
    # Explicitly dispatch to `BaseData.__len__`: `HeteroData` inherits from
    # both `BaseData` and `FeatureStore`, so the call is pinned to avoid
    # any ambiguity in the multiple-inheritance resolution.
    return BaseData.__len__(self)
668+
669+
def __iter__(self):
    # Iteration over a `HeteroData` object is explicitly unsupported —
    # presumably to prevent Python from falling back to implicit
    # `__getitem__`-based iteration (TODO confirm intent).
    raise NotImplementedError
671+
619672

620673
# Helper functions ############################################################
621674

0 commit comments

Comments
 (0)