Skip to content

Commit 2f2f935

Browse files
authored
Merge branch 'master' into master
2 parents e495331 + 3d6eb74 commit 2f2f935

15 files changed

+696
-162
lines changed

CHANGELOG.md

+3-1
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66
## [2.0.5] - 2022-MM-DD
77
### Added
88
- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
9+
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
10+
- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
911
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
1012
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
1113
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
1214
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
13-
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882))
15+
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883))
1416
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
1517
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
1618
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))

test/data/test_data.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ def test_data():
1515
x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t()
1616
edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]])
1717
data = Data(x=x, edge_index=edge_index).to(torch.device('cpu'))
18+
data.validate(raise_on_error=True)
1819

1920
N = data.num_nodes
2021
assert N == 3
@@ -299,7 +300,7 @@ def assert_equal_tensor_tuple(expected, actual):
299300
csc = adj.csc()[-2::-1] # (row, colptr)
300301

301302
# Put:
302-
data.put_edge_index(coo, layout='coo')
303+
data.put_edge_index(coo, layout='coo', size=(3, 3))
303304
data.put_edge_index(csr, layout='csr')
304305
data.put_edge_index(csc, layout='csc')
305306

test/data/test_graph_store.py

+56
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,13 @@
44

55
from torch_geometric.data.graph_store import EdgeLayout
66
from torch_geometric.testing.graph_store import MyGraphStore
7+
from torch_geometric.utils.sort_edge_index import sort_edge_index
8+
9+
10+
def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
11+
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
12+
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
13+
return torch.stack([row, col], dim=0)
714

815

916
def test_graph_store():
@@ -38,3 +45,52 @@ def assert_equal_tensor_tuple(expected, actual):
3845

3946
with pytest.raises(KeyError):
4047
_ = graph_store['edge_2', 'coo']
48+
49+
50+
def test_graph_store_conversion():
51+
graph_store = MyGraphStore()
52+
edge_index = get_edge_index(100, 100, 300)
53+
edge_index = sort_edge_index(edge_index, sort_by_row=False)
54+
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(100, 100))
55+
56+
coo = (edge_index[0], edge_index[1])
57+
csr = adj.csr()[:2]
58+
csc = adj.csc()[-2::-1]
59+
60+
# Put all edge indices:
61+
graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'),
62+
layout='coo', size=(100, 100), is_sorted=True)
63+
64+
graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'),
65+
layout='csr', size=(100, 100))
66+
67+
graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'),
68+
layout='csc', size=(100, 100))
69+
70+
def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor):
71+
assert torch.equal(sort_edge_index(expected), sort_edge_index(actual))
72+
73+
# Convert to COO:
74+
row_dict, col_dict, perm_dict = graph_store.coo()
75+
assert len(row_dict) == len(col_dict) == len(perm_dict) == 3
76+
for key in row_dict.keys():
77+
actual = torch.stack((row_dict[key], col_dict[key]))
78+
assert_edge_index_equal(actual, edge_index)
79+
assert perm_dict[key] is None
80+
81+
# Convert to CSR:
82+
rowptr_dict, col_dict, perm_dict = graph_store.csr()
83+
assert len(rowptr_dict) == len(col_dict) == len(perm_dict) == 3
84+
for key in rowptr_dict:
85+
assert torch.equal(rowptr_dict[key], csr[0])
86+
assert torch.equal(col_dict[key], csr[1])
87+
if key == ('v', '1', 'v'):
88+
assert perm_dict[key] is not None
89+
90+
# Convert to CSC:
91+
row_dict, colptr_dict, perm_dict = graph_store.csc()
92+
assert len(row_dict) == len(colptr_dict) == len(perm_dict) == 3
93+
for key in row_dict:
94+
assert torch.equal(row_dict[key], csc[0])
95+
assert torch.equal(colptr_dict[key], csc[1])
96+
assert perm_dict[key] is None

test/data/test_hetero_data.py

+7-3
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ def test_init_hetero_data():
3939
data['paper', 'paper'].edge_index = edge_index_paper_paper
4040
data['paper', 'author'].edge_index = edge_index_paper_author
4141
data['author', 'paper'].edge_index = edge_index_author_paper
42+
data.validate(raise_on_error=True)
4243

4344
assert len(data) == 2
4445
assert data.node_types == ['v1', 'paper', 'author']
@@ -464,9 +465,12 @@ def assert_equal_tensor_tuple(expected, actual):
464465
csc = adj.csc()[-2::-1] # (row, colptr)
465466

466467
# Put:
467-
data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'))
468-
data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'))
469-
data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'))
468+
data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'),
469+
size=(3, 3))
470+
data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'),
471+
size=(3, 3))
472+
data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'),
473+
size=(3, 3))
470474

471475
# Get:
472476
assert_equal_tensor_tuple(

test/data/test_lightning_datamodule.py

+24-1
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,11 @@
44
import torch
55
import torch.nn.functional as F
66

7-
from torch_geometric.data import LightningDataset, LightningNodeData
7+
from torch_geometric.data import (
8+
LightningDataset,
9+
LightningLinkData,
10+
LightningNodeData,
11+
)
812
from torch_geometric.nn import global_mean_pool
913
from torch_geometric.testing import onlyFullTest, withCUDA, withPackage
1014

@@ -264,3 +268,22 @@ def test_lightning_hetero_node_data(get_dataset):
264268
offset += 5 * devices * math.ceil(400 / (devices * 32)) # `train`
265269
offset += 5 * devices * math.ceil(400 / (devices * 32)) # `val`
266270
assert torch.all(new_x > (old_x + offset - 4)) # Ensure shared data.
271+
272+
273+
@withCUDA
274+
@onlyFullTest
275+
@withPackage('pytorch_lightning')
276+
def test_lightning_hetero_link_data(get_dataset):
277+
# TODO: Add more datasets.
278+
dataset = get_dataset(name='DBLP')
279+
data = dataset[0]
280+
datamodule = LightningLinkData(data, loader='link_neighbor',
281+
num_neighbors=[5], batch_size=32,
282+
num_workers=3)
283+
input_edges = (('author', 'dummy', 'paper'), data['author',
284+
'paper']['edge_index'])
285+
loader = datamodule.dataloader(input_edges=input_edges, input_labels=None,
286+
shuffle=True)
287+
batch = next(iter(loader))
288+
assert (batch['author', 'dummy',
289+
'paper']['edge_label_index'].shape[1] == 32)

test/loader/test_neighbor_loader.py

+16-15
Original file line number | Diff line number | Diff line change
@@ -297,29 +297,30 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
297297
feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)
298298

299299
# Set up edge indices:
300+
301+
# COO:
300302
edge_index = get_edge_index(100, 100, 500)
301303
data['paper', 'to', 'paper'].edge_index = edge_index
302-
graph_store.put_edge_index(
303-
edge_index=SparseTensor.from_edge_index(edge_index).csr()[:2],
304-
edge_type=('paper', 'to', 'paper'),
305-
layout='csr',
306-
)
304+
coo = (edge_index[0], edge_index[1])
305+
graph_store.put_edge_index(edge_index=coo,
306+
edge_type=('paper', 'to', 'paper'),
307+
layout='coo', size=(100, 100))
307308

309+
# CSR:
308310
edge_index = get_edge_index(100, 200, 1000)
309311
data['paper', 'to', 'author'].edge_index = edge_index
310-
graph_store.put_edge_index(
311-
edge_index=SparseTensor.from_edge_index(edge_index).csr()[:2],
312-
edge_type=('paper', 'to', 'author'),
313-
layout='csr',
314-
)
312+
csr = SparseTensor.from_edge_index(edge_index).csr()[:2]
313+
graph_store.put_edge_index(edge_index=csr,
314+
edge_type=('paper', 'to', 'author'),
315+
layout='csr', size=(100, 200))
315316

317+
# CSC:
316318
edge_index = get_edge_index(200, 100, 1000)
317319
data['author', 'to', 'paper'].edge_index = edge_index
318-
graph_store.put_edge_index(
319-
edge_index=SparseTensor.from_edge_index(edge_index).csr()[:2],
320-
edge_type=('author', 'to', 'paper'),
321-
layout='csr',
322-
)
320+
csc = SparseTensor(row=edge_index[1], col=edge_index[0]).csr()[-2::-1]
321+
graph_store.put_edge_index(edge_index=csc,
322+
edge_type=('author', 'to', 'paper'),
323+
layout='csc', size=(200, 100))
323324

324325
# Construct neighbor loaders:
325326
loader1 = NeighborLoader(data, batch_size=20,

torch_geometric/data/__init__.py

+6-1
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,11 @@
44
from .batch import Batch
55
from .dataset import Dataset
66
from .in_memory_dataset import InMemoryDataset
7-
from .lightning_datamodule import LightningDataset, LightningNodeData
7+
from .lightning_datamodule import (
8+
LightningDataset,
9+
LightningLinkData,
10+
LightningNodeData,
11+
)
812
from .makedirs import makedirs
913
from .download import download_url
1014
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
@@ -18,6 +22,7 @@
1822
'InMemoryDataset',
1923
'LightningDataset',
2024
'LightningNodeData',
25+
'LightningLinkData',
2126
'makedirs',
2227
'download_url',
2328
'extract_tar',

0 commit comments

Comments
 (0)