@@ -3093,6 +3093,103 @@ static bool llama_model_load(
 
 using llm_build_cb = std::function<void (struct ggml_tensor * cur, const char * name, int nl)>;
 
+enum llm_rope_type {
+    LLM_ROPE,
+    LLM_ROPE_NEOX,
+    LLM_ROPE_GLM,
+};
+
+// Persimmon: n_rot = n_embd_head/2
+// Other:     n_rot = n_embd_head
+static void llm_build_k_shift(
+      const llama_context & lctx,
+       struct ggml_context * ctx,
+        struct ggml_cgraph * graph,
+                   int64_t   n_rot,
+             llm_rope_type   type,
+        const llm_build_cb & cb) {
+    const auto & model   = lctx.model;
+    const auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+
+    const auto & hparams = model.hparams;
+
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_head_kv   = hparams.n_head_kv;
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int64_t n_embd_head = hparams.n_embd_head();
+
+    const int64_t n_ctx = lctx.cparams.n_ctx;
+
+    const float freq_base  = cparams.rope_freq_base;
+    const float freq_scale = cparams.rope_freq_scale;
+
+    GGML_ASSERT(n_embd_head % n_rot == 0);
+
+    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
+    cb(K_shift, "K_shift", -1);
+
+    int rope_type = 0;
+
+    switch (type) {
+        case LLM_ROPE:      rope_type = 0; break;
+        case LLM_ROPE_NEOX: rope_type = 2; break;
+        case LLM_ROPE_GLM:  rope_type = 4; break;
+    };
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * tmp =
+            // we rotate only the first n_rot dimensions
+            ggml_rope_custom_inplace(ctx,
+                    ggml_view_3d(ctx, kv_self.k,
+                        n_rot, n_head_kv, n_ctx,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                    K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
+        cb(tmp, "K_shifted", il);
+        ggml_build_forward_expand(graph, tmp);
+    }
+}
+
+static void llm_build_kv_store(
+      const llama_context & lctx,
+       struct ggml_context * ctx,
+        struct ggml_cgraph * graph,
+        struct ggml_tensor * k_cur,
+        struct ggml_tensor * v_cur,
+                   int32_t   n_tokens,
+                   int32_t   kv_head,
+        const llm_build_cb & cb,
+                   int64_t   il) {
+    const auto & model   = lctx.model;
+    const auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+
+    const auto & hparams = model.hparams;
+
+    const int64_t n_ctx      = cparams.n_ctx;
+    const int64_t n_embd_gqa = hparams.n_embd_gqa();
+
+    // compute the transposed [n_tokens, n_embd] V matrix
+    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens));
+    //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
+    cb(v_cur_t, "v_cur_t", il);
+
+    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
+            (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+    cb(k_cache_view, "k_cache_view", il);
+
+    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
+            (   n_ctx)*ggml_element_size(kv_self.v),
+            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+    cb(v_cache_view, "v_cache_view", il);
+
+    // important: storing RoPE-ed version of K in the KV cache!
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur,   k_cache_view));
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+}
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
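
A quick offset sketch (not part of the patch, all sizes made up): `llm_build_kv_store` above assumes the K cache holds one contiguous `[n_embd_gqa, n_ctx]` block per layer and the V cache holds the transposed `[n_ctx, n_embd_gqa]` block, so the views it creates start at the element offsets computed below.

```cpp
// Standalone illustration of the cache offsets used by llm_build_kv_store().
// All sizes are hypothetical; only the arithmetic mirrors the views above.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_ctx      = 4096; // cache size in tokens (hypothetical)
    const int64_t n_embd_gqa = 1024; // n_embd_head * n_head_kv (hypothetical)
    const int64_t il         = 10;   // layer index
    const int64_t kv_head    = 37;   // first cache slot written by this batch

    // K cache: rows of n_embd_gqa elements, n_ctx rows per layer
    // -> this batch's slice starts at element n_embd_gqa*(il*n_ctx + kv_head)
    const int64_t k_offs = n_embd_gqa*(il*n_ctx + kv_head);

    // V cache: transposed layout, n_embd_gqa rows of n_ctx elements per layer
    // -> this batch's slice starts at column kv_head of layer il, row stride n_ctx
    const int64_t v_offs = il*n_ctx*n_embd_gqa + kv_head;

    printf("k view starts at element %lld\n", (long long) k_offs);
    printf("v view starts at element %lld, row stride %lld\n", (long long) v_offs, (long long) n_ctx);
    return 0;
}
```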
@@ -3232,101 +3329,94 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
 }
 
-enum llm_rope_type {
-    LLM_ROPE,
-    LLM_ROPE_NEOX,
-    LLM_ROPE_GLM,
-};
-
-// Persimmon: n_rot = n_embd_head/2
-// Other:     n_rot = n_embd_head
-static void llm_build_k_shift(
+// if max_alibi_bias > 0 then apply ALiBi
+// TODO: rework ALiBi to be applied via ggml_add of a mask
+static struct ggml_tensor * llm_build_kqv(
       const llama_context & lctx,
        struct ggml_context * ctx,
-        struct ggml_cgraph * graph,
-                   int64_t   n_rot,
-             llm_rope_type   type,
-        const llm_build_cb & cb) {
+        struct ggml_tensor * cur,
+        struct ggml_tensor * wo,
+        struct ggml_tensor * wo_b,
+        struct ggml_tensor * q_cur,
+        struct ggml_tensor * kq_scale,
+        struct ggml_tensor * kq_mask,
+                   int32_t   n_tokens,
+                   int32_t   n_kv,
+                     float   alibi_bias_max,
+        const llm_build_cb & cb,
+                       int   il) {
     const auto & model   = lctx.model;
     const auto & kv_self = lctx.kv_self;
     const auto & cparams = lctx.cparams;
 
     const auto & hparams = model.hparams;
 
-    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_ctx       = cparams.n_ctx;
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
     const int64_t n_embd_head = hparams.n_embd_head();
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
-    const int64_t n_ctx = lctx.cparams.n_ctx;
+    struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+    cb(q, "q", il);
 
-    const float freq_base  = cparams.rope_freq_base;
-    const float freq_scale = cparams.rope_freq_scale;
+    struct ggml_tensor * k =
+        ggml_view_3d(ctx, kv_self.k,
+                n_embd_head, n_kv, n_head_kv,
+                ggml_element_size(kv_self.k)*n_embd_gqa,
+                ggml_element_size(kv_self.k)*n_embd_head,
+                ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+    cb(k, "k", il);
 
-    GGML_ASSERT(n_embd_head % n_rot == 0);
+    struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
+    cb(kq, "kq", il);
 
-    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
-    cb(K_shift, "K_shift", -1);
+    kq = ggml_scale(ctx, kq, kq_scale);
+    cb(kq, "kq_scaled", il);
 
-    int rope_type = 0;
+    if (alibi_bias_max > 0.0f) {
+        // TODO: n_head or n_head_kv
+        // TODO: K-shift is likely not working
+        // TODO: change to ggml_add
+        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, alibi_bias_max);
+        cb(kq, "kq_scaled_alibi", il);
+    }
 
-    switch (type) {
-        case LLM_ROPE:      rope_type = 0; break;
-        case LLM_ROPE_NEOX: rope_type = 2; break;
-        case LLM_ROPE_GLM:  rope_type = 4; break;
-    };
+    kq = ggml_add(ctx, kq, kq_mask);
+    cb(kq, "kq_masked", il);
 
-    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * tmp =
-            // we rotate only the first n_rot dimensions
-            ggml_rope_custom_inplace(ctx,
-                    ggml_view_3d(ctx, kv_self.k,
-                        n_rot, n_head_kv, n_ctx,
-                        ggml_element_size(kv_self.k)*n_embd_head,
-                        ggml_element_size(kv_self.k)*n_embd_gqa,
-                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                    K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
-        cb(tmp, "K_shifted", il);
-        ggml_build_forward_expand(graph, tmp);
-    }
-}
+    kq = ggml_soft_max(ctx, kq);
+    cb(kq, "kq_soft_max", il);
 
-static void llm_build_kv_store(
-      const llama_context & lctx,
-       struct ggml_context * ctx,
-        struct ggml_cgraph * graph,
-        struct ggml_tensor * k_cur,
-        struct ggml_tensor * v_cur,
-                   int32_t   n_tokens,
-                   int32_t   kv_head,
-        const llm_build_cb & cb,
-                   int64_t   il) {
-    const auto & model   = lctx.model;
-    const auto & kv_self = lctx.kv_self;
-    const auto & cparams = lctx.cparams;
+    // split cached v into n_head heads
+    struct ggml_tensor * v =
+        ggml_view_3d(ctx, kv_self.v,
+                n_kv, n_embd_head, n_head_kv,
+                ggml_element_size(kv_self.v)*n_ctx,
+                ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
+                ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+    cb(v, "v", il);
 
-    const auto & hparams = model.hparams;
+    struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
+    cb(kqv, "kqv", il);
 
-    const int64_t n_ctx      = cparams.n_ctx;
-    const int64_t n_embd_gqa = hparams.n_embd_gqa();
+    struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
+    cb(kqv_merged, "kqv_merged", il);
 
-    // compute the transposed [n_tokens, n_embd] V matrix
-    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_gqa, n_tokens));
-    //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
-    cb(v_cur_t, "v_cur_t", il);
+    cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
+    cb(cur, "kqv_merged_cont", il);
 
-    struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa,
-            (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
-    cb(k_cache_view, "k_cache_view", il);
+    cur = ggml_mul_mat(ctx, wo, cur);
+    if (wo_b) {
+        cb(cur, "kqv_wo", il);
+    }
 
-    struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa,
-            (   n_ctx)*ggml_element_size(kv_self.v),
-            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
-    cb(v_cache_view, "v_cache_view", il);
+    if (wo_b) {
+        cur = ggml_add(ctx, cur, wo_b);
+    }
 
-    // important: storing RoPE-ed version of K in the KV cache!
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur,   k_cache_view));
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+    return cur;
 }
 
 static struct ggml_cgraph * llm_build_llama(
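
For orientation (not part of the patch): the node chain built by `llm_build_kqv` above — kq, kq_scaled, optional kq_scaled_alibi, kq_masked, kq_soft_max, kqv, kqv_merged, kqv_merged_cont, then the wo projection — is ordinary masked scaled-dot-product attention over the cached K/V. Below is a minimal plain-C++ sketch of the same math for a single head, with made-up sizes and the mask/ALiBi terms omitted; it is a reference for the arithmetic only, not for the ggml graph API.

```cpp
// Single-head scaled-dot-product attention on toy data, mirroring the
// kq -> scale -> softmax -> kqv sequence that llm_build_kqv() builds as graph nodes.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int n_kv = 3, n_tokens = 2, d = 4;   // cached positions, new tokens, head size
    std::vector<float> K(n_kv*d, 0.1f);        // K[i*d + c], from the cache
    std::vector<float> Q(n_tokens*d, 0.2f);    // Q[j*d + c], current batch
    std::vector<float> V(n_kv*d, 0.3f);        // V[i*d + c], from the cache
    const float kq_scale = 1.0f/std::sqrt((float) d);

    std::vector<float> out(n_tokens*d, 0.0f);
    for (int j = 0; j < n_tokens; ++j) {
        // "kq" and "kq_scaled": scaled dot products of q_j with every cached k_i
        // (the real graph adds kq_mask here; this toy example attends to everything)
        std::vector<float> kq(n_kv);
        float kq_max = -INFINITY;
        for (int i = 0; i < n_kv; ++i) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c) dot += K[i*d + c]*Q[j*d + c];
            kq[i] = dot*kq_scale;
            kq_max = std::max(kq_max, kq[i]);
        }
        // "kq_soft_max": normalize over the kv dimension
        float sum = 0.0f;
        for (int i = 0; i < n_kv; ++i) { kq[i] = std::exp(kq[i] - kq_max); sum += kq[i]; }
        for (int i = 0; i < n_kv; ++i) kq[i] /= sum;
        // "kqv": weighted sum of the cached value rows
        for (int c = 0; c < d; ++c) {
            for (int i = 0; i < n_kv; ++i) {
                out[j*d + c] += kq[i]*V[i*d + c];
            }
        }
    }
    printf("out[0][0] = %f\n", out[0]);
    return 0;
}
```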
@@ -3348,7 +3438,6 @@ static struct ggml_cgraph * llm_build_llama(
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
-    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 
     GGML_ASSERT(n_embd_head == hparams.n_rot);
 
@@ -3440,67 +3529,10 @@ static struct ggml_cgraph * llm_build_llama(
 
             llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
 
-            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
-            cb(Q, "Q", il);
-
-            struct ggml_tensor * K =
-                ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_kv, n_head_kv,
-                        ggml_element_size(kv_self.k)*n_embd_gqa,
-                        ggml_element_size(kv_self.k)*n_embd_head,
-                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
-            cb(K, "K", il);
-
-            // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            cb(KQ, "KQ", il);
-
-            // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
-            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
-            cb(KQ_scaled, "KQ_scaled", il);
-
-            // KQ_masked = mask_past(KQ_scaled)
-            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
-            cb(KQ_masked, "KQ_masked", il);
-
-            // KQ = soft_max(KQ_masked)
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
-            cb(KQ_soft_max, "KQ_soft_max", il);
-
-            // split cached V into n_head heads
-            struct ggml_tensor * V =
-                ggml_view_3d(ctx0, kv_self.v,
-                        n_kv, n_embd_head, n_head_kv,
-                        ggml_element_size(kv_self.v)*n_ctx,
-                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
-                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
-            cb(V, "V", il);
-
-#if 1
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
-            cb(KQV, "KQV", il);
-#else
-            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
-            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
-            // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
-#endif
-
-            // KQV_merged = KQV.permute(0, 2, 1, 3)
-            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            cb(KQV_merged, "KQV_merged", il);
-
-            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
-            cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
-            cb(cur, "KQV_merged_contiguous", il);
-
-            // projection (no bias)
-            cur = ggml_mul_mat(ctx0,
-                    model.layers[il].wo,
-                    cur);
-            cb(cur, "result_wo", il);
+            cur = llm_build_kqv(lctx, ctx0, cur,
+                    model.layers[il].wo, NULL,
+                    Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il);
+            cb(cur, "kqv_out", il);
         }
 
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
@@ -5164,22 +5196,21 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "krotated",               OFFLOAD_FUNC_KQ  },
     { "qrotated",               OFFLOAD_FUNC_KQ  },
 
-    { "Q",                      OFFLOAD_FUNC_KQ  },
-    { "K",                      OFFLOAD_FUNC_KQ  },
-    { "KQ",                     OFFLOAD_FUNC_KQ  },
-    { "KQ_scaled",              OFFLOAD_FUNC_KQ  },
-    { "KQ_scaled_alibi",        OFFLOAD_FUNC_KQ  },
-    { "KQ_masked",              OFFLOAD_FUNC_KQ  },
-    { "KQ_soft_max",            OFFLOAD_FUNC_V   },
-    { "V",                      OFFLOAD_FUNC_V   },
-    { "KQV",                    OFFLOAD_FUNC_V   },
-    { "KQV_merged",             OFFLOAD_FUNC_V   },
-    { "KQV_merged_contiguous",  OFFLOAD_FUNC_V   },
-
-    { "result_wo",              OFFLOAD_FUNC     },
-    { "result_wo_b",            OFFLOAD_FUNC     },
-    { "inpL_+_result_wo",       OFFLOAD_FUNC     },
+    { "q",                      OFFLOAD_FUNC_KQ  },
+    { "k",                      OFFLOAD_FUNC_KQ  },
+    { "kq",                     OFFLOAD_FUNC_KQ  },
+    { "kq_scaled",              OFFLOAD_FUNC_KQ  },
+    { "kq_scaled_alibi",        OFFLOAD_FUNC_KQ  },
+    { "kq_masked",              OFFLOAD_FUNC_KQ  },
+    { "kq_soft_max",            OFFLOAD_FUNC_V   },
+    { "v",                      OFFLOAD_FUNC_V   },
+    { "kqv",                    OFFLOAD_FUNC_V   },
+    { "kqv_merged",             OFFLOAD_FUNC_V   },
+    { "kqv_merged_cont",        OFFLOAD_FUNC_V   },
+    { "kqv_wo",                 OFFLOAD_FUNC_V   },
+    { "kqv_out",                OFFLOAD_FUNC_V   },
 
+    { "inpL_+_result_wo",       OFFLOAD_FUNC     },
     { "inpFF",                  OFFLOAD_FUNC     },
     { "ffn_norm",               OFFLOAD_FUNC     },
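
Side note (not from the patch): these map keys are the same strings passed to the `cb(...)` callback, and the lookup result decides which offload function a node gets, which is why the labels were renamed in lockstep with the new helpers. A simplified, self-contained sketch of that kind of name-to-function lookup; it uses `std::string` keys and a placeholder enum rather than the project's `const char *` keys and `llm_offload_func_e`.

```cpp
// Toy version of a tensor-name -> offload-function table lookup.
// The enum values and the fallback behavior here are placeholders.
#include <cstdio>
#include <string>
#include <unordered_map>

enum offload_func_e { OFFLOAD_FUNC, OFFLOAD_FUNC_KQ, OFFLOAD_FUNC_V, OFFLOAD_FUNC_NOP };

int main() {
    const std::unordered_map<std::string, offload_func_e> offload_map = {
        { "q",           OFFLOAD_FUNC_KQ },
        { "kq_soft_max", OFFLOAD_FUNC_V  },
        { "kqv_out",     OFFLOAD_FUNC_V  },
        { "ffn_norm",    OFFLOAD_FUNC    },
    };

    for (const char * name : { "kqv_out", "not_in_the_table" }) {
        const auto it = offload_map.find(name);
        const offload_func_e f = (it == offload_map.end()) ? OFFLOAD_FUNC_NOP : it->second;
        printf("%-18s -> %d\n", name, (int) f);
    }
    return 0;
}
```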