Skip to content

Commit df61109

Browse files
lightaimerusty1s
andauthored
MultiAggregation and aggregation_resolver (#4749)
* Add MulAggregation and MultiAggregation * Fix import issue * Support torch_geometric.nn.aggr package, note: jit errors to fix * Add tests for MulAggregation, MultiAggregation, aggregation_resolver and message_passing interface * Formatting * Fix __repr for gen aggrs * Move resolver * Fix test for MulAggregation * Add test for new mp interface * Add test for MultiAggregation * Minor fix * Add warming for MulAggregation with 'ptr' * Resolve aggr to Aggregation module, remove aggrs logic * changelog * Fix mul aggregation * update * update * update * update * reset Co-authored-by: rusty1s <[email protected]>
1 parent 893aca5 commit df61109

File tree

8 files changed

+144
-15
lines changed

8 files changed

+144
-15
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
1111
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
1212
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
13-
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762))
13+
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749))
1414
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
1515
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
1616
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))

test/nn/aggr/test_basic.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
MaxAggregation,
66
MeanAggregation,
77
MinAggregation,
8+
MulAggregation,
89
PowerMeanAggregation,
910
SoftmaxAggregation,
1011
StdAggregation,
@@ -29,7 +30,7 @@ def test_validate():
2930

3031
@pytest.mark.parametrize('Aggregation', [
3132
MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
32-
VarAggregation, StdAggregation
33+
MulAggregation, VarAggregation, StdAggregation
3334
])
3435
def test_basic_aggregation(Aggregation):
3536
x = torch.randn(6, 16)
@@ -41,7 +42,12 @@ def test_basic_aggregation(Aggregation):
4142

4243
out = aggr(x, index)
4344
assert out.size() == (3, x.size(1))
44-
assert torch.allclose(out, aggr(x, ptr=ptr))
45+
46+
if isinstance(aggr, MulAggregation):
47+
with pytest.raises(NotImplementedError, match="requires 'index'"):
48+
aggr(x, ptr=ptr)
49+
else:
50+
assert torch.allclose(out, aggr(x, ptr=ptr))
4551

4652

4753
@pytest.mark.parametrize('Aggregation',
@@ -53,7 +59,7 @@ def test_gen_aggregation(Aggregation, learn):
5359
ptr = torch.tensor([0, 2, 5, 6])
5460

5561
aggr = Aggregation(learn=learn)
56-
assert str(aggr) == f'{Aggregation.__name__}()'
62+
assert str(aggr) == f'{Aggregation.__name__}(learn={learn})'
5763

5864
out = aggr(x, index)
5965
assert out.size() == (3, x.size(1))

test/nn/aggr/test_multi.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
3+
from torch_geometric.nn import MultiAggregation
4+
5+
6+
def test_multi_aggr():
7+
x = torch.randn(6, 16)
8+
index = torch.tensor([0, 0, 1, 1, 1, 2])
9+
ptr = torch.tensor([0, 2, 5, 6])
10+
11+
aggrs = ['mean', 'sum', 'max']
12+
aggr = MultiAggregation(aggrs)
13+
assert str(aggr) == ('MultiAggregation([\n'
14+
' MeanAggregation(),\n'
15+
' SumAggregation(),\n'
16+
' MaxAggregation()\n'
17+
'])')
18+
19+
out = aggr(x, index)
20+
assert out.size() == (3, len(aggrs) * x.size(1))
21+
assert torch.allclose(out, aggr(x, ptr=ptr))

test/nn/test_resolver.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import pytest
12
import torch
23

3-
from torch_geometric.nn.resolver import activation_resolver
4+
import torch_geometric
5+
from torch_geometric.nn.resolver import (
6+
activation_resolver,
7+
aggregation_resolver,
8+
)
49

510

611
def test_activation_resolver():
@@ -11,3 +16,20 @@ def test_activation_resolver():
1116
assert isinstance(activation_resolver('elu'), torch.nn.ELU)
1217
assert isinstance(activation_resolver('relu'), torch.nn.ReLU)
1318
assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)
19+
20+
21+
@pytest.mark.parametrize('aggr_tuple', [
22+
(torch_geometric.nn.aggr.MeanAggregation, 'mean'),
23+
(torch_geometric.nn.aggr.SumAggregation, 'sum'),
24+
(torch_geometric.nn.aggr.MaxAggregation, 'max'),
25+
(torch_geometric.nn.aggr.MinAggregation, 'min'),
26+
(torch_geometric.nn.aggr.MulAggregation, 'mul'),
27+
(torch_geometric.nn.aggr.VarAggregation, 'var'),
28+
(torch_geometric.nn.aggr.StdAggregation, 'std'),
29+
(torch_geometric.nn.aggr.SoftmaxAggregation, 'softmax'),
30+
(torch_geometric.nn.aggr.PowerMeanAggregation, 'powermean'),
31+
])
32+
def test_aggregation_resolver(aggr_tuple):
33+
aggr_module, aggr_repr = aggr_tuple
34+
assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
35+
assert isinstance(aggregation_resolver(aggr_repr), aggr_module)

torch_geometric/nn/aggr/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from .base import Aggregation
2+
from .multi import MultiAggregation
23
from .basic import (
34
MeanAggregation,
45
SumAggregation,
6+
AddAggregation,
57
MaxAggregation,
68
MinAggregation,
9+
MulAggregation,
710
VarAggregation,
811
StdAggregation,
912
SoftmaxAggregation,
@@ -14,10 +17,13 @@
1417

1518
__all__ = classes = [
1619
'Aggregation',
20+
'MultiAggregation',
1721
'MeanAggregation',
1822
'SumAggregation',
23+
'AddAggregation',
1924
'MaxAggregation',
2025
'MinAggregation',
26+
'MulAggregation',
2127
'VarAggregation',
2228
'StdAggregation',
2329
'SoftmaxAggregation',

torch_geometric/nn/aggr/basic.py

+21
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
2222
return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')
2323

2424

25+
AddAggregation = SumAggregation # Alias
26+
27+
2528
class MaxAggregation(Aggregation):
2629
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
2730
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
@@ -36,6 +39,15 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
3639
return self.reduce(x, index, ptr, dim_size, dim, reduce='min')
3740

3841

42+
class MulAggregation(Aggregation):
43+
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
44+
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
45+
dim: int = -2) -> Tensor:
46+
# TODO Currently, `mul` reduction can only operate on `index`:
47+
self.assert_index_present(index)
48+
return self.reduce(x, index, None, dim_size, dim, reduce='mul')
49+
50+
3951
class VarAggregation(Aggregation):
4052
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
4153
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
@@ -61,6 +73,7 @@ def __init__(self, t: float = 1.0, learn: bool = False):
6173
super().__init__()
6274
self._init_t = t
6375
self.t = Parameter(torch.Tensor(1)) if learn else t
76+
self.learn = learn
6477
self.reset_parameters()
6578

6679
def reset_parameters(self):
@@ -77,15 +90,20 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
7790
alpha = softmax(alpha, index, ptr, dim_size, dim)
7891
return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')
7992

93+
def __repr__(self) -> str:
94+
return (f'{self.__class__.__name__}(learn={self.learn})')
95+
8096

8197
class PowerMeanAggregation(Aggregation):
8298
def __init__(self, p: float = 1.0, learn: bool = False):
8399
# TODO Learn distinct `p` per channel.
84100
super().__init__()
85101
self._init_p = p
86102
self.p = Parameter(torch.Tensor(1)) if learn else p
103+
self.learn = learn
87104
self.reset_parameters()
88105

106+
def reset_parameters(self):
89107
if isinstance(self.p, Tensor):
90108
self.p.data.fill_(self._init_p)
91109

@@ -97,3 +115,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
97115
if isinstance(self.p, (int, float)) and self.p == 1:
98116
return out
99117
return out.clamp_(min=0, max=100).pow(1. / self.p)
118+
119+
def __repr__(self) -> str:
120+
return (f'{self.__class__.__name__}(learn={self.learn})')

torch_geometric/nn/aggr/multi.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import List, Optional, Union
2+
3+
import torch
4+
from torch import Tensor
5+
6+
from torch_geometric.nn.aggr import Aggregation
7+
from torch_geometric.nn.resolver import aggregation_resolver
8+
9+
10+
class MultiAggregation(Aggregation):
11+
def __init__(self, aggrs: List[Union[Aggregation, str]]):
12+
super().__init__()
13+
14+
if not isinstance(aggrs, (list, tuple)):
15+
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
16+
f"be a list or tuple (got '{type(aggrs)}')")
17+
18+
if len(aggrs) == 0:
19+
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
20+
f"not be empty")
21+
22+
self.aggrs = [aggregation_resolver(aggr) for aggr in aggrs]
23+
24+
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
25+
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
26+
dim: int = -2) -> Tensor:
27+
outs = []
28+
for aggr in self.aggrs:
29+
outs.append(aggr(x, index, ptr=ptr, dim_size=dim_size, dim=dim))
30+
return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]
31+
32+
def __repr__(self) -> str:
33+
args = [f' {aggr}' for aggr in self.aggrs]
34+
return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args))

torch_geometric/nn/resolver.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
import inspect
2-
from typing import Any, List, Union
2+
from typing import Any, List, Optional, Union
33

4-
import torch
54
from torch import Tensor
65

76

87
def normalize_string(s: str) -> str:
98
return s.lower().replace('-', '').replace('_', '').replace(' ', '')
109

1110

12-
def resolver(classes: List[Any], query: Union[Any, str], *args, **kwargs):
11+
def resolver(classes: List[Any], query: Union[Any, str],
12+
base_cls: Optional[Any], *args, **kwargs):
13+
1314
if query is None or not isinstance(query, str):
1415
return query
1516

16-
query = normalize_string(query)
17+
query_repr = normalize_string(query)
18+
base_cls_repr = normalize_string(base_cls.__name__) if base_cls else ''
1719
for cls in classes:
18-
if query == normalize_string(cls.__name__):
20+
cls_repr = normalize_string(cls.__name__)
21+
if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:
1922
if inspect.isclass(cls):
2023
return cls(*args, **kwargs)
2124
else:
2225
return cls
2326

24-
return ValueError(
25-
f"Could not resolve '{query}' among the choices "
26-
f"{set(normalize_string(cls.__name__) for cls in classes)}")
27+
return ValueError(f"Could not resolve '{query}' among the choices "
28+
f"{set(cls.__name__ for cls in classes)}")
2729

2830

2931
# Activation Resolver #########################################################
@@ -34,11 +36,28 @@ def swish(x: Tensor) -> Tensor:
3436

3537

3638
def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):
39+
import torch
40+
base_cls = torch.nn.Module
41+
3742
acts = [
3843
act for act in vars(torch.nn.modules.activation).values()
39-
if isinstance(act, type) and issubclass(act, torch.nn.Module)
44+
if isinstance(act, type) and issubclass(act, base_cls)
4045
]
4146
acts += [
4247
swish,
4348
]
44-
return resolver(acts, query, *args, **kwargs)
49+
return resolver(acts, query, base_cls, *args, **kwargs)
50+
51+
52+
# Aggregation Resolver ########################################################
53+
54+
55+
def aggregation_resolver(query: Union[Any, str], *args, **kwargs):
56+
import torch_geometric.nn.aggr as aggrs
57+
base_cls = aggrs.Aggregation
58+
59+
aggrs = [
60+
aggr for aggr in vars(aggrs).values()
61+
if isinstance(aggr, type) and issubclass(aggr, base_cls)
62+
]
63+
return resolver(aggrs, query, base_cls, *args, **kwargs)

0 commit comments

Comments
 (0)