Skip to content

Commit 325143b

Browse files
committed
reset
1 parent 1d9996c commit 325143b

File tree

2 files changed

+40
-175
lines changed

2 files changed

+40
-175
lines changed

test/nn/aggr/test_mp_interface.py

-131
This file was deleted.

torch_geometric/nn/conv/message_passing.py

+40-44
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
import torch
1212
from torch import Tensor
1313
from torch.utils.hooks import RemovableHandle
14-
from torch_scatter import gather_csr
14+
from torch_scatter import gather_csr, scatter, segment_csr
1515
from torch_sparse import SparseTensor
1616

17-
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
18-
from torch_geometric.nn.resolver import aggregation_resolver
1917
from torch_geometric.typing import Adj, Size
2018

2119
from .utils.helpers import expand_left
@@ -28,11 +26,7 @@
2826
split_types_repr,
2927
)
3028

31-
BASIC_AGGRS = {
32-
'add', 'sum', 'mean', 'min', 'max', 'mul', 'var', 'std', 'softmax',
33-
'powermean'
34-
}
35-
SPMM_AGGRS = {'add', 'sum', 'mean', 'min', 'max'}
29+
AGGRS = {'add', 'sum', 'mean', 'min', 'max', 'mul'}
3630

3731

3832
class MessagePassing(torch.nn.Module):
@@ -91,36 +85,22 @@ class MessagePassing(torch.nn.Module):
9185
'size_i', 'size_j', 'ptr', 'index', 'dim_size'
9286
}
9387

94-
def __init__(
95-
self,
96-
aggr: Optional[Union[str, List[str], Aggregation]] = "add",
97-
flow: str = "source_to_target",
98-
node_dim: int = -2,
99-
decomposed_layers: int = 1,
100-
):
88+
def __init__(self, aggr: Optional[Union[str, List[str]]] = "add",
89+
flow: str = "source_to_target", node_dim: int = -2,
90+
decomposed_layers: int = 1):
10191

10292
super().__init__()
10393

10494
if aggr is None or isinstance(aggr, str):
105-
assert aggr is None or aggr in BASIC_AGGRS
106-
self.aggr_module = aggregation_resolver(aggr)
95+
assert aggr is None or aggr in AGGRS
10796
self.aggr: Optional[str] = aggr
97+
self.aggrs: List[str] = []
10898
elif isinstance(aggr, (tuple, list)):
109-
assert len(
110-
set(filter(lambda x: isinstance(x, str), aggr))
111-
| BASIC_AGGRS) == len(BASIC_AGGRS)
112-
assert all(
113-
map(lambda x: isinstance(x, Aggregation),
114-
set(aggr).difference(BASIC_AGGRS)))
115-
self.aggr_module: List[Aggregation] = MultiAggregation(
116-
list(map(aggregation_resolver, aggr)))
117-
self.aggr: Optional[str] = str(self.aggr_module)
118-
elif isinstance(aggr, Aggregation):
119-
self.aggr_module: Optional[Aggregation] = aggr
120-
self.aggr: Optional[str] = str(self.aggr_module)
99+
assert len(set(aggr) | AGGRS) == len(AGGRS)
100+
self.aggr: Optional[str] = None
101+
self.aggrs: List[str] = aggr
121102
else:
122-
raise ValueError(f"Only strings, list, tuples and subclasses of"
123-
f"torch_geometric.nn.aggr.Aggregation are valid "
103+
raise ValueError(f"Only strings, list and tuples are valid "
124104
f"aggregation schemes (got '{type(aggr)}')")
125105

126106
self.flow = flow
@@ -308,7 +288,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
308288

309289
# Run "fused" message and aggregation (if applicable).
310290
if (isinstance(edge_index, SparseTensor) and self.fuse
311-
and not self.explain and (self.aggr in SPMM_AGGRS)):
291+
and not self.explain and len(self.aggrs) == 0):
312292
coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
313293
size, kwargs)
314294

@@ -328,9 +308,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
328308
out = self.update(out, **update_kwargs)
329309

330310
# Otherwise, run both functions in separation.
331-
elif isinstance(
332-
edge_index,
333-
Tensor) or not self.fuse or not (self.aggr in SPMM_AGGRS):
311+
elif isinstance(edge_index, Tensor) or not self.fuse:
334312
if decomposed_layers > 1:
335313
user_args = self.__user_args__
336314
decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'}
@@ -370,7 +348,14 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
370348
if res is not None:
371349
aggr_kwargs = res[0] if isinstance(res, tuple) else res
372350

373-
out = self.aggregate(out, **aggr_kwargs)
351+
if len(self.aggrs) == 0:
352+
out = self.aggregate(out, **aggr_kwargs)
353+
else:
354+
outs = []
355+
for aggr in self.aggrs:
356+
tmp = self.aggregate(out, aggr=aggr, **aggr_kwargs)
357+
outs.append(tmp)
358+
out = self.combine(outs)
374359

375360
for hook in self._aggregate_forward_hooks.values():
376361
res = hook(self, (aggr_kwargs, ), out)
@@ -482,21 +467,25 @@ def explain_message(self, inputs: Tensor, size_i: int) -> Tensor:
482467

483468
def aggregate(self, inputs: Tensor, index: Tensor,
484469
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
485-
aggr: Optional[Aggregation] = None) -> Tensor:
470+
aggr: Optional[str] = None) -> Tensor:
486471
r"""Aggregates messages from neighbors as
487472
:math:`\square_{j \in \mathcal{N}(i)}`.
488473
489474
Takes in the output of message computation as first argument and any
490475
argument which was initially passed to :meth:`propagate`.
491476
492-
By default, this function will delegate its call to `Aggregation`
493-
modules to reduce the messages as specified in :meth:`__init__` by the
494-
:obj:`aggr` argument.
477+
By default, this function will delegate its call to scatter functions
478+
that support "add", "mean", "min", "max" and "mul" operations as
479+
specified in :meth:`__init__` by the :obj:`aggr` argument.
495480
"""
496-
aggr = self.aggr_module if aggr is None else aggr
481+
aggr = self.aggr if aggr is None else aggr
497482
assert aggr is not None
498-
return aggr(inputs, index=index, ptr=ptr, dim_size=dim_size,
499-
dim=self.node_dim)
483+
if ptr is not None:
484+
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
485+
return segment_csr(inputs, ptr, reduce=aggr)
486+
else:
487+
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
488+
reduce=aggr)
500489

501490
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
502491
r"""Fuses computations of :func:`message` and :func:`aggregate` into a
@@ -508,6 +497,13 @@ def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
508497
"""
509498
raise NotImplementedError
510499

500+
def combine(self, inputs: List[Tensor]) -> Tensor:
501+
r"""Combines the outputs from multiple aggregations into a single
502+
representation. Will only get called in case :obj:`aggr` holds a list
503+
of aggregation schemes to use."""
504+
assert len(inputs) > 0
505+
return torch.cat(inputs, dim=-1) if len(inputs) > 1 else inputs[0]
506+
511507
def update(self, inputs: Tensor) -> Tensor:
512508
r"""Updates node embeddings in analogy to
513509
:math:`\gamma_{\mathbf{\Theta}}` for each node
@@ -763,7 +759,7 @@ def jittable(self, typing: Optional[str] = None):
763759
prop_types=prop_types,
764760
prop_return_type=prop_return_type,
765761
fuse=self.fuse,
766-
single_aggr=isinstance(self.aggr_module, MultiAggregation),
762+
single_aggr=len(self.aggrs) == 0,
767763
collect_types=collect_types,
768764
user_args=self.__user_args__,
769765
edge_user_args=self.__edge_user_args__,

0 commit comments

Comments
 (0)