diff --git a/CHANGELOG.md b/CHANGELOG.md index 24ae0a8ee31e..7800caaf33cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926)) - Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927)) - Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) - Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885)) diff --git a/test/nn/test_resolver.py b/test/nn/test_resolver.py index 8cb2e9caff4d..88874079cae4 100644 --- a/test/nn/test_resolver.py +++ b/test/nn/test_resolver.py @@ -5,6 +5,7 @@ from torch_geometric.nn.resolver import ( activation_resolver, aggregation_resolver, + normalization_resolver, ) @@ -34,3 +35,21 @@ def test_aggregation_resolver(aggr_tuple): aggr_module, aggr_repr = aggr_tuple assert isinstance(aggregation_resolver(aggr_module()), aggr_module) assert isinstance(aggregation_resolver(aggr_repr), aggr_module) + + +@pytest.mark.parametrize('norm_tuple', [ + (torch_geometric.nn.norm.BatchNorm, 'batch_norm', (16, )), + (torch_geometric.nn.norm.InstanceNorm, 'instance_norm', (16, )), + (torch_geometric.nn.norm.LayerNorm, 'layer_norm', (16, )), + (torch_geometric.nn.norm.GraphNorm, 'graph_norm', (16, )), + (torch_geometric.nn.norm.GraphSizeNorm, 'graphsize_norm', ()), + (torch_geometric.nn.norm.PairNorm, 'pair_norm', ()), + (torch_geometric.nn.norm.MessageNorm, 'message_norm', ()), + (torch_geometric.nn.norm.DiffGroupNorm, 'diffgroup_norm', (16, 4)), +]) +def test_normalization_resolver(norm_tuple): + norm_module, norm_repr, norm_args = norm_tuple + assert isinstance(normalization_resolver(norm_module(*norm_args)), + norm_module) + assert isinstance(normalization_resolver(norm_repr, *norm_args), + norm_module) diff --git a/torch_geometric/nn/resolver.py b/torch_geometric/nn/resolver.py index 8d843da52e85..00a00ed05930 100644 --- a/torch_geometric/nn/resolver.py +++ b/torch_geometric/nn/resolver.py @@ -61,6 +61,22 @@ def activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs): return resolver(acts, act_dict, query, base_cls, *args, **kwargs) +# Normalization Resolver ###################################################### + + +def normalization_resolver(query: Union[Any, str], *args, **kwargs): + import torch + + import torch_geometric.nn.norm as norm + base_cls = torch.nn.Module + norms = [ + norm for norm in vars(norm).values() + if isinstance(norm, type) and issubclass(norm, base_cls) + ] + norm_dict = {} + return resolver(norms, norm_dict, query, base_cls, *args, **kwargs) + + # Aggregation Resolver ########################################################