Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data.validate() and HeteroData.validate() #4885

Merged
merged 5 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `LinkeNeighborLoader` support to lightning datamodule ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
Expand Down
1 change: 1 addition & 0 deletions test/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_data():
x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t()
edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]])
data = Data(x=x, edge_index=edge_index).to(torch.device('cpu'))
data.validate(raise_on_error=True)

N = data.num_nodes
assert N == 3
Expand Down
1 change: 1 addition & 0 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_init_hetero_data():
data['paper', 'paper'].edge_index = edge_index_paper_paper
data['paper', 'author'].edge_index = edge_index_paper_author
data['author', 'paper'].edge_index = edge_index_author_paper
data.validate(raise_on_error=True)

assert len(data) == 2
assert data.node_types == ['v1', 'paper', 'author']
Expand Down
36 changes: 36 additions & 0 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import warnings
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (
Expand Down Expand Up @@ -514,6 +515,34 @@ def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
else:
return 0

def validate(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the data."""
cls_name = self.__class__.__name__
status = True

num_nodes = self.num_nodes
if num_nodes is None:
status = False
warn_or_raise(f"'num_nodes' is undefined in '{cls_name}'",
raise_on_error)

if 'edge_index' in self and self.edge_index.numel() > 0:
if self.edge_index.min() < 0:
status = False
warn_or_raise(
f"'edge_index' contains negative indices in "
f"'{cls_name}' (found {int(self.edge_index.min())})",
raise_on_error)

if num_nodes is not None and self.edge_index.max() >= num_nodes:
status = False
warn_or_raise(
f"'edge_index' contains larger indices than the number "
f"of nodes ({num_nodes}) in '{cls_name}' "
f"(found {int(self.edge_index.max())})", raise_on_error)

return status

def debug(self):
pass # TODO

Expand Down Expand Up @@ -879,3 +908,10 @@ def size_repr(key: Any, value: Any, indent: int = 0) -> str:
return f'{pad}\033[1m{key}\033[0m={out}'
else:
return f'{pad}{key}={out}'


def warn_or_raise(msg: str, raise_on_error: bool = True):
if raise_on_error:
raise ValueError(msg)
else:
warnings.warn(msg)
55 changes: 54 additions & 1 deletion torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data.data import BaseData, Data, size_repr
from torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise
from torch_geometric.data.feature_store import FeatureStore, TensorAttr
from torch_geometric.data.graph_store import (
EDGE_LAYOUT_TO_ATTR_NAME,
Expand Down Expand Up @@ -325,6 +325,59 @@ def is_undirected(self) -> bool:
edge_index, _, _ = to_homogeneous_edge_index(self)
return is_undirected(edge_index, num_nodes=self.num_nodes)

def validate(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the data."""
cls_name = self.__class__.__name__
status = True

for edge_type, store in self._edge_store_dict.items():
src, _, dst = edge_type

num_src_nodes = self[src].num_nodes
num_dst_nodes = self[dst].num_nodes
if num_src_nodes is None:
status = False
warn_or_raise(
f"'num_nodes' is undefined in node type '{src}' of "
f"'{cls_name}'", raise_on_error)

if num_dst_nodes is None:
status = False
warn_or_raise(
f"'num_nodes' is undefined in node type '{dst}' of "
f"'{cls_name}'", raise_on_error)

if 'edge_index' in store and store.edge_index.numel() > 0:
if store.edge_index.min() < 0:
status = False
warn_or_raise(
f"'edge_index' of edge type {edge_type} contains "
f"negative indices in '{cls_name}' "
f"(found {int(store.edge_index.min())})",
raise_on_error)

if (num_src_nodes is not None
and store.edge_index[0].max() >= num_src_nodes):
status = False
warn_or_raise(
f"'edge_index' of edge type {edge_type} contains"
f"larger source indices than the number of nodes"
f"({num_src_nodes}) of this node type in '{cls_name}' "
f"(found {int(store.edge_index[0].max())})",
raise_on_error)

if (num_dst_nodes is not None
and store.edge_index[1].max() >= num_dst_nodes):
status = False
warn_or_raise(
f"'edge_index' of edge type {edge_type} contains"
f"larger destination indices than the number of nodes"
f"({num_dst_nodes}) of this node type in '{cls_name}' "
f"(found {int(store.edge_index[1].max())})",
raise_on_error)

return status

def debug(self):
pass # TODO

Expand Down