
Commit 75c8db5

ggerganov authored and jordankanter committed
llama : fix embeddings (ggml-org#5796)
* llama : fix embeddings (ggml-ci)
* llama : do not use KV cache for non-causal models (ggml-ci)
* embeddings : fix llama_batch_init arg
* llama : add pooling switch
* llama : distinguish token vs sequence embeddings (ggml-ci)
* llama : assert pooling tensor
* llama : simplify causal mask condition (ggml-ci)
* llama : assert input batch with pooling enabled
* readme : update API changes list
1 parent 9f76b6a commit 75c8db5
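
In short, a context created with `cparams.embeddings = true` now distinguishes sequence embeddings, returned by `llama_get_embeddings_seq()` when a pooling type other than NONE is in effect, from token embeddings, returned by `llama_get_embeddings_ith()`. Below is a minimal sketch of that retrieval pattern, mirroring the calls used in the diffs that follow; model/context setup, tokenization and the `llama_decode()` call are assumed to have already happened, and the helper name is made up for illustration:

// sketch only: read back the embedding for output i of a just-decoded batch,
// assuming the context was created with embeddings enabled
#include <cstdio>
#include <vector>

#include "llama.h"

static std::vector<float> read_embedding(llama_context * ctx, const llama_model * model,
                                         const llama_batch & batch, int i) {
    const int n_embd = llama_n_embd(model);

    // sequence embeddings - available only when the pooling type is not NONE
    const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
    if (embd == NULL) {
        // fall back to the token embedding at output position i
        embd = llama_get_embeddings_ith(ctx, i);
    }
    if (embd == NULL) {
        fprintf(stderr, "failed to get embeddings for output %d\n", i);
        return {};
    }

    return std::vector<float>(embd, embd + n_embd);
}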

7 files changed: +358, -133 lines changed

README.md (+1)

@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ### Recent API changes
 
+- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849
 
 ### Hot topics

common/common.cpp (+1, -1)

@@ -1292,7 +1292,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.seed = params.seed;
     cparams.logits_all = params.logits_all;
-    cparams.embedding = params.embedding;
+    cparams.embeddings = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
     cparams.rope_freq_base = params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale;

examples/embedding/embedding.cpp (+21, -7)

@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
 
 static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     // normalize on copy
-    for (int k = 0; k < n_seq; k++) {
-        float * emb = llama_get_embeddings_ith(ctx, k);
-        float * out = output + k * n_embd;
-        normalize(emb, out, n_embd);
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
+
+        float * out = output + batch.seq_id[i][0] * n_embd;
+        normalize(embd, out, n_embd);
     }
 }
 
@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
     // allocate output
     const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
         auto & inp = inputs[k];
+
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
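
A note on the hunks above: the third argument of `llama_batch_init()` is the maximum number of sequence ids per token, not the number of prompts, so `1` is enough, and only the last token of each prompt is flagged to produce an output. The sketch below shows that batching convention with two toy prompts; the token ids are made up for illustration and `llama_batch_add` is the helper from `common.h`:

// sketch only: queue two toy prompts into one batch, requesting a single
// embedding output per sequence by flagging just the final token of each
#include <vector>

#include "common.h"
#include "llama.h"

static void add_prompt(llama_batch & batch, const std::vector<llama_token> & tokens, int seq_id) {
    for (size_t i = 0; i < tokens.size(); i++) {
        // pos = i, one seq_id, request an output only for the last token
        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
    }
}

int main() {
    // capacity of 32 tokens, no embedding inputs, at most 1 seq_id per token
    llama_batch batch = llama_batch_init(32, 0, 1);

    add_prompt(batch, { 1, 15043, 3186 }, 0); // hypothetical token ids, prompt 0
    add_prompt(batch, { 1, 22172, 1058 }, 1); // hypothetical token ids, prompt 1

    // ... llama_decode(ctx, batch) and embedding retrieval would follow ...

    llama_batch_free(batch);
    return 0;
}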

examples/server-embd.py (+34)

@@ -0,0 +1,34 @@
+import asyncio
+import requests
+import numpy as np
+
+n = 8
+
+result = []
+
+async def requests_post_async(*args, **kwargs):
+    return await asyncio.to_thread(requests.post, *args, **kwargs)
+
+async def main():
+    model_url = "http://127.0.0.1:6900"
+    responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
+        url= f"{model_url}/embedding",
+        json= {"content": str(i)*1024}
+    ) for i in range(n)])
+
+    for response in responses:
+        embedding = response.json()["embedding"]
+        print(embedding[-8:])
+        result.append(embedding)
+
+asyncio.run(main())
+
+# compute cosine similarity
+
+for i in range(n-1):
+    for j in range(i+1, n):
+        embedding1 = np.array(result[i])
+        embedding2 = np.array(result[j])
+        similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+        print(f"Similarity between {i} and {j}: {similarity:.2f}")
+

examples/server/server.cpp (+42, -11)

@@ -1210,7 +1210,7 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void send_embedding(server_slot &slot)
+    void send_embedding(server_slot & slot, const llama_batch & batch)
     {
         task_result res;
         res.id = slot.task_id;
@@ -1219,6 +1219,7 @@ struct llama_server_context
         res.stop = true;
 
         const int n_embd = llama_n_embd(model);
+
         if (!params.embedding)
         {
             LOG_WARNING("embedding disabled", {{"params.embedding", params.embedding}});
@@ -1229,12 +1230,29 @@ struct llama_server_context
         }
         else
         {
-            const float *data = llama_get_embeddings(ctx);
-            std::vector<float> embedding(data, data + n_embd);
-            res.result_json = json
-            {
-                {"embedding", embedding},
-            };
+            for (int i = 0; i < batch.n_tokens; ++i) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                    continue;
+                }
+
+                const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+                if (embd == NULL) {
+                    embd = llama_get_embeddings_ith(ctx, i);
+                    if (embd == NULL) {
+                        LOG_ERROR("failed to get embeddings for token", {{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
+                        res.result_json = json
+                        {
+                            {"embedding", std::vector<float>(n_embd, 0.0f)},
+                        };
+                        continue;
+                    }
+                }
+
+                res.result_json = json
+                {
+                    {"embedding", std::vector<float>(embd, embd + n_embd)},
+                };
+            }
         }
         queue_results.send(res);
     }
@@ -1845,7 +1863,7 @@ struct llama_server_context
                                 ga_i += ga_w/ga_n;
                             }
                         }
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
                         slot_npast++;
                     }
 
@@ -1881,7 +1899,7 @@ struct llama_server_context
 
         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
         {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             for (auto & slot : slots)
             {
@@ -1954,7 +1972,7 @@ struct llama_server_context
                     // prompt evaluated for embedding
                     if (slot.embedding)
                     {
-                        send_embedding(slot);
+                        send_embedding(slot, batch_view);
                         slot.release();
                         slot.i_batch = -1;
                         continue;
@@ -2036,6 +2054,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
     printf("  --yarn-beta-fast N    YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+    printf("  --pooling {none,mean,cls}\n");
+    printf("                        pooling type for embeddings, use model default if unspecified\n");
     printf("  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     printf("  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     printf("                        not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -2276,6 +2296,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
            }
            params.yarn_beta_slow = std::stof(argv[i]);
        }
+       else if (arg == "--pooling")
+       {
+           if (++i >= argc) {
+               invalid_param = true;
+               break;
+           }
+           std::string value(argv[i]);
+           /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
+           else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
+           else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+           else { invalid_param = true; break; }
+       }
        else if (arg == "--threads" || arg == "-t")
        {
            if (++i >= argc)
@@ -2330,7 +2362,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
               break;
           }
           params.n_batch = std::stoi(argv[i]);
-          params.n_batch = std::min(512, params.n_batch);
       }
       else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
       {
