
Commit 38728a0

committed Oct 29, 2023
llama : add llm_build_k_shift helper
ggml-ci
1 parent dbf836b commit 38728a0

File tree

1 file changed: +66, -64 lines
 

llama.cpp (+66, -64)

@@ -3230,6 +3230,65 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
 }
 
+enum llm_rope_type {
+    LLM_ROPE,
+    LLM_ROPE_NEOX,
+    LLM_ROPE_GLM,
+};
+
+// Persimmon: n_rot = n_embd_head/2
+// Other:     n_rot = n_embd_head
+static void llm_build_k_shift(
+      const llama_context & lctx,
+       struct ggml_context * ctx,
+        struct ggml_cgraph * graph,
+                    int64_t   n_rot,
+              llm_rope_type   type,
+         const llm_build_cb & cb) {
+    const auto & model   = lctx.model;
+    const auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+
+    const auto & hparams = model.hparams;
+
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int64_t n_embd_head = hparams.n_embd_head();
+
+    const int64_t n_ctx = lctx.cparams.n_ctx;
+
+    const float freq_base  = cparams.rope_freq_base;
+    const float freq_scale = cparams.rope_freq_scale;
+
+    GGML_ASSERT(n_embd_head % n_rot == 0);
+
+    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
+    cb(K_shift, "K_shift", -1);
+
+    int rope_type = 0;
+
+    switch (type) {
+        case LLM_ROPE:      rope_type = 0; break;
+        case LLM_ROPE_NEOX: rope_type = 2; break;
+        case LLM_ROPE_GLM:  rope_type = 4; break;
+    };
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * tmp =
+            // we rotate only the first n_rot dimensions
+            ggml_rope_custom_inplace(ctx,
+                    ggml_view_3d(ctx, kv_self.k,
+                        n_rot, n_head, n_ctx,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                    K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
+        cb(tmp, "K_shifted", il);
+        ggml_build_forward_expand(graph, tmp);
+    }
+}
+
 static struct ggml_cgraph * llm_build_llama(
          llama_context & lctx,
     const llama_batch & batch,
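
Note: the ggml_view_3d call above addresses the flat K cache as an (n_rot, n_head, n_ctx) tensor per layer. Below is a minimal standalone sketch (not part of the commit; the hyperparameter values are assumed examples) that evaluates the same byte-stride arithmetic the view uses:

    // sketch: byte strides of the per-layer K-cache view, with assumed values
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_embd_head = 128;                     // assumed head size
        const int64_t n_head_kv   = 32;                      // assumed KV heads
        const int64_t n_embd_gqa  = n_embd_head * n_head_kv; // K row per token
        const int64_t n_ctx       = 4096;                    // assumed context
        const size_t  elsize      = 2;                       // e.g. f16 K cache

        const size_t nb1    = elsize * n_embd_head;          // step to next head
        const size_t nb2    = elsize * n_embd_gqa;           // step to next ctx position
        const size_t offset = elsize * n_embd_gqa * n_ctx;   // step to next layer (il)

        printf("nb1=%zu nb2=%zu per-layer offset=%zu bytes\n", nb1, nb2, offset);
        return 0;
    }

The rope_type values produced by the switch (0, 2, 4) follow ggml_rope's mode convention: 0 for the original interleaved rotation, 2 for GPT-NeoX style, 4 for GLM style.
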
@@ -3308,21 +3367,7 @@ static struct ggml_cgraph * llm_build_llama(
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
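
With the helper in place, each builder's inline shift loop collapses to a single call: llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb) for llama and baichuan, the same with LLM_ROPE_NEOX for falcon, and LLM_ROPE_NEOX with n_rot = n_embd_head/2 for persimmon, as the remaining hunks show.
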
@@ -3557,21 +3602,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
@@ -3830,21 +3861,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
@@ -4243,14 +4260,15 @@ static struct ggml_cgraph * llm_build_persimmon(
     GGML_ASSERT(!!kv_self.ctx);
 
     const auto & cparams = lctx.cparams;
+
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-    const size_t n_rot = n_embd_head / 2;
+    const int64_t n_rot = n_embd_head / 2;
 
     const float freq_base  = cparams.rope_freq_base;
     const float freq_scale = cparams.rope_freq_scale;
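
The change of n_rot from size_t to int64_t keeps the local consistent with the helper's int64_t n_rot parameter, avoiding a signed/unsigned mix in expressions such as GGML_ASSERT(n_embd_head % n_rot == 0).
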
@@ -4297,23 +4315,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     cb(KQ_mask, "KQ_mask", -1);
 
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                // we rotate only the first n_rot dimensions.
-                ggml_rope_custom_inplace(ctx0,
-                    ggml_view_3d(ctx0, kv_self.k,
-                        n_rot, n_head, n_ctx,
-                        ggml_element_size(kv_self.k)*n_embd_gqa,
-                        ggml_element_size(kv_self.k)*n_embd_head,
-                        ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
-                    ),
-                    K_shift, n_rot, 2, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
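
Note that routing persimmon through the shared helper also changes its view strides: the old inline code passed n_embd_gqa as nb1, n_embd_head as nb2, and an n_embd_head-based layer offset, whereas llm_build_k_shift uses n_embd_head, n_embd_gqa, and n_embd_gqa*n_ctx*il, matching the other builders.
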
@@ -5534,7 +5536,7 @@ static struct ggml_cgraph * llama_build_graph(
 #ifdef GGML_USE_CUBLAS
     const bool do_offload = true;
 #else
-    const bool do_offload = false;
+    const bool do_offload = true; // TODO: set to false after finishing refactoring
 #endif
 
     if (!do_offload) {
