Refactor PNAConv to rely on new Aggregation #4864
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master    #4864      +/-   ##
==========================================
- Coverage   84.58%   82.72%   -1.87%
==========================================
  Files         329      330       +1
  Lines       17847    17844       -3
==========================================
- Hits        15096    14761     -335
- Misses       2751     3083     +332

Continue to review full report at Codecov.
Thanks! This looks pretty clean now :)
torch_geometric/nn/conv/pna_conv.py
Outdated

            raise ValueError(f'Unknown aggregator "{aggregator}".')
        outs.append(out)
    out = torch.cat(outs, dim=-1)
    out = super().aggregate(inputs, index, dim_size=dim_size)

    deg = degree(index, dim_size, dtype=inputs.dtype)
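The removed loop dispatched on aggregator names and concatenated each result per node. As a framework-free illustration of that pattern (a hypothetical helper on plain Python lists, not the PyG code):

```python
def multi_aggregate(values_per_node, aggregators):
    """Apply each named aggregator per node and concatenate the results,
    mirroring the loop that PNAConv previously implemented inline."""
    known = {'sum': sum, 'max': max, 'min': min,
             'mean': lambda v: sum(v) / len(v)}
    out = []
    for node_values in values_per_node:
        feats = []
        for name in aggregators:
            if name not in known:
                raise ValueError(f'Unknown aggregator "{name}".')
            feats.append(known[name](node_values))
        out.append(feats)
    return out

print(multi_aggregate([[1, 2, 3], [4, 6]], ['sum', 'mean', 'max']))
# [[6, 2.0, 3], [10, 5.0, 6]]
```

The refactor moves exactly this dispatch-and-concatenate responsibility out of PNAConv and into the shared Aggregation machinery via super().aggregate().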
Can we move these scalers outside aggregate? This way, we don't have to override aggregate at all.
Oh good idea :-)
I updated to apply the scalers as a hook - WDYT?
Interesting solution, although I personally feel we shouldn't add business logic to hooks TBH. If you are not a fan of moving this logic post self.propagate, we could add it as an additional module to nn.aggr, e.g.,

class DegreeScalerAggregation(Aggregation):
    def __init__(self, aggregator, scalers, deg):
        pass
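Fleshed out, such a module could wrap any aggregation and apply the PNA degree scalers to its output. The following is only a rough sketch under assumed names and signatures (the real class would derive from torch_geometric.nn.aggr.Aggregation and take the degree histogram rather than a precomputed constant):

```python
import torch
from torch import Tensor

class DegreeScalerAggregation:
    """Sketch of the proposed module: wrap an aggregation and apply
    PNA-style degree scalers to its output. Hypothetical shape, not
    the final nn.aggr API."""

    def __init__(self, aggr, scalers, avg_deg_log: float):
        self.aggr = aggr                # callable: (x, index, dim_size) -> Tensor
        self.scalers = scalers          # e.g. ['identity', 'amplification']
        self.avg_deg_log = avg_deg_log  # normalization delta from the training degrees

    def __call__(self, x: Tensor, index: Tensor, dim_size: int) -> Tensor:
        out = self.aggr(x, index, dim_size)
        # Per-node degree, clamped and reshaped to broadcast against `out`:
        deg = torch.bincount(index, minlength=dim_size).clamp_(1).to(x.dtype)
        deg = deg.view(-1, *([1] * (out.dim() - 1)))
        outs = []
        for scaler in self.scalers:
            if scaler == 'identity':
                outs.append(out)
            elif scaler == 'amplification':
                outs.append(out * ((deg + 1).log() / self.avg_deg_log))
            elif scaler == 'attenuation':
                outs.append(out * (self.avg_deg_log / (deg + 1).log()))
            else:
                raise ValueError(f'Unknown scaler "{scaler}".')
        return torch.cat(outs, dim=-1)

def sum_aggr(x: Tensor, index: Tensor, dim_size: int) -> Tensor:
    out = torch.zeros(dim_size, x.size(-1), dtype=x.dtype)
    return out.index_add_(0, index, x)

agg = DegreeScalerAggregation(sum_aggr, ['identity', 'amplification'], 1.0)
x = torch.tensor([[1.0], [2.0], [3.0]])
index = torch.tensor([0, 0, 1])
out = agg(x, index, dim_size=2)  # shape [2, 2]: sums, then degree-scaled sums
```

This keeps both the aggregators and the scalers inside the aggregation layer, which is the composition discussed below.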
Alternatively, I may be mistaken about moving the logic outside aggregate - this should definitely be part of aggregation.
Actually I like the degree scaler as an Aggregation. It keeps it within the Aggregation logic and should be easier to reuse.
Agree on the hooks not feeling like a great solution.
Agree that adding the degree scaler as an Aggregation is better. Alternatively, we can add it as a helper method to the Aggregation base class instead of a separate Aggregation.
I've made changes for this now. I'm leaving some comments in the diff though where I'd like some input.
@Padarn Let me know if you need any help in adding DegreeScalerAggregation :)
Hey, I'm happy to do this but not a lot of time during the week. If you
want to get this merged sooner please don't hesitate to take over.
…On Wed, 29 Jun 2022, 1:49 pm Matthias Fey, ***@***.***> wrote:
@Padarn <https://github.com/Padarn> Let me know if you need any help in
adding DegreeScalerAggregation :)
Force-pushed from 0ec3e95 to d2f91fb (compare)
I've added the DegreeScalerAggregation - it's a tiny bit ugly right now (just a lift and shift). Also the only tests are through PNAConv.
torch_geometric/nn/aggr/scaler.py
Outdated

out = self.agg(x, index, ptr, dim_size, dim)
deg = degree(index, dtype=out.dtype)
deg = deg.clamp_(1).view(*([-1] + [1] * (len(out.shape) - 1)))
This was modified - before it was simply deg.clamp_(1).view(-1, 1, 1). I don't know if there is a much cleaner way to achieve the same thing here.
Thanks for fixing. I think this is still not completely right though. How about we take out.size(), write -1 into the dimension given by dim, and use that for reshaping?
Hmm, makes sense to use the dimension given by dim, but just to clarify on using out.size(): if for example

out.size() = [4, 4, 96]
deg.size() = [4]
dim = -2

I assume we want to do deg.view(1, -1, 1) so that the scaler can broadcast. Does that make sense?
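One way to build that view shape from out.size() and dim, as proposed (a small hypothetical helper for illustration, not the final code):

```python
def broadcast_size(ndim: int, dim: int) -> list:
    """Return an all-ones shape with -1 placed at `dim`, suitable for
    passing to Tensor.view so that `deg` broadcasts against `out`."""
    size = [1] * ndim
    size[dim] = -1  # negative dims index from the end, as in PyTorch
    return size

# For out.size() = [4, 4, 96] and dim = -2 this yields the view above:
print(broadcast_size(3, -2))  # [1, -1, 1]
print(broadcast_size(3, 0))   # [-1, 1, 1], the old view(-1, 1, 1) case
```

With this, deg.view(broadcast_size(out.dim(), dim)) scales the correct dimension regardless of which axis holds the nodes.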
Yes, that's what I meant. Sorry for the confusion.
@Padarn Thanks for adding the DegreeScalerAggregation. IMO, the degree scaler is more like an optional logic that we can do after reduce instead of an aggregation function. I think it makes more sense to make it a helper function in the base class. What do you think?
Actually I do mostly agree about that, but it has two downsides:
1. PNAConv needs to override the aggregation function of MessagePassing.
2. The degree tensor needs to be provided to the helper (unless we add this to the constructor of the base class).
I'm still okay to go with that suggestion though, WDYT?
…On Sun, 3 Jul 2022, 4:41 am Guohao Li, ***@***.***> wrote:
I've added the DegreeScalerAggregation - it's a tiny bit ugly right now
(just a lift and shift). Also the only tests are through PNAConv.
@Padarn <https://github.com/Padarn> Thanks for adding the
DegreeScalerAggregation. IMO, the degree scaler is more like an optional
logic that we can do after reduce instead of an aggregation function. I
think it makes more sense to make it a helper function in the base class.
What do you think?
I personally like the view of seeing PNA as a composition of aggregators and scalers (that's how it is defined in the paper), so I am in favor of keeping it that way.
That makes sense to me. Thoughts on merging this now then and following up on the items I commented on in the diff separately?
Thank you! LGTM. Feel free to merge once the remaining comments are resolved.
Force-pushed from 1ec28ad to eacfff1 (compare)
Simply removes the unneeded handling of aggregation in PNAConv now that this is handled in MessagePassing. Noted in #4712.