Skip to content

Commit d2f91fb

Browse files
committed
create DegreeScalerAggregation
1 parent cd18952 commit d2f91fb

File tree

3 files changed

+81
-42
lines changed

3 files changed

+81
-42
lines changed

torch_geometric/nn/aggr/__init__.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,11 @@
1313
)
1414
from .lstm import LSTMAggregation
1515
from .set2set import Set2Set
16+
from .scaler import DegreeScalerAggregation
1617

1718
__all__ = classes = [
18-
'Aggregation',
19-
'MultiAggregation',
20-
'MeanAggregation',
21-
'SumAggregation',
22-
'MaxAggregation',
23-
'MinAggregation',
24-
'MulAggregation',
25-
'VarAggregation',
26-
'StdAggregation',
27-
'SoftmaxAggregation',
28-
'PowerMeanAggregation',
29-
'LSTMAggregation',
30-
'Set2Set',
19+
'Aggregation', 'MultiAggregation', 'MeanAggregation', 'SumAggregation',
20+
'MaxAggregation', 'MinAggregation', 'MulAggregation', 'VarAggregation',
21+
'StdAggregation', 'SoftmaxAggregation', 'PowerMeanAggregation',
22+
'LSTMAggregation', 'Set2Set', 'DegreeScalerAggregation'
3123
]

torch_geometric/nn/aggr/scaler.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
import torch
4+
from torch import Tensor
5+
6+
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
7+
from torch_geometric.utils import degree
8+
9+
10+
class DegreeScalerAggregation(Aggregation):
11+
"""
12+
Class that combines together one or more aggregators and then transforms
13+
the result with one or more scalers. The scalers are normalised by the
14+
in-degree of the training set and so must be provided at construction.
15+
16+
Args:
17+
aggr (string or list or Aggregation, optional): The list of
18+
aggregations given as :class:`~torch_geometric.nn.aggr.Aggregation`
19+
(or any string that automatically resolves to it).
20+
scalers (list of str): Set of scaling function identifiers, namely
21+
:obj:`"identity"`, :obj:`"amplification"`,
22+
:obj:`"attenuation"`, :obj:`"linear"` and
23+
:obj:`"inverse_linear"`.
24+
deg (Tensor): Histogram of in-degrees of nodes in the training set,
25+
used by scalers to normalize.
26+
aggr_kwargs (List[Dict[str, Any]], optional): Arguments passed to the
27+
respective aggregation functions in case it gets automatically
28+
resolved. (default: :obj:`None`)
29+
"""
30+
def __init__(self, aggrs: List[Union[Aggregation, str]],
31+
scalers: List[str], deg: Tensor,
32+
aggrs_kwargs: Optional[List[Dict[str, Any]]] = None):
33+
34+
super().__init__()
35+
36+
self.agg = MultiAggregation(aggrs, aggrs_kwargs)
37+
self.scalersz = scalers
38+
39+
deg = deg.to(torch.float)
40+
num_nodes = int(deg.sum())
41+
bin_degrees = torch.arange(deg.numel())
42+
self.avg_deg: Dict[str, float] = {
43+
'lin': float((bin_degrees * deg).sum()) / num_nodes,
44+
'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
45+
'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
46+
}
47+
48+
def forward(self, x: Tensor, index: Optional[Tensor] = None,
49+
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
50+
dim: int = -2) -> Tensor:
51+
52+
out = self.agg(x, index, ptr, dim_size, dim)
53+
54+
deg = degree(index, dtype=out.dtype)
55+
deg = deg.clamp_(1).view(-1, 1, 1)
56+
57+
outs = []
58+
for scaler in self.scalers:
59+
if scaler == 'identity':
60+
pass
61+
elif scaler == 'amplification':
62+
out = out * (torch.log(deg + 1) / self.avg_deg['log'])
63+
elif scaler == 'attenuation':
64+
out = out * (self.avg_deg['log'] / torch.log(deg + 1))
65+
elif scaler == 'linear':
66+
out = out * (deg / self.avg_deg['lin'])
67+
elif scaler == 'inverse_linear':
68+
out = out * (self.avg_deg['lin'] / deg)
69+
else:
70+
raise ValueError(f'Unknown scaler "{scaler}".')
71+
outs.append(out)
72+
return torch.cat(outs, dim=-1)

torch_geometric/nn/conv/pna_conv.py

+4-29
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from torch import Tensor
55
from torch.nn import ModuleList, ReLU, Sequential
66

7+
from torch_geometric.nn.aggr import DegreeScalerAggregation
78
from torch_geometric.nn.conv import MessagePassing
89
from torch_geometric.nn.dense.linear import Linear
910
from torch_geometric.typing import Adj, OptTensor
10-
from torch_geometric.utils import degree
1111

1212
from ..inits import reset
1313

@@ -85,7 +85,9 @@ def __init__(self, in_channels: int, out_channels: int,
8585
pre_layers: int = 1, post_layers: int = 1,
8686
divide_input: bool = False, **kwargs):
8787

88-
kwargs.setdefault('aggr', aggregators)
88+
aggr = DegreeScalerAggregation(aggregators, scalers, deg)
89+
kwargs.setdefault('aggr', aggr)
90+
8991
super().__init__(node_dim=0, **kwargs)
9092

9193
if divide_input:
@@ -94,7 +96,6 @@ def __init__(self, in_channels: int, out_channels: int,
9496

9597
self.in_channels = in_channels
9698
self.out_channels = out_channels
97-
self.scalers = scalers
9899
self.edge_dim = edge_dim
99100
self.towers = towers
100101
self.divide_input = divide_input
@@ -132,9 +133,6 @@ def __init__(self, in_channels: int, out_channels: int,
132133

133134
self.lin = Linear(out_channels, out_channels)
134135

135-
# scale hook
136-
self.register_aggregate_forward_hook(self.scale_hook)
137-
138136
self.reset_parameters()
139137

140138
def reset_parameters(self):
@@ -179,29 +177,6 @@ def message(self, x_i: Tensor, x_j: Tensor,
179177
hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]
180178
return torch.stack(hs, dim=1)
181179

182-
@staticmethod
183-
def scale_hook(module, inputs, out: Tensor) -> Tensor:
184-
185-
deg = degree(inputs[0]['index'], dtype=out.dtype)
186-
deg = deg.clamp_(1).view(-1, 1, 1)
187-
188-
outs = []
189-
for scaler in module.scalers:
190-
if scaler == 'identity':
191-
pass
192-
elif scaler == 'amplification':
193-
out = out * (torch.log(deg + 1) / module.avg_deg['log'])
194-
elif scaler == 'attenuation':
195-
out = out * (module.avg_deg['log'] / torch.log(deg + 1))
196-
elif scaler == 'linear':
197-
out = out * (deg / module.avg_deg['lin'])
198-
elif scaler == 'inverse_linear':
199-
out = out * (module.avg_deg['lin'] / deg)
200-
else:
201-
raise ValueError(f'Unknown scaler "{scaler}".')
202-
outs.append(out)
203-
return torch.cat(outs, dim=-1)
204-
205180
def __repr__(self):
206181
return (f'{self.__class__.__name__}({self.in_channels}, '
207182
f'{self.out_channels}, towers={self.towers}, '

0 commit comments

Comments
 (0)