to_hetero fails at dropout due to torch.rand not accepting proxies. #4630

Closed
szemyd opened this issue May 12, 2022 · 2 comments · Fixed by #4653

Comments


szemyd commented May 12, 2022

🐛 Describe the bug

I'm trying to pass these layers through to_hetero():

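from typing import Optional

import torch as t
import torch.nn.functional as F
from torch.nn import ModuleList
from torch_geometric.utils import dropout_adj

class GNNEncoder(t.nn.Module):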
    def __init__(
        self,
        layers: ModuleList,
        p_dropout_edges: Optional[float],
        p_dropout_features: Optional[float],
    ):
        super().__init__()
        self.layers = layers
        self.p_dropout_edges = p_dropout_edges
        self.p_dropout_features = p_dropout_features

    def forward(self, x, edge_index):
        for index, layer in enumerate(self.layers):
            if index == len(self.layers) - 1:
                x = layer(x, edge_index)
            else:
                if self.p_dropout_edges is not None:
                    edge_index, _ = dropout_adj(
                        edge_index,
                        p=self.p_dropout_edges,
                        force_undirected=True,
                        training=self.training,
                    )
                if self.p_dropout_features is not None:
                    x = F.dropout(x, p=self.p_dropout_features, training=self.training)

                x = layer(x, edge_index).relu()

        return x

However, tracing fails at this call:

dropout_adj(
    edge_index,
    p=self.p_dropout_edges,
    force_undirected=True,
    training=self.training,
)

I'm getting a TypeError:
rand() received an invalid combination of arguments - got (Proxy, device=Attribute) ...

It fails at this line:
mask = torch.rand(row.size(0), device=edge_index.device) >= p

Environment

Member

rusty1s commented May 13, 2022

I am afraid dropout_adj does not work yet in combination with to_hetero. Will try to look into it.

Member

rusty1s commented May 15, 2022

I finally had time to look into this. It looks like to_hetero() already works as expected here. The trick is to register, beforehand, any function that should not be traced by torch.fx:

torch.fx.wrap('dropout_adj')

See #4653 for an example.
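For readers who want to see the mechanism without PyTorch Geometric, here is a minimal sketch using plain torch.fx. The helper random_edge_mask and the module DropEdges are hypothetical names made up for this illustration; they only stand in for the torch.rand call inside dropout_adj and are not taken from #4653. The sketch assumes torch.fx.wrap is called at module level, in the same module that defines and calls the helper.

```python
import torch
import torch.fx


def random_edge_mask(edge_index: torch.Tensor, p: float) -> torch.Tensor:
    # Hypothetical stand-in for the masking step in dropout_adj.
    # torch.rand needs a concrete size, which a tracing Proxy cannot provide.
    return torch.rand(edge_index.size(1), device=edge_index.device) >= p


# Register the helper as a leaf function: torch.fx records a single call
# to it in the graph instead of tracing through its body.
torch.fx.wrap('random_edge_mask')


class DropEdges(torch.nn.Module):
    # Hypothetical module that randomly drops columns of edge_index.
    def __init__(self, p: float):
        super().__init__()
        self.p = p

    def forward(self, edge_index):
        mask = random_edge_mask(edge_index, self.p)
        return edge_index[:, mask]


# Tracing succeeds because the helper is treated as an opaque call.
traced = torch.fx.symbolic_trace(DropEdges(p=0.5))
print(traced.code)
```

Because the helper stays a single call_function node, the random mask is drawn again on every forward pass of the traced module rather than being baked into the graph at trace time.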

@rusty1s rusty1s closed this as completed May 15, 2022