Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Confirm that to_hetero() works with custom functions (dropout_adj) #4653

Merged
merged 5 commits into from
May 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Confirm that `to_hetero()` works with custom functions, *e.g.*, `dropout_adj` ([4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))
- Added the `MLP.plain_last=False` option ([4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))
- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))
- Added `HeteroData.subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))
Expand Down
22 changes: 22 additions & 0 deletions test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, ReLU, Sequential
from torch_sparse import SparseTensor

from torch_geometric.nn import BatchNorm, GCNConv, GINEConv, GlobalPooling
from torch_geometric.nn import Linear as LazyLinear
from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero
from torch_geometric.utils import dropout_adj

torch.fx.wrap('dropout_adj')


class Net1(torch.nn.Module):
Expand Down Expand Up @@ -123,6 +127,17 @@ def forward(self, x: Tensor) -> Tensor:
return self.batch_norm(x)


class Net10(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = SAGEConv(16, 32)

def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
x = F.dropout(x, p=0.5, training=self.training)
edge_index, _ = dropout_adj(edge_index, p=0.5, training=self.training)
return self.conv(x, edge_index)


def test_to_hetero():
x_dict = {
'paper': torch.randn(100, 16),
Expand Down Expand Up @@ -213,6 +228,13 @@ def test_to_hetero():
assert out['paper'].size() == (4, 16)
assert out['author'].size() == (8, 16)

model = Net10()
model = to_hetero(model, metadata, debug=False)
out = model(x_dict, edge_index_dict)
assert isinstance(out, dict) and len(out) == 2
assert out['paper'].size() == (100, 32)
assert out['author'].size() == (100, 32)


class GCN(torch.nn.Module):
def __init__(self):
Expand Down