@@ -682,6 +682,114 @@ static bool llama_grammar_match_partial_char(
682
682
return !is_positive_char;
683
683
}
684
684
685
+ // transforms a grammar pushdown stack into N possible stacks, all ending
686
+ // at a character range (terminal element)
687
+ // additionally memorizes the stack to its possible stacks by mapping
688
+ // < llama_grammar_stack, llama_grammar_stacks >
689
+
690
+ struct VectorPointerHash {
691
+ size_t operator ()(const llama_grammar_stack & v) const {
692
+ size_t seed = v.size ();
693
+ for (const auto * ptr : v) {
694
+ seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6 ) + (seed >> 2 );
695
+ }
696
+ return seed;
697
+ }
698
+ };
699
+
700
+ static std::unordered_map<
701
+ llama_grammar_stack,
702
+ llama_grammar_stacks,
703
+ VectorPointerHash>
704
+ llama_grammar_stacks_cache = {};
705
+
706
+ static void llama_grammar_advance_stack_memo (
707
+ const llama_grammar_rules & rules,
708
+ const llama_grammar_stack & stack,
709
+ llama_grammar_stacks & new_stacks);
710
+
711
+ static void llama_grammar_advance_stack_memo_impl (
712
+ const llama_grammar_rules & rules,
713
+ const llama_grammar_stack & stack,
714
+ llama_grammar_stacks & new_stacks) {
715
+ if (stack.empty ()) {
716
+ if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
717
+ new_stacks.emplace_back (stack);
718
+ }
719
+ return ;
720
+ }
721
+
722
+ const llama_grammar_element * pos = stack.back ();
723
+
724
+ switch (pos->type ) {
725
+ case LLAMA_GRETYPE_RULE_REF: {
726
+ const size_t rule_id = static_cast <size_t >(pos->value );
727
+ const llama_grammar_element * subpos = rules[rule_id].data ();
728
+ do {
729
+ // init new stack without the top (pos)
730
+ llama_grammar_stack new_stack (stack.begin (), stack.end () - 1 );
731
+ if (!llama_grammar_is_end_of_sequence (pos + 1 )) {
732
+ // if this rule ref is followed by another element, add that to stack
733
+ new_stack.push_back (pos + 1 );
734
+ }
735
+ if (!llama_grammar_is_end_of_sequence (subpos)) {
736
+ // if alternate is nonempty, add to stack
737
+ new_stack.push_back (subpos);
738
+ }
739
+ llama_grammar_advance_stack_memo (rules, new_stack, new_stacks);
740
+ while (!llama_grammar_is_end_of_sequence (subpos)) {
741
+ // scan to end of alternate def
742
+ subpos++;
743
+ }
744
+ if (subpos->type == LLAMA_GRETYPE_ALT) {
745
+ // there's another alternate def of this rule to process
746
+ subpos++;
747
+ } else {
748
+ break ;
749
+ }
750
+ } while (true );
751
+ break ;
752
+ }
753
+ case LLAMA_GRETYPE_CHAR:
754
+ case LLAMA_GRETYPE_CHAR_NOT:
755
+ case LLAMA_GRETYPE_CHAR_ANY:
756
+ if (std::find (new_stacks.begin (), new_stacks.end (), stack) == new_stacks.end ()) {
757
+ // only add the stack if it's not a duplicate of one we already have
758
+ new_stacks.emplace_back (stack);
759
+ }
760
+ break ;
761
+ default :
762
+ // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
763
+ // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
764
+ // those
765
+ GGML_ABORT (" fatal error" );
766
+ }
767
+ }
768
+
769
+ static void llama_grammar_advance_stack_memo (
770
+ const llama_grammar_rules & rules,
771
+ const llama_grammar_stack & stack,
772
+ llama_grammar_stacks & new_stacks) {
773
+
774
+ llama_grammar_stacks advanced_stacks;
775
+ // Look if stack is already in memory
776
+ auto it = llama_grammar_stacks_cache.find (stack);
777
+ if (it != llama_grammar_stacks_cache.end ()) {
778
+ advanced_stacks = it->second ;
779
+ } else {
780
+ // Advance stacks with memorization
781
+ llama_grammar_advance_stack_memo_impl (rules, stack, advanced_stacks);
782
+ llama_grammar_stacks_cache.insert (make_pair (stack, advanced_stacks));
783
+ }
784
+ // Add the advanced stacks to new_stacks avoiding duplicates
785
+ for (const auto & new_stack : advanced_stacks) {
786
+ if (std::find (new_stacks.begin (), new_stacks.end (), new_stack) == new_stacks.end ()) {
787
+ new_stacks.emplace_back (new_stack);
788
+ }
789
+ }
790
+
791
+ }
792
+
685
793
// transforms a grammar pushdown stack into N possible stacks, all ending
686
794
// at a character range (terminal element)
687
795
static void llama_grammar_advance_stack (
@@ -844,7 +952,7 @@ void llama_grammar_accept(
844
952
if (!llama_grammar_is_end_of_sequence (pos)) {
845
953
new_stack.push_back (pos);
846
954
}
847
- llama_grammar_advance_stack (rules, new_stack, stacks_new);
955
+ llama_grammar_advance_stack_memo (rules, new_stack, stacks_new);
848
956
}
849
957
}
850
958
}
@@ -911,6 +1019,8 @@ struct llama_grammar * llama_grammar_init_impl(
911
1019
const llama_grammar_element ** rules,
912
1020
size_t n_rules,
913
1021
size_t start_rule_index) {
1022
+ // Clear stacks cache
1023
+ llama_grammar_stacks_cache.clear ();
914
1024
const llama_grammar_element * pos;
915
1025
916
1026
// copy rule definitions into vectors
@@ -945,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
945
1055
// if alternate is nonempty, add to stack
946
1056
stack.push_back (pos);
947
1057
}
948
- llama_grammar_advance_stack (vec_rules, stack, stacks);
1058
+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks);
949
1059
while (!llama_grammar_is_end_of_sequence (pos)) {
950
1060
// scan to end of alternate def
951
1061
pos++;
@@ -965,6 +1075,8 @@ struct llama_grammar * llama_grammar_init_impl(
965
1075
}
966
1076
967
1077
struct llama_grammar * llama_grammar_init_impl (const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
1078
+ // Clear stacks cache
1079
+ llama_grammar_stacks_cache.clear ();
968
1080
llama_grammar_parser parser;
969
1081
970
1082
// if there is a grammar, parse it
@@ -1023,7 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
1023
1135
// if alternate is nonempty, add to stack
1024
1136
stack.push_back (pos);
1025
1137
}
1026
- llama_grammar_advance_stack (vec_rules, stack, stacks);
1138
+ llama_grammar_advance_stack_memo (vec_rules, stack, stacks);
1027
1139
while (!llama_grammar_is_end_of_sequence (pos)) {
1028
1140
// scan to end of alternate def
1029
1141
pos++;
0 commit comments