Skip to content

Commit b4f3cbe

Browse files
ZenoTanrusty1s
andauthored
Hotfix: #4856 follow-up, also fix edge_types (#4857)
* fix * fix test Co-authored-by: rusty1s <[email protected]>
1 parent 78bbfbd commit b4f3cbe

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66
## [2.0.5] - 2022-MM-DD
77
### Added
88
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
9-
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854))
9+
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857))
1010
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
1111
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
1212
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))

test/loader/test_neighbor_loader.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -336,12 +336,16 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
336336
for batch1, batch2 in zip(loader1, loader2):
337337
assert len(batch1) == len(batch2)
338338
assert batch1['paper'].batch_size == batch2['paper'].batch_size
339-
assert torch.allclose(batch1['paper'].x, batch2['paper'].x)
340-
assert torch.allclose(batch1['author'].x, batch2['author'].x)
341-
342-
assert torch.allclose(batch1['paper', 'to', 'paper'].edge_index,
343-
batch2['paper', 'to', 'paper'].edge_index)
344-
assert torch.allclose(batch1['paper', 'to', 'author'].edge_index,
345-
batch2['paper', 'to', 'author'].edge_index)
346-
assert torch.allclose(batch1['author', 'to', 'paper'].edge_index,
347-
batch2['author', 'to', 'paper'].edge_index)
339+
340+
# Mapped indices of neighbors may be differently sorted:
341+
assert torch.allclose(batch1['paper'].x.sort()[0],
342+
batch2['paper'].x.sort()[0])
343+
assert torch.allclose(batch1['author'].x.sort()[0],
344+
batch2['author'].x.sort()[0])
345+
346+
assert (batch1['paper', 'to', 'paper'].edge_index.size() == batch1[
347+
'paper', 'to', 'paper'].edge_index.size())
348+
assert (batch1['paper', 'to', 'author'].edge_index.size() == batch1[
349+
'paper', 'to', 'author'].edge_index.size())
350+
assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[
351+
'author', 'to', 'paper'].edge_index.size())

torch_geometric/loader/neighbor_loader.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def __init__(
110110

111111
self.node_types = list(
112112
set(node_attr.group_name for node_attr in node_attrs))
113-
self.edge_types = [edge_attr.edge_type for edge_attr in edge_attrs]
113+
self.edge_types = list(
114+
set(edge_attr.edge_type for edge_attr in edge_attrs))
114115

115116
# Set other required parameters:
116117
if isinstance(num_neighbors, (list, tuple)):

0 commit comments

Comments
 (0)