Skip to content

Commit 6e7cca4

Browse files
jxy and ggerganov
authored
llama : add custom RoPE (#2054)
* Implement customizable RoPE The original RoPE has pre-defined parameters theta_i = 10000^(−2(i−1)/d), for i in [1, 2, ..., d/2] Our customizable RoPE, ggml_rope_custom_inplace, uses theta_i = scale * base^(−2(i−1)/d), for i in [1, 2, ..., d/2] with the defaults matching the original: scale = 1.0 base = 10000 The new command line arguments --rope-freq-base --rope-freq-scale set the two new RoPE parameters. Recent research shows changing these two parameters extends the context limit with minimal loss. 1. Extending Context to 8K kaiokendev https://kaiokendev.github.io/til#extending-context-to-8k 2. Extending Context Window of Large Language Models via Positional Interpolation Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian https://arxiv.org/abs/2306.15595 3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation. https://www.reddit.com/user/bloc97 https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ For the bold, try adding the following command line parameters to your favorite model: -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5 * ggml-metal: fix custom rope * common: fix argument names in help * llama: increase MEM_REQ_EVAL for MODEL_3B It avoids crashing for quantized weights on CPU. Better ways to calculate the required buffer size would be welcome. * llama: make MEM_REQ_EVAL depend on n_ctx * server: use proper Content-Type in curl examples Without the header Content-Type: application/json, curl will POST with Content-Type: application/x-www-form-urlencoded Though our simple server doesn't care, the httplib.h used has a limit with CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 With Content-Type: application/json, we can send large json data. * style : minor fixes, mostly indentations * ggml : fix asserts --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent a6803ca commit 6e7cca4

12 files changed

+185
-67
lines changed

examples/common.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
168168
break;
169169
}
170170
params.n_ctx = std::stoi(argv[i]);
171+
} else if (arg == "--rope-freq-base") {
172+
if (++i >= argc) {
173+
invalid_param = true;
174+
break;
175+
}
176+
params.rope_freq_base = std::stof(argv[i]);
177+
} else if (arg == "--rope-freq-scale") {
178+
if (++i >= argc) {
179+
invalid_param = true;
180+
break;
181+
}
182+
params.rope_freq_scale = std::stof(argv[i]);
171183
} else if (arg == "--memory-f32") {
172184
params.memory_f16 = false;
173185
} else if (arg == "--top-p") {
@@ -493,6 +505,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
493505
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
494506
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
495507
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
508+
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
509+
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
496510
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
497511
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
498512
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -573,6 +587,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
573587
lparams.use_mlock = params.use_mlock;
574588
lparams.logits_all = params.perplexity;
575589
lparams.embedding = params.embedding;
590+
lparams.rope_freq_base = params.rope_freq_base;
591+
lparams.rope_freq_scale = params.rope_freq_scale;
576592

577593
return lparams;
578594
}

examples/common.h

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ struct gpt_params {
3232
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
3333
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
3434
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
35+
float rope_freq_base = 10000.0f; // RoPE base frequency
36+
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
3537

3638
// sampling parameters
3739
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

examples/main/main.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
8484
return 0;
8585
}
8686

87+
if (params.rope_freq_base != 10000.0) {
88+
fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
89+
}
90+
91+
if (params.rope_freq_scale != 1.0) {
92+
fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
93+
}
94+
8795
if (params.n_ctx > 2048) {
88-
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
89-
"expect poor results\n", __func__, params.n_ctx);
96+
fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
97+
" you are on your own\n", __func__, params.n_ctx);
9098
} else if (params.n_ctx < 8) {
9199
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
92100
params.n_ctx = 8;

examples/server/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the
6666
```sh
6767
curl --request POST \
6868
--url http://localhost:8080/completion \
69+
--header "Content-Type: application/json" \
6970
--data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
7071
```
7172

examples/server/chat.sh

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ tokenize() {
3232
--silent \
3333
--request POST \
3434
--url "${API_URL}/tokenize" \
35+
--header "Content-Type: application/json" \
3536
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
3637
| jq '.tokens[]'
3738
}
@@ -64,6 +65,7 @@ chat_completion() {
6465
--no-buffer \
6566
--request POST \
6667
--url "${API_URL}/completion" \
68+
--header "Content-Type: application/json" \
6769
--data-raw "${DATA}")
6870

6971
printf "\n"

examples/server/server.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
608608
fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
609609
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
610610
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
611+
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
612+
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
611613
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
612614
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
613615
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
722724
}
723725
params.n_ctx = std::stoi(argv[i]);
724726
}
727+
else if (arg == "--rope-freq-base")
728+
{
729+
if (++i >= argc) {
730+
invalid_param = true;
731+
break;
732+
}
733+
params.rope_freq_base = std::stof(argv[i]);
734+
}
735+
else if (arg == "--rope-freq-scale")
736+
{
737+
if (++i >= argc) {
738+
invalid_param = true;
739+
break;
740+
}
741+
params.rope_freq_scale = std::stof(argv[i]);
742+
}
725743
else if (arg == "--memory-f32" || arg == "--memory_f32")
726744
{
727745
params.memory_f16 = false;

ggml-metal.m

+26-19
Original file line numberDiff line numberDiff line change
@@ -881,28 +881,35 @@ void ggml_metal_graph_compute(
881881

882882
const int n_past = ((int32_t *)(src1->data))[0];
883883

884+
float freq_base;
885+
float freq_scale;
886+
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
887+
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
888+
884889
[encoder setComputePipelineState:ctx->pipeline_rope];
885890
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
886891
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
887-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
888-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
889-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
890-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
891-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
892-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
893-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
894-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
895-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
896-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
897-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
898-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
899-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
900-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
901-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
902-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
903-
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
904-
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
905-
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
892+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
893+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
894+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
895+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
896+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
897+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
898+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
899+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
900+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
901+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
902+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
903+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
904+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
905+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
906+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
907+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
908+
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
909+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
910+
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
911+
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
912+
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
906913

907914
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
908915
} break;

ggml-metal.metal

+4-2
Original file line numberDiff line numberDiff line change
@@ -656,17 +656,19 @@ kernel void kernel_rope(
656656
constant int & n_past,
657657
constant int & n_dims,
658658
constant int & mode,
659+
constant float & freq_base,
660+
constant float & freq_scale,
659661
uint3 tpig[[thread_position_in_grid]]) {
660662
const int64_t i3 = tpig[2];
661663
const int64_t i2 = tpig[1];
662664
const int64_t i1 = tpig[0];
663665

664666
const bool is_neox = mode & 2;
665-
const float theta_scale = pow(10000.0, -2.0f/n_dims);
667+
const float theta_scale = pow(freq_base, -2.0f/n_dims);
666668

667669
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
668670

669-
float theta = (float)p;
671+
float theta = freq_scale * (float)p;
670672

671673
if (!is_neox) {
672674
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {

ggml.c

+38-12
Original file line numberDiff line numberDiff line change
@@ -6956,6 +6956,8 @@ struct ggml_tensor * ggml_rope_impl(
69566956
int n_past,
69576957
int n_dims,
69586958
int mode,
6959+
float freq_base,
6960+
float freq_scale,
69596961
int n_ctx,
69606962
bool inplace) {
69616963
GGML_ASSERT(n_past >= 0);
@@ -6969,12 +6971,14 @@ struct ggml_tensor * ggml_rope_impl(
69696971

69706972
ggml_scratch_save(ctx);
69716973

6972-
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6974+
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
69736975

69746976
((int32_t *) b->data)[0] = n_past;
69756977
((int32_t *) b->data)[1] = n_dims;
69766978
((int32_t *) b->data)[2] = mode;
69776979
((int32_t *) b->data)[3] = n_ctx;
6980+
memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float));
6981+
memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
69786982

69796983
ggml_scratch_load(ctx);
69806984

@@ -6993,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
69936997
int n_dims,
69946998
int mode,
69956999
int n_ctx) {
6996-
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
7000+
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
69977001
}
69987002

69997003
struct ggml_tensor * ggml_rope_inplace(
@@ -7003,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
70037007
int n_dims,
70047008
int mode,
70057009
int n_ctx) {
7006-
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
7010+
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
7011+
}
7012+
7013+
struct ggml_tensor * ggml_rope_custom_inplace(
7014+
struct ggml_context * ctx,
7015+
struct ggml_tensor * a,
7016+
int n_past,
7017+
int n_dims,
7018+
int mode,
7019+
float freq_base,
7020+
float freq_scale,
7021+
int n_ctx) {
7022+
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
70077023
}
70087024

70097025
// ggml_rope_back
@@ -12074,16 +12090,21 @@ static void ggml_compute_forward_rope_f32(
1207412090
const struct ggml_tensor * src1,
1207512091
struct ggml_tensor * dst) {
1207612092
GGML_ASSERT(src1->type == GGML_TYPE_I32);
12077-
GGML_ASSERT(ggml_nelements(src1) == 4);
12093+
GGML_ASSERT(ggml_nelements(src1) == 6);
1207812094

1207912095
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1208012096
return;
1208112097
}
1208212098

12099+
float freq_base;
12100+
float freq_scale;
12101+
1208312102
const int n_past = ((int32_t *) src1->data)[0];
1208412103
const int n_dims = ((int32_t *) src1->data)[1];
1208512104
const int mode = ((int32_t *) src1->data)[2];
1208612105
const int n_ctx = ((int32_t *) src1->data)[3];
12106+
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
12107+
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
1208712108

1208812109
assert(n_past >= 0);
1208912110

@@ -12112,7 +12133,7 @@ static void ggml_compute_forward_rope_f32(
1211212133
// row index used to determine which thread to use
1211312134
int ir = 0;
1211412135

12115-
const float theta_scale = powf(10000.0, -2.0f/n_dims);
12136+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
1211612137

1211712138
const bool is_neox = mode & 2;
1211812139
const bool is_glm = mode & 4;
@@ -12124,7 +12145,7 @@ static void ggml_compute_forward_rope_f32(
1212412145
if (ir++ < ir0) continue;
1212512146
if (ir > ir1) break;
1212612147

12127-
float theta = (float)p;
12148+
float theta = freq_scale * (float)p;
1212812149

1212912150
if (is_glm) {
1213012151
theta = MIN(p, n_ctx - 2);
@@ -12201,16 +12222,21 @@ static void ggml_compute_forward_rope_f16(
1220112222
const struct ggml_tensor * src1,
1220212223
struct ggml_tensor * dst) {
1220312224
GGML_ASSERT(src1->type == GGML_TYPE_I32);
12204-
GGML_ASSERT(ggml_nelements(src1) == 4);
12225+
GGML_ASSERT(ggml_nelements(src1) == 6);
1220512226

1220612227
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1220712228
return;
1220812229
}
1220912230

12231+
float freq_base;
12232+
float freq_scale;
12233+
1221012234
const int n_past = ((int32_t *) src1->data)[0];
1221112235
const int n_dims = ((int32_t *) src1->data)[1];
1221212236
const int mode = ((int32_t *) src1->data)[2];
1221312237
const int n_ctx = ((int32_t *) src1->data)[3];
12238+
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
12239+
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
1221412240

1221512241
assert(n_past >= 0);
1221612242

@@ -12239,7 +12265,7 @@ static void ggml_compute_forward_rope_f16(
1223912265
// row index used to determine which thread to use
1224012266
int ir = 0;
1224112267

12242-
const float theta_scale = powf(10000.0, -2.0f/n_dims);
12268+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
1224312269

1224412270
const bool is_neox = mode & 2;
1224512271
const bool is_glm = mode & 4;
@@ -12251,7 +12277,7 @@ static void ggml_compute_forward_rope_f16(
1225112277
if (ir++ < ir0) continue;
1225212278
if (ir > ir1) break;
1225312279

12254-
float theta = (float)p;
12280+
float theta = freq_scale * (float)p;
1225512281

1225612282
if (is_glm) {
1225712283
theta = MIN(p, n_ctx - 2);
@@ -12312,7 +12338,7 @@ static void ggml_compute_forward_rope_f16(
1231212338
const float x0 = GGML_FP16_TO_FP32(src[0]);
1231312339
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
1231412340

12315-
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
12341+
dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
1231612342
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1231712343
}
1231812344
}
@@ -15710,7 +15736,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1571015736
// necessary for llama
1571115737
if (src0->grad) {
1571215738
assert(src1->type == GGML_TYPE_I32);
15713-
assert(ggml_nelements(src1) == 4);
15739+
assert(ggml_nelements(src1) == 6);
1571415740
const int n_past = ((int32_t *) src1->data)[0];
1571515741
const int n_dims = ((int32_t *) src1->data)[1];
1571615742
const int mode = ((int32_t *) src1->data)[2];
@@ -15731,7 +15757,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1573115757
{
1573215758
if (src0->grad) {
1573315759
assert(src1->type == GGML_TYPE_I32);
15734-
assert(ggml_nelements(src1) == 4);
15760+
assert(ggml_nelements(src1) == 3);
1573515761
const int n_past = ((int32_t *) src1->data)[0];
1573615762
const int n_dims = ((int32_t *) src1->data)[1];
1573715763
const int mode = ((int32_t *) src1->data)[2];

ggml.h

+11
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,17 @@ extern "C" {
11211121
int mode,
11221122
int n_ctx);
11231123

1124+
// custom RoPE, in-place, returns view(a)
1125+
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
1126+
struct ggml_context * ctx,
1127+
struct ggml_tensor * a,
1128+
int n_past,
1129+
int n_dims,
1130+
int mode,
1131+
float freq_base,
1132+
float freq_scale,
1133+
int n_ctx);
1134+
11241135
// rotary position embedding backward, i.e compute dx from dy
11251136
// a - dy
11261137
GGML_API struct ggml_tensor * ggml_rope_back(

0 commit comments

Comments
 (0)