
Commit 9e1ab29

Update on "scale_grads with foreach + compile"
Scaling gradients with foreach ops, and also compiling the scaling step. tokens_per_second improves from 9 to 9.8 (max value over the first 10 iterations). It helps tokens_per_second overall; for Llama4 the parameters are on CPU, so the wins here are modest, but they should be larger when the parameters are on GPU.

```
tune run --nproc_per_node 8 \
  full_finetune_distributed \
  --config recipes/configs/llama4/scout_17B_16E_full.yaml
```

PS: compilation on the current repo fails when setting skip_rope_interval=4; it has to be tested with skip_rope_interval=None.

[ghstack-poisoned]
2 parents b0a279c + bd7584b commit 9e1ab29
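For context, the idea is to replace a per-parameter Python loop over gradients with a single foreach-style in-place multiply, which can then be wrapped in torch.compile. Below is a hypothetical before/after sketch of that idea; the function names are illustrative, not torchtune APIs, and the foreach version assumes all gradients live on the same device as the scaler.

```python
# Hypothetical sketch of the optimization described in the commit message.
import torch
from torch import nn


def scale_grads_loop(model: nn.Module, scaler: torch.Tensor) -> None:
    # Baseline: one small in-place multiply per parameter (one kernel launch each).
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler)


def scale_grads_foreach(model: nn.Module, scaler: torch.Tensor) -> None:
    # Foreach version: a single fused multiply over the whole gradient list.
    # Assumes all gradients are on the same device as `scaler`.
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if grads:
        torch._foreach_mul_(grads, scaler)


# The scaling step can additionally be wrapped in torch.compile, which is what
# the new `scale_grads` compile flag in the recipe toggles.
scale_grads_compiled = torch.compile(scale_grads_foreach)
```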


2 files changed: +3 −1 lines changed


recipes/full_finetune_distributed.py

+1-1
@@ -318,7 +318,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._compile_model = compile.get("model", True)
         self._compile_loss = compile.get("loss", True)
         self._compile_optimizer_step = compile.get("optimizer_step", False)
-        self._compile_scale_grads = compile_components.get("scale_grads", True)
+        self._compile_scale_grads = compile.get("scale_grads", True)

         # This indirection is needed to apply torch.compile to scale_grads step.
         self._grad_scaler = training.scale_grads_
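The indirection through `self._grad_scaler` lets the recipe later swap in a compiled version of the scaling step when the flag is enabled. A hedged sketch of how that wiring could look (attribute names mirror the diff above; `self`, `torch`, and `training` are assumed to be in scope in the recipe, and the actual recipe code may differ):

```python
# Hypothetical wiring, not the exact recipe code: compile the scaling step once
# if the `scale_grads` compile flag is set, then call it every optimizer step.
if self._compile_scale_grads:
    self._grad_scaler = torch.compile(training.scale_grads_)
else:
    self._grad_scaler = training.scale_grads_
```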

torchtune/training/_grad_scaler.py

+2
@@ -11,8 +11,10 @@
 from torch import nn, Tensor
 from torch.nn.utils.clip_grad import _no_grad, _tensor_or_tensors
 from torch.utils._foreach_utils import _device_has_foreach_support, _has_foreach_support
+from torchtune.utils._logging import deprecated


+@deprecated(msg="Please use `scale_grads_` instead.")
 def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
     """
     Utility to scale the gradients of a model.
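The imports in this file suggest the new `scale_grads_` follows the same device/dtype-grouping pattern as `torch.nn.utils.clip_grad_norm_`. Below is a hedged sketch of that pattern, not the actual torchtune implementation; `scale_grads_sketch` is an illustrative name, and the helpers used are the private utilities already imported in the diff above.

```python
# Sketch of a foreach-based, in-place gradient scaler that groups gradients by
# (device, dtype), mirroring torch.nn.utils.clip_grad_norm_. Illustrative only.
import torch
from torch import nn
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@torch.no_grad()
def scale_grads_sketch(model: nn.Module, scaler: torch.Tensor) -> None:
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return
    # Group gradients so each (device, dtype) bucket is scaled with a single
    # fused foreach kernel.
    grouped = _group_tensors_by_device_and_dtype([grads])
    for (device, _), ([device_grads], _) in grouped.items():
        if _has_foreach_support(device_grads, device) or _device_has_foreach_support(device):
            torch._foreach_mul_(device_grads, scaler.to(device))
        else:
            # Fallback for devices without foreach support: scale one by one.
            for g in device_grads:
                g.mul_(scaler.to(device))
```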
