@@ -2720,8 +2720,8 @@ static void llm_load_tensors(
} break;
case LLM_ARCH_STARCODER:
{
- model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
- model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
+ model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+ model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);

// output
{
@@ -2772,19 +2772,19 @@ static void llm_load_tensors(
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);

layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
- layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split );
+ layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend );

layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
- layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split );
+ layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend );

layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);

layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
- layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split );
+ layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend );

layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
- layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split );
+ layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend );

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
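
Note on the hunk above: the 2-D projection weights stay on `backend_split` (which, in a multi-GPU cuBLAS build, allows them to be row-split across devices), while the 1-D bias vectors are moved to plain `backend` so each stays whole on a single device. A minimal standalone sketch of that placement rule, using a hypothetical enum and helper rather than the real ggml/llama.cpp symbols:

```cpp
#include <cstdio>

// Hypothetical stand-ins for illustration; not the ggml enum or the loader API.
enum class Placement { CPU, GPU, GPU_SPLIT };

// backend       : placement for small per-layer tensors (norms, biases)
// backend_split : placement for large matrices that may be row-split across GPUs
static Placement pick_placement(int n_dims, Placement backend, Placement backend_split) {
    // Only 2-D weight matrices are worth splitting; a 1-D bias should stay whole
    // on whichever device consumes it, which is what the hunk above restores.
    return n_dims == 2 ? backend_split : backend;
}

int main() {
    const Placement backend       = Placement::GPU;
    const Placement backend_split = Placement::GPU_SPLIT;

    std::printf("wqkv (2-D) -> %d\n", (int) pick_placement(2, backend, backend_split)); // GPU_SPLIT
    std::printf("bqkv (1-D) -> %d\n", (int) pick_placement(1, backend, backend_split)); // GPU
    return 0;
}
```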
@@ -4641,6 +4641,8 @@ static struct ggml_cgraph * llm_build_starcoder(

const float norm_eps = hparams.f_norm_eps;

+ const int n_gpu_layers = model.n_gpu_layers;
+
const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
@@ -4685,6 +4687,27 @@ static struct ggml_cgraph * llm_build_starcoder(
}
}

+ const int i_gpu_start = n_layer - n_gpu_layers;
+ (void) i_gpu_start;
+
+ // offload functions set the tensor output backend to GPU
+ // tensors are GPU-accelerated if any input or the output has been offloaded
+ offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+ offload_func_t offload_func_kq = llama_nop;
+ offload_func_t offload_func_v = llama_nop;
+
+ #ifdef GGML_USE_CUBLAS
+ if (n_gpu_layers > n_layer) {
+ offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+ }
+ if (n_gpu_layers > n_layer + 1) {
+ offload_func_v = ggml_cuda_assign_buffers_no_alloc;
+ }
+ if (n_gpu_layers > n_layer + 2) {
+ offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+ }
+ #endif // GGML_USE_CUBLAS
+
{
// Compute position embeddings.
struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
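
The thresholds above follow the usual `-ngl` convention: values up to `n_layer` offload that many repeating layers, and each extra unit beyond `n_layer` additionally offloads the non-repeating tensors, then the V-side tensors, then the K/Q-side tensors. A small self-contained sketch of that decision logic (illustrative only; the real code installs `ggml_cuda_assign_buffers_no_alloc` into function pointers instead of returning flags):

```cpp
#include <cstdio>

// Which tensor groups get offloaded for a given n_gpu_layers / n_layer pair.
struct OffloadPlan {
    int  i_gpu_start;   // first repeating layer whose tensors go to the GPU
    bool non_repeating; // output norm etc.   (n_gpu_layers > n_layer)
    bool v_tensors;     // V-side tensors     (n_gpu_layers > n_layer + 1)
    bool kq_tensors;    // K/Q-side tensors   (n_gpu_layers > n_layer + 2)
};

static OffloadPlan make_plan(int n_gpu_layers, int n_layer) {
    OffloadPlan p;
    p.i_gpu_start   = n_layer - n_gpu_layers; // negative means every layer is offloaded
    p.non_repeating = n_gpu_layers > n_layer;
    p.v_tensors     = n_gpu_layers > n_layer + 1;
    p.kq_tensors    = n_gpu_layers > n_layer + 2;
    return p;
}

int main() {
    // e.g. a 40-layer model with -ngl 43 offloads the repeating layers plus all three extras
    const OffloadPlan p = make_plan(/*n_gpu_layers=*/43, /*n_layer=*/40);
    std::printf("i_gpu_start=%d non_repeating=%d v=%d kq=%d\n",
                p.i_gpu_start, (int) p.non_repeating, (int) p.v_tensors, (int) p.kq_tensors);
    return 0;
}
```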
@@ -4710,6 +4733,7 @@ static struct ggml_cgraph * llm_build_starcoder(
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
ggml_set_name(KQ_mask, "KQ_mask");
+ offload_func_kq(KQ_mask);
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
@@ -4733,44 +4757,67 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_set_name(inpL, "inpL");

for (int il = 0; il < n_layer; ++il) {
+ offload_func_t offload_func = llama_nop;
+
+ #ifdef GGML_USE_CUBLAS
+ if (il >= i_gpu_start) {
+ offload_func = ggml_cuda_assign_buffers_no_alloc;
+ }
+ #endif // GGML_USE_CUBLAS
+
{
// Norm
cur = ggml_norm(ctx0, inpL, norm_eps);
+ offload_func(cur);
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
+ offload_func(cur);
}

{
// Self Attention
- cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ offload_func_kq(cur);

- struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);
- struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd);
- struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ offload_func_kq(cur);

- struct ggml_tensor * Qcur = tmpq;
+ struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+ ggml_set_name(tmpq, "tmpq");
+ ggml_set_name(tmpk, "tmpk");
+ ggml_set_name(tmpv, "tmpv");
+
+ offload_func_kq(tmpq);
+ offload_func_kq(tmpk);
+ offload_func_v (tmpv);
+
+ struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
struct ggml_tensor * Kcur = tmpk;

{
- struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, tmpv);
+ offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur");

struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+ offload_func_kq(k);
ggml_set_name(k, "k");

struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+ offload_func_v(v);
+ ggml_set_name(v, "v");

ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}

- struct ggml_tensor * Q =
- ggml_permute(ctx0,
- ggml_cpy(ctx0,
- Qcur,
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
- 0, 2, 1, 3);
+ struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+ offload_func_kq(Q);
ggml_set_name(Q, "Q");

struct ggml_tensor * K =
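
The replacement code above splits the fused QKV projection row-wise: each token's output row is laid out as `n_embd` Q values, then `n_embd_gqa` K values, then `n_embd_gqa` V values, and the three `ggml_view_2d` calls start at exactly those byte offsets (0, `1*sizeof(float)*(n_embd)`, `1*sizeof(float)*(n_embd + n_embd_gqa)`); wrapping each view in `ggml_cont` makes the slices contiguous before they are named, offloaded, and reshaped. A plain-array sketch of the same offset arithmetic (toy sizes, no ggml):

```cpp
#include <cstddef>
#include <vector>

// One token's row of the fused QKV output: [ Q : n_embd | K : n_embd_gqa | V : n_embd_gqa ].
struct QKVSlices {
    const float * q; // n_embd values
    const float * k; // n_embd_gqa values
    const float * v; // n_embd_gqa values
};

// Mirrors the element offsets used by the three ggml_view_2d calls above.
static QKVSlices split_qkv_row(const float * qkv, std::size_t n_embd, std::size_t n_embd_gqa) {
    QKVSlices s;
    s.q = qkv;                       // byte offset 0
    s.k = qkv + n_embd;              // byte offset 1*sizeof(float)*n_embd
    s.v = qkv + n_embd + n_embd_gqa; // byte offset 1*sizeof(float)*(n_embd + n_embd_gqa)
    return s;
}

int main() {
    const std::size_t n_embd = 8, n_embd_gqa = 8; // toy sizes, not StarCoder's
    std::vector<float> row(n_embd + 2*n_embd_gqa, 0.0f);
    const QKVSlices s = split_qkv_row(row.data(), n_embd, n_embd_gqa);
    (void) s;
    return 0;
}
```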
@@ -4779,23 +4826,28 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+ offload_func_kq(K);
ggml_set_name(K, "K");

// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");

// KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+ offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");

// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
+ offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");

// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+ offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");

// split cached V into n_head heads
@@ -4808,22 +4860,25 @@ static struct ggml_cgraph * llm_build_starcoder(
ggml_set_name(V, "V");

struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ offload_func_v(KQV);
ggml_set_name(KQV, "KQV");

- // KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");

- // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+ offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");
}

// Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
+ offload_func(cur);

// Add the input
cur = ggml_add(ctx0, cur, inpL);
+ offload_func(cur);

struct ggml_tensor * inpFF = cur;

@@ -4832,27 +4887,36 @@ static struct ggml_cgraph * llm_build_starcoder(
// Norm
{
cur = ggml_norm(ctx0, inpFF, norm_eps);
+ offload_func_nr(cur);
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+ offload_func_nr(cur);
}

cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
+ offload_func(cur);

// GELU activation
cur = ggml_gelu(ctx0, cur);
+ offload_func(cur);

// Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
+ offload_func(cur);
}

inpL = ggml_add(ctx0, cur, inpFF);
+
}

// Output Norm
{
cur = ggml_norm(ctx0, inpL, norm_eps);
+ offload_func_nr(cur);
+
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
+ ggml_set_name(cur, "result_norm");
}
- ggml_set_name(cur, "result_norm");

cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
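
Taken together, the additions follow one pattern: every node in the graph is passed through an `offload_func_t` function pointer that is either `llama_nop` or `ggml_cuda_assign_buffers_no_alloc`, so the call sites read the same whether or not a layer is offloaded. A standalone sketch of that pattern with stand-in types (not the ggml/llama.cpp symbols):

```cpp
#include <cstdio>

// Stand-in types for illustration only.
struct tensor { bool on_gpu = false; };

using offload_func_t = void (*)(tensor *);

static void nop_offload(tensor *)   {}                     // plays the role of llama_nop
static void gpu_offload(tensor * t) { t->on_gpu = true; }  // plays the role of ggml_cuda_assign_buffers_no_alloc

int main() {
    const int n_layer = 4, n_gpu_layers = 2;
    const int i_gpu_start = n_layer - n_gpu_layers;

    for (int il = 0; il < n_layer; ++il) {
        // Chosen once per layer; every call site below stays identical.
        const offload_func_t offload_func = (il >= i_gpu_start) ? gpu_offload : nop_offload;

        tensor cur;
        offload_func(&cur);
        std::printf("layer %d on_gpu=%d\n", il, (int) cur.on_gpu);
    }
    return 0;
}
```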