
FSDP2 root level parameter management #1091


Open
dingqingy opened this issue Apr 11, 2025 · 3 comments · May be fixed by #1094
Labels
module: fsdp question Further information is requested

Comments

@dingqingy

Hi,

I am curious about the design decision to manage both the token embeddings and the final output layer at the root FSDP level, instead of wrapping them as separate units like the other transformer blocks.

For example, in the forward pass this coupling seems to unshard the final output layer too early and reshard the token embedding too late.
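To make the concern concrete, here is a toy accounting of unsharded parameter memory during forward, comparing the layout where the root owns both the embeddings and the output with a layout where each is its own unit. Everything here is an assumption for illustration: sizes are made-up units, forward is modeled as one step per unit, and prefetching, activations, and gradients are ignored. `Unit`, `build_layout`, and `peak_unsharded` are hypothetical helpers, not FSDP APIs.

```python
from dataclasses import dataclass


@dataclass
class Unit:
    """A hypothetical FSDP unit: its unsharded parameter size and the inclusive
    range of forward steps during which its parameters stay unsharded."""
    name: str
    size: int
    start: int
    end: int


def build_layout(num_blocks: int, emb: int, blk: int, out: int,
                 root_owns_ends: bool) -> list[Unit]:
    """Forward order: step 0 = embedding, steps 1..num_blocks = blocks, last step = output."""
    last = num_blocks + 1
    # Each block is gathered for its own step and resharded right after.
    units = [Unit(f"layers.{i}", blk, i + 1, i + 1) for i in range(num_blocks)]
    if root_owns_ends:
        # Root unit holds embedding + output, so it stays unsharded for the whole forward.
        units.append(Unit("root", emb + out, 0, last))
    else:
        # Separate units: each end is unsharded only for its own step.
        units.append(Unit("tok_embeddings", emb, 0, 0))
        units.append(Unit("output", out, last, last))
    return units


def peak_unsharded(units: list[Unit]) -> int:
    """Largest total unsharded parameter size at any single forward step."""
    num_steps = max(u.end for u in units) + 1
    return max(
        sum(u.size for u in units if u.start <= step <= u.end)
        for step in range(num_steps)
    )


root = build_layout(num_blocks=3, emb=4, blk=2, out=4, root_owns_ends=True)
separate = build_layout(num_blocks=3, emb=4, blk=2, out=4, root_owns_ends=False)
print(peak_unsharded(root), peak_unsharded(separate))  # 10 4
```

With the root owning both ends, the embedding and output parameters sit unsharded alongside every block, so in this toy model the peak is their sum plus one block.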

Also, regarding the optimization (see here) that disables reshard_after_forward for the last transformer block: would it be more appropriate to apply this optimization to the final linear layer instead?

Thanks!

@awgu
Collaborator

awgu commented Apr 11, 2025

#382 is probably closer to the ideal wrapping. I agree that separately wrapping the embeddings and the final output linear is more efficient. cc: @tianyu-l, if he wants to change it.
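A minimal sketch of the separate wrapping, assuming a torchtitan-style Llama module with `tok_embeddings`, `layers`, `norm`, and `output` attributes, and assuming `fully_shard` is importable from `torch.distributed.fsdp` (recent PyTorch; older releases exposed it as `torch.distributed._composable.fsdp.fully_shard`). This is not the torchtitan implementation: `apply_fsdp` and its injectable `shard_fn` parameter are hypothetical, and the dry run at the end uses a recording stub so the call order can be checked without a process group.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def apply_fsdp(model: Any, shard_fn: Optional[Callable[..., Any]] = None) -> None:
    """Wrap embeddings, each block, and [norm, output] as separate FSDP units.

    Assumes (hypothetically) a model with `tok_embeddings`, a dict-like
    `layers`, `norm`, and `output` attributes.
    """
    if shard_fn is None:
        # Assumed import path for recent PyTorch releases.
        from torch.distributed.fsdp import fully_shard
        shard_fn = fully_shard
    # Submodules are wrapped before the root, innermost units first.
    # Embeddings get their own unit, so they reshard right after use.
    shard_fn(model.tok_embeddings, reshard_after_forward=True)
    # Every block, including the last one, reshards after forward.
    for block in model.layers.values():
        shard_fn(block, reshard_after_forward=True)
    # norm + output form one unit; keeping it unsharded avoids re-gathering it
    # immediately at the start of backward.
    shard_fn([model.norm, model.output], reshard_after_forward=False)
    # The root unit manages any parameters not already assigned to a unit.
    shard_fn(model)


# Dry run with a recording stub instead of a real process group.
calls: list[tuple[Any, Optional[bool]]] = []


def record(module: Any, **kwargs: Any) -> None:
    calls.append((module, kwargs.get("reshard_after_forward")))


toy = SimpleNamespace(tok_embeddings="emb", layers={"0": "blk0", "1": "blk1"},
                      norm="norm", output="out")
apply_fsdp(toy, shard_fn=record)
```

Wrapping the two ends separately means neither is held by the root unit, so the root all-gather at the start of forward no longer carries the output weights.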

@tianyu-l
Contributor

We can do that!
One question for @awgu: if we use reshard_after_forward=False for [norm, output], do we still need reshard_after_forward=False for the last transformer block?

@tianyu-l tianyu-l added module: fsdp question Further information is requested labels Apr 11, 2025
@awgu
Collaborator

awgu commented Apr 11, 2025

@tianyu-l I think we can get rid of reshard_after_forward=False for the last transformer block. I think it slightly increases peak memory, and I have seen several places copy it from torchtitan 😓
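The trade-off in the last two comments can be tallied with a deliberately simplified count. It assumes each FSDP unit is all-gathered once in forward and, if it reshards after forward, once more in backward; the unit names, sizes, and helper functions are hypothetical, and real peak memory also depends on prefetching, activations, and gradients.

```python
def all_gathers_per_step(plan: dict[str, bool]) -> int:
    """Count parameter all-gathers per training step (forward + backward).

    Simplified model: every unit is gathered once in forward and, if it
    reshards after forward, gathered again in backward.
    """
    return sum(2 if reshard else 1 for reshard in plan.values())


def held_across_boundary(plan: dict[str, bool], sizes: dict[str, int]) -> int:
    """Unsharded parameter size still resident when forward hands off to backward."""
    return sum(sizes[name] for name, reshard in plan.items() if not reshard)


# Hypothetical unit sizes; True means reshard_after_forward=True.
sizes = {"tok_embeddings": 4, "layers.0": 2, "layers.1": 2, "norm_output": 4}
keep_last_block = {"tok_embeddings": True, "layers.0": True,
                   "layers.1": False, "norm_output": False}
reshard_all_blocks = {**keep_last_block, "layers.1": True}

print(all_gathers_per_step(keep_last_block), held_across_boundary(keep_last_block, sizes))  # 6 6
print(all_gathers_per_step(reshard_all_blocks), held_across_boundary(reshard_all_blocks, sizes))  # 7 4
```

Keeping the last block unsharded saves one all-gather per step but holds its parameters across the forward/backward boundary. Once `[norm, output]` already stays unsharded, backward can start without waiting on a gather, which is presumably why the last-block special case buys little.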

@tianyu-l tianyu-l self-assigned this Apr 11, 2025
@tianyu-l tianyu-l linked a pull request Apr 11, 2025 that will close this issue
wwwjn added a commit that referenced this issue Apr 13, 2025
As titled. Set reshard_after_forward=False for the last layer to avoid an all-gather right after reshard. Similar to Llama, as discussed in #1091.