
Commit 945b4fb

ggerganov authored and Nexesenex committed
starcoder : add GPU offloading (ggml-org#3827)
* starcoder : do not GPU split 1D bias tensors

* starcoder : offload layers to GPU

ggml-ci
1 parent 20ef442 commit 945b4fb
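The tensor-placement half of this change is easy to miss in the diff below: 2D weight matrices may still be row-split across GPUs via backend_split, while the 1D bias vectors (bqkv, bo, b2, b3) are now assigned the plain per-layer backend, since splitting a small 1D tensor is not supported. Below is a minimal sketch of that placement rule; backend_kind, tensor_spec and place_tensor are hypothetical names used only for illustration, not the llama.cpp loader itself.

// Sketch only: illustrates the bias-placement rule from this commit, not the
// actual llm_load_tensors code. All names here are hypothetical.
enum class backend_kind { cpu, gpu, gpu_split };

struct tensor_spec {
    int          n_dims;        // 1 for bias vectors, 2 for weight matrices
    backend_kind backend;       // per-layer backend (CPU or GPU)
    backend_kind backend_split; // split backend, meaningful only for 2D weights
};

// 1D tensors (biases) are never given the split backend: they stay whole on
// the layer's backend, mirroring the bqkv/bo/b2/b3 changes in the diff below.
static backend_kind place_tensor(const tensor_spec & t) {
    return t.n_dims == 1 ? t.backend : t.backend_split;
}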

File tree

1 file changed: +85, -21 lines


Diff for: llama.cpp

@@ -2720,8 +2720,8 @@ static void llm_load_tensors(
             } break;
         case LLM_ARCH_STARCODER:
             {
-                model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
-                model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
+                model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
 
                 // output
                 {
@@ -2772,19 +2772,19 @@ static void llm_load_tensors(
                     layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
 
                     layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
-                    layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
+                    layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);
 
                     layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
-                    layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
+                    layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);
 
                     layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
                     layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
 
                     layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
-                    layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
+                    layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);
 
                     layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
-                    layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
+                    layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
 
                     if (backend == GGML_BACKEND_GPU) {
                         vram_weights +=
@@ -4641,6 +4641,8 @@ static struct ggml_cgraph * llm_build_starcoder(
 
     const float norm_eps = hparams.f_norm_eps;
 
+    const int n_gpu_layers = model.n_gpu_layers;
+
     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
     const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
@@ -4685,6 +4687,27 @@ static struct ggml_cgraph * llm_build_starcoder(
         }
     }
 
+    const int i_gpu_start = n_layer - n_gpu_layers;
+    (void) i_gpu_start;
+
+    // offload functions set the tensor output backend to GPU
+    // tensors are GPU-accelerated if any input or the output has been offloaded
+    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+    offload_func_t offload_func_kq = llama_nop;
+    offload_func_t offload_func_v = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > n_layer) {
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 1) {
+        offload_func_v = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 2) {
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+    }
+#endif // GGML_USE_CUBLAS
+
     {
         // Compute position embeddings.
         struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@@ -4710,6 +4733,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_set_name(KQ_mask, "KQ_mask");
+    offload_func_kq(KQ_mask);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -4733,44 +4757,67 @@ static struct ggml_cgraph * llm_build_starcoder(
     ggml_set_name(inpL, "inpL");
 
     for (int il = 0; il < n_layer; ++il) {
+        offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (il >= i_gpu_start) {
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
+        }
+#endif // GGML_USE_CUBLAS
+
         {
             // Norm
             cur = ggml_norm(ctx0, inpL, norm_eps);
+            offload_func(cur);
+
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
+            offload_func(cur);
         }
 
         {
             // Self Attention
-            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            offload_func_kq(cur);
 
-            struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);
-            struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd);
-            struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
+            cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+            offload_func_kq(cur);
 
-            struct ggml_tensor * Qcur = tmpq;
+            struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            ggml_set_name(tmpq, "tmpq");
+            ggml_set_name(tmpk, "tmpk");
+            ggml_set_name(tmpv, "tmpv");
+
+            offload_func_kq(tmpq);
+            offload_func_kq(tmpk);
+            offload_func_v (tmpv);
+
+            struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
             struct ggml_tensor * Kcur = tmpk;
 
             {
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, tmpv);
+                offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+                offload_func_kq(k);
                 ggml_set_name(k, "k");
 
                 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         ( n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+                offload_func_v(v);
+                ggml_set_name(v, "v");
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
 
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
-                        0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            offload_func_kq(Q);
             ggml_set_name(Q, "Q");
 
             struct ggml_tensor * K =
@@ -4779,23 +4826,28 @@ static struct ggml_cgraph * llm_build_starcoder(
                     ggml_element_size(kv_self.k)*n_embd_gqa,
                     ggml_element_size(kv_self.k)*n_embd_head,
                     ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            offload_func_kq(K);
             ggml_set_name(K, "K");
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
             // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
+            offload_func_kq(KQ_masked);
             ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
             // split cached V into n_head heads
@@ -4808,22 +4860,25 @@ static struct ggml_cgraph * llm_build_starcoder(
             ggml_set_name(V, "V");
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            offload_func_v(KQV);
             ggml_set_name(KQV, "KQV");
 
-            // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
             cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+            offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
         }
 
         // Projection
         cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
+        offload_func(cur);
 
         // Add the input
         cur = ggml_add(ctx0, cur, inpL);
+        offload_func(cur);
 
         struct ggml_tensor * inpFF = cur;
 
@@ -4832,27 +4887,36 @@ static struct ggml_cgraph * llm_build_starcoder(
             // Norm
             {
                 cur = ggml_norm(ctx0, inpFF, norm_eps);
+                offload_func_nr(cur);
+
                 cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+                offload_func_nr(cur);
             }
 
             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
+            offload_func(cur);
 
             // GELU activation
             cur = ggml_gelu(ctx0, cur);
+            offload_func(cur);
 
             // Projection
             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
+            offload_func(cur);
         }
 
         inpL = ggml_add(ctx0, cur, inpFF);
+
     }
 
     // Output Norm
     {
         cur = ggml_norm(ctx0, inpL, norm_eps);
+        offload_func_nr(cur);
+
         cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
+        ggml_set_name(cur, "result_norm");
     }
-    ggml_set_name(cur, "result_norm");
 
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");
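For reference, the offload thresholds used above follow the convention shared by the other GPU-enabled architectures in llama.cpp at the time: repeating layers are offloaded once the layer index reaches i_gpu_start = n_layer - n_gpu_layers, and the non-repeating, V and KQ paths are only offloaded when -ngl exceeds n_layer by 1, 2 and 3 respectively. The following standalone sketch shows that selection logic only; offload_plan and make_plan are illustrative names, not llama.cpp API, and the layer count of 40 is just an example.

// Standalone sketch of the offload selection used in llm_build_starcoder above.
// offload_plan / make_plan are illustrative names, not llama.cpp API.
#include <cstdio>
#include <initializer_list>

struct offload_plan {
    int  i_gpu_start; // first repeating layer whose tensors are offloaded
    bool offload_nr;  // non-repeating tensors (output norm, ...)
    bool offload_v;   // V-path tensors
    bool offload_kq;  // K/Q-path tensors
};

static offload_plan make_plan(int n_layer, int n_gpu_layers) {
    offload_plan p;
    p.i_gpu_start = n_layer - n_gpu_layers;     // negative means "offload every layer"
    p.offload_nr  = n_gpu_layers > n_layer;     // mirrors: if (n_gpu_layers > n_layer)
    p.offload_v   = n_gpu_layers > n_layer + 1; // mirrors: if (n_gpu_layers > n_layer + 1)
    p.offload_kq  = n_gpu_layers > n_layer + 2; // mirrors: if (n_gpu_layers > n_layer + 2)
    return p;
}

int main() {
    const int n_layer = 40; // example layer count
    for (int ngl : {0, 20, 40, 41, 42, 43}) {
        const offload_plan p = make_plan(n_layer, ngl);
        std::printf("ngl=%2d  i_gpu_start=%3d  nr=%d  v=%d  kq=%d\n",
                    ngl, p.i_gpu_start, (int) p.offload_nr, (int) p.offload_v, (int) p.offload_kq);
    }
    return 0;
}

With -ngl three above the layer count, every repeating layer plus the non-repeating output tensors and both the V and KQ paths end up tagged for the GPU, which is exactly what the offload_func_* assignments in the diff enable under GGML_USE_CUBLAS.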
