Skip to content

Commit d72df66

Browse files
authored
Support for dense aggregations in global_*_pool (#4827)
* global dense pool
* update
1 parent e3a52f9 commit d72df66

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

CHANGELOG.md

+1
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
89
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
910
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
1011
- Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))

test/nn/glob/test_glob.py

+5
Original file line number | Diff line number | Diff line change
@@ -65,3 +65,8 @@ def test_permuted_global_pool():
6565
assert out.size() == (2, 4)
6666
assert torch.allclose(out[0], px1.max(dim=0)[0])
6767
assert torch.allclose(out[1], px2.max(dim=0)[0])
68+
69+
70+
def test_dense_global_pool():
71+
x = torch.randn(3, 16, 32)
72+
assert torch.allclose(global_add_pool(x, None), x.sum(dim=1))

torch_geometric/nn/glob/glob.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -24,9 +24,9 @@ def global_add_pool(x: Tensor, batch: Optional[Tensor],
2424
Automatically calculated if not given. (default: :obj:`None`)
2525
"""
2626
if batch is None:
27-
return x.sum(dim=0, keepdim=True)
27+
return x.sum(dim=-2, keepdim=x.dim() == 2)
2828
size = int(batch.max().item() + 1) if size is None else size
29-
return scatter(x, batch, dim=0, dim_size=size, reduce='add')
29+
return scatter(x, batch, dim=-2, dim_size=size, reduce='add')
3030

3131

3232
def global_mean_pool(x: Tensor, batch: Optional[Tensor],
@@ -48,9 +48,9 @@ def global_mean_pool(x: Tensor, batch: Optional[Tensor],
4848
Automatically calculated if not given. (default: :obj:`None`)
4949
"""
5050
if batch is None:
51-
return x.mean(dim=0, keepdim=True)
51+
return x.mean(dim=-2, keepdim=x.dim() == 2)
5252
size = int(batch.max().item() + 1) if size is None else size
53-
return scatter(x, batch, dim=0, dim_size=size, reduce='mean')
53+
return scatter(x, batch, dim=-2, dim_size=size, reduce='mean')
5454

5555

5656
def global_max_pool(x: Tensor, batch: Optional[Tensor],
@@ -72,9 +72,9 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
7272
Automatically calculated if not given. (default: :obj:`None`)
7373
"""
7474
if batch is None:
75-
return x.max(dim=0, keepdim=True)[0]
75+
return x.max(dim=-2, keepdim=x.dim() == 2)[0]
7676
size = int(batch.max().item() + 1) if size is None else size
77-
return scatter(x, batch, dim=0, dim_size=size, reduce='max')
77+
return scatter(x, batch, dim=-2, dim_size=size, reduce='max')
7878

7979

8080
class GlobalPooling(torch.nn.Module):

0 commit comments

Comments (0)