
Commit e8a2ed7

cebtenzzre and jquesnelle authored and committed
llama : implement YaRN RoPE scaling (ggml-org#2268)
Co-authored-by: cebtenzzre <[email protected]>
Co-authored-by: Jeffrey Quesnelle <[email protected]>
1 parent c09822b commit e8a2ed7

15 files changed (+764, -258 lines)

common/common.cpp (+66, -13)
@@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = std::stof(argv[i]);
+        } else if (arg == "--rope-scaling") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+            else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            else { invalid_param = true; break; }
         } else if (arg == "--rope-scale") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+        } else if (arg == "--yarn-orig-ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_orig_ctx = std::stoi(argv[i]);
+        } else if (arg == "--yarn-ext-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_ext_factor = std::stof(argv[i]);
+        } else if (arg == "--yarn-attn-factor") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_attn_factor = std::stof(argv[i]);
+        } else if (arg == "--yarn-beta-fast") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_fast = std::stof(argv[i]);
+        } else if (arg == "--yarn-beta-slow") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_slow = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
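The LLAMA_ROPE_SCALING_* constants used above come from llama.h, which this commit also touches but which is not shown in this excerpt; the enum presumably looks roughly like the sketch below (only the names appear in the diff, the exact values are an assumption here):

// Sketch of the rope scaling enum assumed by the parsing code above (declared in llama.h; values assumed).
enum llama_rope_scaling_type {
    LLAMA_ROPE_SCALING_UNSPECIFIED = -1, // defer to whatever the model's GGUF metadata specifies
    LLAMA_ROPE_SCALING_NONE        =  0,
    LLAMA_ROPE_SCALING_LINEAR      =  1,
    LLAMA_ROPE_SCALING_YARN        =  2,
};
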
@@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --cfg-negative-prompt-file FNAME\n");
     printf("                        negative prompt file to use for guidance. (default: empty)\n");
     printf("  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
-    printf("  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
+    printf("  --rope-scaling {none,linear,yarn}\n");
+    printf("                        RoPE frequency scaling method, defaults to linear unless specified by the model\n");
+    printf("  --rope-scale N        RoPE context scaling factor, expands context by a factor of N\n");
     printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
-    printf("  --rope-freq-scale N   RoPE frequency linear scaling factor (default: loaded from model)\n");
+    printf("  --rope-freq-scale N   RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-orig-ctx N     YaRN: original context size of model (default: 0 = model training context size)\n");
+    printf("  --yarn-ext-factor N   YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+    printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+    printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+    printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     printf("  --no-penalize-nl      do not penalize newline token\n");
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
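Putting the new options together, a YaRN-extended run might be invoked roughly as follows; the model path and numbers are purely illustrative, and note that --rope-scale 4 is parsed into rope_freq_scale = 1/4 by the argument handling above:

./main -m model.gguf -c 16384 --rope-scaling yarn --rope-scale 4 --yarn-orig-ctx 4096
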
@@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto cparams = llama_context_default_params();

-    cparams.n_ctx           = params.n_ctx;
-    cparams.n_batch         = params.n_batch;
-    cparams.n_threads       = params.n_threads;
-    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
-    cparams.mul_mat_q       = params.mul_mat_q;
-    cparams.seed            = params.seed;
-    cparams.f16_kv          = params.memory_f16;
-    cparams.logits_all      = params.logits_all;
-    cparams.embedding       = params.embedding;
-    cparams.rope_freq_base  = params.rope_freq_base;
-    cparams.rope_freq_scale = params.rope_freq_scale;
+    cparams.n_ctx             = params.n_ctx;
+    cparams.n_batch           = params.n_batch;
+    cparams.n_threads         = params.n_threads;
+    cparams.n_threads_batch   = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    cparams.mul_mat_q         = params.mul_mat_q;
+    cparams.seed              = params.seed;
+    cparams.f16_kv            = params.memory_f16;
+    cparams.logits_all        = params.logits_all;
+    cparams.embedding         = params.embedding;
+    cparams.rope_scaling_type = params.rope_scaling_type;
+    cparams.rope_freq_base    = params.rope_freq_base;
+    cparams.rope_freq_scale   = params.rope_freq_scale;
+    cparams.yarn_ext_factor   = params.yarn_ext_factor;
+    cparams.yarn_attn_factor  = params.yarn_attn_factor;
+    cparams.yarn_beta_fast    = params.yarn_beta_fast;
+    cparams.yarn_beta_slow    = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx     = params.yarn_orig_ctx;

     return cparams;
 }

common/common.h (+7)
@@ -9,6 +9,7 @@
 #define LOG_NO_FILE_LINE_FUNCTION
 #include "log.h"

+#include <cmath>
 #include <string>
 #include <vector>
 #include <random>
@@ -54,6 +55,12 @@ struct gpt_params {
     int32_t n_beams           = 0;     // if non-zero then use beam search of given width.
     float   rope_freq_base    = 0.0f;  // RoPE base frequency
     float   rope_freq_scale   = 0.0f;  // RoPE frequency scaling factor
+    float   yarn_ext_factor   = NAN;   // YaRN extrapolation mix factor
+    float   yarn_attn_factor  = 1.0f;  // YaRN magnitude scaling factor
+    float   yarn_beta_fast    = 32.0f; // YaRN low correction dim
+    float   yarn_beta_slow    = 1.0f;  // YaRN high correction dim
+    int32_t yarn_orig_ctx     = 0;     // YaRN original context length
+    int8_t  rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;

     // // sampling parameters
     struct llama_sampling_params sparams;
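For orientation, a rough sketch of how these knobs enter the computation, following the YaRN paper rather than any code shown in this excerpt: with context-extension ratio s = 1/rope_freq_scale, each RoPE frequency theta_i is blended between plain interpolation and no interpolation, approximately theta_i' = (1 - e*g_i) * theta_i / s + e*g_i * theta_i, where e is yarn_ext_factor and g_i is a per-dimension ramp that is 0 below the yarn_beta_slow correction dimension and 1 above the yarn_beta_fast one; yarn_attn_factor additionally scales the attention magnitude (on the order of 0.1*ln(s) + 1 in the paper). The NAN default on yarn_ext_factor presumably means "let the library pick a value appropriate to the selected scaling type".
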

convert-baichuan-hf-to-gguf.py (+2, -1)
@@ -163,7 +163,8 @@ def parse_args() -> argparse.Namespace:
 if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]:
     if "type" in hparams["rope_scaling"]:
         if hparams["rope_scaling"]["type"] == "linear":
-            gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
+            gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+            gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])


 # TOKENIZATION
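gguf.RopeScalingType and the add_rope_scaling_type / add_rope_scaling_factor writer methods used here are presumably introduced by the gguf-py portion of this commit, which is not shown in this excerpt; the add_rope_scale_linear call they replace could only describe the linear case.
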

convert.py (+48, -49)
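The hunks below teach convert.py to read the rope_scaling block of a Hugging Face config.json. For reference, a YaRN-finetuned checkpoint would carry something roughly like "rope_scaling": {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 4096, "finetuned": true}; the numbers are illustrative, and only the keys actually read by the new code are assumed.
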
@@ -151,8 +151,11 @@ class Params:
     n_head_kv:  int
     f_norm_eps: float

+    rope_scaling_type: gguf.RopeScalingType | None = None
     f_rope_freq_base: float | None = None
     f_rope_scale: float | None = None
+    n_orig_ctx: int | None = None
+    rope_finetuned: bool | None = None

     ftype: GGMLFileType | None = None

@@ -198,20 +201,20 @@ def guessed(model: LazyModel) -> Params:
     def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))

-        n_vocab = config["vocab_size"]
-        n_embd  = config["hidden_size"]
-        n_layer = config["num_hidden_layers"]
-        n_ff    = config["intermediate_size"]
-        n_head  = config["num_attention_heads"]
-        n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
-        f_norm_eps = config["rms_norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
+        rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
         rope_scaling = config.get("rope_scaling")
-        if isinstance(rope_scaling, dict) and rope_scaling.get("type") == "linear":
-            f_rope_scale = config["rope_scaling"].get("factor")
-        else:
-            f_rope_scale = None
+
+        if rope_scaling is not None and (typ := rope_scaling.get("type")):
+            rope_factor = rope_scaling.get("factor")
+            f_rope_scale = rope_factor
+            if typ == "linear":
+                rope_scaling_type = gguf.RopeScalingType.LINEAR
+            elif typ == "yarn":
+                rope_scaling_type = gguf.RopeScalingType.YARN
+                n_orig_ctx = rope_scaling['original_max_position_embeddings']
+                rope_finetuned = rope_scaling['finetuned']
+            else:
+                raise NotImplementedError(f'Unknown rope scaling type: {typ}')

         if "max_sequence_length" in config:
             n_ctx = config["max_sequence_length"]
@@ -222,16 +225,19 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
                             "Suggestion: provide 'config.json' of the model in the same directory containing model files.")

         return Params(
-            n_vocab          = n_vocab,
-            n_embd           = n_embd,
-            n_layer          = n_layer,
-            n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = n_head,
-            n_head_kv        = n_head_kv,
-            f_norm_eps       = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
-            f_rope_scale     = f_rope_scale,
+            n_vocab           = config["vocab_size"],
+            n_embd            = config["hidden_size"],
+            n_layer           = config["num_hidden_layers"],
+            n_ctx             = n_ctx,
+            n_ff              = config["intermediate_size"],
+            n_head            = (n_head := config["num_attention_heads"]),
+            n_head_kv         = config.get("num_key_value_heads", n_head),
+            f_norm_eps        = config["rms_norm_eps"],
+            f_rope_freq_base  = config.get("rope_theta"),
+            rope_scaling_type = rope_scaling_type,
+            f_rope_scale      = f_rope_scale,
+            n_orig_ctx        = n_orig_ctx,
+            rope_finetuned    = rope_finetuned,
         )

     # LLaMA v2 70B params.json
@@ -240,17 +246,8 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))

-        n_vocab          = config["vocab_size"] if "vocab_size" in config else -1
-        n_embd           = config["dim"]
-        n_layer          = config["n_layers"]
-        n_ff             = -1
-        n_head           = config["n_heads"]
-        n_head_kv        = config["n_kv_heads"] if "n_kv_heads" in config else n_head
-        f_norm_eps       = config["norm_eps"]
-        f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
-
         # hack to determine LLaMA v1 vs v2 vs CodeLlama
-        if f_rope_freq_base == 1000000:
+        if config.get("rope_theta") == 1000000:
             # CodeLlama
             n_ctx = 16384
         elif config["norm_eps"] == 1e-05:
@@ -260,22 +257,16 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
             # LLaMA v1
             n_ctx = 2048

-        if n_vocab == -1:
-            n_vocab = model["tok_embeddings.weight"].shape[0]
-
-        if n_ff == -1:
-            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
-
         return Params(
-            n_vocab          = n_vocab,
-            n_embd           = n_embd,
-            n_layer          = n_layer,
+            n_vocab          = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_embd           = config["dim"],
+            n_layer          = config["n_layers"],
             n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = n_head,
-            n_head_kv        = n_head_kv,
-            f_norm_eps       = f_norm_eps,
-            f_rope_freq_base = f_rope_freq_base,
+            n_ff             = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_head           = (n_head := config["n_heads"]),
+            n_head_kv        = config.get("n_kv_heads", n_head),
+            f_norm_eps       = config["norm_eps"],
+            f_rope_freq_base = config.get("rope_theta"),
         )

     @staticmethod
@@ -831,8 +822,16 @@ def add_meta_arch(self, params: Params) -> None:
         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)

-        if params.f_rope_scale is not None:
-            self.gguf.add_rope_scale_linear(params.f_rope_scale)
+        if params.rope_scaling_type:
+            assert params.f_rope_scale is not None
+            self.gguf.add_rope_scaling_type(params.rope_scaling_type)
+            self.gguf.add_rope_scaling_factor(params.f_rope_scale)
+
+        if params.n_orig_ctx is not None:
+            self.gguf.add_rope_scaling_orig_ctx_len(params.n_orig_ctx)
+
+        if params.rope_finetuned is not None:
+            self.gguf.add_rope_scaling_finetuned(params.rope_finetuned)

         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
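The writer calls above presumably end up as GGUF metadata keys along the lines of {arch}.rope.scaling.type, .factor, .original_context_length and .finetuned, which the llama.cpp loader side of this commit (not shown in this excerpt) reads back to pick scaling defaults at context-creation time.
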

examples/finetune/finetune.cpp (+3, -2)
@@ -642,8 +642,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         const int rope_mode = 0;

         return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx,
-            rope_freq_base, rope_freq_scale);
+            t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
+            rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        );
     };

     set_name(tokens_input, "tokens_input");
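Both rope call sites in this commit (here and in train-text-from-scratch.cpp below) pass five extra arguments to ggml_rope_custom. The extended declaration lives in ggml.h, which is not shown in this excerpt; it presumably looks roughly like the sketch below, with these call sites passing n_orig_ctx = 0 and neutral YaRN parameters:

// Presumed extended signature (ggml.h); parameter names here are an assumption.
struct ggml_tensor * ggml_rope_custom(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,           // tensor to rotate
        struct ggml_tensor  * b,           // token positions
        int                   n_dims,
        int                   mode,
        int                   n_ctx,
        int                   n_orig_ctx,  // new: YaRN original training context
        float                 freq_base,
        float                 freq_scale,
        float                 ext_factor,  // new: YaRN extrapolation mix factor
        float                 attn_factor, // new: YaRN magnitude scale
        float                 beta_fast,   // new: YaRN low correction dim
        float                 beta_slow);  // new: YaRN high correction dim
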

examples/server/server.cpp (+55, -4)
@@ -1755,12 +1755,18 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("options:\n");
     printf("  -h, --help                show this help message and exit\n");
     printf("  -v, --verbose             verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
-    printf("  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads);
+    printf("  -t N,  --threads N        number of threads to use during computation (default: %d)\n", params.n_threads);
     printf("  -tb N, --threads-batch N  number of threads to use during batch and prompt processing (default: same as --threads)\n");
-    printf("  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  -c N,  --ctx-size N       size of the prompt context (default: %d)\n", params.n_ctx);
+    printf("  --rope-scaling {none,linear,yarn}\n");
+    printf("                            RoPE frequency scaling method, defaults to linear unless specified by the model\n");
     printf("  --rope-freq-base N        RoPE base frequency (default: loaded from model)\n");
-    printf("  --rope-freq-scale N       RoPE frequency scaling factor (default: loaded from model)\n");
-    printf("  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
+    printf("  --rope-freq-scale N       RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-ext-factor N       YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+    printf("  --yarn-attn-factor N      YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+    printf("  --yarn-beta-slow N        YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+    printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  -b N,  --batch-size N     batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32              use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                            not recommended: doubles context memory required and no measurable increase in quality\n");
     if (llama_mlock_supported())
@@ -1881,6 +1887,19 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_ctx = std::stoi(argv[i]);
         }
+        else if (arg == "--rope-scaling")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            std::string value(argv[i]);
+            /**/ if (value == "none")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+            else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+            else if (value == "yarn")   { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+            else { invalid_param = true; break; }
+        }
         else if (arg == "--rope-freq-base")
         {
             if (++i >= argc)
@@ -1899,6 +1918,38 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.rope_freq_scale = std::stof(argv[i]);
         }
+        else if (arg == "--yarn-ext-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_ext_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-attn-factor")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_attn_factor = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-beta-fast")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_fast = std::stof(argv[i]);
+        }
+        else if (arg == "--yarn-beta-slow")
+        {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_beta_slow = std::stof(argv[i]);
+        }
         else if (arg == "--memory-f32" || arg == "--memory_f32")
         {
             params.memory_f16 = false;

examples/train-text-from-scratch/train-text-from-scratch.cpp (+3, -3)
@@ -349,9 +349,9 @@ static struct ggml_tensor * llama_build_train_graphs(
         // not capturing these, to silcence warnings
         const int rope_mode = 0;

-        return ggml_rope_custom(ctx,
-            t, KQ_pos, n_rot, rope_mode, n_ctx,
-            rope_freq_base, rope_freq_scale);
+        return ggml_rope_custom(
+            ctx, t, KQ_pos, n_rot, rope_mode, n_ctx, 0, rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
+        );
     };

     set_name(tokens_input, "tokens_input");
