Commit 893aca5

nn.aggr.Set2Set (#4762)
* update
* update
* update
* update
* fix test
* update
* update
* add todo
* fix test
1 parent 8bd9ae4 commit 893aca5

11 files changed: +141 -134 lines changed

CHANGELOG.md

+1-1
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
 - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
 - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
-- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731))
+- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762))
 - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
 - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
 - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))

README.md

+1-1
@@ -254,7 +254,7 @@ It is commonly applied to graph-level tasks, which require combining node featur
 <summary><b>Expand to see all implemented pooling layers...</b></summary>
 
 * **[GlobalAttention](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.GlobalAttention)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)]
-* **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)]
+* **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)]
 * **[Sort Pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_sort_pool)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)]
 * **[MinCUT Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool)** from Bianchi *et al.*: [MinCUT Pooling in Graph Neural Networks](https://arxiv.org/abs/1907.00481) (CoRR 2019) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)]
 * **[DMoN Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.dmon_pool.DMoNPooling)** from Tsitsulin *et al.*: [Graph Clustering with Graph Neural Networks](https://arxiv.org/abs/2006.16904) (CoRR 2020) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_dmon_pool.py)]

test/nn/aggr/test_basic.py

+1-4
@@ -20,13 +20,10 @@ def test_validate():
 
     aggr = MeanAggregation()
 
-    with pytest.raises(ValueError, match="either 'index' or 'ptr'"):
-        aggr(x)
-
     with pytest.raises(ValueError, match="invalid dimension"):
         aggr(x, index, dim=-3)
 
-    with pytest.raises(ValueError, match="mismatch between"):
+    with pytest.raises(ValueError, match="invalid 'dim_size'"):
         aggr(x, ptr=ptr, dim_size=2)
 
 

test/nn/aggr/test_lstm.py

-2
@@ -11,8 +11,6 @@ def test_lstm_aggregation():
     aggr = LSTMAggregation(16, 32)
     assert str(aggr) == 'LSTMAggregation(16, 32)'
 
-    aggr.reset_parameters()
-
     with pytest.raises(ValueError, match="is not sorted"):
         aggr(x, torch.tensor([0, 1, 0, 1, 2, 1]))
 

File renamed without changes.

torch_geometric/nn/aggr/__init__.py

+2
@@ -10,6 +10,7 @@
     PowerMeanAggregation,
 )
 from .lstm import LSTMAggregation
+from .set2set import Set2Set
 
 __all__ = classes = [
     'Aggregation',
@@ -22,4 +23,5 @@
     'SoftmaxAggregation',
     'PowerMeanAggregation',
     'LSTMAggregation',
+    'Set2Set',
 ]

torch_geometric/nn/aggr/base.py

+59-17
@@ -1,15 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from torch import Tensor
 from torch_scatter import scatter, segment_csr
 
+from torch_geometric.utils import to_dense_batch
+
 
 class Aggregation(torch.nn.Module, ABC):
     r"""An abstract base class for implementing custom aggregations."""
-    requires_sorted_index = False
-
     @abstractmethod
     def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
@@ -39,26 +39,59 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
                  ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                  dim: int = -2) -> Tensor:
 
+        if dim >= x.dim() or dim < -x.dim():
+            raise ValueError(f"Encountered invalid dimension '{dim}' of "
+                             f"source tensor with {x.dim()} dimensions")
+
         if index is None and ptr is None:
-            raise ValueError(f"Expected that either 'index' or 'ptr' is "
-                             f"passed to '{self.__class__.__name__}'")
+            index = x.new_zeros(x.size(dim), dtype=torch.long)
 
-        if (self.requires_sorted_index and index is not None
-                and not torch.all(index[:-1] <= index[1:])):
+        if ptr is not None:
+            if dim_size is None:
+                dim_size = ptr.numel() - 1
+            elif dim_size != ptr.numel() - 1:
+                raise ValueError(f"Encountered invalid 'dim_size' (got "
+                                 f"'{dim_size}' but expected "
+                                 f"'{ptr.numel() - 1}')")
+
+        if index is not None:
+            if dim_size is None:
+                dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
+            elif index.numel() > 0 and dim_size <= int(index.max()):
+                raise ValueError(f"Encountered invalid 'dim_size' (got "
+                                 f"'{dim_size}' but expected "
+                                 f">= '{int(index.max()) + 1}')")
+
+        return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim)
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}()'
+
+    # Assertions ##############################################################
+
+    def assert_index_present(self, index: Optional[Tensor]):
+        # TODO Currently, not all aggregators support `ptr`. This assert helps
+        # to ensure that we require `index` to be passed to the computation:
+        if index is None:
+            raise NotImplementedError(f"'{self.__class__.__name__}' requires "
+                                      f"'index' to be specified")
+
+    def assert_sorted_index(self, index: Optional[Tensor]):
+        if index is not None and not torch.all(index[:-1] <= index[1:]):
             raise ValueError(f"Can not perform aggregation inside "
                              f"'{self.__class__.__name__}' since the "
                              f"'index' tensor is not sorted")
 
-        if dim >= x.dim() or dim < -x.dim():
-            raise ValueError(f"Encountered invalid dimension '{dim}' of "
-                             f"source tensor with {x.dim()} dimensions")
+    def assert_two_dimensional_input(self, x: Tensor, dim: int):
+        if x.dim() != 2:
+            raise ValueError(f"'{self.__class__.__name__}' requires "
+                             f"two-dimensional inputs (got '{x.dim()}')")
 
-        if (ptr is not None and dim_size is not None
-                and dim_size != ptr.numel() - 1):
-            raise ValueError(f"Encountered mismatch between 'dim_size' (got "
-                             f"'{dim_size}') and 'ptr' (got '{ptr.size(0)}')")
+        if dim not in [-2, 0]:
+            raise ValueError(f"'{self.__class__.__name__}' needs to perform "
+                             f"aggregation in first dimension (got '{dim}')")
 
-        return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim)
+    # Helper methods ##########################################################
 
     def reduce(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
@@ -71,8 +104,17 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None,
         assert index is not None
         return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)
 
-    def __repr__(self) -> str:
-        return f'{self.__class__.__name__}()'
+    def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None,
+                       ptr: Optional[Tensor] = None,
+                       dim_size: Optional[int] = None,
+                       dim: int = -2) -> Tuple[Tensor, Tensor]:
+
+        # TODO Currently, `to_dense_batch` can only operate on `index`:
+        self.assert_index_present(index)
+        self.assert_sorted_index(index)
+        self.assert_two_dimensional_input(x, dim)
+
+        return to_dense_batch(x, index, batch_size=dim_size)
 
 
 ###############################################################################
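
The `dim_size` handling that `__call__` gains in this file can be illustrated without PyTorch. The following is a minimal sketch on plain Python lists, assuming the same three cases as the diff (neither `index` nor `ptr`, `ptr` given, `index` given). `resolve_dim_size` and its `num_elements` parameter are hypothetical names used only for this illustration and are not part of PyG.

```python
from typing import List, Optional


def resolve_dim_size(index: Optional[List[int]] = None,
                     ptr: Optional[List[int]] = None,
                     dim_size: Optional[int] = None,
                     num_elements: int = 0) -> int:
    """Infer or validate the number of output groups (hypothetical helper)."""
    if index is None and ptr is None:
        # No grouping given: every element is assigned to group 0.
        index = [0] * num_elements

    if ptr is not None:
        # `ptr` holds CSR-style boundaries, so it has one entry more than
        # there are groups.
        expected = len(ptr) - 1
        if dim_size is None:
            dim_size = expected
        elif dim_size != expected:
            raise ValueError(f"Encountered invalid 'dim_size' (got "
                             f"'{dim_size}' but expected '{expected}')")

    if index is not None:
        if dim_size is None:
            # Largest group id plus one, or zero for an empty input.
            dim_size = max(index) + 1 if index else 0
        elif index and dim_size <= max(index):
            raise ValueError(f"Encountered invalid 'dim_size' (got "
                             f"'{dim_size}' but expected "
                             f">= '{max(index) + 1}')")

    return dim_size
```

With `ptr`, an explicit `dim_size` must match exactly, whereas with `index` it only has to be large enough, which allows trailing empty groups.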

torch_geometric/nn/aggr/lstm.py

+2-17
@@ -4,7 +4,6 @@
 from torch.nn import LSTM
 
 from torch_geometric.nn.aggr import Aggregation
-from torch_geometric.utils import to_dense_batch
 
 
 class LSTMAggregation(Aggregation):
@@ -22,34 +21,20 @@ class LSTMAggregation(Aggregation):
         out_channels (int): Size of each output sample.
         **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.
     """
-    requires_sorted_index = True
-
     def __init__(self, in_channels: int, out_channels: int, **kwargs):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs)
+        self.reset_parameters()
 
     def reset_parameters(self):
         self.lstm.reset_parameters()
 
     def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
-
-        if index is None:  # TODO
-            raise NotImplementedError(f"'{self.__class__.__name__}' with "
-                                      f"'ptr' not yet supported")
-
-        if x.dim() != 2:
-            raise ValueError(f"'{self.__class__.__name__}' requires "
-                             f"two-dimensional inputs (got '{x.dim()}')")
-
-        if dim not in [-2, 0]:
-            raise ValueError(f"'{self.__class__.__name__}' needs to perform "
-                             f"aggregation in first dimension (got '{dim}')")
-
-        x, _ = to_dense_batch(x, index, batch_size=dim_size)
+        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim)
         return self.lstm(x)[0][:, -1]
 
     def __repr__(self) -> str:
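
`LSTMAggregation.forward` now delegates to the shared `to_dense_batch` helper, which needs a sorted `index` to turn a flat list of rows into one padded sequence per group. The sketch below shows that padding idea on plain Python lists. `dense_batch` is a hypothetical stand-in written for this illustration, not the `torch_geometric.utils.to_dense_batch` implementation.

```python
from typing import List, Optional, Tuple

Row = List[float]


def dense_batch(x: List[Row], index: List[int],
                batch_size: Optional[int] = None,
                fill_value: float = 0.0
                ) -> Tuple[List[List[Row]], List[List[bool]]]:
    """Pad rows grouped by a sorted `index` into equally long sequences."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("'index' is not sorted")

    if batch_size is None:
        batch_size = max(index) + 1 if index else 0

    # Count the rows per group to find the padded sequence length.
    counts = [0] * batch_size
    for group in index:
        counts[group] += 1
    max_len = max(counts, default=0)
    width = len(x[0]) if x else 0

    dense = [[[fill_value] * width for _ in range(max_len)]
             for _ in range(batch_size)]
    mask = [[False] * max_len for _ in range(batch_size)]

    # Copy each row into the next free slot of its group.
    position = [0] * batch_size
    for row, group in zip(x, index):
        dense[group][position[group]] = list(row)
        mask[group][position[group]] = True
        position[group] += 1

    return dense, mask
```

The returned mask marks which slots hold real rows and which are padding, so a caller can tell the two apart after batching.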

torch_geometric/nn/aggr/set2set.py

+67
@@ -0,0 +1,67 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+from torch_geometric.nn.aggr import Aggregation
+from torch_geometric.utils import softmax
+
+
+class Set2Set(Aggregation):
+    r"""The Set2Set aggregation operator based on iterative content-based
+    attention, as described in the `"Order Matters: Sequence to sequence for
+    Sets" <https://arxiv.org/abs/1511.06391>`_ paper
+
+    .. math::
+        \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})
+
+        \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)
+
+        \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i
+
+        \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,
+
+    where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice
+    the dimensionality as the input.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        processing_steps (int): Number of iterations :math:`T`.
+        **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.
+    """
+    def __init__(self, in_channels: int, processing_steps: int, **kwargs):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = 2 * in_channels
+        self.processing_steps = processing_steps
+        self.lstm = torch.nn.LSTM(self.out_channels, in_channels, **kwargs)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lstm.reset_parameters()
+
+    def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
+                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
+                dim: int = -2) -> Tensor:
+
+        # TODO Currently, `to_dense_batch` can only operate on `index`:
+        self.assert_index_present(index)
+        self.assert_two_dimensional_input(x, dim)
+
+        h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))),
+             x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))))
+        q_star = x.new_zeros(dim_size, self.out_channels)
+
+        for _ in range(self.processing_steps):
+            q, h = self.lstm(q_star.unsqueeze(0), h)
+            q = q.view(dim_size, self.in_channels)
+            e = (x * q[index]).sum(dim=-1, keepdim=True)
+            a = softmax(e, index, ptr, dim_size, dim)
+            r = self.reduce(a * x, index, ptr, dim_size, dim, reduce='add')
+            q_star = torch.cat([q, r], dim=-1)
+
+        return q_star
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'{self.out_channels})')
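
The attention part of one Set2Set iteration (score, per-group softmax, weighted sum, concatenation) can be followed on plain Python lists. The sketch below deliberately leaves out the LSTM update and takes the query `q` as an argument, which is a simplifying assumption. `attention_readout` is a hypothetical name used only here.

```python
import math
from typing import List

Row = List[float]


def attention_readout(x: List[Row], index: List[int],
                      q: List[Row]) -> List[Row]:
    """One content-based attention step; `q` holds one query per group."""
    num_groups = len(q)
    width = len(q[0]) if q else 0

    # e_i = x_i . q_{index[i]}
    scores = [sum(a * b for a, b in zip(row, q[group]))
              for row, group in zip(x, index)]

    # Softmax within each group, shifted by the group maximum for stability.
    group_max = [-math.inf] * num_groups
    for score, group in zip(scores, index):
        group_max[group] = max(group_max[group], score)
    exps = [math.exp(score - group_max[group])
            for score, group in zip(scores, index)]
    denom = [0.0] * num_groups
    for value, group in zip(exps, index):
        denom[group] += value
    alpha = [value / denom[group] for value, group in zip(exps, index)]

    # r_g = sum over the rows of group g of alpha_i * x_i
    r = [[0.0] * width for _ in range(num_groups)]
    for row, group, weight in zip(x, index, alpha):
        for k, value in enumerate(row):
            r[group][k] += weight * value

    # q*_g = q_g || r_g, which doubles the feature dimension.
    return [list(q[group]) + r[group] for group in range(num_groups)]
```

Each output row has twice the input width, matching `out_channels = 2 * in_channels` in the class above.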

torch_geometric/nn/glob/__init__.py

+8-2
@@ -2,7 +2,6 @@
 from .glob import GlobalPooling
 from .sort import global_sort_pool
 from .attention import GlobalAttention
-from .set2set import Set2Set
 from .gmt import GraphMultisetTransformer
 
 __all__ = [
@@ -12,8 +11,15 @@
     'GlobalPooling',
     'global_sort_pool',
     'GlobalAttention',
-    'Set2Set',
     'GraphMultisetTransformer',
 ]
 
 classes = __all__
+
+from torch_geometric.deprecation import deprecated  # noqa
+from torch_geometric.nn.aggr import Set2Set  # noqa
+
+Set2Set = deprecated(
+    details="use 'nn.aggr.Set2Set' instead",
+    func_name='nn.glob.Set2Set',
+)(Set2Set)
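
The old import path is kept alive by wrapping the relocated class in a `deprecated(...)` decorator factory. The body of `torch_geometric.deprecation.deprecated` is not part of this diff, so the following is only a sketch of how such a factory could be written with the standard library. The keyword names mirror the call above, and everything else is an assumption.

```python
import functools
import warnings
from typing import Any, Callable, Optional


def deprecated(details: Optional[str] = None,
               func_name: Optional[str] = None) -> Callable:
    """Build a decorator that warns on every call (illustrative sketch)."""
    def decorator(func: Callable) -> Callable:
        name = func_name or func.__name__

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            message = f"'{name}' is deprecated"
            if details is not None:
                message += f", {details}"
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def new_readout(value: int) -> int:
    """Stand-in for the relocated implementation (hypothetical)."""
    return value * 2


# Alias under the old name that forwards to the new implementation.
old_readout = deprecated(
    details="use 'new_readout' instead",
    func_name='old_readout',
)(new_readout)
```

Calling `old_readout(3)` still returns the result of `new_readout(3)` but emits a `DeprecationWarning` naming the replacement.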

torch_geometric/nn/glob/set2set.py

-90
This file was deleted.
