RandomLinkSplit expects reverse edge types for heterogeneous undirected graphs #4674
Comments
It looks like

```python
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit

# Undirected toy graph: every edge is stored in both directions.
edge_index = torch.tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]])
x = torch.rand(4, 2)

# One node type 'n' and one edge type ('n', 'link', 'n').
hetero_data = HeteroData()
hetero_data['n'].x = x
hetero_data['n', 'link', 'n'].edge_index = edge_index

RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
)(hetero_data)
```
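For context, with `is_undirected=True` the graph is treated as undirected, which in `edge_index` form means every edge `(u, v)` is stored together with its reverse `(v, u)`. The pure-Python sketch below checks that property for the `edge_index` used above. It needs no PyG, and the helper name `is_undirected_edge_list` is hypothetical, not a PyG API.

```python
def is_undirected_edge_list(edge_index):
    """Return True if every edge (u, v) has a matching reverse edge (v, u).

    `edge_index` is a pair of equal-length lists [sources, targets],
    mirroring the 2 x num_edges layout used by PyG.
    """
    sources, targets = edge_index
    edges = set(zip(sources, targets))
    return all((v, u) in edges for (u, v) in edges)


# Same connectivity as the snippet above, as plain Python lists.
edge_index = [[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]]
print(is_undirected_edge_list(edge_index))  # True

# Dropping the reverse edge (1, 0) breaks the property.
print(is_undirected_edge_list([[0, 1, 2, 2, 3], [1, 2, 3, 1, 2]]))  # False
```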
Yup, my bad for creating the object incorrectly.
Which PyG version are you using? This does not throw an error for me.
I was on 2.0.4 (stable); upgrading to master indeed gets rid of the assertion error.
Thank you for the help, but I'm still confused about the underlying reason why, if the user sets `rev_edge_types` to the same edge types, the training split is no longer undirected:

```python
# hetero_data and RandomLinkSplit as in the snippet above.
train, val, test = RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
    rev_edge_types=hetero_data.edge_types,
)(hetero_data)
print(train['n', 'link', 'n'].is_undirected())
# >>> False
```

whereas with `rev_edge_types=None`:

```python
train, val, test = RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
    rev_edge_types=None,
)(hetero_data)
print(train['n', 'link', 'n'].is_undirected())
# >>> True
```
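One way to see why the training graph is expected to stay undirected: an undirected link split has to assign both directions of an edge to the same split, otherwise the reverse of a validation or test edge would leak into training. The sketch below is a simplified illustration of that idea, not PyG's `RandomLinkSplit` implementation; the function `undirected_link_split` and its `num_val`/`num_test`/`seed` parameters are hypothetical. It keeps one canonical direction `(min(u, v), max(u, v))` per edge, shuffles and splits those, and then re-adds the reverse direction inside each split.

```python
import random


def undirected_link_split(edge_index, num_val=0.1, num_test=0.2, seed=0):
    """Split an undirected edge list into train/val/test edge lists.

    Hypothetical helper for illustration only. Each undirected edge is
    counted once as (min, max), so both directions always end up in the
    same split.
    """
    sources, targets = edge_index
    # One canonical direction per undirected edge (self-loops kept once).
    canonical = sorted({(min(u, v), max(u, v)) for u, v in zip(sources, targets)})
    rng = random.Random(seed)
    rng.shuffle(canonical)

    n = len(canonical)
    n_val = int(n * num_val)
    n_test = int(n * num_test)
    val = canonical[:n_val]
    test = canonical[n_val:n_val + n_test]
    train = canonical[n_val + n_test:]

    def to_undirected(pairs):
        # Re-add the reverse of every edge; a self-loop is its own reverse.
        both = pairs + [(v, u) for u, v in pairs if u != v]
        return [[u for u, _ in both], [v for _, v in both]]

    return to_undirected(train), to_undirected(val), to_undirected(test)


edge_index = [[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]]
train, val, test = undirected_link_split(edge_index, num_val=0.34, num_test=0.34)
print(len(train[0]), len(val[0]), len(test[0]))  # 2 2 2
```

The real transform does more than this sketch, for example negative sampling and separate supervision edges (`edge_label_index`), which are omitted here.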
Sorry for the late reply. This is indeed a bug. I fixed it in #4757.
🐛 Describe the bug
When splitting heterogeneous undirected graphs, `RandomLinkSplit` expects `rev_edge_types` to be passed, which is ambiguous in this scenario since no separate reverse edge type technically exists. If `None` is passed, an `assert isinstance(rev_edge_types, list)` is triggered. If, on the other hand, the very same `edge_types` are passed, a different `AttributeError` is raised.

To reproduce:
Traceback
Explicitly passing `rev_edge_types`:

Environment

How you installed PyTorch and PyG (conda, pip, source): conda for PyTorch, pip for PyG and its dependencies.