I am curious about the design decision of managing both the token embeddings and the final output layer at the root FSDP level instead of wrapping them as separate units, the way the transformer blocks are.
For example, this coupled management seems to unshard the final output layer too early and reshard the token embeddings too late during forward.
Also, for the optimization (see here) that disables `reshard_after_forward` for the last transformer block, would it be more appropriate to apply it to the final linear layer instead?
Thanks!
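For reference, here is a minimal sketch of the wrapping scheme being described, using a toy model with hypothetical submodule names (`tok_embeddings`, `layers`, `norm`, `output`), and assuming `torch.distributed` is already initialized (e.g. launched with torchrun). This is an illustration, not the actual torchtitan code:

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # public in recent PyTorch; older versions expose it under torch.distributed._composable.fsdp

class ToyTransformer(nn.Module):
    """Toy stand-in for the model; names are hypothetical."""
    def __init__(self, vocab_size: int = 32000, dim: int = 512, n_layers: int = 4):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True)
            for _ in range(n_layers)
        )
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)

model = ToyTransformer()
for i, block in enumerate(model.layers):
    # The optimization in question: the last block stays unsharded after
    # forward so backward does not need to re-all-gather it.
    fully_shard(block, reshard_after_forward=(i < len(model.layers) - 1))

# tok_embeddings, norm, and output have no fully_shard call of their own,
# so their parameters land in the root group: they are all-gathered
# together at the start of forward (too early for `output`) and freed
# together afterwards (too late for `tok_embeddings`).
fully_shard(model)
```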
#382 is probably closer to ideal wrapping. I agree that separately wrapping embeddings and final output linear is more efficient. cc: @tianyu-l if he wants to change it.
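A sketch of what that separate wrapping could look like, continuing the toy model above (again an illustration, not the actual torchtitan code):

```python
# Give the embeddings and the final linear their own FSDP units instead
# of leaving them to the root group.
model = ToyTransformer()
fully_shard(model.tok_embeddings)  # resharded as soon as the lookup is done
for block in model.layers:
    fully_shard(block)
fully_shard(model.output)          # unsharded only when the final projection runs
fully_shard(model)                 # root call still manages whatever is left (e.g. norm)
```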
We can do that!
One question for @awgu: if we use `reshard_after_forward=False` for the `[norm, output]` group, do we still need `reshard_after_forward=False` for the last transformer block?
@tianyu-l I think we can get rid of the `reshard_after_forward=False` for the last transformer block. It increases peak memory slightly, and I have seen several places copy it from torchtitan 😓
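Putting the thread's conclusion together, a sketch of the resulting wrapping on the toy model from above, assuming a PyTorch version where `fully_shard` accepts a list of modules to group them into one unit:

```python
def apply_fsdp(model: ToyTransformer) -> None:
    fully_shard(model.tok_embeddings)
    # No special case for the last block anymore: every block reshards
    # after forward (the default), at the small cost of one extra
    # all-gather for the last block at the start of backward.
    for block in model.layers:
        fully_shard(block)
    # The optimization moves here instead: norm + output run last in
    # forward and first in backward, so keeping them unsharded in
    # between saves an all-gather without holding memory for long.
    fully_shard([model.norm, model.output], reshard_after_forward=False)
    fully_shard(model)
```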