@@ -3230,6 +3230,51 @@ static struct ggml_tensor * llm_build_ffn(
return cur;
}

+ // Persimmon: n_rot = n_embd_head/2
+ // Other: n_rot = n_embd_head
+ static void llm_build_k_shift(
+ const llama_context & lctx,
+ struct ggml_context * ctx,
+ struct ggml_cgraph * graph,
+ int64_t n_rot,
+ const llm_build_cb & cb) {
+ const auto & model = lctx.model;
+ const auto & kv_self = lctx.kv_self;
+ const auto & cparams = lctx.cparams;
+
+ const auto & hparams = model.hparams;
+
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_layer = hparams.n_layer;
+ const int64_t n_embd_gqa = hparams.n_embd_gqa();
+ const int64_t n_embd_head = hparams.n_embd_head();
+
+ const int64_t n_ctx = lctx.cparams.n_ctx;
+
+ const float freq_base = cparams.rope_freq_base;
+ const float freq_scale = cparams.rope_freq_scale;
+
+ GGML_ASSERT(n_embd_head % n_rot == 0);
+
+ struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
+ cb(K_shift, "K_shift", -1);
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * tmp =
+ // we rotate only the first n_rot dimensions
+ ggml_rope_custom_inplace(ctx,
+ ggml_view_3d(ctx, kv_self.k,
+ n_rot, n_head, n_ctx,
+ ggml_element_size(kv_self.k)*n_embd_gqa,
+ ggml_element_size(kv_self.k)*n_embd_head,
+ ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
+ ),
+ K_shift, n_rot, 2, 0, freq_base, freq_scale);
+ cb(tmp, "K_shifted", il);
+ ggml_build_forward_expand(graph, tmp);
+ }
+ }
+
static struct ggml_cgraph * llm_build_llama(
llama_context & lctx,
const llama_batch & batch,
@@ -3308,21 +3353,7 @@ static struct ggml_cgraph * llm_build_llama(

// shift the entire K-cache if needed
if (do_rope_shift) {
- struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
- cb(K_shift, "K_shift", -1);
-
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * tmp =
- ggml_rope_custom_inplace(ctx0,
- ggml_view_3d(ctx0, kv_self.k,
- n_embd_head, n_head_kv, n_ctx,
- ggml_element_size(kv_self.k)*n_embd_head,
- ggml_element_size(kv_self.k)*n_embd_gqa,
- ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
- K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
- cb(tmp, "K_shifted", il);
- ggml_build_forward_expand(gf, tmp);
- }
+ llm_build_k_shift(lctx, ctx0, gf, n_embd_head, cb);
}

for (int il = 0; il < n_layer; ++il) {
@@ -3557,21 +3588,7 @@ static struct ggml_cgraph * llm_build_baichaun(

// shift the entire K-cache if needed
if (do_rope_shift) {
- struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
- cb(K_shift, "K_shift", -1);
-
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * tmp =
- ggml_rope_custom_inplace(ctx0,
- ggml_view_3d(ctx0, kv_self.k,
- n_embd_head, n_head_kv, n_ctx,
- ggml_element_size(kv_self.k)*n_embd_head,
- ggml_element_size(kv_self.k)*n_embd_gqa,
- ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
- K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
- cb(tmp, "K_shifted", il);
- ggml_build_forward_expand(gf, tmp);
- }
+ llm_build_k_shift(lctx, ctx0, gf, n_embd_head, cb);
}

for (int il = 0; il < n_layer; ++il) {
@@ -3830,21 +3847,7 @@ static struct ggml_cgraph * llm_build_falcon(

// shift the entire K-cache if needed
if (do_rope_shift) {
- struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
- cb(K_shift, "K_shift", -1);
-
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * tmp =
- ggml_rope_custom_inplace(ctx0,
- ggml_view_3d(ctx0, kv_self.k,
- n_embd_head, n_head_kv, n_ctx,
- ggml_element_size(kv_self.k)*n_embd_head,
- ggml_element_size(kv_self.k)*n_embd_gqa,
- ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
- K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
- cb(tmp, "K_shifted", il);
- ggml_build_forward_expand(gf, tmp);
- }
+ llm_build_k_shift(lctx, ctx0, gf, n_embd_head, cb);
}

for (int il = 0; il < n_layer; ++il) {
@@ -4243,14 +4246,15 @@ static struct ggml_cgraph * llm_build_persimmon(
GGML_ASSERT(!!kv_self.ctx);

const auto & cparams = lctx.cparams;
+
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_head = hparams.n_head;
const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa();
- const size_t n_rot = n_embd_head / 2;
+ const int64_t n_rot = n_embd_head / 2;

const float freq_base = cparams.rope_freq_base;
const float freq_scale = cparams.rope_freq_scale;
@@ -4297,23 +4301,7 @@ static struct ggml_cgraph * llm_build_persimmon(
cb(KQ_mask, "KQ_mask", -1);

if (do_rope_shift) {
- struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
- cb(K_shift, "K_shift", -1);
-
- for (int il = 0; il < n_layer; ++il) {
- struct ggml_tensor * tmp =
- // we rotate only the first n_rot dimensions.
- ggml_rope_custom_inplace(ctx0,
- ggml_view_3d(ctx0, kv_self.k,
- n_rot, n_head, n_ctx,
- ggml_element_size(kv_self.k)*n_embd_gqa,
- ggml_element_size(kv_self.k)*n_embd_head,
- ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)
- ),
- K_shift, n_rot, 2, 0, freq_base, freq_scale);
- cb(tmp, "K_shifted", il);
- ggml_build_forward_expand(gf, tmp);
- }
+ llm_build_k_shift(lctx, ctx0, gf, n_rot, cb);
}

for (int il = 0; il < n_layer; ++il) {
@@ -5534,7 +5522,7 @@ static struct ggml_cgraph * llama_build_graph(
#ifdef GGML_USE_CUBLAS
const bool do_offload = true;
#else
- const bool do_offload = false;
+ const bool do_offload = true; // TODO: set to false after finishing refactoring
#endif

if (!do_offload) {