Skip to content

Commit cb1632b

Browse files
committedOct 11, 2024
llama : adds llama-grammar memorization stacks (ggml-org#4218)
1 parent 7eee341 commit cb1632b

File tree

1 file changed

+115
-3
lines changed

1 file changed

+115
-3
lines changed
 

‎src/llama-grammar.cpp

+115-3
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,114 @@ static bool llama_grammar_match_partial_char(
682682
return !is_positive_char;
683683
}
684684

685+
// transforms a grammar pushdown stack into N possible stacks, all ending
686+
// at a character range (terminal element)
687+
// additionally memoizes the result by caching the mapping
688+
// < llama_grammar_stack, llama_grammar_stacks >
689+
690+
struct VectorPointerHash {
691+
size_t operator()(const llama_grammar_stack & v) const {
692+
size_t seed = v.size();
693+
for (const auto* ptr : v) {
694+
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
695+
}
696+
return seed;
697+
}
698+
};
699+
700+
// Memo table for llama_grammar_advance_stack_memo: maps a pushdown stack to the
// list of stacks it advances to. Keys are hashed with VectorPointerHash, i.e.
// by the raw element pointers held in the stack.
// The table is filled by llama_grammar_advance_stack_memo and emptied at the
// start of llama_grammar_init_impl.
// NOTE(review): this is file-level mutable state with no synchronization
// visible here, and its keys/values hold raw pointers into a grammar's rules.
// Confirm that grammars are never advanced concurrently and that no entry can
// outlive the rules it points into (e.g. when several grammars coexist).
static std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>
    llama_grammar_stacks_cache;
705+
706+
static void llama_grammar_advance_stack_memo(
707+
const llama_grammar_rules & rules,
708+
const llama_grammar_stack & stack,
709+
llama_grammar_stacks & new_stacks);
710+
711+
static void llama_grammar_advance_stack_memo_impl(
712+
const llama_grammar_rules & rules,
713+
const llama_grammar_stack & stack,
714+
llama_grammar_stacks & new_stacks) {
715+
if (stack.empty()) {
716+
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
717+
new_stacks.emplace_back(stack);
718+
}
719+
return;
720+
}
721+
722+
const llama_grammar_element * pos = stack.back();
723+
724+
switch (pos->type) {
725+
case LLAMA_GRETYPE_RULE_REF: {
726+
const size_t rule_id = static_cast<size_t>(pos->value);
727+
const llama_grammar_element * subpos = rules[rule_id].data();
728+
do {
729+
// init new stack without the top (pos)
730+
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
731+
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
732+
// if this rule ref is followed by another element, add that to stack
733+
new_stack.push_back(pos + 1);
734+
}
735+
if (!llama_grammar_is_end_of_sequence(subpos)) {
736+
// if alternate is nonempty, add to stack
737+
new_stack.push_back(subpos);
738+
}
739+
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
740+
while (!llama_grammar_is_end_of_sequence(subpos)) {
741+
// scan to end of alternate def
742+
subpos++;
743+
}
744+
if (subpos->type == LLAMA_GRETYPE_ALT) {
745+
// there's another alternate def of this rule to process
746+
subpos++;
747+
} else {
748+
break;
749+
}
750+
} while (true);
751+
break;
752+
}
753+
case LLAMA_GRETYPE_CHAR:
754+
case LLAMA_GRETYPE_CHAR_NOT:
755+
case LLAMA_GRETYPE_CHAR_ANY:
756+
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
757+
// only add the stack if it's not a duplicate of one we already have
758+
new_stacks.emplace_back(stack);
759+
}
760+
break;
761+
default:
762+
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
763+
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
764+
// those
765+
GGML_ABORT("fatal error");
766+
}
767+
}
768+
769+
static void llama_grammar_advance_stack_memo(
770+
const llama_grammar_rules & rules,
771+
const llama_grammar_stack & stack,
772+
llama_grammar_stacks & new_stacks) {
773+
774+
llama_grammar_stacks advanced_stacks;
775+
// Look if stack is already in memory
776+
auto it = llama_grammar_stacks_cache.find(stack);
777+
if (it != llama_grammar_stacks_cache.end()) {
778+
advanced_stacks = it->second;
779+
} else {
780+
// Advance stacks with memorization
781+
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
782+
llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
783+
}
784+
// Add the advanced stacks to new_stacks avoiding duplicates
785+
for (const auto & new_stack : advanced_stacks) {
786+
if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) {
787+
new_stacks.emplace_back(new_stack);
788+
}
789+
}
790+
791+
}
792+
685793
// transforms a grammar pushdown stack into N possible stacks, all ending
686794
// at a character range (terminal element)
687795
static void llama_grammar_advance_stack(
@@ -844,7 +952,7 @@ void llama_grammar_accept(
844952
if (!llama_grammar_is_end_of_sequence(pos)) {
845953
new_stack.push_back(pos);
846954
}
847-
llama_grammar_advance_stack(rules, new_stack, stacks_new);
955+
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
848956
}
849957
}
850958
}
@@ -911,6 +1019,8 @@ struct llama_grammar * llama_grammar_init_impl(
9111019
const llama_grammar_element ** rules,
9121020
size_t n_rules,
9131021
size_t start_rule_index) {
1022+
// Clear stacks cache
1023+
llama_grammar_stacks_cache.clear();
9141024
const llama_grammar_element * pos;
9151025

9161026
// copy rule definitions into vectors
@@ -945,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
9451055
// if alternate is nonempty, add to stack
9461056
stack.push_back(pos);
9471057
}
948-
llama_grammar_advance_stack(vec_rules, stack, stacks);
1058+
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
9491059
while (!llama_grammar_is_end_of_sequence(pos)) {
9501060
// scan to end of alternate def
9511061
pos++;
@@ -965,6 +1075,8 @@ struct llama_grammar * llama_grammar_init_impl(
9651075
}
9661076

9671077
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
1078+
// Clear stacks cache
1079+
llama_grammar_stacks_cache.clear();
9681080
llama_grammar_parser parser;
9691081

9701082
// if there is a grammar, parse it
@@ -1023,7 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
10231135
// if alternate is nonempty, add to stack
10241136
stack.push_back(pos);
10251137
}
1026-
llama_grammar_advance_stack(vec_rules, stack, stacks);
1138+
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
10271139
while (!llama_grammar_is_end_of_sequence(pos)) {
10281140
// scan to end of alternate def
10291141
pos++;

0 commit comments

Comments
 (0)