diff --git a/CHANGELOG.md b/CHANGELOG.md index c9580f3dd638..a349416fc941 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828)) - Do not allow accessing edge types in `HeteroData` with two node types when there exists multiple relations between these types ([#4782](https://github.com/pyg-team/pytorch_geometric/pull/4782)) - Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757)) - Fixed a numerical instability in the `GeneralConv` and `neighbor_sample` tests ([#4754](https://github.com/pyg-team/pytorch_geometric/pull/4754)) diff --git a/docs/source/notes/introduction.rst b/docs/source/notes/introduction.rst index b128dfe3a340..d86c08d12424 100644 --- a/docs/source/notes/introduction.rst +++ b/docs/source/notes/introduction.rst @@ -330,7 +330,7 @@ In addition, we can use the :obj:`transform` argument to randomly augment a :cla dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'], pre_transform=T.KNNGraph(k=6), - transform=T.RandomTranslate(0.01)) + transform=T.RandomJitter(0.01)) dataset[0] >>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518]) diff --git a/examples/dgcnn_segmentation.py b/examples/dgcnn_segmentation.py index 9b0581f98f13..c040fa08189d 100644 --- a/examples/dgcnn_segmentation.py +++ b/examples/dgcnn_segmentation.py @@ -13,7 +13,7 @@ category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ - T.RandomTranslate(0.01), + T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2) diff --git a/examples/point_transformer_segmentation.py b/examples/point_transformer_segmentation.py index 3ecfa724ed29..86f920283bea 100644 --- a/examples/point_transformer_segmentation.py +++ b/examples/point_transformer_segmentation.py @@ -22,7 +22,7 @@ category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ - T.RandomTranslate(0.01), + T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2), diff --git a/examples/pointnet2_segmentation.py b/examples/pointnet2_segmentation.py index 15a50723ea47..f316450671cb 100644 --- a/examples/pointnet2_segmentation.py +++ b/examples/pointnet2_segmentation.py @@ -14,7 +14,7 @@ category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') transform = T.Compose([ - T.RandomTranslate(0.01), + T.RandomJitter(0.01), T.RandomRotate(15, axis=0), T.RandomRotate(15, axis=1), T.RandomRotate(15, axis=2) diff --git a/test/transforms/test_random_translate.py b/test/transforms/test_random_jitter.py similarity index 68% rename from test/transforms/test_random_translate.py rename to test/transforms/test_random_jitter.py index 2b335880356c..88e81cb3d08a 100644 --- a/test/transforms/test_random_translate.py +++ b/test/transforms/test_random_jitter.py @@ -1,27 +1,27 @@ import torch from torch_geometric.data import Data -from torch_geometric.transforms import RandomTranslate +from torch_geometric.transforms import RandomJitter -def test_random_translate(): - assert RandomTranslate(0.1).__repr__() == 'RandomTranslate(0.1)' +def test_random_jitter(): + assert RandomJitter(0.1).__repr__() == 'RandomJitter(0.1)' pos = torch.Tensor([[0, 0], [0, 0], [0, 0], [0, 0]]) data = Data(pos=pos) - data = RandomTranslate(0)(data) + data = RandomJitter(0)(data) assert len(data) == 1 assert data.pos.tolist() == pos.tolist() data = Data(pos=pos) - data = RandomTranslate(0.1)(data) + data = RandomJitter(0.1)(data) assert len(data) == 1 assert data.pos.min().item() >= -0.1 assert data.pos.max().item() <= 0.1 data = Data(pos=pos) - data = RandomTranslate([0.1, 1])(data) + data = RandomJitter([0.1, 1])(data) assert len(data) == 1 assert data.pos[:, 0].min().item() >= -0.1 assert data.pos[:, 0].max().item() <= 0.1 diff --git a/torch_geometric/transforms/__init__.py b/torch_geometric/transforms/__init__.py index e9173b6227ae..b9e2c257f930 100644 --- a/torch_geometric/transforms/__init__.py +++ b/torch_geometric/transforms/__init__.py @@ -16,7 +16,7 @@ from .center import Center from .normalize_rotation import NormalizeRotation from .normalize_scale import NormalizeScale -from .random_translate import RandomTranslate +from .random_jitter import RandomJitter from .random_flip import RandomFlip from .linear_transformation import LinearTransformation from .random_scale import RandomScale @@ -70,7 +70,7 @@ 'Center', 'NormalizeRotation', 'NormalizeScale', - 'RandomTranslate', + 'RandomJitter', 'RandomFlip', 'LinearTransformation', 'RandomScale', @@ -109,3 +109,8 @@ ] classes = __all__ + +from torch_geometric.deprecation import deprecated # noqa + +RandomTranslate = deprecated("use 'transforms.RandomJitter' instead", + 'transforms.RandomTranslate')(RandomJitter) diff --git a/torch_geometric/transforms/random_translate.py b/torch_geometric/transforms/random_jitter.py similarity index 89% rename from torch_geometric/transforms/random_translate.py rename to torch_geometric/transforms/random_jitter.py index af123ad1c95f..cbf5d69bd05d 100644 --- a/torch_geometric/transforms/random_translate.py +++ b/torch_geometric/transforms/random_jitter.py @@ -7,10 +7,10 @@ from torch_geometric.transforms import BaseTransform -@functional_transform('random_translate') -class RandomTranslate(BaseTransform): +@functional_transform('random_jitter') +class RandomJitter(BaseTransform): r"""Translates node positions by randomly sampled translation values - within a given interval (functional name: :obj:`random_translate`). + within a given interval (functional name: :obj:`random_jitter`). In contrast to other random transformations, translation is applied separately at each position