Skip to content

Commit 926b5dc

Browse files
authored
Add bias to TAGConv (#4597)
* add bias to TAGConv * changelog
1 parent 5ed4b38 commit 926b5dc

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
1010
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
1111
### Changed
12+
- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
1213
- Fixed subclass behaviour of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))
1314
### Removed

torch_geometric/nn/conv/tag_conv.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch_geometric.nn.conv import MessagePassing
66
from torch_geometric.nn.conv.gcn_conv import gcn_norm
77
from torch_geometric.nn.dense.linear import Linear
8+
from torch_geometric.nn.inits import zeros
89
from torch_geometric.typing import Adj, OptTensor
910

1011

@@ -51,14 +52,21 @@ def __init__(self, in_channels: int, out_channels: int, K: int = 3,
5152
self.K = K
5253
self.normalize = normalize
5354

54-
self.lins = torch.nn.ModuleList(
55-
[Linear(in_channels, out_channels) for _ in range(K + 1)])
55+
self.lins = torch.nn.ModuleList([
56+
Linear(in_channels, out_channels, bias=False) for _ in range(K + 1)
57+
])
58+
59+
if bias:
60+
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
61+
else:
62+
self.register_parameter('bias', None)
5663

5764
self.reset_parameters()
5865

5966
def reset_parameters(self):
6067
for lin in self.lins:
6168
lin.reset_parameters()
69+
zeros(self.bias)
6270

6371
def forward(self, x: Tensor, edge_index: Adj,
6472
edge_weight: OptTensor = None) -> Tensor:
@@ -80,6 +88,10 @@ def forward(self, x: Tensor, edge_index: Adj,
8088
x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
8189
size=None)
8290
out += lin.forward(x)
91+
92+
if self.bias is not None:
93+
out += self.bias
94+
8395
return out
8496

8597
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:

0 commit comments

Comments
 (0)