|
1 | 1 | import copy
|
2 | 2 | from collections.abc import Mapping, Sequence
|
| 3 | +from dataclasses import dataclass |
3 | 4 | from typing import (
|
4 | 5 | Any,
|
5 | 6 | Callable,
|
|
17 | 18 | from torch import Tensor
|
18 | 19 | from torch_sparse import SparseTensor
|
19 | 20 |
|
| 21 | +from torch_geometric.data.feature_store import ( |
| 22 | + FeatureStore, |
| 23 | + FeatureTensorType, |
| 24 | + TensorAttr, |
| 25 | + _field_status, |
| 26 | +) |
20 | 27 | from torch_geometric.data.storage import (
|
21 | 28 | BaseStorage,
|
22 | 29 | EdgeStorage,
|
@@ -300,7 +307,16 @@ def contains_self_loops(self) -> bool:
|
300 | 307 | ###############################################################################
|
301 | 308 |
|
302 | 309 |
|
303 |
| -class Data(BaseData): |
| 310 | +@dataclass |
| 311 | +class DataTensorAttr(TensorAttr): |
| 312 | + r"""Attribute class for `Data`, which does not require a `group_name`.""" |
| 313 | + def __init__(self, attr_name=_field_status.UNSET, |
| 314 | + index=_field_status.UNSET): |
| 315 | + # Treat group_name as optional, and move it to the end |
| 316 | + super().__init__(None, attr_name, index) |
| 317 | + |
| 318 | + |
| 319 | +class Data(BaseData, FeatureStore): |
304 | 320 | r"""A data object describing a homogeneous graph.
|
305 | 321 | The data object can hold node-level, link-level and graph-level attributes.
|
306 | 322 | In general, :class:`~torch_geometric.data.Data` tries to mimic the
|
@@ -348,7 +364,10 @@ class Data(BaseData):
|
348 | 364 | def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
|
349 | 365 | edge_attr: OptTensor = None, y: OptTensor = None,
|
350 | 366 | pos: OptTensor = None, **kwargs):
|
351 |
| - super().__init__() |
| 367 | + # `Data` doesn't support group_name, so we need to adjust `TensorAttr` |
| 368 | + # accordingly here to avoid requiring `group_name` to be set: |
| 369 | + super().__init__(attr_cls=DataTensorAttr) |
| 370 | + |
352 | 371 | self.__dict__['_store'] = GlobalStorage(_parent=self)
|
353 | 372 |
|
354 | 373 | if x is not None:
|
@@ -384,6 +403,9 @@ def __setattr__(self, key: str, value: Any):
|
384 | 403 | def __delattr__(self, key: str):
|
385 | 404 | delattr(self._store, key)
|
386 | 405 |
|
| 406 | + # TODO consider supporting the feature store interface for |
| 407 | + # __getitem__, __setitem__, and __delitem__ so, for example, we |
| 408 | + # can accept key: Union[str, TensorAttr] in __getitem__. |
387 | 409 | def __getitem__(self, key: str) -> Any:
|
388 | 410 | return self._store[key]
|
389 | 411 |
|
@@ -692,6 +714,47 @@ def num_faces(self) -> Optional[int]:
|
692 | 714 | return self.face.size(self.__cat_dim__('face', self.face))
|
693 | 715 | return None
|
694 | 716 |
|
| 717 | + # FeatureStore interface ########################################### |
| 718 | + |
| 719 | + def items(self): |
| 720 | + r"""Returns an `ItemsView` over the stored attributes in the `Data` |
| 721 | + object.""" |
| 722 | + return self._store.items() |
| 723 | + |
| 724 | + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: |
| 725 | + r"""Stores a feature tensor in node storage.""" |
| 726 | + out = getattr(self, attr.attr_name, None) |
| 727 | + if out is not None and attr.index is not None: |
| 728 | + # Attr name exists, handle index: |
| 729 | + out[attr.index] = tensor |
| 730 | + else: |
| 731 | + # No attr name (or None index), just store tensor: |
| 732 | + setattr(self, attr.attr_name, tensor) |
| 733 | + return True |
| 734 | + |
| 735 | + def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: |
| 736 | + r"""Obtains a feature tensor from node storage.""" |
| 737 | + # Retrieve tensor and index accordingly: |
| 738 | + tensor = getattr(self, attr.attr_name, None) |
| 739 | + if tensor is not None: |
| 740 | + # TODO this behavior is a bit odd, since TensorAttr requires that |
| 741 | + # we set `index`. So, we assume here that indexing by `None` is |
| 742 | + # equivalent to not indexing at all, which is not in line with |
| 743 | + # Python semantics. |
| 744 | + return tensor[attr.index] if attr.index is not None else tensor |
| 745 | + return None |
| 746 | + |
| 747 | + def _remove_tensor(self, attr: TensorAttr) -> bool: |
| 748 | + r"""Deletes a feature tensor from node storage.""" |
| 749 | + # Remove tensor entirely: |
| 750 | + if hasattr(self, attr.attr_name): |
| 751 | + delattr(self, attr.attr_name) |
| 752 | + return True |
| 753 | + return False |
| 754 | + |
| 755 | + def __len__(self) -> int: |
| 756 | + return BaseData.__len__(self) |
| 757 | + |
695 | 758 |
|
696 | 759 | ###############################################################################
|
697 | 760 |
|
|
0 commit comments