Skip to content

Commit 5f6e0c0

Browse files
authored
grammar : pre-computed pieces + reserve mem + less string copies (#4330)
* reserve space for codepoints * improvement for the appended 0 * used precomputed token text for grammar sample * reserve candidates_decoded * reserve candidates_grammar * remove candidates_decoded * Revert "remove candidates_decoded" This reverts commit 3773328. * changed decode_utf8 to take src by ref
1 parent 5aa365d commit 5f6e0c0

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

llama.cpp

+7-13
Original file line numberDiff line numberDiff line change
@@ -6851,14 +6851,13 @@ struct llama_grammar_candidate {
68516851
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
68526852
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
68536853
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
6854-
const char * src,
6855-
size_t n_src,
6854+
const std::string & src,
68566855
llama_partial_utf8 partial_start) {
68576856
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
6858-
const char * pos = src;
6857+
const char * pos = src.c_str();
68596858
std::vector<uint32_t> code_points;
68606859
// common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
6861-
code_points.reserve(n_src + 1);
6860+
code_points.reserve(src.size() + 1);
68626861
uint32_t value = partial_start.value;
68636862
int n_remain = partial_start.n_remain;
68646863

@@ -6909,13 +6908,6 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
69096908
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
69106909
}
69116910

6912-
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
6913-
std::string src,
6914-
llama_partial_utf8 partial_start
6915-
) {
6916-
return decode_utf8(src.c_str(), src.size(), partial_start);
6917-
}
6918-
69196911
// returns true iff pos points to the end of one of the definitions of a rule
69206912
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
69216913
switch (pos->type) {
@@ -7554,11 +7546,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
75547546
const llama_token eos = llama_token_eos(&ctx->model);
75557547

75567548
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
7549+
candidates_decoded.reserve(candidates->size);
75577550
std::vector<llama_grammar_candidate> candidates_grammar;
7551+
candidates_grammar.reserve(candidates->size);
75587552

75597553
for (size_t i = 0; i < candidates->size; ++i) {
75607554
const llama_token id = candidates->data[i].id;
7561-
const std::string piece = llama_token_to_piece(ctx, id);
7555+
const std::string & piece = ctx->model.vocab.id_to_token[id].text;
75627556
if (id == eos) {
75637557
if (!allow_eos) {
75647558
candidates->data[i].logit = -INFINITY;
@@ -7770,7 +7764,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
77707764
GGML_ASSERT(false);
77717765
}
77727766

7773-
const std::string piece = llama_token_to_piece(ctx, token);
7767+
const std::string & piece = ctx->model.vocab.id_to_token[token].text;
77747768

77757769
// Note terminating 0 in decoded string
77767770
const auto decoded = decode_utf8(piece, grammar->partial_utf8);

0 commit comments

Comments
 (0)