Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor PNAConv to rely on new Aggregation #4864

Merged
merged 9 commits into from
Jul 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
24 changes: 24 additions & 0 deletions test/nn/aggr/test_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
import torch

from torch_geometric.nn import DegreeScalerAggregation


def test_degree_scaler_aggregation():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
ptr = torch.tensor([0, 2, 5, 6])
deg = torch.tensor([0, 3, 0, 1, 1, 0])

aggrs = ['mean', 'sum', 'max']
scalers = [
'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'
]
aggr = DegreeScalerAggregation(aggrs, scalers, deg)
assert str(aggr) == 'DegreeScalerAggregation()'

out = aggr(x, index)
assert out.size() == (3, 240)

with pytest.raises(NotImplementedError):
aggr(x, ptr=ptr)
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from .lstm import LSTMAggregation
from .set2set import Set2Set
from .scaler import DegreeScalerAggregation

__all__ = classes = [
'Aggregation',
Expand All @@ -28,4 +29,5 @@
'PowerMeanAggregation',
'LSTMAggregation',
'Set2Set',
'DegreeScalerAggregation',
]
79 changes: 79 additions & 0 deletions torch_geometric/nn/aggr/scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.utils import degree


class DegreeScalerAggregation(Aggregation):
"""
Class that combines together one or more aggregators and then transforms
the result with one or more scalers. The scalers are normalised by the
in-degree of the training set and so must be provided at construction.

Args:
aggrs (list of string or list or Aggregation): The list of
aggregations given as :class:`~torch_geometric.nn.aggr.Aggregation`
(or any string that automatically resolves to it).
scalers (list of str): Set of scaling function identifiers, namely
:obj:`"identity"`, :obj:`"amplification"`,
:obj:`"attenuation"`, :obj:`"linear"` and
:obj:`"inverse_linear"`.
deg (Tensor): Histogram of in-degrees of nodes in the training set,
used by scalers to normalize.
aggr_kwargs (List[Dict[str, Any]], optional): Arguments passed to the
respective aggregation functions in case it gets automatically
resolved. (default: :obj:`None`)
"""
def __init__(self, aggrs: List[Union[Aggregation, str]],
scalers: List[str], deg: Tensor,
aggrs_kwargs: Optional[List[Dict[str, Any]]] = None):

super().__init__()

# TODO: Support non-lists
if not isinstance(aggrs, list):
raise RuntimeError("`aggrs` must be a list of aggregations ")

self.aggr = MultiAggregation(aggrs, aggrs_kwargs)
self.scalers = scalers

deg = deg.to(torch.float)
num_nodes = int(deg.sum())
bin_degrees = torch.arange(deg.numel())
self.avg_deg: Dict[str, float] = {
'lin': float((bin_degrees * deg).sum()) / num_nodes,
'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
}

def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

self.assert_index_present(index)

out = self.aggr(x, index, ptr, dim_size, dim)
deg = degree(index, dtype=out.dtype).clamp_(1)

size = [1] * len(out.size())
size[dim] = -1
deg = deg.view(*size)
outs = []
for scaler in self.scalers:
if scaler == 'identity':
pass
elif scaler == 'amplification':
out = out * (torch.log(deg + 1) / self.avg_deg['log'])
elif scaler == 'attenuation':
out = out * (self.avg_deg['log'] / torch.log(deg + 1))
elif scaler == 'linear':
out = out * (deg / self.avg_deg['lin'])
elif scaler == 'inverse_linear':
out = out * (self.avg_deg['lin'] / deg)
else:
raise ValueError(f'Unknown scaler "{scaler}".')
outs.append(out)
return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]
54 changes: 4 additions & 50 deletions torch_geometric/nn/conv/pna_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import torch
from torch import Tensor
from torch.nn import ModuleList, ReLU, Sequential
from torch_scatter import scatter

from torch_geometric.nn.aggr import DegreeScalerAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree

from ..inits import reset

Expand Down Expand Up @@ -86,7 +85,9 @@ def __init__(self, in_channels: int, out_channels: int,
pre_layers: int = 1, post_layers: int = 1,
divide_input: bool = False, **kwargs):

kwargs.setdefault('aggr', None)
aggr = DegreeScalerAggregation(aggregators, scalers, deg)
kwargs.setdefault('aggr', aggr)

super().__init__(node_dim=0, **kwargs)

if divide_input:
Expand All @@ -95,8 +96,6 @@ def __init__(self, in_channels: int, out_channels: int,

self.in_channels = in_channels
self.out_channels = out_channels
self.aggregators = aggregators
self.scalers = scalers
self.edge_dim = edge_dim
self.towers = towers
self.divide_input = divide_input
Expand Down Expand Up @@ -178,51 +177,6 @@ def message(self, x_i: Tensor, x_j: Tensor,
hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]
return torch.stack(hs, dim=1)

def aggregate(self, inputs: Tensor, index: Tensor,
dim_size: Optional[int] = None) -> Tensor:

outs = []
for aggregator in self.aggregators:
if aggregator == 'sum':
out = scatter(inputs, index, 0, None, dim_size, reduce='sum')
elif aggregator == 'mean':
out = scatter(inputs, index, 0, None, dim_size, reduce='mean')
elif aggregator == 'min':
out = scatter(inputs, index, 0, None, dim_size, reduce='min')
elif aggregator == 'max':
out = scatter(inputs, index, 0, None, dim_size, reduce='max')
elif aggregator == 'var' or aggregator == 'std':
mean = scatter(inputs, index, 0, None, dim_size, reduce='mean')
mean_squares = scatter(inputs * inputs, index, 0, None,
dim_size, reduce='mean')
out = mean_squares - mean * mean
if aggregator == 'std':
out = torch.sqrt(torch.relu(out) + 1e-5)
else:
raise ValueError(f'Unknown aggregator "{aggregator}".')
outs.append(out)
out = torch.cat(outs, dim=-1)

deg = degree(index, dim_size, dtype=inputs.dtype)
deg = deg.clamp_(1).view(-1, 1, 1)

outs = []
for scaler in self.scalers:
if scaler == 'identity':
pass
elif scaler == 'amplification':
out = out * (torch.log(deg + 1) / self.avg_deg['log'])
elif scaler == 'attenuation':
out = out * (self.avg_deg['log'] / torch.log(deg + 1))
elif scaler == 'linear':
out = out * (deg / self.avg_deg['lin'])
elif scaler == 'inverse_linear':
out = out * (self.avg_deg['lin'] / deg)
else:
raise ValueError(f'Unknown scaler "{scaler}".')
outs.append(out)
return torch.cat(outs, dim=-1)

def __repr__(self):
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, towers={self.towers}, '
Expand Down