@@ -3230,6 +3230,65 @@ static struct ggml_tensor * llm_build_ffn(
     return cur;
 }
 
+enum llm_rope_type {
+    LLM_ROPE,
+    LLM_ROPE_NEOX,
+    LLM_ROPE_GLM,
+};
+
+// Persimmon: n_rot = n_embd_head/2
+// Other:     n_rot = n_embd_head
+static void llm_build_k_shift(
+      const llama_context & lctx,
+       struct ggml_context * ctx,
+        struct ggml_cgraph * graph,
+                    int64_t   n_rot,
+              llm_rope_type   type,
+         const llm_build_cb & cb) {
+    const auto & model   = lctx.model;
+    const auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+
+    const auto & hparams = model.hparams;
+
+    const int64_t n_head      = hparams.n_head;
+    const int64_t n_layer     = hparams.n_layer;
+    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
+    const int64_t n_embd_head = hparams.n_embd_head();
+
+    const int64_t n_ctx = lctx.cparams.n_ctx;
+
+    const float freq_base  = cparams.rope_freq_base;
+    const float freq_scale = cparams.rope_freq_scale;
+
+    GGML_ASSERT(n_embd_head % n_rot == 0);
+
+    struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
+    cb(K_shift, "K_shift", -1);
+
+    int rope_type = 0;
+
+    switch (type) {
+        case LLM_ROPE:      rope_type = 0; break;
+        case LLM_ROPE_NEOX: rope_type = 2; break;
+        case LLM_ROPE_GLM:  rope_type = 4; break;
+    };
+
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * tmp =
+            // we rotate only the first n_rot dimensions
+            ggml_rope_custom_inplace(ctx,
+                    ggml_view_3d(ctx, kv_self.k,
+                        n_rot, n_head, n_ctx,
+                        ggml_element_size(kv_self.k)*n_embd_head,
+                        ggml_element_size(kv_self.k)*n_embd_gqa,
+                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
+                    K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
+        cb(tmp, "K_shifted", il);
+        ggml_build_forward_expand(graph, tmp);
+    }
+}
+
 static struct ggml_cgraph * llm_build_llama(
          llama_context & lctx,
      const llama_batch & batch,
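
The rope_type values selected by the switch in llm_build_k_shift are ggml's rope `mode` bit-flags rather than an enum of their own. For reference, ggml's rope kernels of this era branch roughly as below (a paraphrased sketch, not a verbatim excerpt from ggml.c):

    // paraphrased from ggml's rope implementation:
    const bool is_neox = mode & 2;  // GPT-NeoX style: rotate across the two halves of each head
    const bool is_glm  = mode & 4;  // ChatGLM variant
    // mode == 0: original rotation of adjacent element pairs (used by LLaMA)
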
@@ -3308,21 +3367,7 @@ static struct ggml_cgraph * llm_build_llama(
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
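
For reference, the 3D view in both the removed code and the new helper addresses the per-layer K cache, where each context position stores a row of n_embd_gqa values (n_head_kv heads of n_embd_head each). With elt = ggml_element_size(kv_self.k), the byte offset of rotary dimension i of head h at position p works out to the sketch below. One behavioral wrinkle: the removed code viewed n_head_kv heads per position while the helper uses hparams.n_head; the two agree only when n_head == n_head_kv, so GQA models (n_head_kv < n_head) deserve a second look here.

    // offset(i, h, p) for layer il, derived from the view's strides:
    //   elt*n_embd_gqa*n_ctx*il   // start of layer il in kv_self.k
    // + elt*n_embd_gqa*p          // row for context position p   (nb2)
    // + elt*n_embd_head*h         // head h within that row       (nb1)
    // + elt*i                     // rotary dimension i           (nb0)
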
@@ -3557,21 +3602,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
@@ -3830,21 +3861,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
     // shift the entire K-cache if needed
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_embd_head, n_head_kv, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
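
A quick way to confirm the Falcon hunk is behavior-preserving is to trace the helper's switch with the arguments from this call (values taken from the diff above):

    // removed:  ggml_rope_custom_inplace(..., K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
    // new:      llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
    //           -> n_rot = n_embd_head, rope_type = 2
    // i.e. the same arguments reach ggml_rope_custom_inplace as before
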
@@ -4243,14 +4260,15 @@ static struct ggml_cgraph * llm_build_persimmon(
     GGML_ASSERT(!!kv_self.ctx);
 
     const auto & cparams = lctx.cparams;
+
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();
-    const size_t n_rot = n_embd_head / 2;
+    const int64_t n_rot = n_embd_head / 2;
 
     const float freq_base  = cparams.rope_freq_base;
     const float freq_scale = cparams.rope_freq_scale;
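
The size_t to int64_t change keeps n_rot consistent with the helper's int64_t parameter and with the other hparams-derived quantities, so expressions involving it stay in signed 64-bit arithmetic (a sketch of the difference; the surrounding names are from this diff):

    // before: const size_t  n_rot = n_embd_head / 2;  // unsigned; mixes with int64_t operands
    // after:  const int64_t n_rot = n_embd_head / 2;  // matches llm_build_k_shift(..., int64_t n_rot, ...)
    //         e.g. GGML_ASSERT(n_embd_head % n_rot == 0) now evaluates entirely in int64_t
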
@@ -4297,23 +4315,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     cb(KQ_mask, "KQ_mask", -1);
 
     if (do_rope_shift) {
-        struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
-        cb(K_shift, "K_shift", -1);
-
-        for (int il = 0; il < n_layer; ++il) {
-            struct ggml_tensor * tmp =
-                // we rotate only the first n_rot dimensions.
-                ggml_rope_custom_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k,
-                            n_rot, n_head, n_ctx,
-                            ggml_element_size(kv_self.k)*n_embd_gqa,
-                            ggml_element_size(kv_self.k)*n_embd_head,
-                            ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
-                        ),
-                        K_shift, n_rot, 2, 0, freq_base, freq_scale);
-            cb(tmp, "K_shifted", il);
-            ggml_build_forward_expand(gf, tmp);
-        }
+        llm_build_k_shift(lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb);
     }
 
     for (int il = 0; il < n_layer; ++il) {
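
Unlike the other three call sites, the Persimmon hunk is not a purely mechanical substitution: the removed code passed the view's two strides in the opposite order and computed the layer offset from n_embd_head rather than a full n_embd_gqa row. Contrasting the two (values taken from the diff; the helper's ordering matches the llama/falcon paths, but this is worth verifying against Persimmon's cache layout):

    // removed (Persimmon-specific):        new helper:
    //   nb1    = elt*n_embd_gqa              nb1    = elt*n_embd_head
    //   nb2    = elt*n_embd_head             nb2    = elt*n_embd_gqa
    //   offset = elt*n_embd_head*n_ctx*il    offset = elt*n_embd_gqa*n_ctx*il
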
@@ -5534,7 +5536,7 @@ static struct ggml_cgraph * llama_build_graph(
 #ifdef GGML_USE_CUBLAS
     const bool do_offload = true;
 #else
-    const bool do_offload = false;
+    const bool do_offload = true; // TODO: set to false after finishing refactoring
 #endif
 
     if (!do_offload) {