scale_grads with foreach + compile #2624
base: gh/IvanKobzarev/2/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2624
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Cancelled Jobs as of commit 9e1ab29 with merge base f3e4747. The cancelled jobs can be retried.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Scaling gradients with foreach, also adding compilation of it. 9 -> 9.8 tokens_per_second (max value over the first 10 iterations). Helps with tokens_per_second (for llama4 the parameters are on CPU, so the wins are not huge, but it should be better when the parameters are on GPU). [ghstack-poisoned]
Cool work! Just a couple comments.
recipes/full_finetune_distributed.py (Outdated)
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, self.dp_degree / num_tokens)
def scale_grads_fn():
Dumb q: Why can't we compile the function directly instead of having this extra function?
Also, will this be addressed in this PR?
Compilation in the current repo fails if skip_rope_interval=4 is set; it has to be tested with skip_rope_interval=None.
I have not tried compiling object.method before. I can try.
Regarding skip_rope_interval: not in this PR.
There is a problem in the flex_attention mask_mod ("CUDA illegal memory access", some bad calculation of indices).
This needs a separate fix; I will ask the Flex Attention engineers to take a look at a minimized repro.
"I have not tried compiling object.method before. I can try."
@IvanKobzarev Did this work?
Separately, I have a similar question to the one I left on the compiled optimizer step PR -- do we need to do this every step? If not, I wonder if we could just do something like self.grad_scaler = get_grad_scale_fn(self._compile) once during init, returning the maybe-compiled scale function, and then call it in a single line here.
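A minimal sketch of the pattern being suggested here, assuming a hypothetical get_grad_scale_fn helper (the name comes from the comment above, not from the torchtune API) and that scale_grads_ is exported from torchtune.training as this PR intends:

```python
import torch
from torchtune import training


def get_grad_scale_fn(compile_scale_grads: bool, backend: str = "inductor"):
    """Return the gradient-scaling function, optionally wrapped in torch.compile."""
    if compile_scale_grads:
        return torch.compile(training.scale_grads_, backend=backend)
    return training.scale_grads_


# In the recipe this would run once during init, e.g.
#   self._grad_scaler = get_grad_scale_fn(self._compile_scale_grads)
# and then be called in a single line per training step.
```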
@ebsmothers @joecummings
Unfortunately for PT2 there is a difference, and it fails if I do not wrap it in a function.
E.g. I tried:
self.grad_scaler = training.scale_grads_
if self._compile_scale_grads:
    self.grad_scaler = torch.compile(training.scale_grads_, backend=self._compile_backend)
Agreed, this is bad UX from PT2 for the user.
I will try to debug this for some time.
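For reference, a standalone sketch of the two variants being compared, using a toy stand-in for the real helper (the failure described above was observed inside the recipe, not necessarily in an isolated script like this):

```python
import torch


def scale_grads_(grads: list[torch.Tensor], scaler: float) -> None:
    # Toy stand-in: scale every gradient tensor in place with a single foreach call.
    torch._foreach_mul_(grads, scaler)


grads = [torch.randn(4, 4) for _ in range(3)]
scaler = 0.5

# Variant 1: compile the imported function directly (the form that failed in the recipe).
direct = torch.compile(scale_grads_)

# Variant 2: compile a thin local wrapper that closes over the arguments,
# mirroring the scale_grads_fn() indirection used in the PR.
def scale_grads_fn():
    scale_grads_(grads, scaler)

wrapped = torch.compile(scale_grads_fn)
wrapped()
```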
@@ -29,7 +30,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
I think this needs to be added to __all__.
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files:
@@            Coverage Diff            @@
##    gh/IvanKobzarev/2/base    #2624   +/- ##
=========================================================
  Coverage          ?   65.56%
=========================================================
  Files             ?      396
  Lines             ?    23804
  Branches          ?        0
=========================================================
  Hits              ?    15607
  Misses            ?     8197
  Partials          ?        0
☔ View full report in Codecov by Sentry.
torchtune/training/_grad_scaler.py (Outdated)
) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
    ret = defaultdict(list)
    for i, tensor in enumerate(tensors):
        ret[(tensor.device, tensor.dtype)].append(tensor)
Noob q: where are we actually using the dtype groupby? Is it implicit in the usage of _foreach_mul_?
Technically it is not used here. I planned to upstream this function into torch, where it could be useful for other functionality, but that can be a long process :)
So we do not use it here and can remove the dtype grouping :)
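For context, a simplified sketch of the device/dtype grouping plus foreach scaling being discussed (not the exact torchtune implementation; the dtype half of the key is kept only to mirror the snippet above):

```python
from collections import defaultdict

import torch
from torch import Tensor


def group_tensors_by_device_and_dtype(
    tensors: list[Tensor],
) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
    # Group tensors so each group can be handled by a single foreach kernel launch.
    grouped: dict[tuple[torch.device, torch.dtype], list[Tensor]] = defaultdict(list)
    for tensor in tensors:
        grouped[(tensor.device, tensor.dtype)].append(tensor)
    return grouped


def scale_grads_foreach(params: list[Tensor], scaler: float) -> None:
    grads = [p.grad for p in params if p.grad is not None]
    for (_device, _dtype), group in group_tensors_by_device_and_dtype(grads).items():
        # _foreach_mul_ only requires same-device inputs, so the dtype part of the
        # key is not strictly needed here, which is the point raised above.
        torch._foreach_mul_(group, scaler)
```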
Updated the diff to compile only during initialization. Just FYI for testing: compilation at the moment needs workarounds for two different problems.
If /tmp/torchinductor_${USER} is removed before every run (or the PT2 cache is disabled), the issue does not fire.
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support

 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
@ebsmothers Should we add a deprecation warning here then? No need to keep both.
Talked with Evan, let's add the deprecation label here.
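A sketch of what that deprecation could look like with a plain warning (torchtune has its own deprecation utilities, so the exact mechanism here is an assumption, and the loop body is a simplified per-parameter version rather than the original implementation):

```python
import warnings

import torch
from torch import nn


def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
    """Deprecated: use scale_grads_ instead."""
    warnings.warn(
        "scale_grads is deprecated and will be removed in a future release; "
        "please use scale_grads_ instead.",
        FutureWarning,
        stacklevel=2,
    )
    # Simplified per-parameter fallback: scale each gradient in place.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler)
```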
compile_components = cfg.get("compile_components")
if self._compile and compile_components:
    self._compile_model = compile_components.get("model", True)
    self._compile_loss = compile_components.get("loss", True)
    self._compile_optimizer_step = compile_components.get(
        "optimizer_step", False
    )
    self._compile_scale_grads = compile_components.get("scale_grads", False)

self._grad_scaler = training.scale_grads_
Can you leave a comment here explaining that we need this indirection for things to work w/ PT2 compile for some reason?
Sure, will add a comment.
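A sketch of the kind of explanatory comment plus init-time setup this could end up as (illustrative wording only; the attribute names mirror the diff above, and the exact shape of the final code is not shown in this thread):

```python
import torch
from torchtune import training


class _InitSketch:
    """Illustrative only: mirrors the recipe's __init__-time indirection."""

    def __init__(self, compile_scale_grads: bool, compile_backend: str = "inductor"):
        # NOTE: keep the grad-scaling function behind self._grad_scaler and set it up
        # once during init; this indirection is currently needed for torch.compile to
        # work with the helper (see the PR discussion).
        self._grad_scaler = training.scale_grads_
        if compile_scale_grads:
            self._grad_scaler = torch.compile(
                self._grad_scaler, backend=compile_backend
            )
```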
LGTM, just deprecate the old scale_grads
Stack from ghstack (oldest at bottom):
Scaling gradients with foreach, also adding compilation of it.
9 -> 9.8 tokens_per_second (max value over the first 10 iterations).
Helps with tokens_per_second (for llama4 the parameters are on CPU, so the wins are not huge, but it should be better when the parameters are on GPU).
Tested with: tune run --nproc_per_node 8 full_finetune_distributed --config recipes/configs/llama4/scout_17B_16E_full.yaml
PS:
Compilation in the current repo fails if skip_rope_interval=4 is set; it has to be tested with skip_rope_interval=None.