From 720b6421697e72d685b88a45e951c4dc9c2da803 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 5 May 2022 13:27:38 +0000 Subject: [PATCH 1/2] add bias to TAGConv --- torch_geometric/nn/conv/tag_conv.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/conv/tag_conv.py b/torch_geometric/nn/conv/tag_conv.py index 0bc24f1b9ec1..c0d18e6c4d44 100644 --- a/torch_geometric/nn/conv/tag_conv.py +++ b/torch_geometric/nn/conv/tag_conv.py @@ -5,6 +5,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.conv.gcn_conv import gcn_norm from torch_geometric.nn.dense.linear import Linear +from torch_geometric.nn.inits import zeros from torch_geometric.typing import Adj, OptTensor @@ -51,14 +52,21 @@ def __init__(self, in_channels: int, out_channels: int, K: int = 3, self.K = K self.normalize = normalize - self.lins = torch.nn.ModuleList( - [Linear(in_channels, out_channels) for _ in range(K + 1)]) + self.lins = torch.nn.ModuleList([ + Linear(in_channels, out_channels, bias=False) for _ in range(K + 1) + ]) + + if bias: + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): for lin in self.lins: lin.reset_parameters() + zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor: @@ -80,6 +88,10 @@ def forward(self, x: Tensor, edge_index: Adj, x = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None) out += lin.forward(x) + + if self.bias is not None: + out += self.bias + return out def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: From 996294694caa68d98a78d2b6d4770c2b884273bb Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 5 May 2022 13:28:48 +0000 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ca1339590768..f87161a30313 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,5 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- The `bias` argument in `TAGConv` is now actually apllied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597)) - Fixed subclass behaviour of `process` and `download` in `Datsaet` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586)) ### Removed