
Commit cf68f4a

ggerganov and ngxson authored and committed
llama : add reranking support (ggml-org#9510)
* py : add XLMRobertaForSequenceClassification [no ci]
* py : fix scalar-tensor conversion [no ci]
* py : fix position embeddings chop [no ci]
* llama : read new cls tensors [no ci]
* llama : add classigication head (wip) [no ci]
* llama : add "rank" pooling type ggml-ci
* server : add rerank endpoint ggml-ci
* llama : aboud ggml_repeat during classification
* rerank : cleanup + comments
* server : accept /rerank endpoint in addition to /v1/rerank [no ci]
* embedding : parse special tokens
* jina : support v1 reranker
* vocab : minor style ggml-ci
* server : initiate tests for later ggml-ci
* server : add docs
* llama : add comment [no ci]
* llama : fix uninitialized tensors
* ci : add rerank tests ggml-ci
* add reranking test
* change test data
* Update examples/server/server.cpp Co-authored-by: Xuan Son Nguyen <[email protected]>
* add `--reranking` argument
* update server docs
* llama : fix comment [no ci] ggml-ci

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: Xuan Son Nguyen <[email protected]>
1 parent acc21d2 commit cf68f4a

18 files changed: +602 −56 lines

ci/run.sh  (+76 −9)

@@ -712,6 +712,81 @@ function gg_run_embd_bge_small {
     set +e
 }
 
+function gg_sum_embd_bge_small {
+    gg_printf '### %s\n\n' "${ci}"
+
+    gg_printf 'BGE Small (BERT):\n'
+    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
+    gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
+}
+
+# rerank_tiny
+
+function gg_run_rerank_tiny {
+    cd ${SRC}
+
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
+    gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
+
+    gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json
+
+    path_models="../models-mnt/rerank-tiny"
+
+    rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release
+
+    set -e
+
+    (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log
+    (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
+
+    python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf
+
+    model_f16="${path_models}/ggml-model-f16.gguf"
+
+    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s><s>hi\nwhat is panda?</s><s>it's a bear\nwhat is panda?</s><s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+
+    # sample output
+    # rerank score 0: 0.029
+    # rerank score 1: 0.029
+    # rerank score 2: 0.135
+
+    # check that the score is in the range [$3, $4]
+    function check_score {
+        qnt="$1"
+        score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1)
+
+        if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then
+            printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4"
+            return 20
+        fi
+
+        printf ' - %s @ %s OK\n' "$qnt" "$score"
+        return 0
+    }
+
+    check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
+    check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log
+    check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.15" | tee -a $OUT/${ci}-rk-f16.log
+
+    set +e
+}
+
+function gg_sum_rerank_tiny {
+    gg_printf '### %s\n\n' "${ci}"
+
+    gg_printf 'Rerank Tiny (Jina):\n'
+    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
+    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)"
+}
+
 function gg_check_build_requirements {
     if ! command -v cmake &> /dev/null; then
         gg_printf 'cmake not found, please install'
@@ -726,15 +801,6 @@ function gg_check_build_requirements {
     fi
 }
 
-function gg_sum_embd_bge_small {
-    gg_printf '### %s\n\n' "${ci}"
-
-    gg_printf 'BGE Small (BERT):\n'
-    gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
-    gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)"
-    gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)"
-}
-
 ## main
 
 export LLAMA_LOG_PREFIX=1
@@ -762,6 +828,7 @@ test $ret -eq 0 && gg_run ctest_release
 
 if [ -z ${GG_BUILD_LOW_PERF} ]; then
     test $ret -eq 0 && gg_run embd_bge_small
+    test $ret -eq 0 && gg_run rerank_tiny
 
     if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
        test $ret -eq 0 && gg_run test_scripts_debug
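
Note on the CI stage above: `gg_run_rerank_tiny` feeds `llama-embedding` one line per query/document pair, with `</s><s>` separating query and document (see the `-p` argument), and `--pooling rank` turns each line into a single relevance score. A minimal Python sketch of the same input layout, for illustration only (the separator and line format are taken from the CI invocation above; `build_rerank_prompts` is a hypothetical helper, not part of the repository):

```python
# Build the newline-separated "<query></s><s><document>" lines that the CI test
# passes to llama-embedding via -p; with --pooling rank, each line yields one score.
def build_rerank_prompts(query: str, documents: list[str]) -> str:
    return "\n".join(f"{query}</s><s>{doc}" for doc in documents)


if __name__ == "__main__":
    docs = ["hi", "it's a bear", "The giant panda is a bear species endemic to China."]
    print(build_rerank_prompts("what is panda?", docs))
```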

common/arg.cpp  (+15 −3)

@@ -284,6 +284,10 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }
 
+    if (params.reranking && params.embedding) {
+        throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
+    }
+
     return true;
 }
 
@@ -391,7 +395,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         [](gpt_params & params) {
             params.verbose_prompt = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    ));
     add_opt(llama_arg(
         {"--no-display-prompt"},
         format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"),
@@ -1093,13 +1097,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         }
     ).set_sparam());
    add_opt(llama_arg(
-        {"--pooling"}, "{none,mean,cls,last}",
+        {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
         [](gpt_params & params, const std::string & value) {
            /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
            else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
-            else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
+            else if (value == "cls")  { params.pooling_type = LLAMA_POOLING_TYPE_CLS;  }
             else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
+            else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
@@ -1749,6 +1754,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.embedding = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    add_opt(llama_arg(
+        {"--reranking", "--rerank"},
+        format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
+        [](gpt_params & params) {
+            params.reranking = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(llama_arg(
         {"--api-key"}, "KEY",
         "API key to use for authentication (default: none)",

common/common.cpp  (+5)

@@ -1023,6 +1023,11 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
 
+    if (params.reranking) {
+        cparams.embeddings = true;
+        cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
+    }
+
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
 

common/common.h  (+1)

@@ -271,6 +271,7 @@ struct gpt_params {
     int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embendings
+    bool reranking = false; // enable reranking support on server
 
     // server params
     int32_t port = 8080; // server listens on this network port

convert_hf_to_gguf.py  (+24 −3)

@@ -291,8 +291,13 @@ def prepare_tensors(self):
                     bid = int(part)
                     break
 
-            for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
-                data: np.ndarray  # type hint
+            for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
+                data = data_torch.squeeze().numpy()
+
+                # if data ends up empty, it means data_torch was a scalar tensor -> restore
+                if len(data.shape) == 0:
+                    data = data_torch.numpy()
+
                 n_dims = len(data.shape)
                 data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
 
@@ -592,6 +597,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
             # ref: https://huggingface.co/databricks/dbrx-base
             res = "dbrx"
+        if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448":
+            # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+            res = "jina-v1-en"
         if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
             res = "jina-v2-en"
@@ -2601,7 +2609,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
 
 
-@Model.register("XLMRobertaModel")
+@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
@@ -2699,6 +2707,11 @@ def set_vocab(self):
         self.gguf_writer.add_add_eos_token(True)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "roberta.", remove the prefix
+        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
+        if name.startswith("roberta."):
+            name = name[8:]
+
         # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
         if name == "embeddings.position_embeddings.weight":
             if self._position_offset is not None:
@@ -3110,6 +3123,14 @@ def set_vocab(self):
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "bert.", remove the prefix
+        # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+        if name.startswith("bert."):
+            name = name[5:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
 
 @Model.register("OpenELMForCausalLM")
 class OpenELMModel(Model):
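
For context on the `prepare_tensors` change above: the old generator squeezed every tensor before conversion, so a single-element classification tensor became a 0-dimensional array and broke the shape-based handling that follows. A small, self-contained illustration of that edge case and the restore step (hypothetical values, not repository code):

```python
# Squeezing a single-element tensor yields a 0-dim array; the converter now
# detects this (len(data.shape) == 0) and falls back to the un-squeezed values.
import torch

scalar_like = torch.tensor([1.5])        # e.g. a classifier bias with one element
data = scalar_like.squeeze().numpy()     # shape () -- no dimensions left
if len(data.shape) == 0:
    data = scalar_like.numpy()           # restore, shape (1,)
print(data.shape)                        # -> (1,)
```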

convert_hf_to_gguf_update.py  (+1)

@@ -81,6 +81,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
     {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
     {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
+    {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", },
     {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", },  # WPM!
     {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
     {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },

examples/embedding/embedding.cpp  (+6 −1)

@@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
     for (const auto & prompt : prompts) {
-        auto inp = ::llama_tokenize(ctx, prompt, true, false);
+        auto inp = ::llama_tokenize(ctx, prompt, true, true);
         if (inp.size() > n_batch) {
             LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                     __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -234,6 +234,11 @@
                 }
                 LOG("\n");
             }
+        } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
+            for (int j = 0; j < n_embd_count; j++) {
+                // NOTE: if you change this log - update the tests in ci/run.sh
+                LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
+            }
         } else {
             // print the first part of the embeddings or for a single prompt, the full embedding
             for (int j = 0; j < n_prompts; j++) {

examples/server/README.md  (+38 −1)

@@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
 **Features:**
 * LLM inference of F16 and quantized models on GPU and CPU
 * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
+* Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510)
 * Parallel decoding with multi-user support
 * Continuous batching
 * Multimodal (wip)
@@ -23,6 +24,7 @@ The project is under active development, and we are [looking for feedback and co
 | -------- | ----------- |
 | `-h, --help, --usage` | print usage and exit |
 | `--version` | show version and build info |
+| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
 | `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
 | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
 | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
@@ -130,14 +132,15 @@ The project is under active development, and we are [looking for feedback and co
 | `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
 | `-sp, --special` | special tokens output enabled (default: false) |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
-| `--pooling {none,mean,cls,last}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
+| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
 | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
 | `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
 | `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
 | `--host HOST` | ip address to listen (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
 | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
+| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
 | `--api-key KEY` | API key to use for authentication (default: none)<br/>(env: LLAMA_API_KEY) |
 | `--api-key-file FNAME` | path to file containing API keys (default: none) |
 | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
@@ -152,6 +155,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
 
+
 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
 
 Example usage of docker compose with environment variables:
@@ -478,6 +482,39 @@ The same as [the embedding example](../embedding) does.
 
 `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
+### POST `/reranking`: Rerank documents according to a given query
+
+Similar to https://jina.ai/reranker/ but might change in the future.
+Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options.
+
+*Options:*
+
+`query`: The query against which the documents will be ranked.
+
+`documents`: An array strings representing the documents to be ranked.
+
+*Aliases:*
+  - `/rerank`
+  - `/v1/rerank`
+  - `/v1/reranking`
+
+*Examples:*
+
+```shell
+curl http://127.0.0.1:8012/v1/rerank \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "some-model",
+        "query": "What is panda?",
+        "top_n": 3,
+        "documents": [
+            "hi",
+            "it is a bear",
+            "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
+        ]
+    }' | jq
+```
+
 ### POST `/infill`: For code infilling.
 
 Takes a prefix and a suffix and returns the predicted completion as stream.
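
To complement the `/v1/rerank` request example in the README diff above, here is a minimal client sketch for the new endpoint. It assumes a Jina-style response containing a `results` array of `{index, relevance_score}` objects, mirroring the API the endpoint is modeled on; the exact field names and the `some-model` placeholder are assumptions, so verify against your server's actual output:

```python
# Hedged sketch of a /v1/rerank client: post a query plus candidate documents,
# then sort the documents by the returned relevance scores (best first).
import requests


def rerank(query: str, documents: list[str],
           url: str = "http://127.0.0.1:8012/v1/rerank") -> list[tuple[float, str]]:
    payload = {
        "model": "some-model",  # placeholder, as in the curl example above
        "query": query,
        "top_n": len(documents),
        "documents": documents,
    }
    resp = requests.post(url, json=payload, timeout=60)
    resp.raise_for_status()
    # assumed shape: {"results": [{"index": i, "relevance_score": s}, ...]}
    results = resp.json().get("results", [])
    ranked = [(float(r["relevance_score"]), documents[r["index"]]) for r in results]
    return sorted(ranked, reverse=True)


if __name__ == "__main__":
    docs = [
        "hi",
        "it is a bear",
        "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.",
    ]
    for score, doc in rerank("What is panda?", docs):
        print(f"{score:8.3f}  {doc}")
```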
