@@ -12093,57 +12093,54 @@ static void ggml_compute_forward_clamp(
12093
12093
12094
12094
// ggml_compute_forward_rope
12095
12095
12096
- // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
12097
- // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
12098
- #define NTKV2_MAX_POS_EMB 2048
12099
- #define NTKV2_CORRECTION_FACTOR(n_rot) (__builtin_logf(NTKV2_MAX_POS_EMB / ((n_rot) * 2 * (float)M_PI)) / 2)
12100
-
12101
12096
// Linear ramp used to blend between two RoPE scaling regimes.
//
// low, high: bounds of the ramp, expressed in rotary-pair indices.
// i0:        element index; the pair index is i0 / 2 (integer division).
//
// Returns 1 when the pair index is at or below `low`, 0 when it is at or
// above `high`, and a value falling linearly from 1 to 0 in between.
//
// The span (high - low) is clamped to at least 0.001 so that low == high
// degenerates to a step instead of dividing by zero. The clamp has to be a
// lower bound (MAX): an upper bound would squash every ramp into a step and
// still allow a zero denominator.
//
// use -ffast-math so MIN and MAX are optimized to vminss and vmaxss
__attribute__((optimize("-ffast-math"), always_inline))
static inline float ggml_rope_ntkv2_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
    return 1 - MIN(1, MAX(0, y));
}
12107
12102
12108
12103
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Blends three candidate rotation angles for the element at index i0:
//   1. theta_linear and theta_ntk are mixed by the ramp over
//      corr_factors[0]..corr_factors[1], weighted by ntk_factor;
//   2. that result and theta_base are mixed by the ramp over
//      corr_factors[2]..corr_factors[3], weighted by extrapolation_factor.
// Returns the blended angle.
static float ggml_rope_ntkv2(
    float theta_base,
    float theta_linear,
    float theta_ntk,
    const float corr_factors[4],
    int64_t i0,
    float ntk_factor,
    float extrapolation_factor) {
    // stage 1: linear-scaled angle -> NTK-scaled angle
    const float ntk_mix      = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
    const float theta_scaled = theta_linear * (1 - ntk_mix) + theta_ntk * ntk_mix;

    // stage 2: scaled angle -> unscaled base angle
    const float base_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
    return theta_scaled * (1 - base_mix) + theta_base * base_mix;
}
12123
+
12124
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
12125
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
12126
+ static float ggml_rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) {
12127
+ static const float max_pos_emb = 2048;
12128
+ return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
12129
+ }
12130
+
12131
// Fills factors[0..3] with the start/end correction factors for the two NTKv2 ramps.
//
// factors[0], factors[1]: start/end derived from BETA_0 / BETA_1
// factors[2], factors[3]: start/end derived from GAMMA_0 / GAMMA_1
//
// Start values are rounded down and clamped to >= 0; end values are rounded
// up and clamped to <= n_dims - 1.
void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) {
    // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
    // Do not change unless there is a good reason for doing so!
    static const float BETA_0  = 1.75f;
    static const float BETA_1  = 1.25f;
    static const float GAMMA_0 = 16.0f;
    static const float GAMMA_1 = 2.0f;

    const float upper_limit = n_dims - 1.0f;

    // start and end correction factors, rounded outwards before clamping
    const float beta_start  = floorf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_0,  freq_base));
    const float beta_end    = ceilf (ggml_rope_ntkv2_corr_factor(n_dims, BETA_1,  freq_base));
    const float gamma_start = floorf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base));
    const float gamma_end   = ceilf (ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base));

    factors[0] = maxf(0.0f, beta_start);
    factors[1] = minf(upper_limit, beta_end);
    factors[2] = maxf(0.0f, gamma_start);
    factors[3] = minf(upper_limit, gamma_end);
}
12148
12145
12149
12146
static void ggml_compute_forward_rope_f32(
@@ -12201,7 +12198,8 @@ static void ggml_compute_forward_rope_f32(
12201
12198
12202
12199
const float theta_scale = powf(freq_base, -2.0f/n_dims);
12203
12200
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
12204
- const float dims_over_base = n_dims / logf(freq_base);
12201
+ float corr_factors[4];
12202
+ ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
12205
12203
12206
12204
const bool is_neox = mode & 2;
12207
12205
const bool is_glm = mode & 4;
@@ -12243,8 +12241,9 @@ static void ggml_compute_forward_rope_f32(
12243
12241
}
12244
12242
} else if (!is_neox) {
12245
12243
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
12246
- const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base,
12247
- freq_scale, i0, ntk_factor, extrapolation_factor, n_dims);
12244
+ const float theta_linear = freq_scale * theta_base;
12245
+ const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
12246
+ i0, ntk_factor, extrapolation_factor);
12248
12247
const float cos_theta = cosf(theta);
12249
12248
const float sin_theta = sinf(theta);
12250
12249
@@ -12343,7 +12342,8 @@ static void ggml_compute_forward_rope_f16(
12343
12342
12344
12343
const float theta_scale = powf(freq_base, -2.0f/n_dims);
12345
12344
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
12346
- const float dims_over_base = n_dims / logf(freq_base);
12345
+ float corr_factors[4];
12346
+ ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
12347
12347
12348
12348
const bool is_neox = mode & 2;
12349
12349
const bool is_glm = mode & 4;
@@ -12385,8 +12385,9 @@ static void ggml_compute_forward_rope_f16(
12385
12385
}
12386
12386
} if (!is_neox) {
12387
12387
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
12388
- const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base,
12389
- freq_scale, i0, ntk_factor, extrapolation_factor, n_dims);
12388
+ const float theta_linear = freq_scale * theta_base;
12389
+ const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
12390
+ i0, ntk_factor, extrapolation_factor);
12390
12391
const float cos_theta = cosf(theta);
12391
12392
const float sin_theta = sinf(theta);
12392
12393
0 commit comments