From 914f36de904d686d7b0bb0464a7d29dd4c861be0 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 14:38:55 +0200 Subject: [PATCH 1/8] update --- test/nn/test_to_hetero_transformer.py | 71 +++++++++++++++++++++------ torch_geometric/nn/glob/__init__.py | 2 + torch_geometric/nn/glob/glob.py | 43 +++++++++++++++- 3 files changed, 101 insertions(+), 15 deletions(-) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index d5458f6bd2ab..d576367d646c 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -5,7 +5,7 @@ from torch.nn import Linear, ReLU, Sequential from torch_sparse import SparseTensor -from torch_geometric.nn import BatchNorm, GCNConv, GINEConv +from torch_geometric.nn import BatchNorm, GCNConv, GINEConv, GlobalPooling from torch_geometric.nn import Linear as LazyLinear from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero @@ -123,11 +123,11 @@ def forward(self, x: Tensor) -> Tensor: def test_to_hetero(): - metadata = (['paper', 'author'], [('paper', 'cites', 'paper'), - ('paper', 'written_by', 'author'), - ('author', 'writes', 'paper')]) - - x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 16)} + return + x_dict = { + 'paper': torch.randn(100, 16), + 'author': torch.randn(100, 16), + } edge_index_dict = { ('paper', 'cites', 'paper'): torch.randint(100, (2, 200), dtype=torch.long), @@ -142,6 +142,8 @@ def test_to_hetero(): ('author', 'writes', 'paper'): torch.randn(200, 8), } + metadata = list(x_dict.keys()), list(edge_index_dict.keys()) + model = Net1() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_attr_dict) @@ -225,13 +227,17 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: def test_to_hetero_with_gcn(): - metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')]) - x_dict = {'paper': torch.randn(100, 16)} + return + x_dict = { + 'paper': torch.randn(100, 16), + } edge_index_dict = { ('paper', '0', 'paper'): torch.randint(100, (2, 200)), ('paper', '1', 'paper'): torch.randint(100, (2, 200)), } + metadata = list(x_dict.keys()), list(edge_index_dict.keys()) + model = GCN() model = to_hetero(model, metadata, debug=False) out = model(x_dict, edge_index_dict) @@ -264,6 +270,7 @@ def forward(self, x, edge_index): def test_to_hetero_and_rgcn_equal_output(): + return torch.manual_seed(1234) # Run `RGCN`: @@ -284,10 +291,6 @@ def test_to_hetero_and_rgcn_equal_output(): out1 = conv(x, edge_index, edge_type) # Run `to_hetero`: - node_types = ['paper', 'author'] - edge_types = [('paper', '_', 'paper'), ('paper', '_', 'author'), - ('author', '_', 'paper')] - x_dict = { 'paper': x[:6], 'author': x[6:], @@ -301,13 +304,14 @@ def test_to_hetero_and_rgcn_equal_output(): edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]), } + node_types, edge_types = list(x_dict.keys()), list(edge_index_dict.keys()) + adj_t_dict = { key: SparseTensor.from_edge_index(edge_index).t() for key, edge_index in edge_index_dict.items() } - metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) - model = to_hetero(RGCN(16, 32), metadata) + model = to_hetero(RGCN(16, 32), (node_types, edge_types)) # Set model weights: for i, edge_type in enumerate(edge_types): @@ -324,3 +328,42 @@ def test_to_hetero_and_rgcn_equal_output(): out3 = model(x_dict, adj_t_dict) out3 = torch.cat([out3['paper'], out3['author']], dim=0) assert torch.allclose(out1, out3, atol=1e-6) + + +class GraphLevelGNN(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = SAGEConv(16, 32) + self.pool = GlobalPooling(aggr='mean') + self.lin = Linear(32, 64) + + def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor: + x = self.conv(x, edge_index) + x = self.pool(x, batch) + x = self.lin(x) + return x + + +def test_graph_level_to_hetero(): + x_dict = { + 'paper': torch.randn(100, 16), + 'author': torch.randn(100, 16), + } + edge_index_dict = { + ('paper', 'written_by', 'author'): + torch.randint(100, (2, 200), dtype=torch.long), + ('author', 'writes', 'paper'): + torch.randint(100, (2, 200), dtype=torch.long), + } + batch_dict = { + 'paper': torch.zeros(100, dtype=torch.long), + 'author': torch.zeros(100, dtype=torch.long), + } + + metadata = list(x_dict.keys()), list(edge_index_dict.keys()) + + model = GraphLevelGNN() + print('-------------------') + model = to_hetero(model, metadata, debug=True) + out = model(x_dict, edge_index_dict, batch_dict) + print(out) diff --git a/torch_geometric/nn/glob/__init__.py b/torch_geometric/nn/glob/__init__.py index 5921d5a86206..8b911ccf859e 100644 --- a/torch_geometric/nn/glob/__init__.py +++ b/torch_geometric/nn/glob/__init__.py @@ -1,4 +1,5 @@ from .glob import global_add_pool, global_mean_pool, global_max_pool +from .glob import GlobalPooling from .sort import global_sort_pool from .attention import GlobalAttention from .set2set import Set2Set @@ -8,6 +9,7 @@ 'global_add_pool', 'global_mean_pool', 'global_max_pool', + 'GlobalPooling', 'global_sort_pool', 'GlobalAttention', 'Set2Set', diff --git a/torch_geometric/nn/glob/glob.py b/torch_geometric/nn/glob/glob.py index 2eb3bfafed66..c4eb0b541297 100644 --- a/torch_geometric/nn/glob/glob.py +++ b/torch_geometric/nn/glob/glob.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import List, Optional, Union +import torch from torch import Tensor from torch_scatter import scatter @@ -74,3 +75,43 @@ def global_max_pool(x: Tensor, batch: Optional[Tensor], return x.max(dim=0, keepdim=True)[0] size = int(batch.max().item() + 1) if size is None else size return scatter(x, batch, dim=0, dim_size=size, reduce='max') + + +class GlobalPooling(torch.nn.Module): + r"""A global pooling module that wraps the usage of + :meth:`~torch_geometric.nn.glob.global_add_pool`, + :meth:`~torch_geometric.nn.glob.global_mean_pool` and + :meth:`~torch_geometric.nn.glob.global_max_pool` into a single module. + + Args: + aggr (string or List[str]): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + If given as a list, will make use of multiple aggregations in which + different outputs will get concatenated in the last dimension. + """ + def __init__(self, aggr: Union[str, List[str]]): + super().__init__() + + self.aggrs = [aggr] if isinstance(aggr, str) else aggr + + assert len(self.aggrs) > 0 + assert len(set(self.aggrs) | {'sum', 'add', 'mean', 'max'}) == 4 + + def forward(self, x: Tensor, batch: Optional[Tensor], + size: Optional[int] = None) -> Tensor: + """""" + xs: List[Tensor] = [] + + for aggr in self.aggrs: + if aggr == 'sum' or aggr == 'add': + xs.append(global_add_pool(x, batch, size)) + elif aggr == 'mean': + xs.append(global_mean_pool(x, batch, size)) + elif aggr == 'max': + xs.append(global_max_pool(x, batch, size)) + + return xs[0] if len(xs) == 1 else torch.cat(xs, dim=-1) + + def __repr__(self) -> str: + aggr = self.aggrs[0] if len(self.aggrs) == 1 else self.aggrs + return f'{self.__class__.__name__}(aggr={aggr})' From 9012fb0925349897d208e3348baebcf88404e86f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 14:47:24 +0200 Subject: [PATCH 2/8] update call --- test/nn/test_to_hetero_transformer.py | 4 +--- torch_geometric/nn/fx.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index d576367d646c..44fb5e1bd374 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -363,7 +363,5 @@ def test_graph_level_to_hetero(): metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = GraphLevelGNN() - print('-------------------') model = to_hetero(model, metadata, debug=True) - out = model(x_dict, edge_index_dict, batch_dict) - print(out) + model(x_dict, edge_index_dict, batch_dict) diff --git a/torch_geometric/nn/fx.py b/torch_geometric/nn/fx.py index 94f1afb6463a..056a6a651187 100644 --- a/torch_geometric/nn/fx.py +++ b/torch_geometric/nn/fx.py @@ -4,8 +4,6 @@ import torch from torch.nn import Module, ModuleDict, ModuleList, Sequential -from torch_geometric.nn.conv import MessagePassing - try: from torch.fx import Graph, GraphModule, Node except (ImportError, ModuleNotFoundError, AttributeError): @@ -132,6 +130,8 @@ def transform(self) -> GraphModule: self._state[node.name] = 'node' elif is_message_passing_op(self.module, node.op, node.target): self._state[node.name] = 'node' + elif is_global_pooling_op(self.module, node.op, node.target): + self._state[node.name] = 'node' elif node.op in ['call_module', 'call_method', 'call_function']: if self.has_edge_level_arg(node): self._state[node.name] = 'edge' @@ -145,6 +145,9 @@ def transform(self) -> GraphModule: op = node.op if is_message_passing_op(self.module, op, node.target): op = 'call_message_passing_module' + elif is_global_pooling_op(self.module, op, node.target): + op = 'call_global_pooling_module' + getattr(self, op)(node, node.target, node.name) # Remove all unused nodes in the computation graph, i.e., all nodes @@ -249,7 +252,14 @@ def get_submodule(module: Module, target: str) -> Module: def is_message_passing_op(module: Module, op: str, target: str) -> bool: + from torch_geometric.nn import MessagePassing + if op == 'call_module': + return isinstance(get_submodule(module, target), MessagePassing) + return False + + +def is_global_pooling_op(module: Module, op: str, target: str) -> bool: + from torch_geometric.nn import GlobalPooling if op == 'call_module': - if isinstance(get_submodule(module, target), MessagePassing): - return True + return isinstance(get_submodule(module, target), GlobalPooling) return False From 073ea2bef45fe4dceff7ac27c3231451b51daf89 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 14:49:11 +0200 Subject: [PATCH 3/8] update --- torch_geometric/nn/fx.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch_geometric/nn/fx.py b/torch_geometric/nn/fx.py index 056a6a651187..e6c408780224 100644 --- a/torch_geometric/nn/fx.py +++ b/torch_geometric/nn/fx.py @@ -30,6 +30,7 @@ class Transformer(object): +-- call_method() +-- call_module() +-- call_message_passing_module() + +-- call_global_pooling_module() +-- output() +-- Erase unused nodes in the graph +-- Iterate over each children module @@ -39,8 +40,9 @@ class Transformer(object): :class:`Transformer` exposes additional functionality: #. It subdivides :func:`call_module` into nodes that call a regular - :class:`torch.nn.Module` (:func:`call_module`) or a - :class:`MessagePassing` module (:func:`call_message_passing_module`). + :class:`torch.nn.Module` (:func:`call_module`), a + :class:`MessagePassing` module (:func:`call_message_passing_module`), + or a :class:`GlobalPooling` module (:func:`call_global_pooling_module`). #. It allows to customize or initialize new children modules via :func:`init_submodule` @@ -83,6 +85,9 @@ def get_attr(self, node: Node, target: Any, name: str): def call_message_passing_module(self, node: Node, target: Any, name: str): pass + def call_global_pooling_module(self, node: Node, target: Any, name: str): + pass + def call_module(self, node: Node, target: Any, name: str): pass @@ -131,7 +136,7 @@ def transform(self) -> GraphModule: elif is_message_passing_op(self.module, node.op, node.target): self._state[node.name] = 'node' elif is_global_pooling_op(self.module, node.op, node.target): - self._state[node.name] = 'node' + self._state[node.name] = 'graph' elif node.op in ['call_module', 'call_method', 'call_function']: if self.has_edge_level_arg(node): self._state[node.name] = 'edge' From 261f93bb3ce13ba63e0c4b8f12158081cad0b9b5 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 14:56:06 +0200 Subject: [PATCH 4/8] update --- torch_geometric/nn/fx.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/torch_geometric/nn/fx.py b/torch_geometric/nn/fx.py index e6c408780224..5e33f7eae271 100644 --- a/torch_geometric/nn/fx.py +++ b/torch_geometric/nn/fx.py @@ -140,8 +140,12 @@ def transform(self) -> GraphModule: elif node.op in ['call_module', 'call_method', 'call_function']: if self.has_edge_level_arg(node): self._state[node.name] = 'edge' - else: + if self.has_node_level_arg(node): self._state[node.name] = 'node' + else: + self._state[node.name] = 'graph' + + print(self._state) # We iterate over each node and may transform it: for node in list(self.graph.nodes): @@ -198,13 +202,13 @@ def _init_submodule(self, module: Module, target: str) -> Module: else: return self.init_submodule(module, target) - def is_edge_level(self, node: Node) -> bool: - return self._state[node.name] == 'edge' + def _is_level(self, node: None, name: str) -> bool: + return self._state[node.name] == name - def has_edge_level_arg(self, node: Node) -> bool: + def _has_level_arg(self, node: Node, name: str) -> bool: def _recurse(value: Any) -> bool: if isinstance(value, Node): - return self.is_edge_level(value) + return getattr(self, f'is_{name}_level')(value) elif isinstance(value, dict): return any([_recurse(v) for v in value.values()]) elif isinstance(value, (list, tuple)): @@ -215,6 +219,24 @@ def _recurse(value: Any) -> bool: return (any([_recurse(value) for value in node.args]) or any([_recurse(value) for value in node.kwargs.values()])) + def is_node_level(self, node: Node) -> bool: + return self._is_level(node, name='node') + + def is_edge_level(self, node: Node) -> bool: + return self._is_level(node, name='edge') + + def is_graph_level(self, node: Node) -> bool: + return self._is_level(node, name='graph') + + def has_node_level_arg(self, node: Node) -> bool: + return self._has_level_arg(node, name='node') + + def has_edge_level_arg(self, node: Node) -> bool: + return self._has_level_arg(node, name='edge') + + def has_graph_level_arg(self, node: Node) -> bool: + return self._has_level_arg(node, name='graph') + def find_by_name(self, name: str) -> Optional[Node]: for node in self.graph.nodes: if node.name == name: From aac7063ed05ffafb673ad481fae8e34e5448a73f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 16:59:42 +0200 Subject: [PATCH 5/8] update --- test/nn/test_to_hetero_transformer.py | 5 +- torch_geometric/nn/fx.py | 2 - torch_geometric/nn/to_hetero_transformer.py | 60 +++++++++++++++++++-- 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index 44fb5e1bd374..0738becda454 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -363,5 +363,6 @@ def test_graph_level_to_hetero(): metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = GraphLevelGNN() - model = to_hetero(model, metadata, debug=True) - model(x_dict, edge_index_dict, batch_dict) + model = to_hetero(model, metadata, debug=True, aggr='mean') + out = model(x_dict, edge_index_dict, batch_dict) + assert out.size() == (1, 64) diff --git a/torch_geometric/nn/fx.py b/torch_geometric/nn/fx.py index 5e33f7eae271..0a97a5969574 100644 --- a/torch_geometric/nn/fx.py +++ b/torch_geometric/nn/fx.py @@ -145,8 +145,6 @@ def transform(self) -> GraphModule: else: self._state[node.name] = 'graph' - print(self._state) - # We iterate over each node and may transform it: for node in list(self.graph.nodes): # Call the corresponding `Transformer` method for each `node.op`, diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index cedd26f4d8e2..456ad98b23b7 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -197,14 +197,14 @@ def call_message_passing_module(self, node: Node, target: Any, name: str): # `keys_per_dst` and append the result to the list. for dst, keys in keys_per_dst.items(): queue = deque([key_name[key] for key in keys]) - i = len(queue) + 1 + i = 1 while len(queue) >= 2: key1, key2 = queue.popleft(), queue.popleft() args = (self.find_by_name(key1), self.find_by_name(key2)) new_name = f'{name}__{dst}' if self.aggr == 'mean' or len(queue) > 0: - new_name += f'{i}' + new_name = f'{new_name}_{i}' out = self.graph.create_node('call_function', target=self.aggrs[self.aggr], @@ -219,9 +219,48 @@ def call_message_passing_module(self, node: Node, target: Any, name: str): 'call_function', target=torch.div, args=(self.find_by_name(key), len(keys_per_dst[dst])), name=f'{name}__{dst}') - self.graph.inserting_after(out) + self.graph.inserting_after(out) + + def call_global_pooling_module(self, node: Node, target: Any, name: str): + # Add calls to node type-wise `GlobalPooling` modules and aggregate + # the outputs to graph type-wise embeddings afterwards. + self.graph.inserting_after(node) + for key in self.metadata[0]: + args, kwargs = self.map_args_kwargs(node, key) + out = self.graph.create_node('call_module', + target=f'{target}.{key2str(key)}', + args=args, kwargs=kwargs, + name=f'{node.name}__{key2str(key)}') + self.graph.inserting_after(out) + + # Perform node-wise aggregation. + queue = deque( + [f'{node.name}__{key2str(key)}' for key in self.metadata[0]]) + i = 1 + while len(queue) >= 2: + key1, key2 = queue.popleft(), queue.popleft() + args = (self.find_by_name(key1), self.find_by_name(key2)) + out = self.graph.create_node('call_function', + target=self.aggrs[self.aggr], + args=args, name=f'{name}_{i}') + self.graph.inserting_after(out) + queue.append(f'{name}_{i}') + i += 1 + + if self.aggr == 'mean': + key = queue.popleft() + out = self.graph.create_node( + 'call_function', target=torch.div, + args=(self.find_by_name(key), len(self.metadata[0])), + name=f'{name}_{i}') + self.graph.inserting_after(out) + + self.replace_all_uses_with(node, out) def call_module(self, node: Node, target: Any, name: str): + if self.is_graph_level(node): + return + # Add calls to node type-wise or edge type-wise modules. self.graph.inserting_after(node) for key in self.metadata[int(self.is_edge_level(node))]: @@ -233,6 +272,9 @@ def call_module(self, node: Node, target: Any, name: str): self.graph.inserting_after(out) def call_method(self, node: Node, target: Any, name: str): + if self.is_graph_level(node): + return + # Add calls to node type-wise or edge type-wise methods. self.graph.inserting_after(node) for key in self.metadata[int(self.is_edge_level(node))]: @@ -243,6 +285,9 @@ def call_method(self, node: Node, target: Any, name: str): self.graph.inserting_after(out) def call_function(self, node: Node, target: Any, name: str): + if self.is_graph_level(node): + return + # Add calls to node type-wise or edge type-wise functions. self.graph.inserting_after(node) for key in self.metadata[int(self.is_edge_level(node))]: @@ -253,6 +298,9 @@ def call_function(self, node: Node, target: Any, name: str): self.graph.inserting_after(out) def output(self, node: Node, target: Any, name: str): + if self.is_graph_level(node.args[0]): + return + # Replace the output by dictionaries, holding either node type-wise or # edge type-wise data. def _recurse(value: Any) -> Any: @@ -281,9 +329,14 @@ def _recurse(value: Any) -> Any: def init_submodule(self, module: Module, target: str) -> Module: # Replicate each module for each node type or edge type. + has_node_level_target = bool( + self.find_by_target(f'{target}.{key2str(self.metadata[0][0])}')) has_edge_level_target = bool( self.find_by_target(f'{target}.{key2str(self.metadata[1][0])}')) + if not has_node_level_target and not has_edge_level_target: + return module + module_dict = torch.nn.ModuleDict() for key in self.metadata[int(has_edge_level_target)]: module_dict[key2str(key)] = copy.deepcopy(module) @@ -296,6 +349,7 @@ def init_submodule(self, module: Module, target: str) -> Module: f"'{target}' will be duplicated, but its parameters " f"cannot be reset. To suppress this warning, add a " f"'reset_parameters()' method to '{target}'") + return module_dict # Helper methods ########################################################## From cac07579221d30896b5aa5763488f51da115ee2b Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 17:23:48 +0200 Subject: [PATCH 6/8] update --- test/nn/test_to_hetero_transformer.py | 5 +---- torch_geometric/nn/fx.py | 4 ++-- torch_geometric/nn/to_hetero_transformer.py | 12 ++++++------ 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index 0738becda454..95bb0e7ab426 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -123,7 +123,6 @@ def forward(self, x: Tensor) -> Tensor: def test_to_hetero(): - return x_dict = { 'paper': torch.randn(100, 16), 'author': torch.randn(100, 16), @@ -227,7 +226,6 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: def test_to_hetero_with_gcn(): - return x_dict = { 'paper': torch.randn(100, 16), } @@ -270,7 +268,6 @@ def forward(self, x, edge_index): def test_to_hetero_and_rgcn_equal_output(): - return torch.manual_seed(1234) # Run `RGCN`: @@ -363,6 +360,6 @@ def test_graph_level_to_hetero(): metadata = list(x_dict.keys()), list(edge_index_dict.keys()) model = GraphLevelGNN() - model = to_hetero(model, metadata, debug=True, aggr='mean') + model = to_hetero(model, metadata, aggr='mean', debug=False) out = model(x_dict, edge_index_dict, batch_dict) assert out.size() == (1, 64) diff --git a/torch_geometric/nn/fx.py b/torch_geometric/nn/fx.py index 0a97a5969574..10d3663df019 100644 --- a/torch_geometric/nn/fx.py +++ b/torch_geometric/nn/fx.py @@ -140,7 +140,7 @@ def transform(self) -> GraphModule: elif node.op in ['call_module', 'call_method', 'call_function']: if self.has_edge_level_arg(node): self._state[node.name] = 'edge' - if self.has_node_level_arg(node): + elif self.has_node_level_arg(node): self._state[node.name] = 'node' else: self._state[node.name] = 'graph' @@ -200,7 +200,7 @@ def _init_submodule(self, module: Module, target: str) -> Module: else: return self.init_submodule(module, target) - def _is_level(self, node: None, name: str) -> bool: + def _is_level(self, node: Node, name: str) -> bool: return self._state[node.name] == name def _has_level_arg(self, node: Node, name: str) -> bool: diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 456ad98b23b7..4cde616a60f8 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -150,7 +150,6 @@ def __init__( def placeholder(self, node: Node, target: Any, name: str): # Adds a `get` call to the input dictionary for every node-type or # edge-type. - if node.type is not None: Type = EdgeType if self.is_edge_level(node) else NodeType node.type = Dict[Type, node.type] @@ -298,13 +297,12 @@ def call_function(self, node: Node, target: Any, name: str): self.graph.inserting_after(out) def output(self, node: Node, target: Any, name: str): - if self.is_graph_level(node.args[0]): - return - # Replace the output by dictionaries, holding either node type-wise or # edge type-wise data. def _recurse(value: Any) -> Any: if isinstance(value, Node): + if self.is_graph_level(value): + return value return { key: self.find_by_name(f'{value.name}__{key2str(key)}') for key in self.metadata[int(self.is_edge_level(value))] @@ -320,8 +318,10 @@ def _recurse(value: Any) -> Any: if node.type is not None and isinstance(node.args[0], Node): output = node.args[0] - Type = EdgeType if self.is_edge_level(output) else NodeType - node.type = Dict[Type, node.type] + if self.is_node_level(output): + node.type = Dict[NodeType, node.type] + elif self.is_edge_level(output): + node.type = Dict[EdgeType, node.type] else: node.type = None From 7300a4b4ad8d74022eeb619c65cf7ba955c9e851 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 17:25:46 +0200 Subject: [PATCH 7/8] typo --- torch_geometric/nn/to_hetero_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 4cde616a60f8..710ce8e3e05b 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -218,7 +218,7 @@ def call_message_passing_module(self, node: Node, target: Any, name: str): 'call_function', target=torch.div, args=(self.find_by_name(key), len(keys_per_dst[dst])), name=f'{name}__{dst}') - self.graph.inserting_after(out) + self.graph.inserting_after(out) def call_global_pooling_module(self, node: Node, target: Any, name: str): # Add calls to node type-wise `GlobalPooling` modules and aggregate From cd74536b393bc88315206c76248ca4eb47475f81 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 2 May 2022 17:27:19 +0200 Subject: [PATCH 8/8] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f4346bb66ce..c58fb878ea91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) +- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed ### Removed