
Commit 1be43b6

LoRA + Llama4 (#2583)
1 parent 9a88c16 commit 1be43b6

20 files changed (+1347 −68 lines)

README.md (+1 −1)

@@ -10,7 +10,7 @@
 [**Overview**](#overview-) | [**Installation**](#installation-%EF%B8%8F) | [**Get Started**](#get-started-) | [**Documentation**](https://pytorch.org/torchtune/main/index.html) | [**Community**](#community-) | [**Citing torchtune**](#citing-torchtune-) | [**License**](#license)
 
 ### 📣 Recent updates 📣
-* *April 2025*: **Llama4** is now available in torchtune! Try out our full finetuning configs [here](recipes/configs/llama4) (LoRA coming soon!)
+* *April 2025*: **Llama4** is now available in torchtune! Try out our full and LoRA finetuning configs [here](recipes/configs/llama4)
 * *February 2025*: Multi-node training is officially [open for business in torchtune](https://pytorch.org/torchtune/main/tutorials/multinode.html)! Full finetune on multiple nodes to take advantage of larger batch sizes and models.
 * *December 2024*: torchtune now supports **Llama 3.3 70B**! Try it out by following our installation instructions [here](#installation-%EF%B8%8F), then run any of the configs [here](recipes/configs/llama3_3).
 * *November 2024*: torchtune has released [v0.4.0](https://github.com/pytorch/torchtune/releases/tag/v0.4.0) which includes stable support for exciting features like activation offloading and multimodal QLoRA
recipes/configs/llama4/scout_17B_16E_lora.yaml (new file, +94 lines)

# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch on 8 devices, run the following command from root:
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora
#
# You can add specific overrides through the command line. For example, to use a larger batch size:
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora batch_size=8
#
# This config was only tested on an 8xA100 machine.

output_dir: /tmp/torchtune/llama4_17Bx16E/lora

# Modeling Arguments
model:
  _component_: torchtune.models.llama4.lora_llama4_scout_17b_16e
  decoder_trainable: "lora"
  encoder_trainable: "frozen"
  fusion_trainable: "lora"
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 16  # higher increases accuracy and memory
  lora_alpha: 32  # usually alpha=2*rank
  lora_dropout: 0.0

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
resume_from_checkpoint: False

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False
seed: null
shuffle: True

# Training arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 1  # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
custom_sharded_layers: ['tok_embeddings']
fsdp_cpu_offload: False
compile: False  # torch.compile, set to true for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
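The `model:` block above maps directly onto the keyword arguments of the builder named in `_component_`. As a rough sketch of what the recipe ends up doing with it (the actual recipe goes through config instantiation, meta-device init, and FSDP sharding rather than calling the builder directly, so treat the snippet as illustrative and assume the builder accepts exactly the keys shown in the config):

```python
# Minimal sketch: build the LoRA Llama4 Scout model from the same arguments as the
# config's `model:` section and freeze everything except the adapter weights.
import torch

from torchtune.models.llama4 import lora_llama4_scout_17b_16e
from torchtune.modules.peft import get_adapter_params, set_trainable_params

# The full Scout model is far too large for a single device, so construct it on
# the meta device here; the recipe materializes and shards the real weights with FSDP.
with torch.device("meta"):
    model = lora_llama4_scout_17b_16e(
        decoder_trainable="lora",
        encoder_trainable="frozen",
        fusion_trainable="lora",
        lora_attn_modules=["q_proj", "v_proj", "output_proj"],
        apply_lora_to_mlp=True,
        apply_lora_to_output=False,
        lora_rank=16,
        lora_alpha=32,
        lora_dropout=0.0,
    )

# Only the LoRA A/B matrices stay trainable; the base weights remain frozen,
# which is what keeps the memory footprint of finetuning manageable.
set_trainable_params(model, get_adapter_params(model))
```

With rank 16 and alpha 32 the config follows the alpha = 2 * rank rule of thumb noted in its comments; raising the rank increases adapter capacity at the cost of memory.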

recipes/dev/lora_finetune_distributed_multi_dataset.py (+2 −5)

@@ -26,12 +26,11 @@
 from torchtune.data._utils import get_dataloader, get_multi_dataset, load_hf_dataset
 from torchtune.datasets._sft import SFTTransform
 from torchtune.modules.peft import (
-    DoRALinear,
+    AdapterModule,
     get_adapter_params,
     get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
-    LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
@@ -495,9 +494,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), self._device:
             lora_device = "cpu" if fsdp_cpu_offload else self._device
             for m in model.modules():
-                if (
-                    isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
-                ) and not lora_weights_state_dict:
+                if isinstance(m, AdapterModule) and not lora_weights_state_dict:
                     # lora may not be covered in state dict
                     # if finetune for the 1st time
                     m.to_empty(device=lora_device)
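The same edit repeats across the recipes below: the concrete `LoRALinear` / `DoRALinear` checks are replaced by a single structural check against `AdapterModule`, so new adapter types (such as the MoE `LoRAGroupedExperts` added for Llama4) get device-initialized without touching each recipe again. The sketch below illustrates the pattern with hypothetical toy classes; it assumes `AdapterModule` behaves like a runtime-checkable protocol keyed on an `adapter_params()` method, which is how torchtune's PEFT utilities describe it, so treat the exact definition as an assumption.

```python
# Illustrative sketch of the structural isinstance() check (toy classes, not torchtune's).
from typing import Protocol, runtime_checkable

import torch.nn as nn


@runtime_checkable
class AdapterModule(Protocol):
    """Anything exposing adapter_params() counts as an adapter."""

    def adapter_params(self) -> list[str]:
        ...


class ToyLoRALinear(nn.Linear):
    def adapter_params(self) -> list[str]:
        return ["lora_a.weight", "lora_b.weight"]


class ToyLoRAGroupedExperts(nn.Module):
    def adapter_params(self) -> list[str]:
        return ["lora_a", "lora_b"]


for m in (ToyLoRALinear(4, 4), ToyLoRAGroupedExperts(), nn.Linear(4, 4)):
    # The plain nn.Linear fails the check; both toy adapters pass without the
    # recipe having to enumerate their concrete classes.
    print(type(m).__name__, isinstance(m, AdapterModule))
```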

recipes/knowledge_distillation_distributed.py (+2 −5)

@@ -24,12 +24,11 @@
 from torchtune.data import padded_collate_packed, padded_collate_sft
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
-    DoRALinear,
+    AdapterModule,
     get_adapter_params,
     get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
-    LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
@@ -478,9 +477,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), self._device:
             lora_device = "cpu" if fsdp_cpu_offload else self._device
             for m in model.modules():
-                if (
-                    isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
-                ) and not lora_weights_state_dict:
+                if isinstance(m, AdapterModule) and not lora_weights_state_dict:
                     # lora may not be covered in state dict
                     # if finetune for the 1st time
                     m.to_empty(device=lora_device)

recipes/lora_dpo_distributed.py (+2 −5)

@@ -23,13 +23,12 @@
 from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
+    AdapterModule,
     disable_adapter,
-    DoRALinear,
     get_adapter_params,
     get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
-    LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
@@ -407,9 +406,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), self._device:
             lora_device = "cpu" if fsdp_cpu_offload else self._device
             for m in model.modules():
-                if (
-                    isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
-                ) and not lora_weights_state_dict:
+                if isinstance(m, AdapterModule) and not lora_weights_state_dict:
                     # lora may not be covered in state dict
                     # if finetune for the 1st time
                     m.to_empty(device=lora_device)

recipes/lora_finetune_distributed.py (+2 −5)

@@ -25,12 +25,11 @@
 from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
-    DoRALinear,
+    AdapterModule,
     get_adapter_params,
     get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
-    LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
@@ -519,9 +518,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), self._device:
             lora_device = "cpu" if fsdp_cpu_offload else self._device
             for m in model.modules():
-                if (
-                    isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
-                ) and not lora_weights_state_dict:
+                if isinstance(m, AdapterModule) and not lora_weights_state_dict:
                     # lora may not be covered in state dict
                     # if finetune for the 1st time
                     m.to_empty(device=lora_device)

recipes/qat_lora_finetune_distributed.py (+2 −4)

@@ -25,12 +25,12 @@
 from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
+    AdapterModule,
     DoRALinear,
     get_adapter_params,
     get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
-    LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
@@ -532,9 +532,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), self._device:
             lora_device = "cpu" if fsdp_cpu_offload else self._device
             for m in model.modules():
-                if (
-                    isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
-                ) and not lora_weights_state_dict:
+                if isinstance(m, AdapterModule) and not lora_weights_state_dict:
                     # lora may not be covered in state dict
                     # if finetune for the 1st time
                     m.to_empty(device=lora_device)

tests/torchtune/models/llama3_2_vision/test_llama_vision_lora.py (+1 −2)

@@ -10,10 +10,9 @@
 from torchtune.models.llama3_2_vision._component_builders import (
     lora_llama3_2_vision_decoder,
     lora_llama3_2_vision_encoder,
-    LoRATrainable,
 )
 from torchtune.modules.model_fusion import DeepFusionModel
-from torchtune.modules.peft import get_adapter_params
+from torchtune.modules.peft import get_adapter_params, LoRATrainable
 from torchtune.training.seed import set_seed
 
 EMBED_DIM = 128

tests/torchtune/modules/moe/test_experts.py (+104 −1)

@@ -9,9 +9,14 @@
 import torch
 from tests.test_utils import assert_expected, fixed_init_model
 from torch import nn
-from torchtune.modules.moe import GroupedExperts
+from torchtune.modules.moe import GroupedExperts, LoRAGroupedExperts
+from torchtune.modules.peft import LoRALinear
 from torchtune.training.seed import set_seed
 
+RANK = 4
+ALPHA = 1.0
+SEQ_LEN = 32
+
 
 @pytest.fixture(autouse=True)
 def random():
@@ -57,3 +62,101 @@ def test_forward(self, experts, num_tokens_per_expert, dim):
 
         assert out.shape == (16, dim)
         assert_expected(out.mean().item(), 120.8260, atol=1e-3, rtol=1e-3)
+
+
+class TestLoRAGroupedExperts:
+    @pytest.fixture
+    def dim(self) -> int:
+        return 64
+
+    @pytest.fixture
+    def hidden_dim(self) -> int:
+        return 128
+
+    @pytest.fixture
+    def num_experts(self) -> int:
+        return 8
+
+    @pytest.fixture
+    def experts_per_token(self) -> int:
+        return 2
+
+    @pytest.fixture
+    def num_tokens_per_expert(self, num_experts) -> torch.Tensor:
+        return torch.tensor([1, 2, 1, 4, 3, 1, 2, 2], dtype=torch.int)
+
+    @pytest.fixture
+    def inputs(self, dim, num_experts, experts_per_token) -> torch.Tensor:
+        inputs = torch.randn(num_experts * experts_per_token, SEQ_LEN, dim)
+        return inputs
+
+    @pytest.fixture
+    def experts(self, dim, hidden_dim, num_experts) -> nn.Module:
+        experts = GroupedExperts(
+            dim=dim,
+            hidden_dim=hidden_dim,
+            num_experts=num_experts,
+        )
+        fixed_init_model(experts, min_val=-0.1, max_val=0.1)
+        return experts
+
+    @pytest.fixture
+    def lora_experts(self, dim, hidden_dim, num_experts) -> nn.Module:
+        experts = LoRAGroupedExperts(
+            dim=dim,
+            hidden_dim=hidden_dim,
+            num_experts=num_experts,
+            rank=RANK,
+            alpha=ALPHA,
+        )
+        fixed_init_model(experts, min_val=-0.1, max_val=0.1)
+        return experts
+
+    @pytest.fixture
+    def lora_linear(self, dim, hidden_dim):
+        def create_lora_linear(dim=dim, hidden_dim=hidden_dim):
+            lora_linear = LoRALinear(
+                in_dim=dim,
+                out_dim=hidden_dim,
+                rank=RANK,
+                alpha=ALPHA,
+            )
+            fixed_init_model(lora_linear)
+            return lora_linear
+
+        return create_lora_linear
+
+    def test_lora_tc_layer_forward(self, lora_linear, lora_experts, inputs):
+        """Compare TC forward with LoRALinear as reference."""
+        lora = lora_linear()
+        actual = lora_experts._lora_tc_layer_forward(
+            inputs[0],
+            lora.weight.T,
+            lora.lora_a.weight.T,
+            lora.lora_b.weight.T,
+        )
+        expected = lora(inputs[0])
+        assert_expected(actual, expected, rtol=1e-6, atol=1e-4)
+
+    def test_forward_disabled(
+        self, experts, lora_experts, inputs, num_tokens_per_expert
+    ):
+        """Test forward with lora layers disabled, comparing against GroupedExperts."""
+        lora_experts.disabled = True
+        actual = lora_experts(inputs, num_tokens_per_expert)
+        expected = experts(inputs, num_tokens_per_expert)
+        assert_expected(actual, expected, rtol=1e-6, atol=1e-4)
+
+    def test_forward(
+        self,
+        lora_experts,
+        inputs,
+        num_experts,
+        experts_per_token,
+        dim,
+        num_tokens_per_expert,
+    ) -> None:
+        expected = torch.tensor(0.441491)
+        actual = lora_experts(inputs, num_tokens_per_expert)
+        assert actual.shape == (num_experts * experts_per_token, SEQ_LEN, dim)
+        torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)
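`test_lora_tc_layer_forward` pins down what the per-expert ("token choice") LoRA projection computes by comparing it against a plain `LoRALinear`. Written out, each expert applies the usual LoRA decomposition to the tokens routed to it. The sketch below assumes the standard `alpha / rank` scaling and expert weights stored in a transposed `[dim, hidden_dim]` layout, which is why the test passes `.T` views of the `LoRALinear` weights; treat the exact torchtune internals as an assumption.

```python
# Reference computation implied by the equivalence test: base projection plus a
# scaled low-rank update, y = x @ W + (alpha / rank) * x @ A @ B.
import torch


def lora_expert_forward(
    x: torch.Tensor,       # [seq_len, dim] tokens routed to a single expert
    w: torch.Tensor,       # [dim, hidden_dim] frozen base weight (transposed layout)
    lora_a: torch.Tensor,  # [dim, rank] low-rank down-projection (transposed layout)
    lora_b: torch.Tensor,  # [rank, hidden_dim] low-rank up-projection (transposed layout)
    alpha: float,
    rank: int,
) -> torch.Tensor:
    return x @ w + (alpha / rank) * (x @ lora_a @ lora_b)


# Shapes mirroring the fixtures above: dim=64, hidden_dim=128, RANK=4, ALPHA=1.0.
x = torch.randn(32, 64)
w = torch.randn(64, 128)
a = torch.randn(64, 4)
b = torch.zeros(4, 128)  # LoRA's B matrix is typically zero-initialized, so the update starts as a no-op
out = lora_expert_forward(x, w, a, b, alpha=1.0, rank=4)
print(out.shape)  # torch.Size([32, 128])
```

With `disabled = True`, as exercised in `test_forward_disabled`, the low-rank term is skipped entirely and the output should match the plain `GroupedExperts`.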

torchtune/_recipe_registry.py (+4)

@@ -442,6 +442,10 @@ class Recipe:
                 name="llama3_2_vision/90B_qlora",
                 file_path="llama3_2_vision/90B_qlora.yaml",
             ),
+            Config(
+                name="llama4/scout_17B_16E_lora",
+                file_path="llama4/scout_17B_16E_lora.yaml",
+            ),
         ],
         supports_distributed=True,
     ),
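Registering the config here is what lets the `tune` CLI resolve `--config llama4/scout_17B_16E_lora` to the YAML file above. A minimal lookup sketch follows; it assumes `get_all_recipes()` is exposed by `torchtune._recipe_registry` and that the new entry sits under the `lora_finetune_distributed` recipe, so treat both as assumptions rather than a documented API.

```python
# Hypothetical lookup illustrating how a recipe/config name pair maps to a YAML path.
from torchtune._recipe_registry import get_all_recipes


def find_config_path(recipe_name: str, config_name: str) -> str:
    for recipe in get_all_recipes():
        if recipe.name == recipe_name:
            for config in recipe.configs:
                if config.name == config_name:
                    return config.file_path
    raise ValueError(f"{config_name!r} is not registered under {recipe_name!r}")


# Expected to return "llama4/scout_17B_16E_lora.yaml" after this commit.
print(find_config_path("lora_finetune_distributed", "llama4/scout_17B_16E_lora"))
```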

torchtune/models/llama3_2_vision/_component_builders.py (−7)

@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from enum import Enum
 from functools import partial
 from typing import List, Optional
 
@@ -327,12 +326,6 @@ def llama3_2_vision_projection_head(
 # ------------------ LoRA Llama 3.2 Vision ------------------
 
 
-class LoRATrainable(Enum):
-    FULL = "full"
-    LORA = "lora"
-    FROZEN = "frozen"
-
-
 def lora_llama3_2_vision_encoder(
     encoder_lora: bool,
     fusion_lora: bool,
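The `LoRATrainable` enum is no longer defined alongside the Llama 3.2 Vision builders; per the test import change above it now lives in `torchtune.modules.peft`, where the Llama4 builders can share it. The `decoder_trainable` / `encoder_trainable` / `fusion_trainable` strings in the new config are its values. Below is a small sketch of that string-to-enum mapping: the enum mirrors the removed definition, while the `parse_trainability` helper is hypothetical and shown only for illustration.

```python
# Mirrors the enum removed above; the helper is a hypothetical illustration of how
# config strings such as decoder_trainable: "lora" map onto it.
from enum import Enum


class LoRATrainable(Enum):
    FULL = "full"
    LORA = "lora"
    FROZEN = "frozen"


def parse_trainability(value: str) -> LoRATrainable:
    # Enum lookup by value raises ValueError for anything other than "full"/"lora"/"frozen".
    return LoRATrainable(value.lower())


assert parse_trainability("lora") is LoRATrainable.LORA      # decoder_trainable / fusion_trainable
assert parse_trainability("frozen") is LoRATrainable.FROZEN  # encoder_trainable
```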
