11
11
import torch
12
12
from torch import Tensor
13
13
from torch .utils .hooks import RemovableHandle
14
- from torch_scatter import gather_csr
14
+ from torch_scatter import gather_csr , scatter , segment_csr
15
15
from torch_sparse import SparseTensor
16
16
17
- from torch_geometric .nn .aggr import Aggregation , MultiAggregation
18
- from torch_geometric .nn .resolver import aggregation_resolver
19
17
from torch_geometric .typing import Adj , Size
20
18
21
19
from .utils .helpers import expand_left
28
26
split_types_repr ,
29
27
)
30
28
31
- BASIC_AGGRS = {
32
- 'add' , 'sum' , 'mean' , 'min' , 'max' , 'mul' , 'var' , 'std' , 'softmax' ,
33
- 'powermean'
34
- }
35
- SPMM_AGGRS = {'add' , 'sum' , 'mean' , 'min' , 'max' }
29
+ AGGRS = {'add' , 'sum' , 'mean' , 'min' , 'max' , 'mul' }
36
30
37
31
38
32
class MessagePassing (torch .nn .Module ):
@@ -91,36 +85,22 @@ class MessagePassing(torch.nn.Module):
91
85
'size_i' , 'size_j' , 'ptr' , 'index' , 'dim_size'
92
86
}
93
87
94
- def __init__ (
95
- self ,
96
- aggr : Optional [Union [str , List [str ], Aggregation ]] = "add" ,
97
- flow : str = "source_to_target" ,
98
- node_dim : int = - 2 ,
99
- decomposed_layers : int = 1 ,
100
- ):
88
+ def __init__ (self , aggr : Optional [Union [str , List [str ]]] = "add" ,
89
+ flow : str = "source_to_target" , node_dim : int = - 2 ,
90
+ decomposed_layers : int = 1 ):
101
91
102
92
super ().__init__ ()
103
93
104
94
if aggr is None or isinstance (aggr , str ):
105
- assert aggr is None or aggr in BASIC_AGGRS
106
- self .aggr_module = aggregation_resolver (aggr )
95
+ assert aggr is None or aggr in AGGRS
107
96
self .aggr : Optional [str ] = aggr
97
+ self .aggrs : List [str ] = []
108
98
elif isinstance (aggr , (tuple , list )):
109
- assert len (
110
- set (filter (lambda x : isinstance (x , str ), aggr ))
111
- | BASIC_AGGRS ) == len (BASIC_AGGRS )
112
- assert all (
113
- map (lambda x : isinstance (x , Aggregation ),
114
- set (aggr ).difference (BASIC_AGGRS )))
115
- self .aggr_module : List [Aggregation ] = MultiAggregation (
116
- list (map (aggregation_resolver , aggr )))
117
- self .aggr : Optional [str ] = str (self .aggr_module )
118
- elif isinstance (aggr , Aggregation ):
119
- self .aggr_module : Optional [Aggregation ] = aggr
120
- self .aggr : Optional [str ] = str (self .aggr_module )
99
+ assert len (set (aggr ) | AGGRS ) == len (AGGRS )
100
+ self .aggr : Optional [str ] = None
101
+ self .aggrs : List [str ] = aggr
121
102
else :
122
- raise ValueError (f"Only strings, list, tuples and subclasses of"
123
- f"torch_geometric.nn.aggr.Aggregation are valid "
103
+ raise ValueError (f"Only strings, list and tuples are valid "
124
104
f"aggregation schemes (got '{ type (aggr )} ')" )
125
105
126
106
self .flow = flow
@@ -308,7 +288,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
308
288
309
289
# Run "fused" message and aggregation (if applicable).
310
290
if (isinstance (edge_index , SparseTensor ) and self .fuse
311
- and not self .explain and (self .aggr in SPMM_AGGRS ) ):
291
+ and not self .explain and len (self .aggrs ) == 0 ):
312
292
coll_dict = self .__collect__ (self .__fused_user_args__ , edge_index ,
313
293
size , kwargs )
314
294
@@ -328,9 +308,7 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
328
308
out = self .update (out , ** update_kwargs )
329
309
330
310
# Otherwise, run both functions in separation.
331
- elif isinstance (
332
- edge_index ,
333
- Tensor ) or not self .fuse or not (self .aggr in SPMM_AGGRS ):
311
+ elif isinstance (edge_index , Tensor ) or not self .fuse :
334
312
if decomposed_layers > 1 :
335
313
user_args = self .__user_args__
336
314
decomp_args = {a [:- 2 ] for a in user_args if a [- 2 :] == '_j' }
@@ -370,7 +348,14 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
370
348
if res is not None :
371
349
aggr_kwargs = res [0 ] if isinstance (res , tuple ) else res
372
350
373
- out = self .aggregate (out , ** aggr_kwargs )
351
+ if len (self .aggrs ) == 0 :
352
+ out = self .aggregate (out , ** aggr_kwargs )
353
+ else :
354
+ outs = []
355
+ for aggr in self .aggrs :
356
+ tmp = self .aggregate (out , aggr = aggr , ** aggr_kwargs )
357
+ outs .append (tmp )
358
+ out = self .combine (outs )
374
359
375
360
for hook in self ._aggregate_forward_hooks .values ():
376
361
res = hook (self , (aggr_kwargs , ), out )
@@ -482,21 +467,25 @@ def explain_message(self, inputs: Tensor, size_i: int) -> Tensor:
482
467
483
468
def aggregate (self , inputs : Tensor , index : Tensor ,
484
469
ptr : Optional [Tensor ] = None , dim_size : Optional [int ] = None ,
485
- aggr : Optional [Aggregation ] = None ) -> Tensor :
470
+ aggr : Optional [str ] = None ) -> Tensor :
486
471
r"""Aggregates messages from neighbors as
487
472
:math:`\square_{j \in \mathcal{N}(i)}`.
488
473
489
474
Takes in the output of message computation as first argument and any
490
475
argument which was initially passed to :meth:`propagate`.
491
476
492
- By default, this function will delegate its call to `Aggregation`
493
- modules to reduce the messages as specified in :meth:`__init__` by the
494
- :obj:`aggr` argument.
477
+ By default, this function will delegate its call to scatter functions
478
+ that support "add", "mean", "min", "max" and "mul" operations as
479
+ specified in :meth:`__init__` by the :obj:`aggr` argument.
495
480
"""
496
- aggr = self .aggr_module if aggr is None else aggr
481
+ aggr = self .aggr if aggr is None else aggr
497
482
assert aggr is not None
498
- return aggr (inputs , index = index , ptr = ptr , dim_size = dim_size ,
499
- dim = self .node_dim )
483
+ if ptr is not None :
484
+ ptr = expand_left (ptr , dim = self .node_dim , dims = inputs .dim ())
485
+ return segment_csr (inputs , ptr , reduce = aggr )
486
+ else :
487
+ return scatter (inputs , index , dim = self .node_dim , dim_size = dim_size ,
488
+ reduce = aggr )
500
489
501
490
def message_and_aggregate (self , adj_t : SparseTensor ) -> Tensor :
502
491
r"""Fuses computations of :func:`message` and :func:`aggregate` into a
@@ -508,6 +497,13 @@ def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
508
497
"""
509
498
raise NotImplementedError
510
499
500
+ def combine (self , inputs : List [Tensor ]) -> Tensor :
501
+ r"""Combines the outputs from multiple aggregations into a single
502
+ representation. Will only get called in case :obj:`aggr` holds a list
503
+ of aggregation schemes to use."""
504
+ assert len (inputs ) > 0
505
+ return torch .cat (inputs , dim = - 1 ) if len (inputs ) > 1 else inputs [0 ]
506
+
511
507
def update (self , inputs : Tensor ) -> Tensor :
512
508
r"""Updates node embeddings in analogy to
513
509
:math:`\gamma_{\mathbf{\Theta}}` for each node
@@ -763,7 +759,7 @@ def jittable(self, typing: Optional[str] = None):
763
759
prop_types = prop_types ,
764
760
prop_return_type = prop_return_type ,
765
761
fuse = self .fuse ,
766
- single_aggr = isinstance (self .aggr_module , MultiAggregation ) ,
762
+ single_aggr = len (self .aggrs ) == 0 ,
767
763
collect_types = collect_types ,
768
764
user_args = self .__user_args__ ,
769
765
edge_user_args = self .__edge_user_args__ ,
0 commit comments