
RandomLinkSplit expects reverse edge types for heterogeneous undirected graphs #4674

Closed
saiden89 opened this issue May 18, 2022 · 6 comments · Fixed by #4757

@saiden89
Contributor

🐛 Describe the bug

When splitting heterogeneous undirected graphs, RandomLinkSplit expects rev_edge_types to be passed, which is ambiguous in this scenario since no reverse edge types technically exist. If None is passed, the assert isinstance(rev_edge_types, list) is triggered. If, on the other hand, the very same edge_types are passed, a different AttributeError is raised.

To reproduce:

import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import RandomLinkSplit


edge_index = torch.tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]])
x = torch.rand(4, 2)

hetero_data = HeteroData()
hetero_data['n'] = x
hetero_data['n', 'link', 'n'].edge_index = edge_index
data = HeteroData(x=x, edge_index=edge_index)
assert data.is_undirected()
RandomLinkSplit(is_undirected=True, edge_types=hetero_data.edge_types)(hetero_data)

Traceback

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
issue.ipynb Cell 1 in <cell line: 14>()
     12 data = HeteroData(x=x, edge_index=edge_index)
     13 assert data.is_undirected()
---> 14 RandomLinkSplit(is_undirected=True, edge_types=hetero_data.edge_types)(hetero_data)

File torch_geometric/transforms/random_link_split.py:107, in RandomLinkSplit.__init__(self, num_val, num_test, is_undirected, key, split_labels, add_negative_train_samples, neg_sampling_ratio, disjoint_train_ratio, edge_types, rev_edge_types)
    104 self.rev_edge_types = rev_edge_types
    106 if isinstance(edge_types, list):
--> 107     assert isinstance(rev_edge_types, list)
    108     assert len(edge_types) == len(rev_edge_types)

AssertionError:

Explicitly passing rev_edge_types:

RandomLinkSplit(is_undirected=True, edge_types=hetero_data.edge_types, rev_edge_types=hetero_data.edge_types)(hetero_data)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
issue.ipynb Cell 1 in <cell line: 14>()
     12 data = HeteroData(x=x, edge_index=edge_index_1)
     13 assert data.is_undirected()
---> 14 RandomLinkSplit(is_undirected=True, edge_types=hetero_data.edge_types, rev_edge_types=hetero_data.edge_types)(hetero_data)

File torch_geometric/transforms/random_link_split.py:195, in RandomLinkSplit.__call__(self, data)
    191 num_neg_test = int(num_test * self.neg_sampling_ratio)
    193 num_neg = num_neg_train + num_neg_val + num_neg_test
--> 195 size = store.size()
    196 if store._key is None or store._key[0] == store._key[-1]:
    197     size = size[0]

File torch_geometric/data/storage.py:365, in EdgeStorage.size(self, dim)
    361 if self._key is None:
    362     raise NameError("Unable to infer 'size' without explicit "
    363                     "'_key' assignment")
--> 365 size = (self._parent()[self._key[0]].num_nodes,
    366         self._parent()[self._key[-1]].num_nodes)
    368 return size if dim is None else size[dim]

AttributeError: 'Tensor' object has no attribute 'num_nodes'

Environment

  • PyG version: 2.0.4+master
  • PyTorch version: 1.12+nightly
  • OS: Linux
  • Python version: 3.10.4
  • CUDA/cuDNN version: 10.2
  • How you installed PyTorch and PyG (conda, pip, source): conda for PyTorch, pip for PyG and dependencies.
@saiden89 added the bug label May 18, 2022
@rusty1s
Member

rusty1s commented May 18, 2022

It looks like hetero_data is created incorrectly. This fixes the issue:

import torch
from torch_geometric.data import HeteroData, Data
from torch_geometric.transforms import RandomLinkSplit

edge_index = torch.tensor([[0, 1, 2, 1, 2, 3], [1, 2, 3, 0, 1, 2]])
x = torch.rand(4, 2)

hetero_data = HeteroData()
hetero_data['n'].x = x  # assign node features via the '.x' attribute, not `hetero_data['n'] = x`
hetero_data['n', 'link', 'n'].edge_index = edge_index

RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
)(hetero_data)
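
As the follow-up comments note, on PyG master this runs without passing rev_edge_types, while on 2.0.4 stable the assert still fires. For a quick sanity check of the returned splits, here is a minimal sketch, assuming the imports and the corrected hetero_data from the snippet above; the variable names (train_data, etc.) are just illustrative:

train_data, val_data, test_data = RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
)(hetero_data)

# RandomLinkSplit stores the supervision links in 'edge_label' / 'edge_label_index'.
print(train_data['n', 'link', 'n'].is_undirected())        # expected: True, see discussion below
print(train_data['n', 'link', 'n'].edge_label_index.shape)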

@saiden89
Contributor Author

saiden89 commented May 18, 2022

Yup, my bad for creating the object incorrectly.
The proposed solution still triggers the assert, but if the user manually sets rev_edge_types=hetero_data.edge_types it passes. I still think it is a bit misleading to force the user to manually specify reverse edges in an undirected context.

@rusty1s
Member

rusty1s commented May 18, 2022

Which PyG version are you operating on? This does not throw an error for me.

@saiden89
Contributor Author

saiden89 commented May 18, 2022

I was on 2.0.4 stable; upgrading to master indeed gets rid of the assert.

@Padarn closed this as completed May 23, 2022
@saiden89
Contributor Author

Thank you for the help, but I'm still confused about the underlying reason why, if the user sets rev_edge_types=hetero_data.edge_types, the splitter returns directed objects (even though is_undirected=True is specified), whereas with rev_edge_types=None the returned splits are undirected.
Could somebody please elaborate on this behavior?

train, val, test = RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
    rev_edge_types=hetero_data.edge_types
)(hetero_data)
print(train['n', 'link', 'n'].is_undirected())
>>> False

train, val, test = RandomLinkSplit(
    is_undirected=True,
    edge_types=hetero_data.edge_types,
    rev_edge_types=None
)(hetero_data)
print(train['n', 'link', 'n'].is_undirected())
>>> True

@rusty1s
Member

rusty1s commented Jun 2, 2022

Sorry for the late reply. This is indeed a bug. I fixed it in #4757.
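
For completeness, a hedged sketch of how the expected behavior could be verified once the fix lands, assuming #4757 makes both invocations respect is_undirected=True, and reusing the hetero_data defined earlier in the thread:

# Verification sketch: both rev_edge_types=None and rev_edge_types=edge_types
# should yield undirected train splits when is_undirected=True.
for rev in (None, hetero_data.edge_types):
    train, val, test = RandomLinkSplit(
        is_undirected=True,
        edge_types=hetero_data.edge_types,
        rev_edge_types=rev,
    )(hetero_data)
    assert train['n', 'link', 'n'].is_undirected()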

@rusty1s rusty1s linked a pull request Jun 2, 2022 that will close this issue