Skip to content

Commit 208cd52

Browse files
committed
vulkan : implement YaRN RoPE scaling (ggml-org#2268)
The NeoX cur_rot part is different because I'm pretty sure my original implementation was wrong.
1 parent 1829f1d commit 208cd52

File tree

5 files changed

+123
-69
lines changed

5 files changed

+123
-69
lines changed

ggml-vulkan.cpp

+23-13
Original file line numberDiff line numberDiff line change
@@ -1195,8 +1195,8 @@ void ggml_vk_rope(
11951195
const std::shared_ptr<kp::Tensor>& inB,
11961196
const std::shared_ptr<kp::Tensor>& out,
11971197
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
1198-
ggml_type src0t, int32_t n_dims, int32_t mode,
1199-
float freq_base, float freq_scale,
1198+
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
1199+
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
12001200
int32_t ne01, int32_t ne02, int32_t ne03,
12011201
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
12021202
int32_t ne0,
@@ -1224,15 +1224,15 @@ void ggml_vk_rope(
12241224

12251225
struct PushConstants {
12261226
uint32_t inAOff, inBOff, outOff;
1227-
int32_t n_dims, mode;
1228-
float freq_base, freq_scale;
1227+
int32_t n_dims, mode, n_orig_ctx;
1228+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
12291229
uint32_t nb00, nb01, nb02, nb03;
12301230
int32_t ne0;
12311231
uint32_t nb0, nb1, nb2, nb3;
12321232
} pushConsts {
12331233
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
1234-
n_dims, mode,
1235-
freq_base, freq_scale,
1234+
n_dims, mode, n_orig_ctx,
1235+
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
12361236
nb00, nb01, nb02, nb03,
12371237
ne0,
12381238
nb0, nb1, nb2, nb3
@@ -1545,13 +1545,23 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
15451545
GGML_ASSERT(ne10 == ne02);
15461546
GGML_ASSERT(src0t == dstt);
15471547
// const int n_past = ((int32_t *) dst->op_params)[0];
1548-
const int n_dims = ((int32_t *) dst->op_params)[1];
1549-
const int mode = ((int32_t *) dst->op_params)[2];
1550-
float freq_base;
1551-
float freq_scale;
1552-
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
1553-
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
1554-
ggml_vk_rope(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
1548+
const int n_dims = ((int32_t *) dst->op_params)[1];
1549+
const int mode = ((int32_t *) dst->op_params)[2];
1550+
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
1551+
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1552+
1553+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1554+
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1555+
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1556+
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1557+
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1558+
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1559+
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1560+
ggml_vk_rope(
1561+
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
1562+
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
1563+
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
1564+
);
15551565
} break;
15561566
case GGML_OP_DUP:
15571567
case GGML_OP_CPY:

kompute/common.comp

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#define GELU_COEF_A 0.044715
2222
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
23+
#define TWOPI_F 6.283185307179586f
2324

2425
#define QK_K 256
2526

kompute/op_rope_f16.comp

+12-28
Original file line numberDiff line numberDiff line change
@@ -8,50 +8,32 @@
88

99
#version 450
1010

11-
#include "common.comp"
12-
13-
// TODO: use a local size of 32 or more (Metal uses 1024)
14-
layout(local_size_x = 1) in;
11+
#include "rope_common.comp"
1512

1613
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
1714
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
1815
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
1916

20-
layout (push_constant) uniform parameter {
21-
uint inAOff;
22-
uint inBOff;
23-
uint outOff;
24-
int n_dims;
25-
int mode;
26-
float freq_base;
27-
float freq_scale;
28-
uint nb00;
29-
uint nb01;
30-
uint nb02;
31-
uint nb03;
32-
int ne0;
33-
uint nb0;
34-
uint nb1;
35-
uint nb2;
36-
uint nb3;
37-
} pcs;
38-
3917
void main() {
4018
const uint i3 = gl_WorkGroupID.z;
4119
const uint i2 = gl_WorkGroupID.y;
4220
const uint i1 = gl_WorkGroupID.x;
4321

4422
const bool is_neox = (pcs.mode & 2) != 0;
23+
24+
float corr_dims[2];
25+
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
26+
4527
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
4628

4729
const int p = inB[pcs.inBOff + i2];
4830

49-
float theta = pcs.freq_scale * float(p);
31+
float theta = float(p);
5032

5133
if (!is_neox) {
5234
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
53-
const float cos_theta = cos(theta);
54-
const float sin_theta = sin(theta);
35+
float cos_theta, sin_theta;
36+
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
5537

5638
theta *= theta_scale;
5739

@@ -68,8 +50,10 @@ void main() {
6850
const float inv_ndims = -1.f/pcs.n_dims;
6951
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
7052
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
71-
const float cos_theta = cos(theta);
72-
const float sin_theta = sin(theta);
53+
const uint cur_rot = ib * pcs.n_dims + ic;
54+
55+
float cos_theta, sin_theta;
56+
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
7357

7458
theta *= theta_scale;
7559

kompute/op_rope_f32.comp

+12-28
Original file line numberDiff line numberDiff line change
@@ -8,50 +8,32 @@
88

99
#version 450
1010

11-
#include "common.comp"
12-
13-
// TODO: use a local size of 32 or more (Metal uses 1024)
14-
layout(local_size_x = 1) in;
11+
#include "rope_common.comp"
1512

1613
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
1714
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
1815
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
1916

20-
layout (push_constant) uniform parameter {
21-
uint inAOff;
22-
uint inBOff;
23-
uint outOff;
24-
int n_dims;
25-
int mode;
26-
float freq_base;
27-
float freq_scale;
28-
uint nb00;
29-
uint nb01;
30-
uint nb02;
31-
uint nb03;
32-
int ne0;
33-
uint nb0;
34-
uint nb1;
35-
uint nb2;
36-
uint nb3;
37-
} pcs;
38-
3917
void main() {
4018
const uint i3 = gl_WorkGroupID.z;
4119
const uint i2 = gl_WorkGroupID.y;
4220
const uint i1 = gl_WorkGroupID.x;
4321

4422
const bool is_neox = (pcs.mode & 2) != 0;
23+
24+
float corr_dims[2];
25+
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
26+
4527
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
4628

4729
const int p = inB[pcs.inBOff + i2];
4830

49-
float theta = pcs.freq_scale * float(p);
31+
float theta = float(p);
5032

5133
if (!is_neox) {
5234
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
53-
const float cos_theta = cos(theta);
54-
const float sin_theta = sin(theta);
35+
float cos_theta, sin_theta;
36+
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
5537

5638
theta *= theta_scale;
5739

@@ -68,8 +50,10 @@ void main() {
6850
const float inv_ndims = -1.f/pcs.n_dims;
6951
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
7052
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
71-
const float cos_theta = cos(theta);
72-
const float sin_theta = sin(theta);
53+
const uint cur_rot = ib * pcs.n_dims + ic;
54+
55+
float cos_theta, sin_theta;
56+
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
7357

7458
theta *= theta_scale;
7559

kompute/rope_common.comp

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/**
2+
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
3+
*
4+
* This software is licensed under the terms of the Software for Open Models License (SOM),
5+
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
6+
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
7+
*/
8+
9+
#include "common.comp"
10+
11+
// TODO: use a local size of 32 or more (Metal uses 1024)
12+
layout(local_size_x = 1) in;
13+
14+
layout (push_constant) uniform parameter {
15+
uint inAOff;
16+
uint inBOff;
17+
uint outOff;
18+
int n_dims;
19+
int mode;
20+
int n_orig_ctx;
21+
float freq_base;
22+
float freq_scale;
23+
float ext_factor;
24+
float attn_factor;
25+
float beta_fast;
26+
float beta_slow;
27+
uint nb00;
28+
uint nb01;
29+
uint nb02;
30+
uint nb03;
31+
int ne0;
32+
uint nb0;
33+
uint nb1;
34+
uint nb2;
35+
uint nb3;
36+
} pcs;
37+
38+
float rope_yarn_ramp(const float low, const float high, const float i0) {
39+
const float y = (i0 / 2 - low) / max(0.001f, high - low);
40+
return 1.0f - min(1.0f, max(0.0f, y));
41+
}
42+
43+
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
44+
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
45+
void rope_yarn(
46+
float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
47+
out float cos_theta, out float sin_theta
48+
) {
49+
// Get n-d rotational scaling corrected for extrapolation
50+
float theta_interp = freq_scale * theta_extrap;
51+
float theta = theta_interp;
52+
if (ext_factor != 0.0f) {
53+
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
54+
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
55+
56+
// Get n-d magnitude scaling corrected for interpolation
57+
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
58+
}
59+
cos_theta = cos(theta) * mscale;
60+
sin_theta = sin(theta) * mscale;
61+
}
62+
63+
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
64+
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
65+
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
66+
return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
67+
}
68+
69+
void rope_yarn_corr_dims(
70+
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
71+
) {
72+
// start and end correction dims
73+
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
74+
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
75+
}

0 commit comments

Comments
 (0)