|
2 | 2 |
|
3 | 3 | import pytest
|
4 | 4 | import torch
|
| 5 | +import torch.nn.functional as F |
5 | 6 | from torch import Tensor
|
6 | 7 | from torch.nn import Linear, ReLU, Sequential
|
7 | 8 | from torch_sparse import SparseTensor
|
8 | 9 |
|
9 | 10 | from torch_geometric.nn import BatchNorm, GCNConv, GINEConv, GlobalPooling
|
10 | 11 | from torch_geometric.nn import Linear as LazyLinear
|
11 | 12 | from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero
|
| 13 | +from torch_geometric.utils import dropout_adj |
| 14 | + |
| 15 | +torch.fx.wrap('dropout_adj') |
12 | 16 |
|
13 | 17 |
|
14 | 18 | class Net1(torch.nn.Module):
|
@@ -123,6 +127,17 @@ def forward(self, x: Tensor) -> Tensor:
|
123 | 127 | return self.batch_norm(x)
|
124 | 128 |
|
125 | 129 |
|
| 130 | +class Net10(torch.nn.Module): |
| 131 | + def __init__(self): |
| 132 | + super().__init__() |
| 133 | + self.conv = SAGEConv(16, 32) |
| 134 | + |
| 135 | + def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: |
| 136 | + x = F.dropout(x, p=0.5, training=self.training) |
| 137 | + edge_index, _ = dropout_adj(edge_index, p=0.5, training=self.training) |
| 138 | + return self.conv(x, edge_index) |
| 139 | + |
| 140 | + |
126 | 141 | def test_to_hetero():
|
127 | 142 | x_dict = {
|
128 | 143 | 'paper': torch.randn(100, 16),
|
@@ -213,6 +228,13 @@ def test_to_hetero():
|
213 | 228 | assert out['paper'].size() == (4, 16)
|
214 | 229 | assert out['author'].size() == (8, 16)
|
215 | 230 |
|
| 231 | + model = Net10() |
| 232 | + model = to_hetero(model, metadata, debug=False) |
| 233 | + out = model(x_dict, edge_index_dict) |
| 234 | + assert isinstance(out, dict) and len(out) == 2 |
| 235 | + assert out['paper'].size() == (100, 32) |
| 236 | + assert out['author'].size() == (100, 32) |
| 237 | + |
216 | 238 |
|
217 | 239 | class GCN(torch.nn.Module):
|
218 | 240 | def __init__(self):
|
|
0 commit comments