diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c8426c8e00d..17df37f2fb51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Allow for `setter` properties in `Data` and `HeteroData` ([#4682](https://github.com/pyg-team/pytorch_geometric/pull/4682)) - Allow for optional `edge_weight` in `GCN2Conv` ([#4670](https://github.com/pyg-team/pytorch_geometric/pull/4670)) - Fixed the interplay between `TUDataset` and `pre_transform` that modify node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669)) - Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pyg-lib/pull/4664), [#4667](https://github.com/pyg-team/pyg-lib/pull/4667)) diff --git a/test/data/test_data.py b/test/data/test_data.py index 3733072cc508..b794364be308 100644 --- a/test/data/test_data.py +++ b/test/data/test_data.py @@ -213,3 +213,29 @@ def test_data_share_memory(): for data in data_list: assert data.x.is_shared() assert torch.allclose(data.x, torch.full((8, ), 4.)) + + +def test_data_setter_properties(): + class MyData(Data): + def __init__(self): + super().__init__() + self.my_attr1 = 1 + self.my_attr2 = 2 + + @property + def my_attr1(self): + return self._my_attr1 + + @my_attr1.setter + def my_attr1(self, value): + self._my_attr1 = value + + data = MyData() + assert data.my_attr2 == 2 + + assert 'my_attr1' not in data._store + assert data.my_attr1 == 1 + + data.my_attr1 = 2 + assert 'my_attr1' not in data._store + assert data.my_attr1 == 2 diff --git a/torch_geometric/data/data.py b/torch_geometric/data/data.py index 0a778214a582..a6bc90710d89 100644 --- a/torch_geometric/data/data.py +++ b/torch_geometric/data/data.py @@ -375,7 +375,11 @@ def __getattr__(self, key: str) -> Any: return getattr(self._store, key) def __setattr__(self, key: str, value: Any): - setattr(self._store, key, value) + propobj = getattr(self.__class__, key, None) + if propobj is None or propobj.fset is None: + setattr(self._store, key, value) + else: + propobj.fset(self, value) def __delattr__(self, key: str): delattr(self._store, key)