diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36f855cc0ee4..56fd02033c72 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
 - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
 - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
 - Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
diff --git a/torch_geometric/nn/dense/diff_pool.py b/torch_geometric/nn/dense/diff_pool.py
index 54e746431325..6e6abb92e0cf 100644
--- a/torch_geometric/nn/dense/diff_pool.py
+++ b/torch_geometric/nn/dense/diff_pool.py
@@ -3,7 +3,7 @@
 EPS = 1e-15
 
 
-def dense_diff_pool(x, adj, s, mask=None):
+def dense_diff_pool(x, adj, s, mask=None, normalize=True):
     r"""The differentiable pooling operator from the `"Hierarchical Graph
     Representation Learning with Differentiable Pooling"
     <https://arxiv.org/abs/1806.08804>`_ paper
@@ -44,6 +44,9 @@
         mask (BoolTensor, optional): Mask matrix
             :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
             the valid nodes for each graph. (default: :obj:`None`)
+        normalize (bool, optional): If set to :obj:`False`, the link
+            prediction loss is not divided by :obj:`adj.numel()`.
+            (default: :obj:`True`)
 
     :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
         :class:`Tensor`)
@@ -66,7 +69,8 @@ def dense_diff_pool(x, adj, s, mask=None):
 
     link_loss = adj - torch.matmul(s, s.transpose(1, 2))
     link_loss = torch.norm(link_loss, p=2)
-    link_loss = link_loss / adj.numel()
+    if normalize is True:
+        link_loss = link_loss / adj.numel()
 
     ent_loss = (-s * torch.log(s + EPS)).sum(dim=-1).mean()
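
A minimal usage sketch of the new flag (assuming a PyTorch Geometric build that already includes this patch; the batch size, node/cluster counts, and random inputs below are illustrative only):

```python
import torch
from torch_geometric.nn import dense_diff_pool

B, N, C, F = 2, 20, 5, 16                  # batch size, nodes, clusters, feature dim (assumed)
x = torch.randn(B, N, F)                   # dense node feature matrices
adj = torch.rand(B, N, N)                  # dense adjacency matrices
s = torch.randn(B, N, C)                   # cluster assignment scores (softmax is applied inside)
mask = torch.ones(B, N, dtype=torch.bool)  # all nodes valid

# Default behavior: the link prediction loss is divided by adj.numel().
out, out_adj, link_loss, ent_loss = dense_diff_pool(x, adj, s, mask)

# With normalize=False, the raw Frobenius-norm loss is returned instead.
_, _, raw_link_loss, _ = dense_diff_pool(x, adj, s, mask, normalize=False)

# The two losses differ exactly by the adj.numel() factor (= B * N * N).
assert torch.allclose(link_loss, raw_link_loss / adj.numel())
```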