Support DTensor params in local_sgd/diloco #168
Conversation
9a89a69 to 6ba0eee
LGTM -- it is unfortunate that we have to add special cases for DTensor for this kind of thing. Does it do a full `.to_local()` call if you don't do this?
Without the DTensor conditionals, operations like these:

```python
def _restore_parameters(self) -> None:
    with torch.no_grad():
        # TODO: consider running copy on a separate stream
        for name, p in self._model.named_parameters():
            p.data.copy_(self.original_parameters[name], non_blocking=False)
            if isinstance(p, DTensor):
```
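The quoted diff cuts off at the `isinstance` check. As a rough sketch only (not the PR's actual code), the DTensor branch presumably re-wraps the saved local shard with the parameter's mesh and placements before copying; the standalone `restore_parameters` function, its arguments, and the branch body below are assumptions:

```python
import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor


def restore_parameters(model: nn.Module, original_parameters: dict[str, torch.Tensor]) -> None:
    """Sketch: copy saved backups (local shards for DTensor params) back into the model."""
    with torch.no_grad():
        for name, p in model.named_parameters():
            if isinstance(p, DTensor):
                # Assumed: the backup holds this rank's local shard, so wrap it back
                # into a DTensor matching the param's mesh/placements before copy_.
                restored = DTensor.from_local(
                    original_parameters[name],
                    device_mesh=p.device_mesh,
                    placements=p.placements,
                )
                p.data.copy_(restored, non_blocking=False)
            else:
                p.data.copy_(original_parameters[name], non_blocking=False)
```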
If `p` is a DTensor, does `p.copy_()`, instead of `p.data.copy_()`, work?
IIRC it didn't work and caused a segfault.
Tested in torchtitan (pytorch/torchtitan#1122): when the model is sharded with FSDP we need to convert the params to local tensors, update them, then convert back to DTensors (roughly the round-trip sketched below).
Ideally I would like to add a test in the integration tests, but that would require us to set up the device_mesh / FSDP for a model.
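For illustration only, here is a minimal sketch of the save side of that round-trip, assuming FSDP-sharded parameters are DTensors whose local shards are what get backed up and later averaged; the `save_parameters` function name and its exact shape are assumptions, not code from this PR or torchtitan:

```python
import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor


def save_parameters(model: nn.Module) -> dict[str, torch.Tensor]:
    """Sketch: snapshot params as plain local tensors (one shard per rank for DTensors)."""
    backup: dict[str, torch.Tensor] = {}
    with torch.no_grad():
        for name, p in model.named_parameters():
            if isinstance(p, DTensor):
                # FSDP-sharded params are DTensors; keep only this rank's shard so
                # later averaging and copies operate on regular local tensors.
                backup[name] = p.data.to_local().clone()
            else:
                backup[name] = p.data.clone()
    return backup
```

The matching restore would wrap each saved shard back into a DTensor with `DTensor.from_local`, as in the earlier sketch.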