Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LSTMAggregation #4731

Merged
merged 7 commits into from
May 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.0.5] - 2022-MM-DD
### Added
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
7 changes: 5 additions & 2 deletions test/nn/aggr/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ def test_validate():

aggr = MeanAggregation()

with pytest.raises(ValueError, match='invalid dimension'):
with pytest.raises(ValueError, match="either 'index' or 'ptr'"):
aggr(x)

with pytest.raises(ValueError, match="invalid dimension"):
aggr(x, index, dim=-3)

with pytest.raises(ValueError, match='mismatch between'):
with pytest.raises(ValueError, match="mismatch between"):
aggr(x, ptr=ptr, dim_size=2)


Expand Down
20 changes: 20 additions & 0 deletions test/nn/aggr/test_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest
import torch

from torch_geometric.nn import LSTMAggregation


def test_lstm_aggregation():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])

aggr = LSTMAggregation(16, 32)
assert str(aggr) == 'LSTMAggregation(16, 32)'

aggr.reset_parameters()

with pytest.raises(ValueError, match="is not sorted"):
aggr(x, torch.tensor([0, 1, 0, 1, 2, 1]))

out = aggr(x, index)
assert out.size() == (3, 32)
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SoftmaxAggregation,
PowerMeanAggregation,
)
from .lstm import LSTMAggregation

__all__ = classes = [
'Aggregation',
Expand All @@ -20,4 +21,5 @@
'StdAggregation',
'SoftmaxAggregation',
'PowerMeanAggregation',
'LSTMAggregation',
]
21 changes: 14 additions & 7 deletions torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

class Aggregation(torch.nn.Module, ABC):
r"""An abstract base class for implementing custom aggregations."""
requires_sorted_index = False

@abstractmethod
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand Down Expand Up @@ -37,6 +39,16 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

if index is None and ptr is None:
raise ValueError(f"Expected that either 'index' or 'ptr' is "
f"passed to '{self.__class__.__name__}'")

if (self.requires_sorted_index and index is not None
and not torch.all(index[:-1] <= index[1:])):
raise ValueError(f"Can not perform aggregation inside "
f"'{self.__class__.__name__}' since the "
f"'index' tensor is not sorted")

if dim >= x.dim() or dim < -x.dim():
raise ValueError(f"Encountered invalid dimension '{dim}' of "
f"source tensor with {x.dim()} dimensions")
Expand All @@ -52,17 +64,12 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2, reduce: str = 'add') -> Tensor:

assert index is not None or ptr is not None

if ptr is not None:
ptr = expand_left(ptr, dim, dims=x.dim())
return segment_csr(x, ptr, reduce=reduce)

if index is not None:
return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

raise ValueError(f"Error in '{self.__class__.__name__}': "
f"One of 'index' or 'ptr' must be defined")
assert index is not None
return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
Expand Down
57 changes: 57 additions & 0 deletions torch_geometric/nn/aggr/lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Optional

from torch import Tensor
from torch.nn import LSTM

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import to_dense_batch


class LSTMAggregation(Aggregation):
r"""Performs LSTM-style aggregation in which the elements to aggregate are
interpreted as a sequence.

.. warn::
:class:`LSTMAggregation` is not permutation-invariant.

.. note::
:class:`LSTMAggregation` requires sorted indices.

Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
**kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.
"""
requires_sorted_index = True

def __init__(self, in_channels: int, out_channels: int, **kwargs):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs)

def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

if index is None: # TODO
raise NotImplementedError(f"'{self.__class__.__name__}' with "
f"'ptr' not yet supported")

if x.dim() != 2:
raise ValueError(f"'{self.__class__.__name__}' requires "
f"two-dimensional inputs (got '{x.dim()}')")

if dim not in [-2, 0]:
raise ValueError(f"'{self.__class__.__name__}' needs to perform "
f"aggregation in first dimension (got '{dim}')")

x, _ = to_dense_batch(x, index, batch_size=dim_size)
return self.lstm(x)[0][:, -1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I am not familiar with LSTM. Some comments: I don't know whether using the output of the last time step is a good idea or not. 1. many padded zeros and fed into LSTM for nodes with smaller degrees. 2. the forgetting issue for long sequences.

How about something like this to remove the effects of padded zeros and average over all the time steps (not sure how it performs):

x, mask = to_dense_batch(x, index, batch_size=dim_size)
return self.reduce(self.lstm(x)[0][mask], index, ptr, dim_size, dim, reduce='mean')

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a valid point. AFAIK, GraphSAGE (and any other LSTM-style aggregation procedure, e.g., Jumping Knowledge), just read out the embeddings after the last element of the sequence has been processed, so I opted to go for this solution. It might be very interesting to explore which reduction performs better here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it is worth trying it out later.


def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels})')