Skip to content

Commit f30c571

Browse files
committed
llama: reduce code duplication in NTKv2 RoPE
1 parent 03a715f commit f30c571

File tree

3 files changed

+48
-64
lines changed

3 files changed

+48
-64
lines changed

ggml-cuda.cu

+7-27
Original file line numberDiff line numberDiff line change
@@ -1875,14 +1875,14 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
18751875
cpy_1(cx + x_offset, cdst + dst_offset);
18761876
}
18771877

1878-
static __device__ float ntkv2_ramp(const float low, const float high, const int i0) {
1878+
static __device__ float ggml_rope_ntkv2_ramp(const float low, const float high, const int i0) {
18791879
const float y = (i0 / 2 - low) / min(0.001f, high - low);
18801880
return 1.0f - min(1.0f, max(0.0f, y));
18811881
}
18821882

18831883
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
18841884
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1885-
static __device__ float compute_ntkv2(
1885+
static __device__ float ggml_rope_ntkv2(
18861886
float theta_base,
18871887
float theta_linear,
18881888
float theta_ntk,
@@ -1893,10 +1893,10 @@ static __device__ float compute_ntkv2(
18931893
float ramp_mix;
18941894
float theta;
18951895

1896-
ramp_mix = ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
1896+
ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
18971897
theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
18981898

1899-
ramp_mix = ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
1899+
ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
19001900
theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
19011901
return theta;
19021902
}
@@ -1918,7 +1918,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols,
19181918
const float theta_base = p*powf(theta_scale, col/2);
19191919
const float theta_linear = freq_scale * theta_base;
19201920
const float theta_ntk = p*powf(theta_ntk_scale, col/2);
1921-
const float theta = compute_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor,
1921+
const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors, col, ntk_factor,
19221922
extrapolation_factor);
19231923
const float sin_theta = sinf(theta);
19241924
const float cos_theta = cosf(theta);
@@ -2974,13 +2974,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
29742974
(void) i1;
29752975
}
29762976

2977-
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
2978-
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
2979-
static float ntkv2_correction_factor(const int n_dims, const float n_rot, const float base) {
2980-
static const float max_pos_emb = 2048;
2981-
return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
2982-
}
2983-
29842977
inline void ggml_cuda_op_rope(
29852978
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
29862979
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -3018,21 +3011,8 @@ inline void ggml_cuda_op_rope(
30183011
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
30193012
} else {
30203013
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
3021-
3022-
// Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
3023-
// Do not change unless there is a good reason for doing so!
3024-
static const float BETA_0 = 1.75f;
3025-
static const float BETA_1 = 1.25f;
3026-
static const float GAMMA_0 = 16.0f;
3027-
static const float GAMMA_1 = 2.0f;
3028-
3029-
// start and end correction factors
3030-
const float corr_factors[4] = {
3031-
max(0.0f, floorf(ntkv2_correction_factor(n_dims, BETA_0, freq_base))),
3032-
min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, BETA_1, freq_base))),
3033-
max(0.0f, floorf(ntkv2_correction_factor(n_dims, GAMMA_0, freq_base))),
3034-
min(n_dims - 1.0f, ceilf(ntkv2_correction_factor(n_dims, GAMMA_1, freq_base))),
3035-
};
3014+
float corr_factors[4];
3015+
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
30363016

30373017
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor,
30383018
extrapolation_factor, theta_scale, theta_ntk_scale, p, corr_factors, cudaStream_main);

ggml.c

+38-37
Original file line numberDiff line numberDiff line change
@@ -12093,57 +12093,54 @@ static void ggml_compute_forward_clamp(
1209312093

1209412094
// ggml_compute_forward_rope
1209512095

12096-
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
12097-
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
12098-
#define NTKV2_MAX_POS_EMB 2048
12099-
#define NTKV2_CORRECTION_FACTOR(n_rot) (__builtin_logf(NTKV2_MAX_POS_EMB / ((n_rot) * 2 * (float)M_PI)) / 2)
12100-
1210112096
// use -ffast-math so MIN and MAX are optimized to vminss and vmaxss
1210212097
__attribute__((optimize("-ffast-math"), always_inline))
12103-
static inline float ntkv2_ramp(const float low, const float high, const int i0) {
12098+
static inline float ggml_rope_ntkv2_ramp(const float low, const float high, const int i0) {
1210412099
const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
1210512100
return 1 - MIN(1, MAX(0, y));
1210612101
}
1210712102

1210812103
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
1210912104
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
12110-
static float compute_ntkv2(
12105+
static float ggml_rope_ntkv2(
1211112106
float theta_base,
12107+
float theta_linear,
1211212108
float theta_ntk,
12113-
float dims_over_base,
12114-
float freq_scale,
12109+
const float corr_factors[4],
1211512110
int64_t i0,
1211612111
float ntk_factor,
12117-
float extrapolation_factor,
12118-
int n_dims) {
12112+
float extrapolation_factor) {
12113+
float ramp_mix;
12114+
float theta;
12115+
12116+
ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
12117+
theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
12118+
12119+
ramp_mix = ggml_rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
12120+
theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
12121+
return theta;
12122+
}
12123+
12124+
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
12125+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
12126+
static float ggml_rope_ntkv2_corr_factor(const int n_dims, const float n_rot, const float base) {
12127+
static const float max_pos_emb = 2048;
12128+
return n_dims * logf(max_pos_emb / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
12129+
}
12130+
12131+
void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]) {
1211912132
// Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
1212012133
// Do not change unless there is a good reason for doing so!
1212112134
static const float BETA_0 = 1.75f;
1212212135
static const float BETA_1 = 1.25f;
1212312136
static const float GAMMA_0 = 16.0f;
1212412137
static const float GAMMA_1 = 2.0f;
1212512138

12126-
static const float low_1p = NTKV2_CORRECTION_FACTOR(BETA_0);
12127-
static const float high_1p = NTKV2_CORRECTION_FACTOR(BETA_1);
12128-
static const float low_2p = NTKV2_CORRECTION_FACTOR(GAMMA_0);
12129-
static const float high_2p = NTKV2_CORRECTION_FACTOR(GAMMA_1);
12130-
1213112139
// start and end correction factors
12132-
const float low_1 = maxf(0.0f, floorf(low_1p * dims_over_base));
12133-
const float high_1 = minf(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
12134-
const float low_2 = maxf(0.0f, floorf(low_2p * dims_over_base));
12135-
const float high_2 = minf(n_dims - 1.0f, ceilf(high_2p * dims_over_base));
12136-
12137-
const float theta_linear = freq_scale * theta_base;
12138-
float ramp_mix;
12139-
float theta;
12140-
12141-
ramp_mix = ntkv2_ramp(low_1, high_1, i0) * ntk_factor;
12142-
theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
12143-
12144-
ramp_mix = ntkv2_ramp(low_2, high_2, i0) * extrapolation_factor;
12145-
theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
12146-
return theta;
12140+
factors[0] = maxf(0.0f, floorf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_0, freq_base)));
12141+
factors[1] = minf(n_dims - 1.0f, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, BETA_1, freq_base)));
12142+
factors[2] = maxf(0.0f, floorf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_0, freq_base)));
12143+
factors[3] = minf(n_dims - 1.0f, ceilf(ggml_rope_ntkv2_corr_factor(n_dims, GAMMA_1, freq_base)));
1214712144
}
1214812145

1214912146
static void ggml_compute_forward_rope_f32(
@@ -12201,7 +12198,8 @@ static void ggml_compute_forward_rope_f32(
1220112198

1220212199
const float theta_scale = powf(freq_base, -2.0f/n_dims);
1220312200
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
12204-
const float dims_over_base = n_dims / logf(freq_base);
12201+
float corr_factors[4];
12202+
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
1220512203

1220612204
const bool is_neox = mode & 2;
1220712205
const bool is_glm = mode & 4;
@@ -12243,8 +12241,9 @@ static void ggml_compute_forward_rope_f32(
1224312241
}
1224412242
} else if (!is_neox) {
1224512243
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
12246-
const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base,
12247-
freq_scale, i0, ntk_factor, extrapolation_factor, n_dims);
12244+
const float theta_linear = freq_scale * theta_base;
12245+
const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
12246+
i0, ntk_factor, extrapolation_factor);
1224812247
const float cos_theta = cosf(theta);
1224912248
const float sin_theta = sinf(theta);
1225012249

@@ -12343,7 +12342,8 @@ static void ggml_compute_forward_rope_f16(
1234312342

1234412343
const float theta_scale = powf(freq_base, -2.0f/n_dims);
1234512344
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
12346-
const float dims_over_base = n_dims / logf(freq_base);
12345+
float corr_factors[4];
12346+
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
1234712347

1234812348
const bool is_neox = mode & 2;
1234912349
const bool is_glm = mode & 4;
@@ -12385,8 +12385,9 @@ static void ggml_compute_forward_rope_f16(
1238512385
}
1238612386
} if (!is_neox) {
1238712387
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
12388-
const float theta = compute_ntkv2(theta_base, theta_ntk, dims_over_base,
12389-
freq_scale, i0, ntk_factor, extrapolation_factor, n_dims);
12388+
const float theta_linear = freq_scale * theta_base;
12389+
const float theta = ggml_rope_ntkv2(theta_base, theta_linear, theta_ntk, corr_factors,
12390+
i0, ntk_factor, extrapolation_factor);
1239012391
const float cos_theta = cosf(theta);
1239112392
const float sin_theta = sinf(theta);
1239212393

ggml.h

+3
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,9 @@ extern "C" {
11341134
float extrapolation_factor,
11351135
int n_ctx);
11361136

1137+
// compute correction factors for NTKv2 RoPE scaling
1138+
void ggml_rope_ntkv2_corr_factors(int n_dims, const float freq_base, float factors[4]);
1139+
11371140
// rotary position embedding backward, i.e compute dx from dy
11381141
// a - dy
11391142
GGML_API struct ggml_tensor * ggml_rope_back(

0 commit comments

Comments
 (0)