Skip to content

Commit 6fd6f5b

Browse files
Padarnrusty1s
andauthored
Fix dimension in edge filter selection (#4629)
* fix dimension in edge filter * update changelog * Update CHANGELOG.md * update Co-authored-by: Matthias Fey <[email protected]>
1 parent bbff5b7 commit 6fd6f5b

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818
- Removed unnecessary colons and fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616))
1919
- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
2020
- Fixed subclass behaviour of `process` and `download` in `Datsaet` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))
21+
- Fixed filtering of attributes for loaders in case `__cat_dim__ != 0` ([#4629](https://github.com/pyg-team/pytorch_geometric/pull/4629))
2122
### Removed

torch_geometric/loader/utils.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def index_select(value: Tensor, index: Tensor, dim: int = 0) -> Tensor:
2121
numel = math.prod(size)
2222
storage = value.storage()._new_shared(numel)
2323
out = value.new(storage).view(size)
24-
return torch.index_select(value, 0, index, out=out)
24+
return torch.index_select(value, dim, index, out=out)
2525

2626

2727
def edge_type_to_str(edge_type: Union[EdgeType, str]) -> str:
@@ -101,7 +101,8 @@ def filter_node_store_(store: NodeStorage, out_store: NodeStorage,
101101

102102
elif store.is_node_attr(key):
103103
index = index.to(value.device)
104-
out_store[key] = index_select(value, index, dim=0)
104+
dim = store._parent().__cat_dim__(key, value, store)
105+
out_store[key] = index_select(value, index, dim=dim)
105106

106107
return store
107108

@@ -132,13 +133,14 @@ def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
132133
is_sorted=False, trust_data=True)
133134

134135
elif store.is_edge_attr(key):
136+
dim = store._parent().__cat_dim__(key, value, store)
135137
if perm is None:
136138
index = index.to(value.device)
137-
out_store[key] = index_select(value, index, dim=0)
139+
out_store[key] = index_select(value, index, dim=dim)
138140
else:
139141
perm = perm.to(value.device)
140142
index = index.to(value.device)
141-
out_store[key] = index_select(value, perm[index], dim=0)
143+
out_store[key] = index_select(value, perm[index], dim=dim)
142144

143145
return store
144146

0 commit comments

Comments
 (0)