@@ -2695,8 +2695,8 @@ static void llm_load_tensors(
                 } break;
             case LLM_ARCH_STARCODER:
                 {
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
-                    model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);

                     // output
                     {
@@ -2747,19 +2747,19 @@ static void llm_load_tensors(
                         layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);

                         layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
-                        layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split);
+                        layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend);

                         layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
-                        layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
+                        layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend);

                         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
                         layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);

                         layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
-                        layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
+                        layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend);

                         layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
-                        layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
+                        layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);

                         if (backend == GGML_BACKEND_GPU) {
                             vram_weights +=
@@ -4616,6 +4616,8 @@ static struct ggml_cgraph * llm_build_starcoder(

     const float norm_eps = hparams.f_norm_eps;

+    const int n_gpu_layers = model.n_gpu_layers;
+
     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
@@ -4660,6 +4662,27 @@ static struct ggml_cgraph * llm_build_starcoder(
         }
     }

+    const int i_gpu_start = n_layer - n_gpu_layers;
+    (void) i_gpu_start;
+
+    // offload functions set the tensor output backend to GPU
+    // tensors are GPU-accelerated if any input or the output has been offloaded
+    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+    offload_func_t offload_func_kq = llama_nop;
+    offload_func_t offload_func_v  = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+    if (n_gpu_layers > n_layer) {
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 1) {
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
+    }
+    if (n_gpu_layers > n_layer + 2) {
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
+    }
+#endif // GGML_USE_CUBLAS
+
     {
         // Compute position embeddings.
         struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@@ -4685,6 +4708,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     ggml_set_name(KQ_mask, "KQ_mask");
+    offload_func_kq(KQ_mask);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
@@ -4708,44 +4732,67 @@ static struct ggml_cgraph * llm_build_starcoder(
     ggml_set_name(inpL, "inpL");

     for (int il = 0; il < n_layer; ++il) {
+        offload_func_t offload_func = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+        if (il >= i_gpu_start) {
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
+        }
+#endif // GGML_USE_CUBLAS
+
         {
             // Norm
             cur = ggml_norm(ctx0, inpL, norm_eps);
+            offload_func(cur);
+
             cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
+            offload_func(cur);
         }

         {
             // Self Attention
-            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            offload_func_kq(cur);

-            struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd);
-            struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd);
-            struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa));
+            cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+            offload_func_kq(cur);

-            struct ggml_tensor * Qcur = tmpq;
+            struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            ggml_set_name(tmpq, "tmpq");
+            ggml_set_name(tmpk, "tmpk");
+            ggml_set_name(tmpv, "tmpv");
+
+            offload_func_kq(tmpq);
+            offload_func_kq(tmpk);
+            offload_func_v (tmpv);
+
+            struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
             struct ggml_tensor * Kcur = tmpk;

             {
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, tmpv);
+                offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");

                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
+                offload_func_kq(k);
                 ggml_set_name(k, "k");

                 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
+                offload_func_v(v);
+                ggml_set_name(v, "v");

                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }

-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)),
-                        0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+            offload_func_kq(Q);
             ggml_set_name(Q, "Q");

             struct ggml_tensor * K =
@@ -4754,23 +4801,28 @@ static struct ggml_cgraph * llm_build_starcoder(
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+            offload_func_kq(K);
             ggml_set_name(K, "K");

             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            offload_func_kq(KQ);
             ggml_set_name(KQ, "KQ");

             // KQ_scaled = KQ / sqrt(n_embd_head)
             // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");

             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
+            offload_func_kq(KQ_masked);
             ggml_set_name(KQ_masked, "KQ_masked");

             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");

             // split cached V into n_head heads
@@ -4783,22 +4835,25 @@ static struct ggml_cgraph * llm_build_starcoder(
             ggml_set_name(V, "V");

             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            offload_func_v(KQV);
             ggml_set_name(KQV, "KQV");

-            // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");

-            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
             cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
+            offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
         }

         // Projection
         cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
+        offload_func(cur);

         // Add the input
         cur = ggml_add(ctx0, cur, inpL);
+        offload_func(cur);

         struct ggml_tensor * inpFF = cur;

@@ -4807,27 +4862,36 @@ static struct ggml_cgraph * llm_build_starcoder(
             // Norm
             {
                 cur = ggml_norm(ctx0, inpFF, norm_eps);
+                offload_func_nr(cur);
+
                 cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+                offload_func_nr(cur);
             }

             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
+            offload_func(cur);

             // GELU activation
             cur = ggml_gelu(ctx0, cur);
+            offload_func(cur);

             // Projection
             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
+            offload_func(cur);
         }

         inpL = ggml_add(ctx0, cur, inpFF);
+
     }

     // Output Norm
     {
         cur = ggml_norm(ctx0, inpL, norm_eps);
+        offload_func_nr(cur);
+
         cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
+        ggml_set_name(cur, "result_norm");
     }
-    ggml_set_name(cur, "result_norm");

     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");
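
Note on the offload_func_* calls added throughout: they are plain function pointers that default to llama_nop and, once enough layers are requested on the GPU, are pointed at ggml_cuda_assign_buffers_no_alloc so every tensor they are applied to is tagged for the CUDA backend. Below is a minimal standalone sketch of that dispatch pattern; my_tensor, nop_offload and assign_to_gpu are made-up stand-ins for illustration, not llama.cpp or ggml API.

// sketch of the function-pointer offload dispatch used in the diff above
// (placeholder types only; compile with any C++11 compiler)
#include <cstdio>

struct my_tensor { bool on_gpu = false; };          // stand-in for ggml_tensor

using offload_func_t = void (*)(my_tensor *);

static void nop_offload(my_tensor *) {}             // analogous to llama_nop
static void assign_to_gpu(my_tensor * t) {          // stand-in for ggml_cuda_assign_buffers_no_alloc
    t->on_gpu = true;
}

int main() {
    const int n_layer      = 40;
    const int n_gpu_layers = 43;  // n_layer + 3: also offload non-repeating, V and KQ tensors

    offload_func_t offload_func_nr = nop_offload;
    offload_func_t offload_func_kq = nop_offload;
    offload_func_t offload_func_v  = nop_offload;

    // same thresholds as in the diff: extra "layers" beyond n_layer enable extra offloads
    if (n_gpu_layers > n_layer)     { offload_func_nr = assign_to_gpu; }
    if (n_gpu_layers > n_layer + 1) { offload_func_v  = assign_to_gpu; }
    if (n_gpu_layers > n_layer + 2) { offload_func_kq = assign_to_gpu; }

    my_tensor norm, kq, v;
    offload_func_nr(&norm);  // each graph tensor is passed through its matching offload function
    offload_func_kq(&kq);
    offload_func_v (&v);

    std::printf("norm on GPU: %d, KQ on GPU: %d, V on GPU: %d\n", norm.on_gpu, kq.on_gpu, v.on_gpu);
    return 0;
}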
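
For reference, the tmpq/tmpk/tmpv views in the attention hunk carve Q, K and V out of the single fused QKV projection by per-row offset: Q starts at 0, K at n_embd floats, V at n_embd + n_embd_gqa floats, with cur->nb[1] as the row stride. The tiny standalone illustration below reproduces only that offset arithmetic on a plain float buffer; the sizes and names are made up and none of it is ggml API.

// illustration of the Q/K/V row offsets behind the tmpq/tmpk/tmpv views
#include <cstdio>
#include <vector>

int main() {
    const int n_embd     = 8;  // hypothetical embedding size
    const int n_embd_gqa = 4;  // hypothetical grouped-query K/V width
    const int n_tokens   = 3;

    const int row = n_embd + 2*n_embd_gqa;  // one fused QKV row per token
    std::vector<float> qkv(row * n_tokens);
    for (size_t i = 0; i < qkv.size(); ++i) qkv[i] = (float) i;

    for (int t = 0; t < n_tokens; ++t) {
        const float * q = qkv.data() + t*row;                        // offset 0
        const float * k = qkv.data() + t*row + n_embd;               // offset n_embd
        const float * v = qkv.data() + t*row + n_embd + n_embd_gqa;  // offset n_embd + n_embd_gqa
        std::printf("token %d: q[0]=%.0f k[0]=%.0f v[0]=%.0f\n", t, q[0], k[0], v[0]);
    }
    return 0;
}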