diff --git a/finetune_pp_peft.py b/finetune_pp_peft.py
index d530d4a..2788481 100644
--- a/finetune_pp_peft.py
+++ b/finetune_pp_peft.py
@@ -99,7 +99,7 @@ def main():
         device_map[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = device_id
         device_map[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = device_id
 
-    model = transformers.LLaMAForCausalLM.from_pretrained(
+    model = transformers.LlamaForCausalLM.from_pretrained(
         args.model_path,
         load_in_8bit=True,
         device_map=device_map,
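
For context: released versions of `transformers` expose the LLaMA model class as `LlamaForCausalLM`; the `LLaMAForCausalLM` spelling only existed in pre-release branches, which is what this diff fixes. Below is a minimal, self-contained sketch of the corrected load call, assuming a hypothetical checkpoint path and a simple two-GPU layer split (the real script derives `model_path` and the per-module `device_map` from command-line arguments):

```python
import transformers

# Hypothetical values for illustration; the actual script builds these from args.
model_path = "path/to/llama-7b-hf"
num_layers, num_gpus = 32, 2

# Pin the embedding and head modules, spread decoder layers evenly across GPUs.
device_map = {
    "model.embed_tokens": 0,
    "model.norm": num_gpus - 1,
    "lm_head": num_gpus - 1,
}
for layer_i in range(num_layers):
    device_map[f"model.layers.{layer_i}"] = layer_i * num_gpus // num_layers

# Note the corrected class name: LlamaForCausalLM, not LLaMAForCausalLM.
model = transformers.LlamaForCausalLM.from_pretrained(
    model_path,
    load_in_8bit=True,      # 8-bit weights via bitsandbytes
    device_map=device_map,  # explicit pipeline-parallel placement
)
```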