From ff7066a4c3468fa10c503760e546ebb3721521ef Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Wed, 15 Jun 2022 00:18:58 +0000 Subject: [PATCH 1/9] init --- test/data/test_data.py | 25 +++++++++++ test/data/test_hetero_data.py | 27 ++++++++++++ torch_geometric/data/data.py | 63 ++++++++++++++++++++++++++- torch_geometric/data/feature_store.py | 2 +- torch_geometric/data/hetero_data.py | 53 +++++++++++++++++++++- 5 files changed, 167 insertions(+), 3 deletions(-) diff --git a/test/data/test_data.py b/test/data/test_data.py index b794364be308..be06d43bed0b 100644 --- a/test/data/test_data.py +++ b/test/data/test_data.py @@ -239,3 +239,28 @@ def my_attr1(self, value): data.my_attr1 = 2 assert 'my_attr1' not in data._store assert data.my_attr1 == 2 + + +# Feature Store ############################################################### + + +def test_basic_feature_store(): + data = Data() + x = torch.randn(20, 20) + + # Put tensor: + assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None) + assert torch.equal(data.x, x) + + # Put (modify) tensor slice: + x[15:] = 0 + data.put_tensor(0, attr_name='x', index=slice(15, None, None)) + + # Get tensor: + out = data.get_tensor(attr_name='x', index=None) + assert torch.equal(x, out) + + # Remove tensor: + assert 'x' in data.__dict__['_store'] + data.remove_tensor(attr_name='x', index=None) + assert 'x' not in data.__dict__['_store'] diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py index ba5f7a33f389..e526d7ef7a20 100644 --- a/test/data/test_hetero_data.py +++ b/test/data/test_hetero_data.py @@ -400,3 +400,30 @@ def test_hetero_data_to_canonical(): with pytest.raises(TypeError, match="missing 1 required"): data['user', 'product'] + + +# Feature Store ############################################################### + + +def test_basic_feature_store(): + data = HeteroData() + x = torch.randn(20, 20) + + # Put tensor: + assert data.put_tensor(copy.deepcopy(x), group_name='paper', attr_name='x', + index=None) + assert torch.equal(data['paper'].x, x) + + # Put (modify) tensor slice: + x[15:] = 0 + data.put_tensor(0, group_name='paper', attr_name='x', + index=slice(15, None, None)) + + # Get tensor: + out = data.get_tensor(group_name='paper', attr_name='x', index=None) + assert torch.equal(x, out) + + # Remove tensor: + assert 'x' in data.__dict__['_mapping'] + data.remove_tensor(group_name='paper', attr_name='x', index=None) + assert 'x' not in data['paper'].__dict__['_mapping'] diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 580c9bdd3b6e..794b391ebe9f 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -1,5 +1,6 @@ import copy from collections.abc import Mapping, Sequence +from dataclasses import dataclass from typing import ( Any, Callable, @@ -17,6 +18,12 @@ from torch import Tensor from torch_sparse import SparseTensor +from torch_geometric.data.feature_store import ( + FeatureStore, + FeatureTensorType, + TensorAttr, + _field_status, +) from torch_geometric.data.storage import ( BaseStorage, EdgeStorage, @@ -300,7 +307,16 @@ def contains_self_loops(self) -> bool: ############################################################################### -class Data(BaseData): +@dataclass +class DataTensorAttr(TensorAttr): + r"""Attribute class for `Data`, which does not require a `group_name`.""" + def __init__(self, attr_name=_field_status.UNSET, + index=_field_status.UNSET): + # Treat group_name as optional, and move it to the end + super().__init__(None, attr_name, index) + + +class Data(BaseData, FeatureStore): r"""A data object describing a homogeneous graph. The data object can hold node-level, link-level and graph-level attributes. In general, :class:`~torch_geometric.data.Data` tries to mimic the @@ -365,6 +381,10 @@ def __init__(self, x: OptTensor = None, edge_index: OptTensor = None, for key, value in kwargs.items(): setattr(self, key, value) + # `Data` does not support group_nae, so we need to adjust `TensorAttr` + # accordingly here to avoid requiring `group_name` to be set: + FeatureStore.__init__(self, attr_cls=DataTensorAttr) + def __getattr__(self, key: str) -> Any: if '_store' not in self.__dict__: raise RuntimeError( @@ -692,6 +712,47 @@ def num_faces(self) -> Optional[int]: return self.face.size(self.__cat_dim__('face', self.face)) return None + # :obj:`FeatureStore` interface ########################################### + + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: + r"""Stores a feature tensor in node storage.""" + + if not attr.is_set('index'): + attr.index = None + + out = getattr(self, attr.attr_name, None) + if out is not None: + # Attr name exists, handle index: + out[attr.index] = tensor + else: + # No attr nane, just store tensor: + setattr(self, attr.attr_name, tensor) + return True + + def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: + r"""Obtains a feature tensor from node storage.""" + # Retrieve tensor and index accordingly: + tensor = getattr(self, attr.attr_name) + if tensor is not None: + # TODO this behavior is a bit odd, since TensorAttr requires that + # we set `index`. So, we assume here that indexing by `None` is + # equivalent to not indexing at all, which is not in line with + # Python semantics. + return tensor[attr.index] if attr.index is not None else tensor + return None + + def _remove_tensor(self, attr: TensorAttr) -> bool: + r"""Deletes a feature tensor from node storage.""" + # Remove tensor entirely: + delattr(self, attr.attr_name) + return True + + def __len__(self) -> int: + return BaseData.__len__(self) + + def __iter__(self): + raise NotImplementedError + ############################################################################### diff --git a/torch_geometric/data/feature_store.py b/torch_geometric/data/feature_store.py index bc7d10322497..b9c2aa623cc6 100644 --- a/torch_geometric/data/feature_store.py +++ b/torch_geometric/data/feature_store.py @@ -245,7 +245,7 @@ def __init__(self, attr_cls: Any = TensorAttr): attributes by subclassing :class:`TensorAttr` and passing the subclass as :obj:`attr_cls`.""" super().__init__() - self._attr_cls = attr_cls + self.__dict__['_attr_cls'] = attr_cls # Core (CRUD) ############################################################# diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index d4e77c1a80e3..6b5d2fd786a2 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -10,6 +10,11 @@ from torch_sparse import SparseTensor from torch_geometric.data.data import BaseData, Data, size_repr +from torch_geometric.data.feature_store import ( + FeatureStore, + FeatureTensorType, + TensorAttr, +) from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage from torch_geometric.typing import EdgeType, NodeType, QueryType from torch_geometric.utils import bipartite_subgraph, is_undirected @@ -18,7 +23,7 @@ NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage] -class HeteroData(BaseData): +class HeteroData(BaseData, FeatureStore): r"""A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjunct storage objects. Storage objects can hold either node-level, link-level or graph-level @@ -105,6 +110,10 @@ def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs): else: setattr(self, key, value) + # `HeteroData` supports group_name, attr_name, and index, so we can + # initialize the feature store without modification: + FeatureStore.__init__(self) + def __getattr__(self, key: str) -> Any: # `data.*_dict` => Link to node and edge stores. # `data.*` => Link to the `_global_store`. @@ -616,6 +625,48 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]: return data + # :obj:`FeatureStore` interface ########################################### + + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: + r"""Stores a feature tensor in node storage.""" + if not attr.is_set('index'): + attr.index = None + + out = self._node_store_dict.get(attr.group_name, None) + if out: + # Group name exists, handle index: + val = getattr(out, attr.attr_name) + if val is not None: + val[attr.index] = tensor + else: + # No group name, just store tensor in node storage: + setattr(self[attr.group_name], attr.attr_name, tensor) + return True + + def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: + r"""Obtains a feature tensor from node storage.""" + # Retrieve tensor and index accordingly: + tensor = getattr(self[attr.group_name], attr.attr_name) + if tensor is not None: + # TODO this behavior is a bit odd, since TensorAttr requires that + # we set `index`. So, we assume here that indexing by `None` is + # equivalent to not indexing at all, which is not in line with + # Python semantics. + return tensor[attr.index] if attr.index is not None else tensor + return None + + def _remove_tensor(self, attr: TensorAttr) -> bool: + r"""Deletes a feature tensor from node storage.""" + # Remove tensor entirely: + delattr(self[attr.group_name], attr.attr_name) + return True + + def __len__(self) -> int: + return BaseData.__len__(self) + + def __iter__(self): + raise NotImplementedError + # Helper functions ############################################################ From 1688803f73e9bb1e4f84880dad609d243d848e35 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Wed, 15 Jun 2022 00:47:08 +0000 Subject: [PATCH 2/9] update --- torch_geometric/data/data.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 794b391ebe9f..491770f808a7 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -750,9 +750,6 @@ def _remove_tensor(self, attr: TensorAttr) -> bool: def __len__(self) -> int: return BaseData.__len__(self) - def __iter__(self): - raise NotImplementedError - ############################################################################### From 59198b67b4756e02e27b50179ca7c587fe53210d Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 17 Jun 2022 02:00:58 +0000 Subject: [PATCH 3/9] fix, todo: failing tests --- test/data/test_hetero_data.py | 2 +- torch_geometric/data/batch.py | 6 ++++- torch_geometric/data/data.py | 36 ++++++++++++++++------------- torch_geometric/data/hetero_data.py | 14 ++++++----- 4 files changed, 34 insertions(+), 24 deletions(-) diff --git a/test/data/test_hetero_data.py b/test/data/test_hetero_data.py index e526d7ef7a20..b26832bcb068 100644 --- a/test/data/test_hetero_data.py +++ b/test/data/test_hetero_data.py @@ -424,6 +424,6 @@ def test_basic_feature_store(): assert torch.equal(x, out) # Remove tensor: - assert 'x' in data.__dict__['_mapping'] + assert 'x' in data['paper'].__dict__['_mapping'] data.remove_tensor(group_name='paper', attr_name='x', index=None) assert 'x' not in data['paper'].__dict__['_mapping'] diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index ecaae4d663b3..8afbee3d4ba6 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -23,8 +23,12 @@ def __call__(cls, *args, **kwargs): new_cls = base_cls else: name = f'{base_cls.__name__}{cls.__name__}' + + class MetaResolver(type(cls), type(base_cls)): + pass + if name not in globals(): - globals()[name] = type(name, (cls, base_cls), {}) + globals()[name] = MetaResolver(name, (cls, base_cls), {}) new_cls = globals()[name] params = list(inspect.signature(base_cls.__init__).parameters.items()) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 491770f808a7..c9663fb26ce9 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -364,7 +364,10 @@ class Data(BaseData, FeatureStore): def __init__(self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None, **kwargs): - super().__init__() + # `Data` doesn't support group_name, so we need to adjust `TensorAttr` + # accordingly here to avoid requiring `group_name` to be set: + super(Data, self).__init__(attr_cls=DataTensorAttr) + self.__dict__['_store'] = GlobalStorage(_parent=self) if x is not None: @@ -381,10 +384,6 @@ def __init__(self, x: OptTensor = None, edge_index: OptTensor = None, for key, value in kwargs.items(): setattr(self, key, value) - # `Data` does not support group_nae, so we need to adjust `TensorAttr` - # accordingly here to avoid requiring `group_name` to be set: - FeatureStore.__init__(self, attr_cls=DataTensorAttr) - def __getattr__(self, key: str) -> Any: if '_store' not in self.__dict__: raise RuntimeError( @@ -404,6 +403,9 @@ def __setattr__(self, key: str, value: Any): def __delattr__(self, key: str): delattr(self._store, key) + # TODO consider supporting the feature store interface for + # __getitem__, __setitem__, and __delitem__ so, for example, we + # can accept key: Union[str, TensorAttr] in __getitem__. def __getitem__(self, key: str) -> Any: return self._store[key] @@ -712,40 +714,42 @@ def num_faces(self) -> Optional[int]: return self.face.size(self.__cat_dim__('face', self.face)) return None - # :obj:`FeatureStore` interface ########################################### + # FeatureStore interface ########################################### def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: r"""Stores a feature tensor in node storage.""" - - if not attr.is_set('index'): - attr.index = None - out = getattr(self, attr.attr_name, None) - if out is not None: + if out is not None and attr.index is not None: # Attr name exists, handle index: out[attr.index] = tensor else: - # No attr nane, just store tensor: + # No attr name (or None index), just store tensor: setattr(self, attr.attr_name, tensor) return True def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: r"""Obtains a feature tensor from node storage.""" # Retrieve tensor and index accordingly: - tensor = getattr(self, attr.attr_name) + tensor = getattr(self, attr.attr_name, None) if tensor is not None: # TODO this behavior is a bit odd, since TensorAttr requires that # we set `index`. So, we assume here that indexing by `None` is # equivalent to not indexing at all, which is not in line with # Python semantics. - return tensor[attr.index] if attr.index is not None else tensor + if attr.index is None: + return tensor + + dim = self.__cat_dim__(attr.attr_name, tensor) + return torch.index_select(tensor, attr.index, dim=dim) return None def _remove_tensor(self, attr: TensorAttr) -> bool: r"""Deletes a feature tensor from node storage.""" # Remove tensor entirely: - delattr(self, attr.attr_name) - return True + if hasattr(self, attr.attr_name): + delattr(self, attr.attr_name) + return True + return False def __len__(self) -> int: return BaseData.__len__(self) diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 6b5d2fd786a2..99c8c70e43fb 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -97,6 +97,8 @@ class HeteroData(BaseData, FeatureStore): DEFAULT_REL = 'to' def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs): + super(HeteroData, self).__init__() + self.__dict__['_global_store'] = BaseStorage(_parent=self) self.__dict__['_node_store_dict'] = {} self.__dict__['_edge_store_dict'] = {} @@ -110,10 +112,6 @@ def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs): else: setattr(self, key, value) - # `HeteroData` supports group_name, attr_name, and index, so we can - # initialize the feature store without modification: - FeatureStore.__init__(self) - def __getattr__(self, key: str) -> Any: # `data.*_dict` => Link to node and edge stores. # `data.*` => Link to the `_global_store`. @@ -639,7 +637,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: if val is not None: val[attr.index] = tensor else: - # No group name, just store tensor in node storage: + # No node storage found, just store tensor in new one: setattr(self[attr.group_name], attr.attr_name, tensor) return True @@ -652,7 +650,11 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: # we set `index`. So, we assume here that indexing by `None` is # equivalent to not indexing at all, which is not in line with # Python semantics. - return tensor[attr.index] if attr.index is not None else tensor + if attr.index is None: + return tensor + + dim = self[attr.group_name].__cat_dim__(attr.attr_name, tensor) + return torch.index_select(tensor, attr.index, dim=dim) return None def _remove_tensor(self, attr: TensorAttr) -> bool: From 82a10e06b81dde4ac805ddf871ff69be1c5ce06b Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 17 Jun 2022 02:50:22 +0000 Subject: [PATCH 4/9] fix to_heterodata --- torch_geometric/data/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index c9663fb26ce9..2307195dc284 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -605,7 +605,7 @@ def to_heterogeneous(self, node_type: Optional[Tensor] = None, data = HeteroData() for i, key in enumerate(node_type_names): - for attr, value in self.items(): + for attr, value in self._store.items(): if attr == 'node_type' or attr == 'edge_type': continue elif isinstance(value, Tensor) and self.is_node_attr(attr): @@ -616,7 +616,7 @@ def to_heterogeneous(self, node_type: Optional[Tensor] = None, for i, key in enumerate(edge_type_names): src, _, dst = key - for attr, value in self.items(): + for attr, value in self._store.items(): if attr == 'node_type' or attr == 'edge_type': continue elif attr == 'edge_index': @@ -629,7 +629,7 @@ def to_heterogeneous(self, node_type: Optional[Tensor] = None, # Add global attributes. keys = set(data.keys) | {'node_type', 'edge_type', 'num_nodes'} - for attr, value in self.items(): + for attr, value in self._store.items(): if attr in keys: continue if len(data.node_stores) == 1: From 996591f47cefa3202a1020ae7a47ee075374c6b9 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 17 Jun 2022 02:52:13 +0000 Subject: [PATCH 5/9] CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 622e86fbf2b5..22225e0b67fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807)) - Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805)) - Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750)) - Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756)) From a2a099fb163bf4a621286bbafa6838d89596375a Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 17 Jun 2022 02:56:28 +0000 Subject: [PATCH 6/9] fix --- torch_geometric/data/data.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 2307195dc284..90ef8fed81a0 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -605,7 +605,7 @@ def to_heterogeneous(self, node_type: Optional[Tensor] = None, data = HeteroData() for i, key in enumerate(node_type_names): - for attr, value in self._store.items(): + for attr, value in self.items(): if attr == 'node_type' or attr == 'edge_type': continue elif isinstance(value, Tensor) and self.is_node_attr(attr): @@ -616,7 +616,7 @@ def to_heterogeneous(self, node_type: Optional[Tensor] = None, for i, key in enumerate(edge_type_names): src, _, dst = key - for attr, value in self._store.items(): + for attr, value in self.items(): if attr == 'node_type' or attr == 'edge_type': continue elif attr == 'edge_index': @@ -629,7 +629,7 @@ def to_heterogeneous(self, node_type: Optional[Tensor] = None, # Add global attributes. keys = set(data.keys) | {'node_type', 'edge_type', 'num_nodes'} - for attr, value in self._store.items(): + for attr, value in self.items(): if attr in keys: continue if len(data.node_stores) == 1: @@ -716,6 +716,9 @@ def num_faces(self) -> Optional[int]: # FeatureStore interface ########################################### + def items(self): + return self._store.items() + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: r"""Stores a feature tensor in node storage.""" out = getattr(self, attr.attr_name, None) From cdd1a72219a398bdba44c136bb8d1811bd8c843f Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 17 Jun 2022 15:11:29 +0000 Subject: [PATCH 7/9] update --- torch_geometric/data/data.py | 6 +----- torch_geometric/data/hetero_data.py | 13 ++++++------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 90ef8fed81a0..d62894c543ae 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -739,11 +739,7 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: # we set `index`. So, we assume here that indexing by `None` is # equivalent to not indexing at all, which is not in line with # Python semantics. - if attr.index is None: - return tensor - - dim = self.__cat_dim__(attr.attr_name, tensor) - return torch.index_select(tensor, attr.index, dim=dim) + return tensor[attr.index] if attr.index is not None else tensor return None def _remove_tensor(self, attr: TensorAttr) -> bool: diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 99c8c70e43fb..0afc8d5469fe 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -650,18 +650,17 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: # we set `index`. So, we assume here that indexing by `None` is # equivalent to not indexing at all, which is not in line with # Python semantics. - if attr.index is None: - return tensor - - dim = self[attr.group_name].__cat_dim__(attr.attr_name, tensor) - return torch.index_select(tensor, attr.index, dim=dim) + return tensor[attr.index] if attr.index is not None else tensor return None def _remove_tensor(self, attr: TensorAttr) -> bool: r"""Deletes a feature tensor from node storage.""" # Remove tensor entirely: - delattr(self[attr.group_name], attr.attr_name) - return True + if (hasattr(self, attr.group_name) + and hasattr(self[attr.group_name], attr.attr_name)): + delattr(self[attr.group_name], attr.attr_name) + return True + return False def __len__(self) -> int: return BaseData.__len__(self) From 2991fbc65636753a670cb93d7be492d4e2f45ec8 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Fri, 17 Jun 2022 15:27:01 +0000 Subject: [PATCH 8/9] fix --- torch_geometric/data/hetero_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 0afc8d5469fe..593f1dcaf279 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -656,8 +656,7 @@ def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: def _remove_tensor(self, attr: TensorAttr) -> bool: r"""Deletes a feature tensor from node storage.""" # Remove tensor entirely: - if (hasattr(self, attr.group_name) - and hasattr(self[attr.group_name], attr.attr_name)): + if hasattr(self[attr.group_name], attr.attr_name): delattr(self[attr.group_name], attr.attr_name) return True return False From 7504407dce86c02056ffaf5621fc69708967b530 Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Mon, 20 Jun 2022 16:59:21 +0000 Subject: [PATCH 9/9] comments --- torch_geometric/data/batch.py | 4 ++++ torch_geometric/data/data.py | 4 +++- torch_geometric/data/hetero_data.py | 8 +++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torch_geometric/data/batch.py b/torch_geometric/data/batch.py index 8afbee3d4ba6..43e553ab1097 100644 --- a/torch_geometric/data/batch.py +++ b/torch_geometric/data/batch.py @@ -24,6 +24,10 @@ def __call__(cls, *args, **kwargs): else: name = f'{base_cls.__name__}{cls.__name__}' + # NOTE `MetaResolver` is necessary to resolve metaclass conflict + # problems between `DynamicInheritance` and the metaclass of + # `base_cls`. In particular, it creates a new common metaclass + # from the defined metaclasses. class MetaResolver(type(cls), type(base_cls)): pass diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index d62894c543ae..3a222246b44e 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -366,7 +366,7 @@ def __init__(self, x: OptTensor = None, edge_index: OptTensor = None, pos: OptTensor = None, **kwargs): # `Data` doesn't support group_name, so we need to adjust `TensorAttr` # accordingly here to avoid requiring `group_name` to be set: - super(Data, self).__init__(attr_cls=DataTensorAttr) + super().__init__(attr_cls=DataTensorAttr) self.__dict__['_store'] = GlobalStorage(_parent=self) @@ -717,6 +717,8 @@ def num_faces(self) -> Optional[int]: # FeatureStore interface ########################################### def items(self): + r"""Returns an `ItemsView` over the stored attributes in the `Data` + object.""" return self._store.items() def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: diff --git a/torch_geometric/data/hetero_data.py b/torch_geometric/data/hetero_data.py index 593f1dcaf279..051833a36371 100644 --- a/torch_geometric/data/hetero_data.py +++ b/torch_geometric/data/hetero_data.py @@ -97,7 +97,7 @@ class HeteroData(BaseData, FeatureStore): DEFAULT_REL = 'to' def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs): - super(HeteroData, self).__init__() + super().__init__() self.__dict__['_global_store'] = BaseStorage(_parent=self) self.__dict__['_node_store_dict'] = {} @@ -632,10 +632,12 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: out = self._node_store_dict.get(attr.group_name, None) if out: - # Group name exists, handle index: + # Group name exists, handle index or create new attribute name: val = getattr(out, attr.attr_name) if val is not None: val[attr.index] = tensor + else: + setattr(self[attr.group_name], attr.attr_name, tensor) else: # No node storage found, just store tensor in new one: setattr(self[attr.group_name], attr.attr_name, tensor) @@ -644,7 +646,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: r"""Obtains a feature tensor from node storage.""" # Retrieve tensor and index accordingly: - tensor = getattr(self[attr.group_name], attr.attr_name) + tensor = getattr(self[attr.group_name], attr.attr_name, None) if tensor is not None: # TODO this behavior is a bit odd, since TensorAttr requires that # we set `index`. So, we assume here that indexing by `None` is