Skip to content

Commit 243b707

Browse files
lightaimerusty1s
andauthored
Add results with 10 runs for RevGNN (#4730)
* Add results with 10 runs * changelog Co-authored-by: Matthias Fey <[email protected]>
1 parent 282c4f9 commit 243b707

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721))
1010
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
1111
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
12-
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715))
12+
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
1313
- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676))
1414
- Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628))
1515
- Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))

examples/rev_gnn.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# Model Paramters: 206,607
2-
# Peak GPU memory usage: 1.57 G
3-
# RevGNN with 7 layers and 160 channels reaches around 0.8200 test accuracy.
4-
# Final Train: 0.9373, Highest Val: 0.9230, Final Test: 0.8200.
5-
# Training longer should produces better results.
1+
# Peak GPU memory usage is around 1.57 G
2+
# | RevGNN Models | Test Acc | Val Acc |
3+
# |-------------------------|-----------------|-----------------|
4+
# | 112 layers 160 channels | 0.8307 ± 0.0030 | 0.9290 ± 0.0007 |
5+
# | 7 layers 160 channels | 0.8276 ± 0.0027 | 0.9272 ± 0.0006 |
66

77
import os.path as osp
88

@@ -93,7 +93,7 @@ def forward(self, x, edge_index):
9393

9494
train_loader = RandomNodeSampler(data, num_parts=10, shuffle=True,
9595
num_workers=5)
96-
# Increase the num_parts of the test loader if you cannot have fix
96+
# Increase the num_parts of the test loader if you cannot fit
9797
# the full batch graph into your GPU:
9898
test_loader = RandomNodeSampler(data, num_parts=1, num_workers=5)
9999

@@ -180,8 +180,18 @@ def test(epoch):
180180
return train_acc, valid_acc, test_acc
181181

182182

183-
for epoch in range(1, 501):
183+
best_val = 0.0
184+
final_train = 0.0
185+
final_test = 0.0
186+
for epoch in range(1, 1001):
184187
loss = train(epoch)
185188
train_acc, val_acc, test_acc = test(epoch)
189+
if val_acc > best_val:
190+
best_val = val_acc
191+
final_train = train_acc
192+
final_test = test_acc
186193
print(f'Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
187194
f'Test: {test_acc:.4f}')
195+
196+
print(f'Final Train: {final_train:.4f}, Best Val: {best_val:.4f}, '
197+
f'Final Test: {final_test:.4f}')

0 commit comments

Comments
 (0)