Skip to content

Commit f2f49bd

Browse files
committed
Update on "compile optimizer"
Compiling the optimizer improves perf of the Llama4 Scout model: 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations). Peak memory is the same. ``` tune run --nproc_per_node 8 \ full_finetune_distributed \ --config recipes/configs/llama4/scout_17B_16E_full.yaml ``` PS: Compilation in the current repo fails if `skip_rope_interval=4,` is set; had to test with `skip_rope_interval=None,`. [ghstack-poisoned]
1 parent 7a6a4d2 commit f2f49bd

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

recipes/configs/llama4/scout_17B_16E_full.yaml

+8-6
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,14 @@ device: cuda
6969
enable_activation_checkpointing: True
7070
enable_activation_offloading: False
7171
fsdp_cpu_offload: True
72-
compile: False # torch.compile, set to true for perf/memory improvement
73-
74-
compile_components:
75-
model: True
76-
loss: True
77-
optimizer_step: False
72+
# compile True means use torch.compile for all components
73+
# compile False means no torch.compile
74+
# compile Dictionary with keys: "model", "loss", "optimizer_step"
75+
# enables torch.compile only for specified components.
76+
compile: False
77+
# model: True
78+
# loss: True
79+
# optimizer_step: False
7880

7981
# Reduced precision
8082
dtype: bf16

recipes/full_finetune_distributed.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -306,19 +306,17 @@ def setup(self, cfg: DictConfig) -> None:
306306
# Load the base model
307307
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
308308

309-
self._compile = cfg.get("compile", False)
309+
compile = cfg.get("compile")
310+
compile_bool = bool(compile)
310311
self._compile_backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
311312

312-
self._compile_model = False
313-
self._compile_loss = False
314-
self._compile_optimizer_step = False
315-
compile_components = cfg.get("compile_components")
316-
if self._compile and compile_components:
317-
self._compile_model = compile_components.get("model", True)
318-
self._compile_loss = compile_components.get("loss", True)
319-
self._compile_optimizer_step = compile_components.get(
320-
"optimizer_step", False
321-
)
313+
self._compile_model = compile_bool
314+
self._compile_loss = compile_bool
315+
self._compile_optimizer_step = compile_bool
316+
if isinstance(compile, dict):
317+
self._compile_model = compile.get("model", True)
318+
self._compile_loss = compile.get("loss", True)
319+
self._compile_optimizer_step = compile.get("optimizer_step", False)
322320

323321
self._model = self._setup_model(
324322
cfg_model=cfg.model,

0 commit comments

Comments
 (0)