
scale_grads with foreach + compile #2624


Open · wants to merge 10 commits into base: gh/IvanKobzarev/2/base

Conversation

@IvanKobzarev commented Apr 22, 2025

Stack from ghstack (oldest at bottom):

Scale gradients with a foreach op, and also compile the scaling function.
tokens_per_second goes from 9 to 9.8 (max value over the first 10 iterations).

This helps tokens_per_second; for llama4 the parameters are on CPU, so the wins are not huge, but they should be larger when the parameters are on GPU.

```
tune run --nproc_per_node 8 \
  full_finetune_distributed \
  --config recipes/configs/llama4/scout_17B_16E_full.yaml
```

PS:
Compilation in the current repo fails with skip_rope_interval=4; for now it has to be tested with skip_rope_interval=None.
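
For reference, here is a minimal sketch of what a foreach-based gradient-scaling helper can look like. The body below is an assumption for illustration only, not the exact code added in this PR; the name scale_grads_ follows the function added here, and the signature mirrors the existing scale_grads.

```python
from collections import defaultdict

import torch
from torch import nn


def scale_grads_(model: nn.Module, scaler: torch.Tensor) -> None:
    """Sketch: scale all gradients in place with one foreach kernel per device."""
    # Group gradients by device so each _foreach_mul_ call gets a homogeneous list.
    grads_per_device: dict[torch.device, list[torch.Tensor]] = defaultdict(list)
    for p in model.parameters():
        if p.grad is not None:
            grads_per_device[p.grad.device].append(p.grad)
    for device, grads in grads_per_device.items():
        # One fused multiply per device instead of one small kernel per parameter.
        torch._foreach_mul_(grads, scaler.to(device))
```

Passing the scale as a tensor rather than a Python float also tends to help torch.compile avoid recompiling when the value changes between steps.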


pytorch-bot bot commented Apr 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2624

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Cancelled Jobs

As of commit 9e1ab29 with merge base f3e4747:

NEW FAILURE - The following job has failed:

  • GPU tests / gpu_test (3.11, stable) (gh)
    tests/recipes/test_full_finetune_distributed.py::TestFullFinetuneDistributedRecipe::test_loss_2d_parallel[llama3/8B_full-llama3-tune-4-1-True-2]

CANCELLED JOBS - The following jobs were cancelled. Please retry:

  • GPU tests / gpu_test (3.10, stable) (gh)
    tests/recipes/test_full_finetune_distributed.py::TestFullFinetuneDistributedRecipe::test_loss_2d_parallel[llama3/8B_full-llama3-tune-4-1-True-2]
  • GPU tests / gpu_test (3.9, stable) (gh)
    tests/recipes/test_full_finetune_distributed.py::TestFullFinetuneDistributedRecipe::test_loss_2d_parallel[llama3/8B_full-llama3-tune-4-1-True-2]

This comment was automatically generated by Dr. CI and updates every 15 minutes.

IvanKobzarev added a commit that referenced this pull request Apr 22, 2025
ghstack-source-id: 9c9ea57
Pull Request resolved: #2624
@facebook-github-bot added the CLA Signed label on Apr 22, 2025 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).

IvanKobzarev added a commit that referenced this pull request Apr 22, 2025
ghstack-source-id: 081a1a9
Pull Request resolved: #2624
Contributor @joecummings left a comment:

Cool work! Just a couple comments.

```python
# Manually scale the gradients from unnormalized loss by total # of tokens
training.scale_grads(self._model, self.dp_degree / num_tokens)
def scale_grads_fn():
```
Contributor:
Dumb q: Why can't we compile the function directly instead of having this extra function?

Contributor:

Also, will this be addressed in this PR?

> Compilation in the current repo fails with skip_rope_interval=4; have to test with skip_rope_interval=None.

Author:
I have not tried compiling object.method before. I can try.

Author @IvanKobzarev commented Apr 22, 2025:

> skip_rope_interval

Not in this PR.
There is a problem in the flex_attention mask_mod: a "CUDA illegal memory access" caused by a bad index calculation.
This needs a separate fix; I will ask the flex attention engineers to take a look at a minimized repro.

Contributor:
> I have not tried compiling object.method before. I can try.

@IvanKobzarev Did this work?

Contributor:
Separately, I have a similar question to the one I left on the compiled optimizer step PR: do we need to do this every step? If not, I wonder if we could just do something like self.grad_scaler = get_grad_scale_fn(self._compile) once during init, returning the maybe-compiled scale function, and then call it in a single line here.
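
A minimal sketch of that suggestion is below; get_grad_scale_fn is a hypothetical name taken from the comment above, not an existing torchtune API.

```python
import torch
from torchtune import training


def get_grad_scale_fn(compile_scale_grads: bool, backend: str = "inductor"):
    """Sketch: return the grad-scaling function, optionally wrapped in torch.compile."""
    if compile_scale_grads:
        return torch.compile(training.scale_grads, backend=backend)
    return training.scale_grads


# Hypothetical usage in the recipe:
#   in __init__:       self.grad_scaler = get_grad_scale_fn(self._compile_scale_grads)
#   in the train step: self.grad_scaler(self._model, self.dp_degree / num_tokens)
```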

Author:
@ebsmothers @joecummings
Unfortunately it makes a difference for PT2, and it fails if I don't wrap it in a function.
E.g. I tried:

```python
self.grad_scaler = training.scale_grads_
if self._compile_scale_grads:
    self.grad_scaler = torch.compile(training.scale_grads_, backend=self._compile_backend)
```

Agreed that this is bad UX for the user on the PT2 side.

I will try to debug this for some time.

```diff
@@ -29,7 +30,7 @@
     shard_model,
     validate_no_params_on_meta_device,
 )
-from torchtune.training._grad_scaler import scale_grads
+from torchtune.training._grad_scaler import scale_grads, scale_grads_
```
Contributor:
I think this needs to be added to `__all__`
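
For illustration, the export could look roughly like this in torchtune/training/__init__.py (a sketch; the surrounding imports and entries are omitted):

```python
# torchtune/training/__init__.py (sketch, other imports and entries omitted)
from torchtune.training._grad_scaler import scale_grads, scale_grads_

__all__ = [
    # ...
    "scale_grads",
    "scale_grads_",
]
```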


IvanKobzarev added a commit that referenced this pull request Apr 22, 2025
ghstack-source-id: c482c46
Pull Request resolved: #2624
@codecov-commenter

Codecov Report

Attention: Patch coverage is 28.94737% with 27 lines in your changes missing coverage. Please review.

Please upload report for BASE (gh/IvanKobzarev/2/base@1dd7eb2). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| torchtune/training/_grad_scaler.py | 31.25% | 22 Missing ⚠️ |
| recipes/full_finetune_distributed.py | 0.00% | 5 Missing ⚠️ |
Additional details and impacted files
@@                    Coverage Diff                    @@
##             gh/IvanKobzarev/2/base    #2624   +/-   ##
=========================================================
  Coverage                          ?   65.56%           
=========================================================
  Files                             ?      396           
  Lines                             ?    23804           
  Branches                          ?        0           
=========================================================
  Hits                              ?    15607           
  Misses                            ?     8197           
  Partials                          ?        0           



IvanKobzarev added a commit that referenced this pull request Apr 22, 2025
ghstack-source-id: 5d888be
Pull Request resolved: #2624
```python
) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
    ret = defaultdict(list)
    for i, tensor in enumerate(tensors):
        ret[(tensor.device, tensor.dtype)].append(tensor)
```
Contributor:
Noob q: where are we actually using the dtype groupby? Is it implicit in the usage of _foreach_mul_?

Author:
Technically it is not used here. I planned to upstream this function into torch, where it could be useful for other functionality, but that can be a long process :)
So we do not use it here and can remove the dtype grouping :)
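
For context, here is a sketch of the kind of grouping helper being discussed and how it would feed the foreach call; it is illustrative only, and the helper name is an assumption rather than the exact one in the PR.

```python
from collections import defaultdict

import torch
from torch import Tensor


def _group_tensors_by_device_and_dtype(
    tensors: list[Tensor],
) -> dict[tuple[torch.device, torch.dtype], list[Tensor]]:
    """Sketch: group tensors so each foreach kernel sees a homogeneous list."""
    ret: dict[tuple[torch.device, torch.dtype], list[Tensor]] = defaultdict(list)
    for tensor in tensors:
        ret[(tensor.device, tensor.dtype)].append(tensor)
    return ret


# For in-place scaling only the device key is needed, as noted above:
# grads = [p.grad for p in model.parameters() if p.grad is not None]
# for (device, _dtype), group in _group_tensors_by_device_and_dtype(grads).items():
#     torch._foreach_mul_(group, scaler.to(device))
```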


IvanKobzarev added a commit that referenced this pull request Apr 28, 2025
ghstack-source-id: 961b75c
Pull Request resolved: #2624

IvanKobzarev added a commit that referenced this pull request Apr 28, 2025
ghstack-source-id: 9db049b
Pull Request resolved: #2624
Author @IvanKobzarev commented:

Updated the diff to compile only during initialization and use self.grad_scaler in the loop.

Just FYI for testing: compilation currently needs workarounds for two different problems:

  1. There is a problem with RNG state preservation, which can be worked around with this patch:

```diff
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 668353867ab..493883542f9 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -249,8 +249,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
             prior_dtype = torch.get_default_dtype()
             torch_rng_state = torch.random.get_rng_state()
             cuda_rng_state = None
-            if torch.cuda.is_available():
-                cuda_rng_state = torch.cuda.get_rng_state()
+            # if torch.cuda.is_available():
+            #     cuda_rng_state = torch.cuda.get_rng_state()
             allow_tf32 = torch._C._get_cublas_allow_tf32()
             prior_fwd_from_src = torch.fx.graph_module._forward_from_src
             torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
@@ -281,8 +281,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
                 )
                 if prior_mobile_allocator_state != curr_mobile_allocator_state:
                     torch._C._unset_default_mobile_cpu_allocator()
-                if cuda_rng_state is not None:
-                    torch.cuda.set_rng_state(cuda_rng_state)
+                # if cuda_rng_state is not None:
+                #     torch.cuda.set_rng_state(cuda_rng_state)
                 torch._C._set_cublas_allow_tf32(allow_tf32)
                 torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                 assert guards.check(), (
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b75b1d6c39f..7ca67523704 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2110,15 +2110,15 @@ def preserve_rng_state():
     with disable_current_modes(), disable_functorch():
         rng_state = torch.clone(torch.random.get_rng_state())
         skip_frame_if_in_functorch_mode(rng_state)
-        if torch.cuda.is_available():
-            cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+        # if torch.cuda.is_available():
+        #     cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
     try:
         yield
     finally:
         with torch.utils._python_dispatch._disable_current_modes():
             torch.random.set_rng_state(rng_state)
-            if torch.cuda.is_available():
-                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
+            # if torch.cuda.is_available():
+            #     torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]


 def is_jit_model(model0):
```

  2. There is an illegal memory access in chunked flex attention x caching. It does not fire if /tmp/torchinductor_${USER} is removed before every run (or if the PT2 cache is disabled).


IvanKobzarev added a commit that referenced this pull request Apr 28, 2025
ghstack-source-id: 273da3e
Pull Request resolved: #2624
```diff
-from torch import nn
+from torch import nn, Tensor
+from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
+from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support


 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
```
Contributor:
@ebsmothers Should we add a deprecation warning here then? No need to keep both.

Contributor:
Talked with Evan, let's add the deprecation label here.
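
A minimal sketch of what the deprecation could look like, using a plain warnings-based approach and assuming scale_grads_ keeps the same (model, scaler) signature; torchtune's own deprecation helper, if preferred, could be used instead.

```python
import warnings

import torch
from torch import nn


def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
    """Deprecated in favor of ``scale_grads_`` (sketch)."""
    warnings.warn(
        "scale_grads is deprecated and will be removed in a future release; "
        "use scale_grads_ instead.",
        FutureWarning,
        stacklevel=2,
    )
    # scale_grads_ is assumed to live in the same module (torchtune/training/_grad_scaler.py).
    scale_grads_(model, scaler)
```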

```python
compile_components = cfg.get("compile_components")
if self._compile and compile_components:
    self._compile_model = compile_components.get("model", True)
    self._compile_loss = compile_components.get("loss", True)
    self._compile_optimizer_step = compile_components.get(
        "optimizer_step", False
    )
    self._compile_scale_grads = compile_components.get("scale_grads", False)

self._grad_scaler = training.scale_grads_
```
Contributor:
Can you leave a comment here explaining that we need this indirection for things to work w/ PT2 compile for some reason?

Author:
Sure, will add a comment.


Contributor @joecummings left a comment:
LGTM, just deprecate the old scale_grads



