
Refactor PNAConv to rely on new Aggregation #4864

Merged (9 commits) on Jul 3, 2022

Conversation

@Padarn (Contributor) commented Jun 26, 2022

Simply removes the unneeded handling of aggregation in PNAConv now that this is handled in MessagePassing.

Noted in #4712

@Padarn Padarn requested review from lightaime and rusty1s June 26, 2022 01:45
@codecov bot commented Jun 26, 2022

Codecov Report

Merging #4864 (2c7a289) into master (e430d94) will decrease coverage by 1.86%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4864      +/-   ##
==========================================
- Coverage   84.58%   82.72%   -1.87%     
==========================================
  Files         329      330       +1     
  Lines       17847    17844       -3     
==========================================
- Hits        15096    14761     -335     
- Misses       2751     3083     +332     
Impacted Files Coverage Δ
torch_geometric/nn/aggr/__init__.py 100.00% <100.00%> (ø)
torch_geometric/nn/aggr/scaler.py 100.00% <100.00%> (ø)
torch_geometric/nn/conv/pna_conv.py 92.00% <100.00%> (-2.65%) ⬇️
torch_geometric/nn/models/dimenet_utils.py 0.00% <0.00%> (-75.52%) ⬇️
torch_geometric/nn/models/dimenet.py 14.51% <0.00%> (-53.00%) ⬇️
torch_geometric/nn/conv/utils/typing.py 81.25% <0.00%> (-17.50%) ⬇️
torch_geometric/nn/inits.py 67.85% <0.00%> (-7.15%) ⬇️
torch_geometric/nn/resolver.py 86.04% <0.00%> (-6.98%) ⬇️
torch_geometric/io/tu.py 93.90% <0.00%> (-2.44%) ⬇️
torch_geometric/nn/models/mlp.py 98.52% <0.00%> (-1.48%) ⬇️
... and 6 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e430d94...2c7a289. Read the comment docs.

@rusty1s changed the title from "Refactor PNAConv to use rely on new Aggregation" to "Refactor PNAConv to rely on new Aggregation" on Jun 26, 2022
@rusty1s (Member) left a comment:

Thanks! This looks pretty clean now :)

raise ValueError(f'Unknown aggregator "{aggregator}".')
outs.append(out)
out = torch.cat(outs, dim=-1)
out = super().aggregate(inputs, index, dim_size=dim_size)

deg = degree(index, dim_size, dtype=inputs.dtype)
Member:

Can we move these scalers outside aggregate? This way, we don't have to override aggregate at all.

Contributor Author:

Oh good idea :-)

Contributor Author:

I updated to apply the scalers as a hook - WDYT?

Member:

Interesting solution, although I personally feel we shouldn't add business logic to hooks TBH. If you are not a fan of moving this logic post self.propagate, we could add it as an additional module to nn.aggr, e.g.,

class DegreeScalerAggregation(Aggregation):
    def __init__(self, aggregator, scalers, deg):
        pass
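A fuller sketch of what such a module might look like (the constructor arguments, the plain `nn.Module` base class, and the `bincount`-based degree computation are all assumptions made for illustration, not the final PyG API):

```python
import torch
from torch import Tensor


class DegreeScalerAggregation(torch.nn.Module):
    """Sketch only: combine a base aggregation with a degree-based scaler,
    as in the PNA paper. The interface here is an assumption, not the
    final torch_geometric.nn.aggr API."""

    def __init__(self, aggr: str = 'mean', scaler: str = 'identity'):
        super().__init__()
        self.aggr = aggr
        self.scaler = scaler

    def forward(self, x: Tensor, index: Tensor, dim_size: int) -> Tensor:
        # Base aggregation: sum the messages grouped by target node `index`.
        out = x.new_zeros(dim_size, x.size(-1))
        out.index_add_(0, index, x)

        # In-degree per target node, clamped so isolated nodes divide by 1.
        deg = torch.bincount(index, minlength=dim_size).to(x.dtype).clamp_(min=1)

        if self.aggr == 'mean':
            out = out / deg.view(-1, 1)

        # Example scaler: the "amplification" scaler log(deg + 1) from PNA.
        if self.scaler == 'amplification':
            out = out * torch.log(deg + 1).view(-1, 1)
        return out
```

The point of the sketch is the composition: aggregate first, then rescale the per-node result by a function of the node degree, which is exactly how PNA defines its aggregator/scaler pairs.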

Member:

Alternatively, I may have been mistaken in suggesting we move the logic outside aggregate; this should definitely be part of aggregation.

Contributor Author:

Actually I like the degree scaler as an Aggregation. It keeps it within the Aggregation logic and should be easier to reuse.

Agree on the hooks not feeling like a great solution.

Contributor:

Agree that adding the degree scaler as an Aggregation is better. Alternatively, we could add it as a helper method to the Aggregation base class instead of an Aggregation.

Contributor Author:

I've made the changes for this now. I'm leaving some comments in the diff, though, where I'd like some input.

@rusty1s (Member) commented Jun 29, 2022

@Padarn Let me know if you need any help in adding DegreeScalerAggregation :)

@Padarn (Contributor Author) commented Jun 29, 2022 via email

@Padarn Padarn force-pushed the padarn/pna-migrate branch from 0ec3e95 to d2f91fb Compare June 30, 2022 13:16
@Padarn (Contributor Author) commented Jun 30, 2022

I've added the DegreeScalerAggregation - it's a tiny bit ugly right now (just a lift and shift). Also the only tests are through PNAConv.


out = self.agg(x, index, ptr, dim_size, dim)
deg = degree(index, dtype=out.dtype)
deg = deg.clamp_(1).view(*([-1] + [1] * (len(out.shape) - 1)))
Contributor Author:

This was modified; before, it was simply

deg.clamp_(1).view(-1, 1, 1)

I don't know if there is a much cleaner way to achieve the same thing here.

Member:

Thanks for fixing. I think this is still not completely right, though. How about we take out.size(), write -1 into the dimension given by dim, and use that for reshaping?

Contributor Author:

Hmm, makes sense to use the dimension given to dim, but just to clarify on using out.size(): if, for example,

out.size() = [4, 4, 96]
deg.size() = [4]
dim = -2

I assume we want to do

deg.view(1, -1, 1)

so that the scaler can broadcast

Does that make sense?
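The reshaping being discussed can be sketched as a small helper (`broadcast_shape` is a hypothetical name for illustration, not anything in the PR):

```python
def broadcast_shape(out_shape, dim):
    """Build a view shape for a 1-D degree vector: -1 at the node
    dimension `dim`, 1 everywhere else, so it broadcasts against `out`."""
    dim = dim % len(out_shape)  # normalize negative dims
    return [-1 if i == dim else 1 for i in range(len(out_shape))]


# The example from the comment above: out.size() = [4, 4, 96], dim = -2
print(broadcast_shape([4, 4, 96], -2))  # prints [1, -1, 1]
```

This reproduces deg.view(1, -1, 1) for dim = -2, and reduces to the original deg.view(-1, 1, 1) behavior when the node dimension is 0.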

Member:

Yes, that's what I meant. Sorry for the confusion.

@lightaime (Contributor):

I've added the DegreeScalerAggregation - it's a tiny bit ugly right now (just a lift and shift). Also the only tests are through PNAConv.

@Padarn Thanks for adding the DegreeScalerAggregation. IMO, the degree scaler is more like an optional logic that we can do after reduce instead of an aggregation function. I think it makes more sense to make it a helper function in the base class. What do you think?

@Padarn (Contributor Author) commented Jul 2, 2022 via email

@rusty1s (Member) commented Jul 2, 2022

I personally like the view of seeing PNA as a composition of aggregators and scalers (that's how it is defined in the paper), so I am in favor of keeping it that way.

@Padarn (Contributor Author) commented Jul 3, 2022

That makes sense to me.

Thoughts on merging this now then and following up on the items I commented on in the diff separately?

@rusty1s (Member) left a comment:

Thank you! LGTM. Feel free to merge once the remaining comments are resolved.



@Padarn Padarn force-pushed the padarn/pna-migrate branch from 1ec28ad to eacfff1 Compare July 3, 2022 07:30
@Padarn Padarn merged commit fd944f3 into pyg-team:master Jul 3, 2022