diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ffe23f23878..5f80882306c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755)) - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926)) - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723)) -- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872)) +- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), 
def test_degree_scaler_aggregation():
    """Smoke-test ``DegreeScalerAggregation`` on a tiny three-group input.

    Checks the module's string representation, the shape of the output
    produced from an ``index`` vector, and that calling the module with only
    ``ptr`` raises ``NotImplementedError``.
    """
    # Six 16-dimensional rows assigned to groups 0, 1 and 2.
    feats = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])
    ptr = torch.tensor([0, 2, 5, 6])
    # Degree histogram handed to the module at construction time.
    deg_hist = torch.tensor([0, 3, 0, 1, 1, 0])

    aggr_names = ['mean', 'sum', 'max']
    scaler_names = [
        'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'
    ]
    module = DegreeScalerAggregation(aggr_names, scaler_names, deg_hist)
    assert str(module) == 'DegreeScalerAggregation()'

    # Expected width: 3 aggregators * 5 scalers * 16 features = 240.
    result = module(feats, index)
    assert result.size() == (3, 240)

    # Only `index`-based grouping is accepted; `ptr` alone must be rejected.
    with pytest.raises(NotImplementedError):
        module(feats, ptr=ptr)
class DegreeScalerAggregation(Aggregation):
    """
    Class that combines together one or more aggregators and then transforms
    the result with one or more scalers. The scalers are normalised by the
    in-degree of the training set and so must be provided at construction.

    Every scaler is applied independently to the raw multi-aggregation
    output, and the scaled copies are concatenated along the last dimension.

    Args:
        aggrs (list of str or list of Aggregation): The list of
            aggregations given as :class:`~torch_geometric.nn.aggr.Aggregation`
            (or any string that automatically resolves to it).
        scalers (list of str): Set of scaling function identifiers, namely
            :obj:`"identity"`, :obj:`"amplification"`,
            :obj:`"attenuation"`, :obj:`"linear"` and
            :obj:`"inverse_linear"`.
        deg (Tensor): Histogram of in-degrees of nodes in the training set,
            used by scalers to normalize.
        aggrs_kwargs (List[Dict[str, Any]], optional): Arguments passed to the
            respective aggregation functions in case it gets automatically
            resolved. (default: :obj:`None`)

    Raises:
        RuntimeError: If :obj:`aggrs` is not a list.
    """
    def __init__(self, aggrs: List[Union[Aggregation, str]],
                 scalers: List[str], deg: Tensor,
                 aggrs_kwargs: Optional[List[Dict[str, Any]]] = None):

        super().__init__()

        # TODO: Support non-lists
        if not isinstance(aggrs, list):
            raise RuntimeError("`aggrs` must be a list of aggregations ")

        self.aggr = MultiAggregation(aggrs, aggrs_kwargs)
        self.scalers = scalers

        deg = deg.to(torch.float)
        num_nodes = int(deg.sum())
        # Bin `i` of the histogram counts the nodes with in-degree `i`. Build
        # the bin values on the histogram's own device so that a non-CPU
        # `deg` does not trigger a device mismatch below.
        bin_degrees = torch.arange(deg.numel(), dtype=deg.dtype,
                                   device=deg.device)
        # Histogram-weighted averages of d, log(d + 1) and exp(d).
        self.avg_deg: Dict[str, float] = {
            'lin': float((bin_degrees * deg).sum()) / num_nodes,
            'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
            'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
        }

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        """Aggregate :obj:`x` and rescale the result by the group degrees.

        Args:
            x (Tensor): The source tensor.
            index (Tensor, optional): Group assignment of each entry of
                :obj:`x` along :obj:`dim`. Required, since the degrees are
                computed from it.
            ptr (Tensor, optional): Forwarded to the wrapped aggregation.
            dim_size (int, optional): Forwarded to the wrapped aggregation.
            dim (int, optional): The dimension along which to aggregate.
                (default: :obj:`-2`)

        Returns:
            Tensor: One scaled copy of the aggregation output per entry of
            :obj:`scalers`, concatenated along the last dimension.

        Raises:
            ValueError: If an entry of :obj:`scalers` is not a known scaler.
        """
        # NOTE(review): rejects calls without `index`; the accompanying test
        # expects a NotImplementedError for `ptr`-only input.
        self.assert_index_present(index)

        out = self.aggr(x, index, ptr, dim_size, dim)

        # Size the degree vector after the aggregated output: deriving it
        # from `index` alone yields too few entries whenever `dim_size`
        # exceeds `index.max() + 1` (i.e., trailing groups are empty).
        deg = degree(index, num_nodes=out.size(dim), dtype=out.dtype)
        deg = deg.clamp_(1)  # Avoid division by zero for empty groups.

        # Reshape so that `deg` broadcasts against `out` along `dim`.
        size = [1] * out.dim()
        size[dim] = -1
        deg = deg.view(*size)

        outs = []
        for scaler in self.scalers:
            # Always scale the untouched `out`: re-assigning `out` here would
            # compound every scaler with all the ones applied before it.
            if scaler == 'identity':
                out_scaled = out
            elif scaler == 'amplification':
                out_scaled = out * (torch.log(deg + 1) / self.avg_deg['log'])
            elif scaler == 'attenuation':
                out_scaled = out * (self.avg_deg['log'] / torch.log(deg + 1))
            elif scaler == 'linear':
                out_scaled = out * (deg / self.avg_deg['lin'])
            elif scaler == 'inverse_linear':
                out_scaled = out * (self.avg_deg['lin'] / deg)
            else:
                raise ValueError(f'Unknown scaler "{scaler}".')
            outs.append(out_scaled)
        return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]
+from torch_geometric.nn.aggr import DegreeScalerAggregation from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptTensor -from torch_geometric.utils import degree from ..inits import reset @@ -86,7 +85,9 @@ def __init__(self, in_channels: int, out_channels: int, pre_layers: int = 1, post_layers: int = 1, divide_input: bool = False, **kwargs): - kwargs.setdefault('aggr', None) + aggr = DegreeScalerAggregation(aggregators, scalers, deg) + kwargs.setdefault('aggr', aggr) + super().__init__(node_dim=0, **kwargs) if divide_input: @@ -95,8 +96,6 @@ def __init__(self, in_channels: int, out_channels: int, self.in_channels = in_channels self.out_channels = out_channels - self.aggregators = aggregators - self.scalers = scalers self.edge_dim = edge_dim self.towers = towers self.divide_input = divide_input @@ -178,51 +177,6 @@ def message(self, x_i: Tensor, x_j: Tensor, hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)] return torch.stack(hs, dim=1) - def aggregate(self, inputs: Tensor, index: Tensor, - dim_size: Optional[int] = None) -> Tensor: - - outs = [] - for aggregator in self.aggregators: - if aggregator == 'sum': - out = scatter(inputs, index, 0, None, dim_size, reduce='sum') - elif aggregator == 'mean': - out = scatter(inputs, index, 0, None, dim_size, reduce='mean') - elif aggregator == 'min': - out = scatter(inputs, index, 0, None, dim_size, reduce='min') - elif aggregator == 'max': - out = scatter(inputs, index, 0, None, dim_size, reduce='max') - elif aggregator == 'var' or aggregator == 'std': - mean = scatter(inputs, index, 0, None, dim_size, reduce='mean') - mean_squares = scatter(inputs * inputs, index, 0, None, - dim_size, reduce='mean') - out = mean_squares - mean * mean - if aggregator == 'std': - out = torch.sqrt(torch.relu(out) + 1e-5) - else: - raise ValueError(f'Unknown aggregator "{aggregator}".') - outs.append(out) - out = torch.cat(outs, dim=-1) - - deg = 
degree(index, dim_size, dtype=inputs.dtype) - deg = deg.clamp_(1).view(-1, 1, 1) - - outs = [] - for scaler in self.scalers: - if scaler == 'identity': - pass - elif scaler == 'amplification': - out = out * (torch.log(deg + 1) / self.avg_deg['log']) - elif scaler == 'attenuation': - out = out * (self.avg_deg['log'] / torch.log(deg + 1)) - elif scaler == 'linear': - out = out * (deg / self.avg_deg['lin']) - elif scaler == 'inverse_linear': - out = out * (self.avg_deg['lin'] / deg) - else: - raise ValueError(f'Unknown scaler "{scaler}".') - outs.append(out) - return torch.cat(outs, dim=-1) - def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, towers={self.towers}, '