Skip to content

Commit f045451

Browse files
authored
Confirm that to_hetero() works with custom functions (dropout_adj) (#4653)
* to_hetero_dropout * changelog * update
1 parent ced3886 commit f045451

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))
89
- Added the `MLP.plain_last=False` option ([4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))
910
- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
1011
- Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))

test/nn/test_to_hetero_transformer.py

+22
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
import pytest
44
import torch
5+
import torch.nn.functional as F
56
from torch import Tensor
67
from torch.nn import Linear, ReLU, Sequential
78
from torch_sparse import SparseTensor
89

910
from torch_geometric.nn import BatchNorm, GCNConv, GINEConv, GlobalPooling
1011
from torch_geometric.nn import Linear as LazyLinear
1112
from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero
13+
from torch_geometric.utils import dropout_adj
14+
15+
torch.fx.wrap('dropout_adj')
1216

1317

1418
class Net1(torch.nn.Module):
@@ -123,6 +127,17 @@ def forward(self, x: Tensor) -> Tensor:
123127
return self.batch_norm(x)
124128

125129

130+
class Net10(torch.nn.Module):
131+
def __init__(self):
132+
super().__init__()
133+
self.conv = SAGEConv(16, 32)
134+
135+
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
136+
x = F.dropout(x, p=0.5, training=self.training)
137+
edge_index, _ = dropout_adj(edge_index, p=0.5, training=self.training)
138+
return self.conv(x, edge_index)
139+
140+
126141
def test_to_hetero():
127142
x_dict = {
128143
'paper': torch.randn(100, 16),
@@ -213,6 +228,13 @@ def test_to_hetero():
213228
assert out['paper'].size() == (4, 16)
214229
assert out['author'].size() == (8, 16)
215230

231+
model = Net10()
232+
model = to_hetero(model, metadata, debug=False)
233+
out = model(x_dict, edge_index_dict)
234+
assert isinstance(out, dict) and len(out) == 2
235+
assert out['paper'].size() == (100, 32)
236+
assert out['author'].size() == (100, 32)
237+
216238

217239
class GCN(torch.nn.Module):
218240
def __init__(self):

0 commit comments

Comments
 (0)