
Commit fdee152

starcoder : add GPU offloading (#3827)
* starcoder : do not GPU split 1D bias tensors
* starcoder : offload layers to GPU

ggml-ci
1 parent 41aee4d commit fdee152
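
Background for reading the diff: llama.cpp's graph builders drive GPU offloading through "offload functions" that are applied to freshly created tensors. `llama_nop` is the CPU no-op, and in a cuBLAS build the offload function becomes `ggml_cuda_assign_buffers_no_alloc`, which marks a tensor's output for GPU computation. The sketch below is a self-contained toy version of that pattern for orientation only; `toy_tensor`, `toy_nop`, and `toy_assign_gpu` are simplified stand-ins, not the real ggml API.

#include <cstdio>

// Simplified stand-ins for the real ggml types/functions (see ggml.h / ggml-cuda.h).
enum toy_backend { TOY_BACKEND_CPU, TOY_BACKEND_GPU };

struct toy_tensor {
    const char * name;
    toy_backend  backend;
};

// Same shape as llama.cpp's offload_func_t: a function applied to a tensor
// right after it is created while building the compute graph.
typedef void (*toy_offload_func_t)(toy_tensor * t);

static void toy_nop(toy_tensor * t)        { (void) t; }                     // cf. llama_nop (CPU path)
static void toy_assign_gpu(toy_tensor * t) { t->backend = TOY_BACKEND_GPU; } // cf. ggml_cuda_assign_buffers_no_alloc

int main() {
    const bool gpu_enabled = true; // e.g. built with cuBLAS and a large enough layer count

    toy_offload_func_t offload_func_kq = gpu_enabled ? toy_assign_gpu : toy_nop;

    toy_tensor KQ = { "KQ", TOY_BACKEND_CPU };
    offload_func_kq(&KQ); // mark the tensor so its op runs on the GPU

    std::printf("%s -> %s\n", KQ.name, KQ.backend == TOY_BACKEND_GPU ? "GPU" : "CPU");
    return 0;
}

The commit below wires this pattern into the StarCoder graph builder and, at load time, stops row-splitting the 1D bias tensors.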

File tree

1 file changed: +85 -21 lines changed


Diff for: llama.cpp

@@ -2695,8 +2695,8 @@ static void llm_load_tensors(
 } break;
 case LLM_ARCH_STARCODER:
 {
-model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
-model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
+model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);

 // output
 {
@@ -2747,19 +2747,19 @@ static void llm_load_tensors(
 layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);

 layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
-layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
+layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);

 layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
-layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
+layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);

 layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);

 layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
-layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
+layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);

 layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
-layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
+layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);

 if (backend == GGML_BACKEND_GPU) {
 vram_weights +=
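
The two hunks above implement the first bullet of the commit message: the 1D bias tensors (bqkv, bo, b2, b3) now use the layer's regular `backend` instead of `backend_split`. The split backend is meant for sharding large 2D weight matrices across GPUs by rows, which is not meaningful for small 1D bias vectors, so those are now placed whole on the layer's backend. The following self-contained sketch mirrors how the per-layer `backend`/`backend_split` pair is typically derived from the requested GPU layer count; the enum and the concrete numbers are illustrative assumptions, only the `i_gpu_start = n_layer - n_gpu_layers` split point follows the code.

#include <cstdio>

// Illustrative mirror of the per-layer backend choice in llm_load_tensors.
// The enum and the numbers are made up for this sketch.
enum toy_backend { TOY_CPU, TOY_GPU, TOY_GPU_SPLIT };

int main() {
    const int n_layer      = 24; // hypothetical StarCoder-style model
    const int n_gpu_layers = 10; // e.g. requested with -ngl 10
    const int i_gpu_start  = n_layer - n_gpu_layers;

    for (int i = 0; i < n_layer; ++i) {
        // 2D weight matrices on offloaded layers may be row-split across GPUs ...
        const toy_backend backend_split = i < i_gpu_start ? TOY_CPU : TOY_GPU_SPLIT;
        // ... while 1D tensors (norms and, after this commit, the StarCoder biases)
        // go whole onto the layer's backend.
        const toy_backend backend       = i < i_gpu_start ? TOY_CPU : TOY_GPU;

        std::printf("layer %2d: weight backend=%d, bias backend=%d\n", i, (int) backend_split, (int) backend);
    }
    return 0;
}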
@@ -4616,6 +4616,8 @@ static struct ggml_cgraph * llm_build_starcoder(

 const float norm_eps = hparams.f_norm_eps;

+const int n_gpu_layers = model.n_gpu_layers;
+
 const int32_t n_tokens = batch.n_tokens;
 const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
 const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
@@ -4660,6 +4662,27 @@ static struct ggml_cgraph * llm_build_starcoder(
 }
 }

+const int i_gpu_start = n_layer - n_gpu_layers;
+(void) i_gpu_start;
+
+// offload functions set the tensor output backend to GPU
+// tensors are GPU-accelerated if any input or the output has been offloaded
+offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+offload_func_t offload_func_kq = llama_nop;
+offload_func_t offload_func_v = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+if (n_gpu_layers > n_layer) {
+offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+}
+if (n_gpu_layers > n_layer + 1) {
+offload_func_v = ggml_cuda_assign_buffers_no_alloc;
+}
+if (n_gpu_layers > n_layer + 2) {
+offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+}
+#endif // GGML_USE_CUBLAS
+
 {
 // Compute position embeddings.
 struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
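
The hunk above follows the same pattern as the other graph builders: three offload functions start as the `llama_nop` no-op and, under cuBLAS, are enabled in tiers as `n_gpu_layers` exceeds the repeating layer count. One extra layer turns on `offload_func_nr` for the non-repeating tensors, a second turns on `offload_func_v` for the V path, and a third turns on `offload_func_kq` for the K/Q path. A small self-contained sketch of which tiers become active for a given `-ngl` value; the layer count is a made-up example, the thresholds are the ones from the diff.

#include <cstdio>

// Which offload tiers activate for a given n_gpu_layers, mirroring the
// n_gpu_layers > n_layer / + 1 / + 2 checks added above. n_layer is hypothetical.
int main() {
    const int n_layer = 24;
    for (int n_gpu_layers = n_layer; n_gpu_layers <= n_layer + 3; ++n_gpu_layers) {
        const bool nr = n_gpu_layers > n_layer;     // non-repeating tensors (e.g. output norm)
        const bool v  = n_gpu_layers > n_layer + 1; // V-path tensors
        const bool kq = n_gpu_layers > n_layer + 2; // K/Q-path tensors
        std::printf("-ngl %2d: nr=%d v=%d kq=%d\n", n_gpu_layers, (int) nr, (int) v, (int) kq);
    }
    return 0;
}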
@@ -4685,6 +4708,7 @@ static struct ggml_cgraph * llm_build_starcoder(
 // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
 struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
 ggml_set_name(KQ_mask, "KQ_mask");
+offload_func_kq(KQ_mask);
 ggml_allocr_alloc(lctx.alloc, KQ_mask);
 if (!ggml_allocr_is_measure(lctx.alloc)) {
 float * data = (float *) KQ_mask->data;
@@ -4708,44 +4732,67 @@ static struct ggml_cgraph * llm_build_starcoder(
 ggml_set_name(inpL, "inpL");

 for (int il = 0; il < n_layer; ++il) {
+offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+if (il >= i_gpu_start) {
+offload_func = ggml_cuda_assign_buffers_no_alloc;
+}
+#endif // GGML_USE_CUBLAS
+
 {
 // Norm
 cur = ggml_norm(ctx0, inpL, norm_eps);
+offload_func(cur);
+
 cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
+offload_func(cur);
 }

 {
 // Self Attention
-cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
+cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+offload_func_kq(cur);

-struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);
-struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd);
-struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
+cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+offload_func_kq(cur);

-struct ggml_tensor * Qcur = tmpq;
+struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+ggml_set_name(tmpq, "tmpq");
+ggml_set_name(tmpk, "tmpk");
+ggml_set_name(tmpv, "tmpv");
+
+offload_func_kq(tmpq);
+offload_func_kq(tmpk);
+offload_func_v (tmpv);
+
+struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
 struct ggml_tensor * Kcur = tmpk;

 {
-struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+struct ggml_tensor * Vcur = ggml_transpose(ctx0, tmpv);
+offload_func_v(Vcur);
 ggml_set_name(Vcur, "Vcur");

 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+offload_func_kq(k);
 ggml_set_name(k, "k");

 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
 ( n_ctx)*ggml_element_size(kv_self.v),
 (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+offload_func_v(v);
+ggml_set_name(v, "v");

 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
 }

-struct ggml_tensor * Q =
-ggml_permute(ctx0,
-ggml_cpy(ctx0,
-Qcur,
-ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
-0, 2, 1, 3);
+struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+offload_func_kq(Q);
 ggml_set_name(Q, "Q");

 struct ggml_tensor * K =
@@ -4754,23 +4801,28 @@ static struct ggml_cgraph * llm_build_starcoder(
 ggml_element_size(kv_self.k)*n_embd_gqa,
 ggml_element_size(kv_self.k)*n_embd_head,
 ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+offload_func_kq(K);
 ggml_set_name(K, "K");

 // K * Q
 struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+offload_func_kq(KQ);
 ggml_set_name(KQ, "KQ");

 // KQ_scaled = KQ / sqrt(n_embd_head)
 // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
 struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+offload_func_kq(KQ_scaled);
 ggml_set_name(KQ_scaled, "KQ_scaled");

 // KQ_masked = mask_past(KQ_scaled)
 struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
+offload_func_kq(KQ_masked);
 ggml_set_name(KQ_masked, "KQ_masked");

 // KQ = soft_max(KQ_masked)
 struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+offload_func_v(KQ_soft_max);
 ggml_set_name(KQ_soft_max, "KQ_soft_max");

 // split cached V into n_head heads
@@ -4783,22 +4835,25 @@ static struct ggml_cgraph * llm_build_starcoder(
 ggml_set_name(V, "V");

 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+offload_func_v(KQV);
 ggml_set_name(KQV, "KQV");

-// KQV_merged = KQV.permute(0, 2, 1, 3)
 struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+offload_func_v(KQV_merged);
 ggml_set_name(KQV_merged, "KQV_merged");

-// cur = KQV_merged.contiguous().view(n_embd, n_tokens)
 cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+offload_func_v(cur);
 ggml_set_name(cur, "KQV_merged_contiguous");
 }

 // Projection
 cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
+offload_func(cur);

 // Add the input
 cur = ggml_add(ctx0, cur, inpL);
+offload_func(cur);

 struct ggml_tensor * inpFF = cur;

@@ -4807,27 +4862,36 @@ static struct ggml_cgraph * llm_build_starcoder(
 // Norm
 {
 cur = ggml_norm(ctx0, inpFF, norm_eps);
+offload_func_nr(cur);
+
 cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+offload_func_nr(cur);
 }

 cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
+offload_func(cur);

 // GELU activation
 cur = ggml_gelu(ctx0, cur);
+offload_func(cur);

 // Projection
 cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
+offload_func(cur);
 }

 inpL = ggml_add(ctx0, cur, inpFF);
+
 }

 // Output Norm
 {
 cur = ggml_norm(ctx0, inpL, norm_eps);
+offload_func_nr(cur);
+
 cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
+ggml_set_name(cur, "result_norm");
 }
-ggml_set_name(cur, "result_norm");

 cur = ggml_mul_mat(ctx0, model.output, cur);
 ggml_set_name(cur, "result_output");
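
With the graph builder changes above, StarCoder models go through the same offload path as the other architectures, so in a cuBLAS build they should respond to the examples' usual `--n-gpu-layers` (`-ngl`) option; values above the repeating layer count additionally pull the output norm, the V path, and the K/Q path onto the GPU, as sketched earlier.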
