Skip to content

Commit c13d62c

Browse files
authored
Rename RandomTranslate to RandomJitter (#4828)
* update * changelog * typo * typo
1 parent d72df66 commit c13d62c

File tree

8 files changed

+21
-15
lines changed

8 files changed

+21
-15
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3535
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
3636
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
3737
### Changed
38+
- Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828))
3839
- Do not allow accessing edge types in `HeteroData` with two node types when there exists multiple relations between these types ([#4782](https://github.com/pyg-team/pytorch_geometric/pull/4782))
3940
- Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757))
4041
- Fixed a numerical instability in the `GeneralConv` and `neighbor_sample` tests ([#4754](https://github.com/pyg-team/pytorch_geometric/pull/4754))

docs/source/notes/introduction.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ In addition, we can use the :obj:`transform` argument to randomly augment a :cla
330330
331331
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
332332
pre_transform=T.KNNGraph(k=6),
333-
transform=T.RandomTranslate(0.01))
333+
transform=T.RandomJitter(0.01))
334334
335335
dataset[0]
336336
>>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])

examples/dgcnn_segmentation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
category = 'Airplane' # Pass in `None` to train on all categories.
1414
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
1515
transform = T.Compose([
16-
T.RandomTranslate(0.01),
16+
T.RandomJitter(0.01),
1717
T.RandomRotate(15, axis=0),
1818
T.RandomRotate(15, axis=1),
1919
T.RandomRotate(15, axis=2)

examples/point_transformer_segmentation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
category = 'Airplane' # Pass in `None` to train on all categories.
2323
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
2424
transform = T.Compose([
25-
T.RandomTranslate(0.01),
25+
T.RandomJitter(0.01),
2626
T.RandomRotate(15, axis=0),
2727
T.RandomRotate(15, axis=1),
2828
T.RandomRotate(15, axis=2),

examples/pointnet2_segmentation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
category = 'Airplane' # Pass in `None` to train on all categories.
1515
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
1616
transform = T.Compose([
17-
T.RandomTranslate(0.01),
17+
T.RandomJitter(0.01),
1818
T.RandomRotate(15, axis=0),
1919
T.RandomRotate(15, axis=1),
2020
T.RandomRotate(15, axis=2)

test/transforms/test_random_translate.py test/transforms/test_random_jitter.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
import torch
22

33
from torch_geometric.data import Data
4-
from torch_geometric.transforms import RandomTranslate
4+
from torch_geometric.transforms import RandomJitter
55

66

7-
def test_random_translate():
8-
assert RandomTranslate(0.1).__repr__() == 'RandomTranslate(0.1)'
7+
def test_random_jitter():
8+
assert RandomJitter(0.1).__repr__() == 'RandomJitter(0.1)'
99

1010
pos = torch.Tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
1111

1212
data = Data(pos=pos)
13-
data = RandomTranslate(0)(data)
13+
data = RandomJitter(0)(data)
1414
assert len(data) == 1
1515
assert data.pos.tolist() == pos.tolist()
1616

1717
data = Data(pos=pos)
18-
data = RandomTranslate(0.1)(data)
18+
data = RandomJitter(0.1)(data)
1919
assert len(data) == 1
2020
assert data.pos.min().item() >= -0.1
2121
assert data.pos.max().item() <= 0.1
2222

2323
data = Data(pos=pos)
24-
data = RandomTranslate([0.1, 1])(data)
24+
data = RandomJitter([0.1, 1])(data)
2525
assert len(data) == 1
2626
assert data.pos[:, 0].min().item() >= -0.1
2727
assert data.pos[:, 0].max().item() <= 0.1

torch_geometric/transforms/__init__.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .center import Center
1717
from .normalize_rotation import NormalizeRotation
1818
from .normalize_scale import NormalizeScale
19-
from .random_translate import RandomTranslate
19+
from .random_jitter import RandomJitter
2020
from .random_flip import RandomFlip
2121
from .linear_transformation import LinearTransformation
2222
from .random_scale import RandomScale
@@ -70,7 +70,7 @@
7070
'Center',
7171
'NormalizeRotation',
7272
'NormalizeScale',
73-
'RandomTranslate',
73+
'RandomJitter',
7474
'RandomFlip',
7575
'LinearTransformation',
7676
'RandomScale',
@@ -109,3 +109,8 @@
109109
]
110110

111111
classes = __all__
112+
113+
from torch_geometric.deprecation import deprecated # noqa
114+
115+
RandomTranslate = deprecated("use 'transforms.RandomJitter' instead",
116+
'transforms.RandomTranslate')(RandomJitter)

torch_geometric/transforms/random_translate.py torch_geometric/transforms/random_jitter.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from torch_geometric.transforms import BaseTransform
88

99

10-
@functional_transform('random_translate')
11-
class RandomTranslate(BaseTransform):
10+
@functional_transform('random_jitter')
11+
class RandomJitter(BaseTransform):
1212
r"""Translates node positions by randomly sampled translation values
13-
within a given interval (functional name: :obj:`random_translate`).
13+
within a given interval (functional name: :obj:`random_jitter`).
1414
In contrast to other random transformations,
1515
translation is applied separately at each position
1616

0 commit comments

Comments
 (0)