GlobalPooling and graph-level to_hetero support #4582

Merged · 8 commits · May 2, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
+- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
 ### Removed
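
Taken together, the two entries above let `to_hetero` transform models that end in a graph-level readout. A minimal usage sketch, condensed from the `test_graph_level_to_hetero` test added below (the model, metadata, and shapes are illustrative, not part of the diff):

import torch
from torch import Tensor
from torch_geometric.nn import GlobalPooling, SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = SAGEConv(16, 32)
        self.pool = GlobalPooling(aggr='mean')

    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
        x = self.conv(x, edge_index)
        return self.pool(x, batch)  # graph-level output: [num_graphs, 32]

metadata = (['paper', 'author'],
            [('paper', 'written_by', 'author'), ('author', 'writes', 'paper')])
model = to_hetero(GNN(), metadata, aggr='mean')

x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 16)}
edge_index_dict = {
    ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),
    ('author', 'writes', 'paper'): torch.randint(100, (2, 200)),
}
batch_dict = {key: torch.zeros(100, dtype=torch.long) for key in x_dict}

out = model(x_dict, edge_index_dict, batch_dict)  # one tensor of shape [1, 32];
                                                  # per-node-type outputs averaged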
67 changes: 53 additions & 14 deletions test/nn/test_to_hetero_transformer.py
@@ -5,7 +5,7 @@
 from torch.nn import Linear, ReLU, Sequential
 from torch_sparse import SparseTensor
 
-from torch_geometric.nn import BatchNorm, GCNConv, GINEConv
+from torch_geometric.nn import BatchNorm, GCNConv, GINEConv, GlobalPooling
 from torch_geometric.nn import Linear as LazyLinear
 from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero
 
@@ -123,11 +123,10 @@ def forward(self, x: Tensor) -> Tensor:


 def test_to_hetero():
-    metadata = (['paper', 'author'], [('paper', 'cites', 'paper'),
-                                      ('paper', 'written_by', 'author'),
-                                      ('author', 'writes', 'paper')])
-
-    x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 16)}
+    x_dict = {
+        'paper': torch.randn(100, 16),
+        'author': torch.randn(100, 16),
+    }
     edge_index_dict = {
         ('paper', 'cites', 'paper'):
         torch.randint(100, (2, 200), dtype=torch.long),
@@ -142,6 +141,8 @@ def test_to_hetero():
         ('author', 'writes', 'paper'): torch.randn(200, 8),
     }
 
+    metadata = list(x_dict.keys()), list(edge_index_dict.keys())
+
     model = Net1()
     model = to_hetero(model, metadata, debug=False)
     out = model(x_dict, edge_attr_dict)
@@ -225,13 +226,16 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:


 def test_to_hetero_with_gcn():
-    metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')])
-    x_dict = {'paper': torch.randn(100, 16)}
+    x_dict = {
+        'paper': torch.randn(100, 16),
+    }
     edge_index_dict = {
         ('paper', '0', 'paper'): torch.randint(100, (2, 200)),
         ('paper', '1', 'paper'): torch.randint(100, (2, 200)),
     }
 
+    metadata = list(x_dict.keys()), list(edge_index_dict.keys())
+
     model = GCN()
     model = to_hetero(model, metadata, debug=False)
     out = model(x_dict, edge_index_dict)
@@ -284,10 +288,6 @@ def test_to_hetero_and_rgcn_equal_output():
     out1 = conv(x, edge_index, edge_type)
 
     # Run `to_hetero`:
-    node_types = ['paper', 'author']
-    edge_types = [('paper', '_', 'paper'), ('paper', '_', 'author'),
-                  ('author', '_', 'paper')]
-
     x_dict = {
         'paper': x[:6],
         'author': x[6:],
@@ -301,13 +301,14 @@
         edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]),
     }
 
+    node_types, edge_types = list(x_dict.keys()), list(edge_index_dict.keys())
+
     adj_t_dict = {
         key: SparseTensor.from_edge_index(edge_index).t()
         for key, edge_index in edge_index_dict.items()
     }
 
-    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
-    model = to_hetero(RGCN(16, 32), metadata)
+    model = to_hetero(RGCN(16, 32), (node_types, edge_types))
 
     # Set model weights:
     for i, edge_type in enumerate(edge_types):
@@ -324,3 +325,41 @@
     out3 = model(x_dict, adj_t_dict)
     out3 = torch.cat([out3['paper'], out3['author']], dim=0)
     assert torch.allclose(out1, out3, atol=1e-6)


+class GraphLevelGNN(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = SAGEConv(16, 32)
+        self.pool = GlobalPooling(aggr='mean')
+        self.lin = Linear(32, 64)
+
+    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:
+        x = self.conv(x, edge_index)
+        x = self.pool(x, batch)
+        x = self.lin(x)
+        return x
+
+
+def test_graph_level_to_hetero():
+    x_dict = {
+        'paper': torch.randn(100, 16),
+        'author': torch.randn(100, 16),
+    }
+    edge_index_dict = {
+        ('paper', 'written_by', 'author'):
+        torch.randint(100, (2, 200), dtype=torch.long),
+        ('author', 'writes', 'paper'):
+        torch.randint(100, (2, 200), dtype=torch.long),
+    }
+    batch_dict = {
+        'paper': torch.zeros(100, dtype=torch.long),
+        'author': torch.zeros(100, dtype=torch.long),
+    }
+
+    metadata = list(x_dict.keys()), list(edge_index_dict.keys())
+
+    model = GraphLevelGNN()
+    model = to_hetero(model, metadata, aggr='mean', debug=False)
+    out = model(x_dict, edge_index_dict, batch_dict)
+    assert out.size() == (1, 64)
57 changes: 46 additions & 11 deletions torch_geometric/nn/fx.py
@@ -4,8 +4,6 @@
 import torch
 from torch.nn import Module, ModuleDict, ModuleList, Sequential
 
-from torch_geometric.nn.conv import MessagePassing
-
 try:
     from torch.fx import Graph, GraphModule, Node
 except (ImportError, ModuleNotFoundError, AttributeError):
@@ -32,6 +30,7 @@ class Transformer(object):
                 +-- call_method()
                 +-- call_module()
                 +-- call_message_passing_module()
+                +-- call_global_pooling_module()
                 +-- output()
             +-- Erase unused nodes in the graph
             +-- Iterate over each children module
@@ -41,8 +40,9 @@ class Transformer(object):
     :class:`Transformer` exposes additional functionality:
 
     #. It subdivides :func:`call_module` into nodes that call a regular
-       :class:`torch.nn.Module` (:func:`call_module`) or a
-       :class:`MessagePassing` module (:func:`call_message_passing_module`).
+       :class:`torch.nn.Module` (:func:`call_module`), a
+       :class:`MessagePassing` module (:func:`call_message_passing_module`),
+       or a :class:`GlobalPooling` module (:func:`call_global_pooling_module`).
 
     #. It allows to customize or initialize new children modules via
        :func:`init_submodule`
@@ -85,6 +85,9 @@ def get_attr(self, node: Node, target: Any, name: str):
     def call_message_passing_module(self, node: Node, target: Any, name: str):
         pass
 
+    def call_global_pooling_module(self, node: Node, target: Any, name: str):
+        pass
+
     def call_module(self, node: Node, target: Any, name: str):
         pass
 
@@ -132,11 +135,15 @@ def transform(self) -> GraphModule:
                 self._state[node.name] = 'node'
             elif is_message_passing_op(self.module, node.op, node.target):
                 self._state[node.name] = 'node'
+            elif is_global_pooling_op(self.module, node.op, node.target):
+                self._state[node.name] = 'graph'
             elif node.op in ['call_module', 'call_method', 'call_function']:
                 if self.has_edge_level_arg(node):
                     self._state[node.name] = 'edge'
-                else:
+                elif self.has_node_level_arg(node):
                     self._state[node.name] = 'node'
+                else:
+                    self._state[node.name] = 'graph'
 
         # We iterate over each node and may transform it:
         for node in list(self.graph.nodes):
@@ -145,6 +152,9 @@
             op = node.op
             if is_message_passing_op(self.module, op, node.target):
                 op = 'call_message_passing_module'
+            elif is_global_pooling_op(self.module, op, node.target):
+                op = 'call_global_pooling_module'
+
             getattr(self, op)(node, node.target, node.name)
 
         # Remove all unused nodes in the computation graph, i.e., all nodes
@@ -190,13 +200,13 @@ def _init_submodule(self, module: Module, target: str) -> Module:
         else:
             return self.init_submodule(module, target)
 
-    def is_edge_level(self, node: Node) -> bool:
-        return self._state[node.name] == 'edge'
+    def _is_level(self, node: Node, name: str) -> bool:
+        return self._state[node.name] == name
 
-    def has_edge_level_arg(self, node: Node) -> bool:
+    def _has_level_arg(self, node: Node, name: str) -> bool:
         def _recurse(value: Any) -> bool:
             if isinstance(value, Node):
-                return self.is_edge_level(value)
+                return getattr(self, f'is_{name}_level')(value)
             elif isinstance(value, dict):
                 return any([_recurse(v) for v in value.values()])
             elif isinstance(value, (list, tuple)):
@@ -207,6 +217,24 @@ def _recurse(value: Any) -> bool:
         return (any([_recurse(value) for value in node.args])
                 or any([_recurse(value) for value in node.kwargs.values()]))
 
+    def is_node_level(self, node: Node) -> bool:
+        return self._is_level(node, name='node')
+
+    def is_edge_level(self, node: Node) -> bool:
+        return self._is_level(node, name='edge')
+
+    def is_graph_level(self, node: Node) -> bool:
+        return self._is_level(node, name='graph')
+
+    def has_node_level_arg(self, node: Node) -> bool:
+        return self._has_level_arg(node, name='node')
+
+    def has_edge_level_arg(self, node: Node) -> bool:
+        return self._has_level_arg(node, name='edge')
+
+    def has_graph_level_arg(self, node: Node) -> bool:
+        return self._has_level_arg(node, name='graph')
+
     def find_by_name(self, name: str) -> Optional[Node]:
         for node in self.graph.nodes:
             if node.name == name:
@@ -249,7 +277,14 @@ def get_submodule(module: Module, target: str) -> Module:


 def is_message_passing_op(module: Module, op: str, target: str) -> bool:
+    from torch_geometric.nn import MessagePassing
     if op == 'call_module':
-        if isinstance(get_submodule(module, target), MessagePassing):
-            return True
+        return isinstance(get_submodule(module, target), MessagePassing)
+    return False
+
+
+def is_global_pooling_op(module: Module, op: str, target: str) -> bool:
+    from torch_geometric.nn import GlobalPooling
+    if op == 'call_module':
+        return isinstance(get_submodule(module, target), GlobalPooling)
     return False
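
For intuition, the dispatch above can be exercised directly. A small sketch, not part of this diff (the `ModuleDict` container and names are illustrative, and it assumes both helpers remain importable from `torch_geometric.nn.fx`):

import torch
from torch_geometric.nn import GlobalPooling, SAGEConv
from torch_geometric.nn.fx import is_global_pooling_op, is_message_passing_op

# One submodule of each kind:
modules = torch.nn.ModuleDict({
    'conv': SAGEConv(16, 32),           # a MessagePassing subclass
    'pool': GlobalPooling(aggr='max'),  # the new pooling wrapper
})

# Only 'call_module' ops are inspected; the target string is resolved to
# the actual submodule and checked via isinstance:
assert is_message_passing_op(modules, 'call_module', 'conv')
assert is_global_pooling_op(modules, 'call_module', 'pool')
assert not is_global_pooling_op(modules, 'call_method', 'pool')

During `transform()`, a matching node then flips the tracked state: message passing keeps outputs node-level (`'node'`), global pooling switches them to graph-level (`'graph'`), and downstream ops without node- or edge-level arguments stay graph-level.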
2 changes: 2 additions & 0 deletions torch_geometric/nn/glob/__init__.py
@@ -1,4 +1,5 @@
 from .glob import global_add_pool, global_mean_pool, global_max_pool
+from .glob import GlobalPooling
 from .sort import global_sort_pool
 from .attention import GlobalAttention
 from .set2set import Set2Set
@@ -8,6 +9,7 @@
     'global_add_pool',
     'global_mean_pool',
     'global_max_pool',
+    'GlobalPooling',
     'global_sort_pool',
     'GlobalAttention',
     'Set2Set',
43 changes: 42 additions & 1 deletion torch_geometric/nn/glob/glob.py
@@ -1,5 +1,6 @@
-from typing import Optional
+from typing import List, Optional, Union
 
+import torch
 from torch import Tensor
 from torch_scatter import scatter
 
@@ -74,3 +75,43 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor],
         return x.max(dim=0, keepdim=True)[0]
     size = int(batch.max().item() + 1) if size is None else size
     return scatter(x, batch, dim=0, dim_size=size, reduce='max')
+
+
+class GlobalPooling(torch.nn.Module):
+    r"""A global pooling module that wraps the usage of
+    :meth:`~torch_geometric.nn.glob.global_add_pool`,
+    :meth:`~torch_geometric.nn.glob.global_mean_pool` and
+    :meth:`~torch_geometric.nn.glob.global_max_pool` into a single module.
+
+    Args:
+        aggr (string or List[str]): The aggregation scheme to use
+            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
+            If given as a list, will make use of multiple aggregations in
+            which different outputs will get concatenated in the last
+            dimension.
+    """
+    def __init__(self, aggr: Union[str, List[str]]):
+        super().__init__()
+
+        self.aggrs = [aggr] if isinstance(aggr, str) else aggr
+
+        assert len(self.aggrs) > 0
+        assert len(set(self.aggrs) | {'sum', 'add', 'mean', 'max'}) == 4
+
+    def forward(self, x: Tensor, batch: Optional[Tensor],
+                size: Optional[int] = None) -> Tensor:
+        """"""
+        xs: List[Tensor] = []
+
+        for aggr in self.aggrs:
+            if aggr == 'sum' or aggr == 'add':
+                xs.append(global_add_pool(x, batch, size))
+            elif aggr == 'mean':
+                xs.append(global_mean_pool(x, batch, size))
+            elif aggr == 'max':
+                xs.append(global_max_pool(x, batch, size))
+
+        return xs[0] if len(xs) == 1 else torch.cat(xs, dim=-1)
+
+    def __repr__(self) -> str:
+        aggr = self.aggrs[0] if len(self.aggrs) == 1 else self.aggrs
+        return f'{self.__class__.__name__}(aggr={aggr})'
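
The union-based assertion in `__init__` is a compact subset check: the union with {'sum', 'add', 'mean', 'max'} keeps size 4 only if every requested aggregation is one of those four names. A brief usage sketch (tensor shapes are illustrative):

import torch
from torch_geometric.nn import GlobalPooling

pool = GlobalPooling(aggr=['mean', 'max'])
print(pool)  # GlobalPooling(aggr=['mean', 'max'])

x = torch.randn(6, 16)                    # 6 nodes with 16 features each
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # two graphs with 3 nodes each

out = pool(x, batch)
print(out.size())  # torch.Size([2, 32]): mean and max concatenated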