
Commit 5b6468f

llama : avoid ggml_repeat during classification

Committed Sep 23, 2024 · 1 parent 5f95dcc

File tree

1 file changed: +14 -4 lines

src/llama.cpp (+14 -4)
@@ -10096,9 +10096,6 @@ struct llm_build_context {
                     cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
                     cur = ggml_tanh(ctx0, cur);
                     cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
-
-                    // broadcast across the embedding size to make it compatible with the llama_get_embeddings API
-                    cur = ggml_repeat(ctx0, cur, inp);
                 } break;
             default:
                 {
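
For context, the deleted lines tiled the single classification score across the embedding dimension so it could be read back through the generic llama_get_embeddings buffer. Below is a minimal sketch of that broadcast, assuming the standard ggml API where ggml_repeat(ctx, a, b) repeats tensor a until it matches the shape of b; the helper name broadcast_score is hypothetical and not part of the commit:

    // Sketch only, not part of the commit: mirrors the removed broadcast.
    #include "ggml.h"

    // cur holds one classifier score per sequence, inp is the pooled input of
    // shape [n_embd, n_rows]; repeating cur to inp's shape makes the score look
    // like a regular n_embd-wide embedding to downstream readers.
    static struct ggml_tensor * broadcast_score(struct ggml_context * ctx0,
                                                struct ggml_tensor * cur,
                                                struct ggml_tensor * inp) {
        return ggml_repeat(ctx0, cur, inp);
    }
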
@@ -16831,7 +16828,6 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_MEAN:
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_LAST:
-                case LLAMA_POOLING_TYPE_RANK:
                     {
                         // extract sequence embeddings (cleared before processing each batch)
                         auto & embd_seq_out = lctx.embd_seq;
@@ -16845,6 +16841,20 @@ static int llama_decode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rank score - a single float per sequence
+                        auto & embd_seq_out = lctx.embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         GGML_ABORT("unknown pooling type");
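
With this change, RANK pooling stores a single float per sequence in lctx.embd_seq instead of an n_embd-wide broadcast. Below is a minimal sketch of how a caller might read that score, assuming the llama.cpp C API of the same period (llama_get_embeddings_seq); the helper print_rank_scores is hypothetical and not part of the commit:

    // Sketch only, not part of the commit.
    #include "llama.h"
    #include <cstdio>

    // With LLAMA_POOLING_TYPE_RANK, each pooled sequence is assumed to expose its
    // score as the first (and only) float returned by llama_get_embeddings_seq().
    static void print_rank_scores(llama_context * ctx, int n_seqs) {
        for (int s = 0; s < n_seqs; ++s) {
            const float * out = llama_get_embeddings_seq(ctx, s);
            if (out == nullptr) {
                continue; // this sequence was not pooled in the current batch
            }
            printf("seq %d: rank score = %f\n", s, out[0]);
        }
    }
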
