
Support DTensor params in local_sgd/diloco #168


Merged: 1 commit, Apr 21, 2025

Conversation

@H-Huang (Member) commented on Apr 18, 2025:

Tested in torchtitan (pytorch/torchtitan#1122). When the model is sharded with FSDP, the params are DTensors, so we need to convert them to local tensors, apply the update, then convert back to DTensors.

Ideally I would like to add a test in the integration tests, but that would require setting up the device_mesh / FSDP for a model.
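
For context, a minimal sketch of the round trip described above, with illustrative helper names rather than the PR's actual code:

```python
import torch
from torch.distributed.tensor import DTensor


def extract_local(p: torch.Tensor) -> torch.Tensor:
    # Under FSDP the parameter is a DTensor; copies and collectives on it
    # need a DeviceMesh, so work with the plain local shard instead.
    return p.to_local() if isinstance(p, DTensor) else p


def write_back(p: torch.nn.Parameter, local: torch.Tensor) -> None:
    # Re-wrap the updated local shard so the copy is DTensor -> DTensor
    # with the parameter's original mesh and placements.
    with torch.no_grad():
        if isinstance(p, DTensor):
            local = DTensor.from_local(local, p.device_mesh, p.placements)
        p.data.copy_(local, non_blocking=False)
```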

@facebook-github-bot added the CLA Signed label on Apr 18, 2025
@H-Huang requested review from d4l3k, fegin and XilunWu on Apr 18, 2025 19:28
@H-Huang force-pushed the diloco_titan branch 2 times, most recently from 9a89a69 to 6ba0eee on Apr 18, 2025 19:41
@d4l3k (Member) left a comment:


LGTM -- it is unfortunate that we have to add special cases for DTensor for this kind of thing. Does it do a full .to_local() call if you don't do this?

@H-Huang (Member, Author) commented on Apr 21, 2025:

Without the DTensor conditionals, operations like .copy_() and the pseudogradient calculation fail due to incompatible types. The allreduce also fails with: got exception in all reduce -- skipping remaining: found no DeviceMesh from dtensor args for c10d.allreduce_.default!
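
For illustration, a hedged sketch of that failure mode and the local-shard workaround for the reduce step, using plain torch.distributed rather than whatever collective wrapper the code actually uses:

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor


def averaged_pseudogradient(p: torch.Tensor, original: torch.Tensor) -> torch.Tensor:
    # Computing `original - p` directly fails when `p` is a DTensor and
    # `original` is a plain tensor: the mixed args give c10d no DeviceMesh
    # to dispatch on. Reducing the local shard avoids that.
    local = p.to_local() if isinstance(p, DTensor) else p
    pseudograd = original - local  # DiLoCo outer "gradient": old minus new
    dist.all_reduce(pseudograd, op=dist.ReduceOp.AVG)  # AVG needs NCCL
    return pseudograd
```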


Review thread on this hunk in _restore_parameters:

```diff
 def _restore_parameters(self) -> None:
     with torch.no_grad():
         # TODO: consider running copy on a separate stream
         for name, p in self._model.named_parameters():
-            p.data.copy_(self.original_parameters[name], non_blocking=False)
+            if isinstance(p, DTensor):
```
A reviewer (Contributor) commented on the if isinstance(p, DTensor): line:

If p is a DTensor, does p.copy_(), instead of p.data.copy_(), work?

@H-Huang (Member, Author) replied:
IIRC it didn't work and caused a segfault.
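
For completeness, a hedged sketch of a DTensor-safe restore consistent with the hunk above (the merged code may differ in its details):

```python
import torch
from torch.distributed.tensor import DTensor


def restore_param(p: torch.nn.Parameter, saved_local: torch.Tensor) -> None:
    with torch.no_grad():
        if isinstance(p, DTensor):
            # Wrap the saved local shard back into a DTensor matching the
            # parameter's sharding; shape/stride give the global geometry.
            saved = DTensor.from_local(
                saved_local,
                p.device_mesh,
                p.placements,
                shape=p.shape,
                stride=p.stride(),
            )
        else:
            saved = saved_local
        # Copy through .data; per the thread, p.copy_() segfaulted here.
        p.data.copy_(saved, non_blocking=False)
```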

@H-Huang merged commit 360c5c5 into pytorch:main on Apr 21, 2025 (10 of 14 checks passed).
Labels: CLA Signed (managed by the Meta Open Source bot)
6 participants