Skip to content

Commit 85cddb3

Browse files
Hu-ChuxuanChuxuan Hurusty1swsad1
authored
Add a normalize parameter to dense_diff_pool (#4847)
* modified link_loss to make it viable to not normalizing * modified link_loss to make it viable to not normalizing * changelog modified * Update CHANGELOG.md Co-authored-by: Jinu Sunil <[email protected]> * update Co-authored-by: Chuxuan Hu <[email protected]> Co-authored-by: Matthias Fey <[email protected]> Co-authored-by: Jinu Sunil <[email protected]>
1 parent 97c50a0 commit 85cddb3

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
89
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
910
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
1011
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))

torch_geometric/nn/dense/diff_pool.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
EPS = 1e-15
44

55

6-
def dense_diff_pool(x, adj, s, mask=None):
6+
def dense_diff_pool(x, adj, s, mask=None, normalize=True):
77
r"""The differentiable pooling operator from the `"Hierarchical Graph
88
Representation Learning with Differentiable Pooling"
99
<https://arxiv.org/abs/1806.08804>`_ paper
@@ -44,6 +44,9 @@ def dense_diff_pool(x, adj, s, mask=None):
4444
mask (BoolTensor, optional): Mask matrix
4545
:math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
4646
the valid nodes for each graph. (default: :obj:`None`)
47+
normalize (bool, optional): If set to :obj:`False`, the link
48+
prediction loss is not divided by :obj:`adj.numel()`.
49+
(default: :obj:`True`)
4750
4851
:rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
4952
:class:`Tensor`)
@@ -66,7 +69,8 @@ def dense_diff_pool(x, adj, s, mask=None):
6669

6770
link_loss = adj - torch.matmul(s, s.transpose(1, 2))
6871
link_loss = torch.norm(link_loss, p=2)
69-
link_loss = link_loss / adj.numel()
72+
if normalize is True:
73+
link_loss = link_loss / adj.numel()
7074

7175
ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
7276

0 commit comments

Comments
 (0)