Commit 4ea65bd

Fix HANConv propagation (#4753)

* fix HanConv
* changelog

1 parent 934e880

File tree

2 files changed: +8 -7 lines changed


CHANGELOG.md (+1)
```diff
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed a bug in `HANConv` in which destination node features rather than source node features were propagated ([#4753](https://github.com/pyg-team/pytorch_geometric/pull/4753))
 - Fixed versions of `checkout` and `setup-python` in CI ([#4751](https://github.com/pyg-team/pytorch_geometric/pull/4751))
 - Fixed `protobuf` version ([#4719](https://github.com/pyg-team/pytorch_geometric/pull/4719))
 - Fixed the ranking protocol bug in the RGCN link prediction example ([#4688](https://github.com/pyg-team/pytorch_geometric/pull/4688))
```

torch_geometric/nn/conv/han_conv.py (+7 -7)
```diff
@@ -135,13 +135,13 @@ def forward(
             edge_type = '__'.join(edge_type)
             lin_src = self.lin_src[edge_type]
             lin_dst = self.lin_dst[edge_type]
+            x_src = x_node_dict[src_type]
             x_dst = x_node_dict[dst_type]
-            alpha_src = (x_node_dict[src_type] * lin_src).sum(dim=-1)
+            alpha_src = (x_src * lin_src).sum(dim=-1)
             alpha_dst = (x_dst * lin_dst).sum(dim=-1)
-            alpha = (alpha_src, alpha_dst)
-            # propagate_type: (x_dst: Tensor, alpha: PairTensor)
-            out = self.propagate(edge_index, x_dst=x_dst, alpha=alpha,
-                                 size=None)
+            # propagate_type: (x_dst: PairTensor, alpha: PairTensor)
+            out = self.propagate(edge_index, x=(x_src, x_dst),
+                                 alpha=(alpha_src, alpha_dst), size=None)

             out = F.relu(out)
             out_dict[dst_type].append(out)
@@ -157,15 +157,15 @@ def forward(

         return out_dict

-    def message(self, x_dst_i: Tensor, alpha_i: Tensor, alpha_j: Tensor,
+    def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor,
                 index: Tensor, ptr: Optional[Tensor],
                 size_i: Optional[int]) -> Tensor:

         alpha = alpha_j + alpha_i
         alpha = F.leaky_relu(alpha, self.negative_slope)
         alpha = softmax(alpha, index, ptr, size_i)
         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
-        out = x_dst_i * alpha.view(-1, self.heads, 1)
+        out = x_j * alpha.view(-1, self.heads, 1)
         return out.view(-1, self.out_channels)

     def __repr__(self) -> str:
```
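For context on why the old code was wrong: in PyG's `MessagePassing`, arguments to `message()` whose names end in `_j` are gathered from the source node of each edge, while `_i`-suffixed arguments come from the destination node. Weighting `x_dst_i` therefore aggregated each destination node's own features back onto itself instead of aggregating its neighbors' features. Below is a minimal NumPy sketch of that gather convention using toy data (the variable names mirror PyG's `x_j`/`x_i`, but this is illustrative code, not the library implementation):

```python
import numpy as np

# Toy bipartite graph: 3 source nodes, 2 destination nodes.
x_src = np.array([[1.0], [2.0], [3.0]])   # source node features
x_dst = np.array([[10.0], [20.0]])        # destination node features

# edge_index[0] holds the source index, edge_index[1] the destination
# index of each edge (PyG convention).
edge_index = np.array([[0, 1, 2],
                       [0, 0, 1]])

# `_j` gathers from the source side of each edge, `_i` from the
# destination side.
x_j = x_src[edge_index[0]]  # what message() should weight (after the fix)
x_i = x_dst[edge_index[1]]  # what the buggy x_dst_i effectively was

# Sum-aggregate the per-edge messages onto destination nodes, as the
# fixed code does with x_j:
out = np.zeros_like(x_dst)
np.add.at(out, edge_index[1], x_j)
```

With `x_j`, destination node 0 receives the features of its two source neighbors, which is the semantics HAN's attention is meant to have; with `x_i`, it would only ever see copies of its own features.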
