Skip to content

Commit d82a5c0

Browse files
committed
llama : add llm_build_kqv helper (wip)
1 parent c9121fd commit d82a5c0

File tree

1 file changed

+179
-148
lines changed

1 file changed

+179
-148
lines changed

Diff for: llama.cpp

+179-148
Original file line number | Diff line number | Diff line change
@@ -3093,6 +3093,103 @@ static bool llama_model_load(
30933093

30943094
using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
30953095

3096+
// RoPE variants selectable by the graph-building helpers.
// llm_build_k_shift translates each enumerator into the integer rope mode it passes to
// ggml_rope_custom_inplace: LLM_ROPE -> 0, LLM_ROPE_NEOX -> 2, LLM_ROPE_GLM -> 4.
enum llm_rope_type {
3097+
LLM_ROPE,
3098+
LLM_ROPE_NEOX,
3099+
LLM_ROPE_GLM,
3100+
};
3101+
3102+
// Persimmon: n_rot = n_embd_head/2
3103+
// Other: n_rot = n_embd_head
3104+
static void llm_build_k_shift(
3105+
const llama_context & lctx,
3106+
struct ggml_context * ctx,
3107+
struct ggml_cgraph * graph,
3108+
int64_t n_rot,
3109+
llm_rope_type type,
3110+
const llm_build_cb & cb) {
3111+
const auto & model = lctx.model;
3112+
const auto & kv_self = lctx.kv_self;
3113+
const auto & cparams = lctx.cparams;
3114+
3115+
const auto & hparams = model.hparams;
3116+
3117+
const int64_t n_layer = hparams.n_layer;
3118+
const int64_t n_head_kv = hparams.n_head_kv;
3119+
const int64_t n_embd_gqa = hparams.n_embd_gqa();
3120+
const int64_t n_embd_head = hparams.n_embd_head();
3121+
3122+
const int64_t n_ctx = lctx.cparams.n_ctx;
3123+
3124+
const float freq_base = cparams.rope_freq_base;
3125+
const float freq_scale = cparams.rope_freq_scale;
3126+
3127+
GGML_ASSERT(n_embd_head % n_rot == 0);
3128+
3129+
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
3130+
cb(K_shift, "K_shift", -1);
3131+
3132+
int rope_type = 0;
3133+
3134+
switch (type) {
3135+
case LLM_ROPE: rope_type = 0; break;
3136+
case LLM_ROPE_NEOX: rope_type = 2; break;
3137+
case LLM_ROPE_GLM: rope_type = 4; break;
3138+
};
3139+
3140+
for (int il = 0; il < n_layer; ++il) {
3141+
struct ggml_tensor * tmp =
3142+
// we rotate only the first n_rot dimensions
3143+
ggml_rope_custom_inplace(ctx,
3144+
ggml_view_3d(ctx, kv_self.k,
3145+
n_rot, n_head_kv, n_ctx,
3146+
ggml_element_size(kv_self.k)*n_embd_head,
3147+
ggml_element_size(kv_self.k)*n_embd_gqa,
3148+
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
3149+
K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
3150+
cb(tmp, "K_shifted", il);
3151+
ggml_build_forward_expand(graph, tmp);
3152+
}
3153+
}
3154+
3155+
static void llm_build_kv_store(
3156+
const llama_context & lctx,
3157+
struct ggml_context * ctx,
3158+
struct ggml_cgraph * graph,
3159+
struct ggml_tensor * k_cur,
3160+
struct ggml_tensor * v_cur,
3161+
int32_t n_tokens,
3162+
int32_t kv_head,
3163+
const llm_build_cb & cb,
3164+
int64_t il) {
3165+
const auto & model = lctx.model;
3166+
const auto & kv_self = lctx.kv_self;
3167+
const auto & cparams = lctx.cparams;
3168+
3169+
const auto & hparams = model.hparams;
3170+
3171+
const int64_t n_ctx = cparams.n_ctx;
3172+
const int64_t n_embd_gqa = hparams.n_embd_gqa();
3173+
3174+
// compute the transposed [n_tokens, n_embd] V matrix
3175+
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens));
3176+
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
3177+
cb(v_cur_t, "v_cur_t", il);
3178+
3179+
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
3180+
(ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
3181+
cb(k_cache_view, "k_cache_view", il);
3182+
3183+
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
3184+
( n_ctx)*ggml_element_size(kv_self.v),
3185+
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
3186+
cb(v_cache_view, "v_cache_view", il);
3187+
3188+
// important: storing RoPE-ed version of K in the KV cache!
3189+
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
3190+
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
3191+
}
3192+
30963193
enum llm_norm_type {
30973194
LLM_NORM,
30983195
LLM_NORM_RMS,
@@ -3232,101 +3329,94 @@ static struct ggml_tensor * llm_build_ffn(
32323329
return cur;
32333330
}
32343331

3235-
enum llm_rope_type {
3236-
LLM_ROPE,
3237-
LLM_ROPE_NEOX,
3238-
LLM_ROPE_GLM,
3239-
};
3240-
3241-
// Persimmon: n_rot = n_embd_head/2
3242-
// Other: n_rot = n_embd_head
3243-
static void llm_build_k_shift(
3332+
// Builds the attention part of one layer's graph and returns the projected result.
//
// Steps, each intermediate being reported through cb(tensor, name, il):
//   q   = permute(q_cur)                      -> "q"
//   k   = 3D view of kv_self.k for layer il   -> "k"
//   kq  = mul_mat(k, q)                       -> "kq"
//   kq  = scale(kq, kq_scale)                 -> "kq_scaled"
//   kq  = alibi(kq) if alibi_bias_max > 0     -> "kq_scaled_alibi"
//   kq  = kq + kq_mask                        -> "kq_masked"
//   kq  = soft_max(kq)                        -> "kq_soft_max"
//   v   = 3D view of kv_self.v for layer il   -> "v"
//   kqv = mul_mat(v, kq), permuted and made contiguous as (n_embd, n_tokens)
//                                             -> "kqv", "kqv_merged", "kqv_merged_cont"
//   cur = mul_mat(wo, cur), plus wo_b when wo_b is non-NULL
//
// Parameters:
//   cur            - overwritten before it is read; only used as the output variable
//   wo, wo_b       - output projection weight and optional bias (wo_b may be NULL)
//   q_cur          - query tensor for the current batch
//   kq_scale       - tensor handed to ggml_scale
//   kq_mask        - tensor added to the scaled scores
//   n_tokens, n_kv - sizes used for the final reshape and for the cache views
//   alibi_bias_max - ALiBi is applied only when this is > 0
//   il             - layer index, used for the cache offsets and passed to cb
//
// if alibi_bias_max > 0 then apply ALiBi
3333+
// TODO: rework ALiBi to be applied via ggml_add of a mask
3334+
static struct ggml_tensor * llm_build_kqv(
32443335
const llama_context & lctx,
32453336
struct ggml_context * ctx,
3246-
struct ggml_cgraph * graph,
3247-
int64_t n_rot,
3248-
llm_rope_type type,
3249-
const llm_build_cb & cb) {
3337+
struct ggml_tensor * cur,
3338+
struct ggml_tensor * wo,
3339+
struct ggml_tensor * wo_b,
3340+
struct ggml_tensor * q_cur,
3341+
struct ggml_tensor * kq_scale,
3342+
struct ggml_tensor * kq_mask,
3343+
int32_t n_tokens,
3344+
int32_t n_kv,
3345+
float alibi_bias_max,
3346+
const llm_build_cb & cb,
3347+
int il) {
32503348
const auto & model = lctx.model;
32513349
const auto & kv_self = lctx.kv_self;
32523350
const auto & cparams = lctx.cparams;
32533351

32543352
const auto & hparams = model.hparams;
32553353

3256-
const int64_t n_layer = hparams.n_layer;
3354+
const int64_t n_ctx = cparams.n_ctx;
3355+
const int64_t n_embd = hparams.n_embd;
3356+
const int64_t n_head = hparams.n_head;
32573357
const int64_t n_head_kv = hparams.n_head_kv;
3258-
const int64_t n_embd_gqa = hparams.n_embd_gqa();
32593358
const int64_t n_embd_head = hparams.n_embd_head();
3359+
const int64_t n_embd_gqa = hparams.n_embd_gqa();
32603360

3261-
const int64_t n_ctx = lctx.cparams.n_ctx;
3361+
// reorder q_cur with the axis permutation (0, 2, 1, 3) before multiplying against K
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
3362+
cb(q, "q", il);
32623363

3263-
const float freq_base = cparams.rope_freq_base;
3264-
const float freq_scale = cparams.rope_freq_scale;
3364+
// view of the K cache with extents (n_embd_head, n_kv, n_head_kv), starting at layer il's byte offset
struct ggml_tensor * k =
3365+
ggml_view_3d(ctx, kv_self.k,
3366+
n_embd_head, n_kv, n_head_kv,
3367+
ggml_element_size(kv_self.k)*n_embd_gqa,
3368+
ggml_element_size(kv_self.k)*n_embd_head,
3369+
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
3370+
cb(k, "k", il);
32653371

3266-
GGML_ASSERT(n_embd_head % n_rot == 0);
3372+
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
3373+
cb(kq, "kq", il);
32673374

3268-
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
3269-
cb(K_shift, "K_shift", -1);
3375+
kq = ggml_scale(ctx, kq, kq_scale);
3376+
cb(kq, "kq_scaled", il);
32703377

3271-
int rope_type = 0;
3378+
if (alibi_bias_max > 0.0f) {
3379+
// TODO: n_head or n_head_kv
3380+
// TODO: K-shift is likely not working
3381+
// TODO: change to ggml_add
3382+
kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, alibi_bias_max);
3383+
cb(kq, "kq_scaled_alibi", il);
3384+
}
32723385

3273-
switch (type) {
3274-
case LLM_ROPE: rope_type = 0; break;
3275-
case LLM_ROPE_NEOX: rope_type = 2; break;
3276-
case LLM_ROPE_GLM: rope_type = 4; break;
3277-
};
3386+
// kq_mask is added to the scores ahead of the softmax
kq = ggml_add(ctx, kq, kq_mask);
3387+
cb(kq, "kq_masked", il);
32783388

3279-
for (int il = 0; il < n_layer; ++il) {
3280-
struct ggml_tensor * tmp =
3281-
// we rotate only the first n_rot dimensions
3282-
ggml_rope_custom_inplace(ctx,
3283-
ggml_view_3d(ctx, kv_self.k,
3284-
n_rot, n_head_kv, n_ctx,
3285-
ggml_element_size(kv_self.k)*n_embd_head,
3286-
ggml_element_size(kv_self.k)*n_embd_gqa,
3287-
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
3288-
K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
3289-
cb(tmp, "K_shifted", il);
3290-
ggml_build_forward_expand(graph, tmp);
3291-
}
3292-
}
3389+
kq = ggml_soft_max(ctx, kq);
3390+
cb(kq, "kq_soft_max", il);
32933391

3294-
static void llm_build_kv_store(
3295-
const llama_context & lctx,
3296-
struct ggml_context * ctx,
3297-
struct ggml_cgraph * graph,
3298-
struct ggml_tensor * k_cur,
3299-
struct ggml_tensor * v_cur,
3300-
int32_t n_tokens,
3301-
int32_t kv_head,
3302-
const llm_build_cb & cb,
3303-
int64_t il) {
3304-
const auto & model = lctx.model;
3305-
const auto & kv_self = lctx.kv_self;
3306-
const auto & cparams = lctx.cparams;
3392+
// split cached v into n_head_kv heads (view extents: n_kv, n_embd_head, n_head_kv)
3393+
struct ggml_tensor * v =
3394+
ggml_view_3d(ctx, kv_self.v,
3395+
n_kv, n_embd_head, n_head_kv,
3396+
ggml_element_size(kv_self.v)*n_ctx,
3397+
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
3398+
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
3399+
cb(v, "v", il);
33073400

3308-
const auto & hparams = model.hparams;
3401+
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
3402+
cb(kqv, "kqv", il);
33093403

3310-
const int64_t n_ctx = cparams.n_ctx;
3311-
const int64_t n_embd_gqa = hparams.n_embd_gqa();
3404+
// same (0, 2, 1, 3) permutation as applied to q_cur above
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
3405+
cb(kqv_merged, "kqv_merged", il);
33123406

3313-
// compute the transposed [n_tokens, n_embd] V matrix
3314-
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens));
3315-
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
3316-
cb(v_cur_t, "v_cur_t", il);
3407+
// the incoming value of cur is discarded here; from this point cur holds the attention output
cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
3408+
cb(cur, "kqv_merged_cont", il);
33173409

3318-
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
3319-
(ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
3320-
cb(k_cache_view, "k_cache_view", il);
3410+
// output projection; the un-biased result is reported as "kqv_wo" only when a bias follows
cur = ggml_mul_mat(ctx, wo, cur);
3411+
if (wo_b) {
3412+
cb(cur, "kqv_wo", il);
3413+
}
33213414

3322-
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
3323-
( n_ctx)*ggml_element_size(kv_self.v),
3324-
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
3325-
cb(v_cache_view, "v_cache_view", il);
3415+
if (wo_b) {
3416+
cur = ggml_add(ctx, cur, wo_b);
3417+
}
33263418

3327-
// important: storing RoPE-ed version of K in the KV cache!
3328-
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
3329-
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
3419+
// the final tensor is not passed to cb here
return cur;
33303420
}
33313421

33323422
static struct ggml_cgraph * llm_build_llama(
@@ -3348,7 +3438,6 @@ static struct ggml_cgraph * llm_build_llama(
33483438
const int64_t n_head = hparams.n_head;
33493439
const int64_t n_head_kv = hparams.n_head_kv;
33503440
const int64_t n_embd_head = hparams.n_embd_head();
3351-
const int64_t n_embd_gqa = hparams.n_embd_gqa();
33523441

33533442
GGML_ASSERT(n_embd_head == hparams.n_rot);
33543443

@@ -3440,67 +3529,10 @@ static struct ggml_cgraph * llm_build_llama(
34403529

34413530
llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
34423531

3443-
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
3444-
cb(Q, "Q", il);
3445-
3446-
struct ggml_tensor * K =
3447-
ggml_view_3d(ctx0, kv_self.k,
3448-
n_embd_head, n_kv, n_head_kv,
3449-
ggml_element_size(kv_self.k)*n_embd_gqa,
3450-
ggml_element_size(kv_self.k)*n_embd_head,
3451-
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
3452-
cb(K, "K", il);
3453-
3454-
// K * Q
3455-
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
3456-
cb(KQ, "KQ", il);
3457-
3458-
// KQ_scaled = KQ / sqrt(n_embd_head)
3459-
// KQ_scaled shape [n_kv, n_tokens, n_head, 1]
3460-
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
3461-
cb(KQ_scaled, "KQ_scaled", il);
3462-
3463-
// KQ_masked = mask_past(KQ_scaled)
3464-
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
3465-
cb(KQ_masked, "KQ_masked", il);
3466-
3467-
// KQ = soft_max(KQ_masked)
3468-
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
3469-
cb(KQ_soft_max, "KQ_soft_max", il);
3470-
3471-
// split cached V into n_head heads
3472-
struct ggml_tensor * V =
3473-
ggml_view_3d(ctx0, kv_self.v,
3474-
n_kv, n_embd_head, n_head_kv,
3475-
ggml_element_size(kv_self.v)*n_ctx,
3476-
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
3477-
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
3478-
cb(V, "V", il);
3479-
3480-
#if 1
3481-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
3482-
cb(KQV, "KQV", il);
3483-
#else
3484-
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
3485-
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
3486-
// is there a better way?
3487-
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
3488-
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
3489-
#endif
3490-
3491-
// KQV_merged = KQV.permute(0, 2, 1, 3)
3492-
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
3493-
cb(KQV_merged, "KQV_merged", il);
3494-
3495-
// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
3496-
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
3497-
cb(cur, "KQV_merged_contiguous", il);
3498-
3499-
// projection (no bias)
3500-
cur = ggml_mul_mat(ctx0,
3501-
model.layers[il].wo,
3502-
cur);
3503-
cb(cur, "result_wo", il);
3532+
cur = llm_build_kqv(lctx, ctx0, cur,
3533+
model.layers[il].wo, NULL,
3534+
Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
3535+
cb(cur, "kqv_out", il);
35043536
}
35053537

35063538
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
@@ -5164,22 +5196,21 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
51645196
{ "krotated", OFFLOAD_FUNC_KQ },
51655197
{ "qrotated", OFFLOAD_FUNC_KQ },
51665198

5167-
{ "Q", OFFLOAD_FUNC_KQ },
5168-
{ "K", OFFLOAD_FUNC_KQ },
5169-
{ "KQ", OFFLOAD_FUNC_KQ },
5170-
{ "KQ_scaled", OFFLOAD_FUNC_KQ },
5171-
{ "KQ_scaled_alibi", OFFLOAD_FUNC_KQ },
5172-
{ "KQ_masked", OFFLOAD_FUNC_KQ },
5173-
{ "KQ_soft_max", OFFLOAD_FUNC_V },
5174-
{ "V", OFFLOAD_FUNC_V },
5175-
{ "KQV", OFFLOAD_FUNC_V },
5176-
{ "KQV_merged", OFFLOAD_FUNC_V },
5177-
{ "KQV_merged_contiguous", OFFLOAD_FUNC_V },
5178-
5179-
{ "result_wo", OFFLOAD_FUNC },
5180-
{ "result_wo_b", OFFLOAD_FUNC },
5181-
{ "inpL_+_result_wo", OFFLOAD_FUNC },
5199+
{ "q", OFFLOAD_FUNC_KQ },
5200+
{ "k", OFFLOAD_FUNC_KQ },
5201+
{ "kq", OFFLOAD_FUNC_KQ },
5202+
{ "kq_scaled", OFFLOAD_FUNC_KQ },
5203+
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
5204+
{ "kq_masked", OFFLOAD_FUNC_KQ },
5205+
{ "kq_soft_max", OFFLOAD_FUNC_V },
5206+
{ "v", OFFLOAD_FUNC_V },
5207+
{ "kqv", OFFLOAD_FUNC_V },
5208+
{ "kqv_merged", OFFLOAD_FUNC_V },
5209+
{ "kqv_merged_cont", OFFLOAD_FUNC_V },
5210+
{ "kqv_wo", OFFLOAD_FUNC_V },
5211+
{ "kqv_out", OFFLOAD_FUNC_V },
51825212

5213+
{ "inpL_+_result_wo", OFFLOAD_FUNC },
51835214
{ "inpFF", OFFLOAD_FUNC },
51845215

51855216
{ "ffn_norm", OFFLOAD_FUNC },

0 commit comments

Comments
 (0)