From 45d2252153cd72fbaef273705d8b6fc848845f2a Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:35:58 +0200 Subject: [PATCH 1/8] Backported . (any chat) from llama.cpp --- llama_cpp/llama_grammar.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 0ac7354bb..01798e74f 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -432,10 +432,12 @@ def end(self) -> "std.map[T, U].iterator[T, U]": # // be an inclusive range ([a-z]) # LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - # // modifies a preceding LLAMA_GRETYPE_CHAR or # // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) # LLAMA_GRETYPE_CHAR_ALT = 6, + +# // any character (.) +# LLAMA_GRETYPE_CHAR_ANY = 7, # }; class llama_gretype(Enum): """grammar element type""" @@ -447,6 +449,7 @@ class llama_gretype(Enum): LLAMA_GRETYPE_CHAR_NOT = 4 # inverse char(s) ([^a], [^a-b] [^abc]) LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z]) LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ANY = 7 # any character (.) # struct parse_state { @@ -830,6 +833,10 @@ def parse_sequence( # if (last_sym_start == out_elements.size()) { # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); # } + elif pos[0] == '.': + last_sym_start = out_elements.size() + out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR_ANY, 0)) + pos = parse_space(pos + 1, is_nested) elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): raise RuntimeError("expecting preceding item to */+/? at " + str(pos)) @@ -1039,6 +1046,7 @@ def is_char_element(elem: LlamaGrammarElement) -> bool: llama_gretype.LLAMA_GRETYPE_CHAR_NOT, llama_gretype.LLAMA_GRETYPE_CHAR_ALT, llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype.LLAMA_GRETYPE_CHAR_ANY, ) @@ -1135,6 +1143,8 @@ def print_rule( + str(i) ) print_grammar_char(file, elem.value) + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ANY: + print(".", file=file, end="") # if (is_char_element(elem)) { # switch (rule[i + 1].type) { # case LLAMA_GRETYPE_CHAR_ALT: From 90c2bc4922265c4cac9266b7550ed9e41f1fd477 Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Mon, 29 Jul 2024 15:25:09 +0200 Subject: [PATCH 2/8] unfinished {count,optionalmax) --- llama_cpp/llama_grammar.py | 295 ++++++++++++++++++++++++++++--------- 1 file changed, 224 insertions(+), 71 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 01798e74f..39ba8e9dc 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -563,11 +563,34 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: return value, pos +"""#static bool is_digit_char(char c) { +# return '0' <= c && c <= '9'; +#} +def is_digit_char(c: str) -> bool: + return "0" <= c <= "9" + + # bool is_word_char(char c) { # return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); # } def is_word_char(c: str) -> bool: - return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") + return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") or is_digit_char(c) +""" + +##optimized version +# Original is_digit_char time: 2.868295 seconds +# Optimized is_digit_char time: 1.993195 seconds +# Original is_word_char time: 3.856689 seconds +# Optimized is_word_char time: 2.052832 seconds + +digit_chars = set("0123456789") +word_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789") + +def is_digit_char(c: str) -> bool: + return c in digit_chars + +def is_word_char(c: str) -> bool: + return c in word_chars # std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { @@ -679,6 +702,25 @@ def parse_name(src: const_char_p) -> const_char_p: return pos +#static const char * parse_int(const char * src) { +# const char * pos = src; +# while (is_digit_char(*pos)) { +# pos++; +# } +# if (pos == src) { +# throw std::runtime_error(std::string("expecting integer at ") + src); +# } +# return pos; +#} +def parse_int(src: const_char_p) -> const_char_p: + pos = const_char_p(src) # type: const_char_p + while is_digit_char(pos[0]): + pos += 1 + if pos == src: + raise RuntimeError("expecting integer at " + str(src)) + return pos + + # const char * parse_space(const char * src, bool newline_ok) { # const char * pos = src; # while (*pos == ' ' || *pos == '\t' || *pos == '#' || @@ -721,8 +763,104 @@ def parse_sequence( # const char * pos = src; last_sym_start = out_elements.size() # type: int pos = const_char_p(src) # type: const_char_p + + + # auto handle_repetitions = [&](int min_times, int max_times) { + + # if (last_sym_start == out_elements.size()) { + # throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + # } + + # // apply transformation to previous symbol (last_sym_start to end) according to + # // the following rewrite rules: + # // S{m,n} --> S S S (m times) S'(n-m) + # // S'(x) ::= S S'(x-1) | + # // (... n-m definitions of these S' rules ...) + # // S'(1) ::= S | + # // S{m,} --> S S S (m times) S' + # // S' ::= S S' | + # // S* --> S{0,} + # // --> S' ::= S S' | + # // S+ --> S{1,} + # // --> S S' + # // S' ::= S S' | + # // S? --> S{0,1} + # // --> S' + # // S' ::= S | + + # std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); + # if (min_times == 0) { + # out_elements.resize(last_sym_start); + # } else { + # // Repeat the previous elements (min_times - 1) times + # for (int i = 1; i < min_times; i++) { + # out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); + # } + # } + + # uint32_t last_rec_rule_id = 0; + # auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + # std::vector<llama_grammar_element> rec_rule(previous_elements); + # for (int i = 0; i < n_opt; i++) { + # rec_rule.resize(previous_elements.size()); + # uint32_t rec_rule_id = generate_symbol_id(state, rule_name); + # if (i > 0 || max_times < 0) { + # rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + # } + # rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + # rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + # add_rule(state, rec_rule_id, rec_rule); + # last_rec_rule_id = rec_rule_id; + # } + # if (n_opt > 0) { + # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + # } + # }; + + def handle_repetitions(min_times: int, max_times: int) -> None: + if last_sym_start == out_elements.size(): + raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos)) + + + previous_elements = out_elements[last_sym_start:] + print("type-1 ", type(out_elements)) + if min_times == 0: + out_elements.resize(last_sym_start) + else: + # Repeat the previous elements (min_times - 1) times + for i in range(1, min_times): + out_elements.extend(previous_elements) + + last_rec_rule_id = 0 # type: int + n_opt = 1 if max_times < 0 else max_times - min_times # type: int + rec_rule = previous_elements # type: List[LlamaGrammarElement] + print("type1", type(rec_rule)) + print('ahhhhhhhhh') + for i in range(n_opt): + rec_rule = previous_elements + rec_rule.resize(len(previous_elements)) + rec_rule_id = generate_symbol_id(state, rule_name) # type: int + if i > 0 or max_times < 0: + rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id)) + rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) + rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) + add_rule(state, rec_rule_id, rec_rule) + + last_rec_rule_id = rec_rule_id + if n_opt > 0: + out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id)) + + + + + + + + # while (*pos) { while pos[0]: + # if (*pos == '"') { // literal string # pos++; # last_sym_start = out_elements.size(); @@ -767,12 +905,12 @@ def parse_sequence( while pos[0] != "]": char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] - type = ( + _type = ( llama_gretype.LLAMA_GRETYPE_CHAR_ALT if last_sym_start < out_elements.size() else start_type ) # type: llama_gretype - out_elements.push_back(LlamaGrammarElement(type, char_pair[0])) + out_elements.push_back(LlamaGrammarElement(_type, char_pair[0])) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); # pos = endchar_pair.second; @@ -829,83 +967,98 @@ def parse_sequence( if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) pos = parse_space(pos + 1, is_nested) - # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - # if (last_sym_start == out_elements.size()) { - # throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); - # } elif pos[0] == '.': last_sym_start = out_elements.size() out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR_ANY, 0)) pos = parse_space(pos + 1, is_nested) - elif pos[0] in ("*", "+", "?"): # repetition operator - if last_sym_start == out_elements.size(): - raise RuntimeError("expecting preceding item to */+/? at " + str(pos)) - # // apply transformation to previous symbol (last_sym_start to end) according to - # // rewrite rules: - # // S* --> S' ::= S S' | - # // S+ --> S' ::= S S' | S - # // S? --> S' ::= S | - # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - # std::vector<llama_grammar_element> sub_rule; - # // add preceding symbol to generated rule - # sub_rule.insert( - # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - sub_rule_id = generate_symbol_id(state, rule_name) # type: int - sub_rule = std.vector[ - LlamaGrammarElement - ]() # type: std.vector[LlamaGrammarElement] - sub_rule.insert( - sub_rule.end(), - out_elements.begin() + last_sym_start, - out_elements.end(), - ) - # if (*pos == '*' || *pos == '+') { - # // cause generated rule to recurse - # sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - # } - # // mark start of alternate def - # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if pos[0] in ("*", "+"): - sub_rule.push_back( - LlamaGrammarElement( - llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id - ) - ) - sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) - # if (*pos == '+') { - # // add preceding symbol as alternate only for '+' (otherwise empty) - # sub_rule.insert( - # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - # } - # sub_rule.push_back({LLAMA_GRETYPE_END, 0}); - # add_rule(state, sub_rule_id, sub_rule); - # // in original rule, replace previous symbol with reference to generated rule - # out_elements.resize(last_sym_start); - # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - # pos = parse_space(pos + 1, is_nested); - if pos[0] == "+": - # add preceding symbol as alternate only for '+' (otherwise empty) - sub_rule.insert( - sub_rule.end(), - out_elements.begin() + last_sym_start, - out_elements.end(), - ) - sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) - add_rule(state, sub_rule_id, sub_rule) - # in original rule, replace previous symbol with reference to generated rule - out_elements.resize(last_sym_start) - out_elements.push_back( - LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) - ) - pos = parse_space(pos + 1, is_nested) + + # } else if (*pos == '*') { + # pos = parse_space(pos + 1, is_nested); + # handle_repetitions(0, -1); + # } else if (*pos == '+') { + # pos = parse_space(pos + 1, is_nested); + # handle_repetitions(1, -1); + # } else if (*pos == '?') { + # pos = parse_space(pos + 1, is_nested); + # handle_repetitions(0, 1); + # } else if (*pos == '{') { + # pos = parse_space(pos + 1, is_nested); + + # if (!is_digit_char(*pos)) { + # throw std::runtime_error(std::string("expecting an int at ") + pos); + # } + # const char * int_end = parse_int(pos); + # int min_times = std::stoul(std::string(pos, int_end - pos)); + # pos = parse_space(int_end, is_nested); + + # int max_times = -1; + + # if (*pos == '}') { + # max_times = min_times; + # pos = parse_space(pos + 1, is_nested); + # } else if (*pos == ',') { + # pos = parse_space(pos + 1, is_nested); + + # if (is_digit_char(*pos)) { + # const char * int_end = parse_int(pos); + # max_times = std::stoul(std::string(pos, int_end - pos)); + # pos = parse_space(int_end, is_nested); + # } + + # if (*pos != '}') { + # throw std::runtime_error(std::string("expecting '}' at ") + pos); + # } + # pos = parse_space(pos + 1, is_nested); + # } else { + # throw std::runtime_error(std::string("expecting ',' at ") + pos); + # } + # handle_repetitions(min_times, max_times); # } else { # break; # } + elif pos[0] == "*": + pos = parse_space(pos + 1, is_nested) + handle_repetitions(0, -1) + elif pos[0] == "+": + pos = parse_space(pos + 1, is_nested) + handle_repetitions(1, -1) + elif pos[0] == "?": + pos = parse_space(pos + 1, is_nested) + handle_repetitions(0, 1) + elif pos[0] == "{": + pos = parse_space(pos + 1, is_nested) + + if not is_digit_char(pos[0]): + raise RuntimeError("expecting an int at " + str(pos)) + + + + int_end = parse_int(pos) + min_times = int(str(pos)[:int_end - pos]) + pos = parse_space(int_end, is_nested) + max_times = -1 + if pos[0] == "}": + max_times = min_times + pos = parse_space(pos + 1, is_nested) + elif pos[0] == ",": + pos = parse_space(pos + 1, is_nested) + if is_digit_char(pos[0]): + int_end = parse_int(pos) + max_times = int(str(pos)[:int_end - pos]) + pos = parse_space(int_end, is_nested) + if pos[0] != "}": + raise RuntimeError("expecting '}' at " + str(pos)) + pos = parse_space(pos + 1, is_nested) + else: + raise RuntimeError("expecting ',' at " + str(pos)) + handle_repetitions(min_times, max_times) + else: break - # } - # return pos; - # } + + + + return pos From 5c050e82b2f8b9ce0f9606a7e70f0944578a6bcd Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Mon, 29 Jul 2024 15:48:19 +0200 Subject: [PATCH 3/8] implemented slice function in std:vector --- llama_cpp/llama_grammar.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 39ba8e9dc..bda3b162f 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -244,7 +244,8 @@ def __add__(self, value: int) -> "std.vector[T].iterator": def __sub__(self, value: int) -> "std.vector[T].iterator": return self.__class__(self._vector, self._index - value) - def __init__(self): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self._version = 0 def modify(self): @@ -309,7 +310,7 @@ def insert( first: "std.vector[T].iterator", last: "std.vector[T].iterator", ) -> None: - self[pos._index : pos._index] = list( + self[pos._index:pos._index] = list( islice(first._vector, first._index, last._index) ) @@ -319,6 +320,24 @@ def begin(self) -> "std.vector[T].iterator": def end(self) -> "std.vector[T].iterator": return self.iterator(self, self.size()) + def __getitem__(self, index): + if isinstance(index, slice): + return std.vector(super().__getitem__(index)) + return super().__getitem__(index) + + def __setitem__(self, index, value): + self.modify() + if isinstance(index, slice): + if isinstance(value, std.vector): + value = list(value) + super().__setitem__(index, value) + else: + super().__setitem__(index, value) + + def __delitem__(self, index): + self.modify() + super().__delitem__(index) + class map(Generic[T, U], OrderedDict[T, U]): """C++ implementation of std::map.""" @@ -410,7 +429,6 @@ def begin(self) -> "std.map[T, U].iterator[T, U]": def end(self) -> "std.map[T, U].iterator[T, U]": return self.iterator(self, Sentinel()) - # // grammar element type # enum llama_gretype { # // end of rule definition @@ -824,7 +842,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: previous_elements = out_elements[last_sym_start:] - print("type-1 ", type(out_elements)) + if min_times == 0: out_elements.resize(last_sym_start) else: @@ -835,8 +853,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: last_rec_rule_id = 0 # type: int n_opt = 1 if max_times < 0 else max_times - min_times # type: int rec_rule = previous_elements # type: List[LlamaGrammarElement] - print("type1", type(rec_rule)) - print('ahhhhhhhhh') + for i in range(n_opt): rec_rule = previous_elements rec_rule.resize(len(previous_elements)) @@ -1263,6 +1280,7 @@ def print_rule( # print_grammar_char(file, elem.value); # break; # } + for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: From 4c74a82281872fdce86bcc135c80d45e5703a986 Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:31:24 +0200 Subject: [PATCH 4/8] fixed mistake done while reading --- llama_cpp/llama_grammar.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index bda3b162f..1fae0e688 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -841,7 +841,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos)) - previous_elements = out_elements[last_sym_start:] + previous_elements:std.vector[LlamaGrammarElement] = out_elements[last_sym_start:out_elements.size()] if min_times == 0: out_elements.resize(last_sym_start) @@ -859,12 +859,12 @@ def handle_repetitions(min_times: int, max_times: int) -> None: rec_rule.resize(len(previous_elements)) rec_rule_id = generate_symbol_id(state, rule_name) # type: int if i > 0 or max_times < 0: - rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id)) + rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id)) rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, rec_rule_id, rec_rule) - last_rec_rule_id = rec_rule_id + if n_opt > 0: out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id)) @@ -1058,6 +1058,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: max_times = min_times pos = parse_space(pos + 1, is_nested) elif pos[0] == ",": + pos = parse_space(pos + 1, is_nested) if is_digit_char(pos[0]): int_end = parse_int(pos) @@ -1281,6 +1282,7 @@ def print_rule( # break; # } + for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: From 1fd884069f8b081f3b4258ce844ee5988755bf18 Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Mon, 29 Jul 2024 16:50:27 +0200 Subject: [PATCH 5/8] ported https://github.com/ggerganov/llama.cpp/pull/7194 --- llama_cpp/llama_grammar.py | 58 +++++++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 14 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 1fae0e688..fded061b0 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -891,6 +891,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: pos += 1 last_sym_start = out_elements.size() while pos[0] != '"': + assert pos[0] is not None, "Unexpected end of input" char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( @@ -920,6 +921,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: # : start_type; # out_elements.push_back({type, char_pair.first}); while pos[0] != "]": + assert pos[0] is not None, "Unexpected end of input" char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] _type = ( @@ -935,6 +937,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: # } # } if pos[0] == "-" and pos[1] != "]": + assert pos[1] is not None, "Unexpected end of input" endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( @@ -1159,33 +1162,59 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: elif pos[0]: raise RuntimeError("expecting newline or end at " + str(pos)) return parse_space(pos, True) + +#parse_state parse(const char * src) { +# try { +# parse_state state; +# const char * pos = parse_space(src, true); +# while (*pos) { +# pos = parse_rule(state, pos); +# } +# // Validate the state to ensure that all rules are defined +# for (const auto & rule : state.rules) { +# for (const auto & elem : rule) { +# if (elem.type == LLAMA_GRETYPE_RULE_REF) { +# // Ensure that the rule at that location exists +# if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { +# // Get the name of the rule that is missing +# for (const auto & kv : state.symbol_ids) { +# if (kv.second == elem.value) { +# throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); +# } +# } +# } +# } +# } +# } +# return state; +# } catch (const std::exception & err) { +# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); +# return parse_state(); +# } +#} -# parse_state parse(const char * src) { -# try { -# parse_state state; -# const char * pos = parse_space(src, true); -# while (*pos) { -# pos = parse_rule(state, pos); -# } -# return state; -# } catch (const std::exception & err) { -# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); -# return parse_state(); -# } -# } def parse(src: const_char_p) -> parse_state: try: state = parse_state() # type: parse_state pos = parse_space(src, True) # type: const_char_p while pos[0]: pos = parse_rule(state, pos) + # Validate the state to ensure that all rules are defined + for rule in state.rules: + for elem in rule: + if elem.type == llama_gretype.LLAMA_GRETYPE_RULE_REF: + # Ensure that the rule at that location exists + if elem.value >= len(state.rules) or not state.rules[elem.value]: + # Get the name of the rule that is missing + for kv in state.symbol_ids: + if kv.second == elem.value: + raise RuntimeError("Undefined rule identifier '" + kv.first + "'") return state except Exception as err: print(f"{parse.__name__}: error parsing grammar: {err}") return parse_state() - # void print_grammar_char(FILE * file, uint32_t c) { # if (0x20 <= c && c <= 0x7f) { # fprintf(file, "%c", static_cast<char>(c)); @@ -1283,6 +1312,7 @@ def print_rule( # } + for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype if case is llama_gretype.LLAMA_GRETYPE_END: From 81cf9092225dd6719682ba7d31084e61d7b66fd9 Mon Sep 17 00:00:00 2001 From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com> Date: Tue, 30 Jul 2024 07:44:19 +0200 Subject: [PATCH 6/8] multiple fixes, var copy --- llama_cpp/llama_grammar.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index fded061b0..8fd679b15 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -523,7 +523,6 @@ def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int: result = state.symbol_ids.insert(std.string(src, len), next_id) return result[0].second # type: ignore - # uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { # uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); # state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; @@ -841,23 +840,22 @@ def handle_repetitions(min_times: int, max_times: int) -> None: raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos)) - previous_elements:std.vector[LlamaGrammarElement] = out_elements[last_sym_start:out_elements.size()] + previous_elements:std.vector[LlamaGrammarElement] = std.vector(out_elements[last_sym_start:]) # type: std.vector[LlamaGrammarElement] if min_times == 0: out_elements.resize(last_sym_start) else: # Repeat the previous elements (min_times - 1) times for i in range(1, min_times): - out_elements.extend(previous_elements) + out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()) last_rec_rule_id = 0 # type: int n_opt = 1 if max_times < 0 else max_times - min_times # type: int - rec_rule = previous_elements # type: List[LlamaGrammarElement] + rec_rule = std.vector(previous_elements) # type: List[LlamaGrammarElement] for i in range(n_opt): - rec_rule = previous_elements - rec_rule.resize(len(previous_elements)) - rec_rule_id = generate_symbol_id(state, rule_name) # type: int + rec_rule.resize(previous_elements.size()) + rec_rule_id = generate_symbol_id(state, rule_name) # type: int if i > 0 or max_times < 0: rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id)) rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) @@ -868,12 +866,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: if n_opt > 0: out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id)) - - - - - - + # while (*pos) { while pos[0]: @@ -1208,8 +1201,8 @@ def parse(src: const_char_p) -> parse_state: if elem.value >= len(state.rules) or not state.rules[elem.value]: # Get the name of the rule that is missing for kv in state.symbol_ids: - if kv.second == elem.value: - raise RuntimeError("Undefined rule identifier '" + kv.first + "'") + if kv[1] == elem.value: + raise RuntimeError("Undefined rule identifier '" + kv[0] + "'") return state except Exception as err: print(f"{parse.__name__}: error parsing grammar: {err}") From 6d53877fe81e350768cc41433817e91a1d532261 Mon Sep 17 00:00:00 2001 From: Andrei Betlen <abetlen@gmail.com> Date: Sun, 4 Aug 2024 17:14:41 -0400 Subject: [PATCH 7/8] Rewrite LlamaGrammar internals in python style --- llama_cpp/llama_cpp.py | 2 +- llama_cpp/llama_grammar.py | 1632 ++++++++---------------------------- 2 files changed, 351 insertions(+), 1283 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 727195cf5..d598cf9f1 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -3002,7 +3002,7 @@ def llama_grammar_init( n_rules: Union[ctypes.c_size_t, int], start_rule_index: Union[ctypes.c_size_t, int], /, -) -> llama_grammar_p: +) -> Optional[llama_grammar_p]: """Initialize a grammar from a set of rules.""" ... diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 8fd679b15..4d48e34ab 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,644 +1,93 @@ """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" # flake8: noqa -from pathlib import Path -import sys -from ctypes import * # type: ignore -from enum import Enum -from itertools import islice, groupby +from itertools import groupby from typing import ( Any, - Callable, - Dict, Set, - Generic, List, Optional, - OrderedDict, - TextIO, Tuple, - TypeVar, Union, - overload, ) -import llama_cpp.llama_cpp as llama_cpp - -# Type aliases -llama_grammar_element = llama_cpp.llama_grammar_element -llama_grammar_element_p = llama_cpp.llama_grammar_element_p -llama_grammar_p = llama_cpp.llama_grammar_p - -# Type variables -Ptr = TypeVar("Ptr", bound="const_char_p") -T = TypeVar("T") -U = TypeVar("U") -V = TypeVar("V") -W = TypeVar("W") - - -class Sentinel: - """Used to mark the end of a iterator of std::vector & std::map.""" - - -class LlamaGrammar: - """Keeps reference counts of all the arguments, so that they are not - garbage collected by Python.""" - - def __del__(self) -> None: - """Free the grammar pointer when the object is deleted.""" - if self.grammar is not None: - llama_cpp.llama_grammar_free(self.grammar) - self.grammar = None - - def __init__( - self, - parsed_grammar: "parse_state", - ) -> None: - """Initialize the grammar pointer from the parsed state.""" - self._grammar_rules = ( - parsed_grammar.c_rules() - ) # type: std.vector[std.vector[LlamaGrammarElement]] - self._n_rules = self._grammar_rules.size() # type: int - self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int - self.init() - - @classmethod - def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": - """Convert a GBNF grammar to a Llama grammar.""" - parsed_grammar = parse(const_char_p(grammar)) # type: parse_state - if parsed_grammar.rules.empty(): - raise ValueError( - f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" - ) - if verbose: - print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) - print_grammar(sys.stderr, parsed_grammar) - print(file=sys.stderr) - return cls(parsed_grammar) - - @classmethod - def from_json_schema( - cls, - json_schema: str, - verbose: bool = True, - ) -> "LlamaGrammar": - """Convert a JSON schema to a Llama grammar.""" - return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose) - - @classmethod - def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": - try: - with open(file) as f: - grammar = f.read() - except Exception as err: - raise Exception( - f"{cls.from_file.__name__}: error reading grammar file: {err}" - ) - - if grammar: - return cls.from_string(grammar, verbose=verbose) - - raise ValueError( - f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" - ) - - def init(self) -> None: - # Step 1: Convert LlamaGrammarElement to llama_grammar_element - self._element_lists = [ - [ - llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) - for elem in subvector - ] - for subvector in self._grammar_rules - ] # type: List[List[llama_grammar_element]] - - # Step 2: Convert each list to llama_grammar_element array and get pointer - self._element_arrays = [ - (llama_grammar_element * len(sublist))(*sublist) - for sublist in self._element_lists - ] # type: List[Array[llama_grammar_element]] - - # Step 3: Get pointer of each array - self._element_array_pointers = [ - cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays - ] # type: List[llama_grammar_element_p] - - # Step 4: Make array of these pointers and get its pointer - self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( - *self._element_array_pointers - ) - self.grammar = llama_cpp.llama_grammar_init( - self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) - ) - - def reset(self) -> None: - if self.grammar is not None: - llama_cpp.llama_grammar_free(self.grammar) - self.init() - - -class LlamaGrammarElement: - def __init__(self, type: "llama_gretype", value: int): - self.type = type - self.value = value # Unicode code point or rule ID - - -class const_char_p: - """C++ implementation of const char *.""" - - def __init__(self, value: Union[str, Ptr], move: Optional[int] = None): - if isinstance(value, const_char_p): - # We're copying an existing const_char_p - self.value = value.value - self.pos = value.pos + (move or 0) - return - - # We're creating a new const_char_p - self.value = value - self.pos = move or 0 - - def __str__(self) -> str: - assert self.value is not None, "null pointer" - return self.value[self.pos :] - - def __getitem__(self, index: int) -> str: - value = str(self) - return value[index] if index < len(value) else "" - - @overload - def __add__(self: Ptr, other: int) -> Ptr: ... - - @overload - def __add__(self: Ptr, other: Ptr) -> int: ... - - def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: - return ( - self.__class__(self.value, self.pos + other) - if isinstance(other, int) - else self.pos + other.pos - ) - - @overload - def __sub__(self: Ptr, other: int) -> Ptr: ... - - @overload - def __sub__(self: Ptr, other: Ptr) -> int: ... - - def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]: - return ( - self.__class__(self.value, self.pos - other) - if isinstance(other, int) - else self.pos - other.pos - ) - - def __eq__(self: Ptr, other: Ptr) -> bool: - assert self.value == other.value, "comparing pointers from different strings" - return self.pos == other.pos - - def __lt__(self: Ptr, other: Ptr) -> bool: - assert self.value == other.value, "comparing pointers from different strings" - return self.pos < other.pos - - def __gt__(self: Ptr, other: Ptr) -> bool: - assert self.value == other.value, "comparing pointers from different strings" - return self.pos > other.pos - - -class std: - @staticmethod - def string(ptr: const_char_p, length: Optional[int] = None) -> str: - """C++ implementation of std::string constructor.""" - value = str(ptr) - if length is not None: - value = value[:length] - return value - - class vector(Generic[T], List[T]): - """C++ implementation of std::vector.""" - - class iterator: - def __init__(self, vector: "std.vector[T]", index: int): - self._vector = vector - self._index = index - self._version = vector._version - - def _check_version(self): - if self._version != self._vector._version: - raise RuntimeError("Iterator used after vector was modified.") - - def __iter__(self): - return self - - def __next__(self) -> T: - self._check_version() - if self._index >= self._vector.size(): - raise StopIteration - value = self._vector[self._index] - self._index += 1 - return value - - def __add__(self, value: int) -> "std.vector[T].iterator": - return self.__class__(self._vector, self._index + value) - - def __sub__(self, value: int) -> "std.vector[T].iterator": - return self.__class__(self._vector, self._index - value) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._version = 0 - - def modify(self): - # This is a bit of a hack to make sure iterators are invalidated - self._version += 1 - - def push_back(self, value: T) -> None: - self.modify() - self.append(value) - - def pop_back(self) -> None: - self.modify() - if not self.empty(): - self.pop() - - def back(self) -> T: - return self[-1] - - def size(self) -> int: - return len(self) - - def clear(self) -> None: - self.modify() - super().clear() - - def empty(self) -> bool: - return self.size() == 0 - - def data(self) -> "std.vector[T]": - return self - - def resize( - self, - new_size: int, - fill_value_factory: Optional[Callable[[], T]] = None, - ) -> None: - if new_size > self.size(): - if fill_value_factory is None: - raise ValueError("A fill value factory function must be provided.") - self.reserve(new_size, fill_value_factory) - elif new_size < self.size(): - self[:] = self[:new_size] - - def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None: - if capacity > self.size(): - fill_value = fill_value_factory() - self.extend([fill_value] * (capacity - self.size())) - - def front(self) -> T: - if not self.empty(): - return self[0] - else: - raise IndexError("Vector is empty.") - - def assign(self, count: int, value: T) -> None: - self.clear() - self.extend([value] * count) - - def insert( - self, - pos: "std.vector[T].iterator", - first: "std.vector[T].iterator", - last: "std.vector[T].iterator", - ) -> None: - self[pos._index:pos._index] = list( - islice(first._vector, first._index, last._index) - ) - - def begin(self) -> "std.vector[T].iterator": - return self.iterator(self, 0) - - def end(self) -> "std.vector[T].iterator": - return self.iterator(self, self.size()) - - def __getitem__(self, index): - if isinstance(index, slice): - return std.vector(super().__getitem__(index)) - return super().__getitem__(index) - - def __setitem__(self, index, value): - self.modify() - if isinstance(index, slice): - if isinstance(value, std.vector): - value = list(value) - super().__setitem__(index, value) - else: - super().__setitem__(index, value) - - def __delitem__(self, index): - self.modify() - super().__delitem__(index) - - class map(Generic[T, U], OrderedDict[T, U]): - """C++ implementation of std::map.""" - - class iterator(Generic[V, W]): - def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]): - self._map = _map - self.iter = iter(_map) - self.key = key - self._advance() - - def _sanitize_key(self) -> T: - if isinstance(self.key, Sentinel): - raise StopIteration - return self.key - - def _advance(self) -> None: - try: - while next(self.iter) != self.key: - pass - except StopIteration: - self.key = Sentinel() - - def __next__(self) -> Tuple[T, U]: - key = self._sanitize_key() - if key in self._map: - value = self._map[key] - self._advance() - return key, value - else: - raise StopIteration - - def get(self) -> Tuple[T, U]: - key = self._sanitize_key() - return key, self._map[key] +import enum +import typing - @property - def first(self) -> T: - return self._sanitize_key() +import llama_cpp.llama_cpp as llama_cpp - @property - def second(self) -> U: - return self._map[self._sanitize_key()] +class GrammarElementType(enum.IntEnum): + END = llama_cpp.LLAMA_GRETYPE_END + ALT = llama_cpp.LLAMA_GRETYPE_ALT + RULE_REF = llama_cpp.LLAMA_GRETYPE_RULE_REF + CHAR = llama_cpp.LLAMA_GRETYPE_CHAR + CHAR_NOT = llama_cpp.LLAMA_GRETYPE_CHAR_NOT + CHAR_RNG_UPPER = llama_cpp.LLAMA_GRETYPE_CHAR_RNG_UPPER + CHAR_ALT = llama_cpp.LLAMA_GRETYPE_CHAR_ALT + CHAR_ANY = llama_cpp.LLAMA_GRETYPE_CHAR_ANY + +import dataclasses + +@dataclasses.dataclass +class GrammarElement: + type: GrammarElementType + value: int + +@dataclasses.dataclass +class ParseState: + symbol_ids: typing.Dict[str, int] = dataclasses.field(default_factory=dict) + rules: typing.List[typing.List[GrammarElement]] = dataclasses.field(default_factory=list) + + +def decode_utf8(src: str) -> typing.Tuple[int, str]: + lookup: list[int] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4] + first_byte: int = ord(src[0]) + highbits: int = first_byte >> 4 + length: int = lookup[highbits] + mask: int = (1 << (8 - length)) - 1 + value: int = first_byte & mask + end: int = min(len(src), length) # Prevent overrun + + pos: int = 1 + for pos in range(1, end): + if not src[pos]: + break + value = (value << 6) + (ord(src[pos]) & 0x3F) - def insert( - self, key: T, value: U - ) -> Tuple["std.map[T, U].iterator[T, U]", bool]: - if key in self: - return self.iterator(self, key), False - else: - self[key] = value - return self.iterator(self, key), True + return value, src[pos:] if pos < len(src) else "" - def find(self, key: T) -> "std.map[T, U].iterator[T, U]": - if key in self: - return self.iterator(self, key) - else: - return self.end() - def at(self, key: T) -> U: - if key in self: - return self[key] - else: - raise KeyError("The provided key is not found in the map.") - - def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None: - key = iterator.first - if key in self: - del self[key] - - def size(self) -> int: - return len(self) - - def empty(self) -> bool: - return self.size() == 0 - - def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]": - try: - keys = sorted(list(self.keys())) # type: ignore - for k in keys: - if k >= key: - return self.iterator(self, k) - raise ValueError("No key found that is not less than the input key") - except TypeError: - raise TypeError("Keys of type T cannot be sorted.") - - def begin(self) -> "std.map[T, U].iterator[T, U]": - return self.iterator(self, next(iter(self))) - - def end(self) -> "std.map[T, U].iterator[T, U]": - return self.iterator(self, Sentinel()) - -# // grammar element type -# enum llama_gretype { -# // end of rule definition -# LLAMA_GRETYPE_END = 0, - -# // start of alternate definition for rule -# LLAMA_GRETYPE_ALT = 1, - -# // non-terminal element: reference to rule -# LLAMA_GRETYPE_RULE_REF = 2, - -# // terminal element: character (code point) -# LLAMA_GRETYPE_CHAR = 3, - -# // inverse char(s) ([^a], [^a-b] [^abc]) -# LLAMA_GRETYPE_CHAR_NOT = 4, - -# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to -# // be an inclusive range ([a-z]) -# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - -# // modifies a preceding LLAMA_GRETYPE_CHAR or -# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) -# LLAMA_GRETYPE_CHAR_ALT = 6, - -# // any character (.) -# LLAMA_GRETYPE_CHAR_ANY = 7, -# }; -class llama_gretype(Enum): - """grammar element type""" - - LLAMA_GRETYPE_END = 0 # end of rule definition - LLAMA_GRETYPE_ALT = 1 # start of alternate definition for rule - LLAMA_GRETYPE_RULE_REF = 2 # non-terminal element: reference to rule - LLAMA_GRETYPE_CHAR = 3 # terminal element: character (code point) - LLAMA_GRETYPE_CHAR_NOT = 4 # inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_ALT = 6 # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ANY = 7 # any character (.) - - -# struct parse_state { -# std::map<std::string, uint32_t> symbol_ids; -# std::vector<std::vector<llama_grammar_element>> rules; -# std::vector<const llama_grammar_element *> c_rules(); -# }; -class parse_state: - def __init__(self): - self.symbol_ids: std.map[str, int] = std.map() - self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector() - - # std::vector<const llama_grammar_element *> parse_state::c_rules() { - # std::vector<const llama_grammar_element *> ret; - # for (const auto & rule : rules) { - # ret.push_back(rule.data()); - # } - # return ret; - # } - def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]: - ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]] - for rule in self.rules: - ret.push_back(rule.data()) - return ret - - def __repr__(self) -> str: - return ( - f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" - ) +def get_symbol_id(state: ParseState, name: str) -> int: + next_id = len(state.symbol_ids) + return state.symbol_ids.setdefault(name, next_id) -# struct llama_grammar { -# const std::vector<std::vector<llama_grammar_element>> rules; -# std::vector<std::vector<const llama_grammar_element *>> stacks; -# }; -# class llama_grammar: -# def __init__( -# self, -# rules: std.vector[std.vector[llama_grammar_element]], -# stacks: std.vector[std.vector[llama_grammar_element]], -# ): -# self.rules = rules -# self.stacks = stacks - - -# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { -# uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); -# auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); -# return result.first->second; -# } -def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int: - next_id = state.symbol_ids.size() # type: int - result = state.symbol_ids.insert(std.string(src, len), next_id) - return result[0].second # type: ignore - -# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { -# uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); -# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; -# return next_id; -# } -def generate_symbol_id(state: parse_state, base_name: str) -> int: - next_id = state.symbol_ids.size() # type: int - state.symbol_ids[base_name + "_" + str(next_id)] = next_id +def generate_symbol_id(state: ParseState, base_name: str) -> int: + next_id = len(state.symbol_ids) + state.symbol_ids[f"{base_name}_{next_id}"] = next_id return next_id -# void add_rule( -# parse_state & state, -# uint32_t rule_id, -# const std::vector<llama_grammar_element> & rule) { -# if (state.rules.size() <= rule_id) { -# state.rules.resize(rule_id + 1); -# } -# state.rules[rule_id] = rule; -# } -def add_rule( - state: parse_state, - rule_id: int, - rule: std.vector[LlamaGrammarElement], -) -> None: - if state.rules.size() <= rule_id: - state.rules.resize( - rule_id + 1, - fill_value_factory=std.vector[LlamaGrammarElement], - ) +def add_rule(state: ParseState, rule_id: int, rule: typing.List[GrammarElement]) -> None: + if len(state.rules) <= rule_id: + state.rules.extend([[]] * (rule_id + 1 - len(state.rules))) state.rules[rule_id] = rule -# std::pair<uint32_t, const char *> decode_utf8(const char * src) { -# static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; -# uint8_t first_byte = static_cast<uint8_t>(*src); -# uint8_t highbits = first_byte >> 4; -# int len = lookup[highbits]; -# uint8_t mask = (1 << (8 - len)) - 1; -# uint32_t value = first_byte & mask; -# const char * end = src + len; // may overrun! -# const char * pos = src + 1; -# for ( ; pos < end && *pos; pos++) { -# value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); -# } -# return std::make_pair(value, pos); -# } -def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: - """Decodes a UTF-8 character from the source string.""" - # Get the codepoint of the first character - value = ord(src[0]) - # Move the pointer ahead one character - pos = src + 1 - - return value, pos - - -"""#static bool is_digit_char(char c) { -# return '0' <= c && c <= '9'; -#} def is_digit_char(c: str) -> bool: return "0" <= c <= "9" -# bool is_word_char(char c) { -# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); -# } def is_word_char(c: str) -> bool: - return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") or is_digit_char(c) -""" - -##optimized version -# Original is_digit_char time: 2.868295 seconds -# Optimized is_digit_char time: 1.993195 seconds -# Original is_word_char time: 3.856689 seconds -# Optimized is_word_char time: 2.052832 seconds - -digit_chars = set("0123456789") -word_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789") + return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or is_digit_char(c) -def is_digit_char(c: str) -> bool: - return c in digit_chars -def is_word_char(c: str) -> bool: - return c in word_chars - - -# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { -# const char * pos = src; -# const char * end = src + size; -# uint32_t value = 0; -# for ( ; pos < end && *pos; pos++) { -# value <<= 4; -# char c = *pos; -# if ('a' <= c && c <= 'f') { -# value += c - 'a' + 10; -# } else if ('A' <= c && c <= 'F') { -# value += c - 'A' + 10; -# } else if ('0' <= c && c <= '9') { -# value += c - '0'; -# } else { -# break; -# } -# } -# if (pos != end) { -# throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); -# } -# return std::make_pair(value, pos); -# } -def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: - pos = const_char_p(src) # type: const_char_p - end = src + size # type: const_char_p - value = 0 # type: int - while pos < end and pos[0]: +def parse_hex(src: str, size: int) -> typing.Tuple[int, str]: + pos = 0 + value = 0 + for _ in range(size): value <<= 4 - c = pos[0] # type: str + c = src[pos] if "a" <= c <= "f": value += ord(c) - ord("a") + 10 elif "A" <= c <= "F": @@ -648,752 +97,370 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: else: break pos += 1 - if pos != end: - raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src)) - return (value, pos) - - -# std::pair<uint32_t, const char *> parse_char(const char * src) { -# if (*src == '\\') { -# switch (src[1]) { -# case 'x': return parse_hex(src + 2, 2); -# case 'u': return parse_hex(src + 2, 4); -# case 'U': return parse_hex(src + 2, 8); -# case 't': return std::make_pair('\t', src + 2); -# case 'r': return std::make_pair('\r', src + 2); -# case 'n': return std::make_pair('\n', src + 2); -# case '\\': -# case '"': -# case '[': -# case ']': -# return std::make_pair(src[1], src + 2); -# default: -# throw std::runtime_error(std::string("unknown escape at ") + src); -# } -# } else if (*src) { -# return decode_utf8(src); -# } -# throw std::runtime_error("unexpected end of input"); -# } -def parse_char(src: const_char_p) -> Tuple[int, const_char_p]: - if src[0] == "\\": - case = src[1] # type: str - if case == "x": - return parse_hex(src + 2, 2) - elif case == "u": - return parse_hex(src + 2, 4) - elif case == "U": - return parse_hex(src + 2, 8) - elif case == "t": - return (ord("\t"), src + 2) # implicit cast - elif case == "r": - return (ord("\r"), src + 2) # implicit cast - elif case == "n": - return (ord("\n"), src + 2) # implicit cast - elif case in ("\\", '"', "[", "]"): - return (ord(case), src + 2) # implicit cast - else: - raise RuntimeError("unknown escape at " + str(src)) - elif src[0]: - return decode_utf8(src) - else: - raise RuntimeError("unexpected end of input") - - -# const char * parse_name(const char * src) { -# const char * pos = src; -# while (is_word_char(*pos)) { -# pos++; -# } -# if (pos == src) { -# throw std::runtime_error(std::string("expecting name at ") + src); -# } -# return pos; -# } -def parse_name(src: const_char_p) -> const_char_p: - pos = const_char_p(src) # type: const_char_p - while is_word_char(pos[0]): + if pos != size: + raise ValueError(f"expecting {size} hex chars at {src}") + return value, src[pos:] + + +def parse_space(src: str, newline_ok: bool) -> str: + pos = 0 + while pos < len(src) and (src[pos] in (" ", "\t", "#") or (newline_ok and src[pos] in ("\r", "\n"))): + if src[pos] == "#": + while pos < len(src) and src[pos] not in ("\r", "\n"): + pos += 1 pos += 1 - if pos == src: - raise RuntimeError("expecting name at " + str(src)) - return pos + return src[pos:] + + +def parse_name(src: str) -> typing.Tuple[str, str]: + pos = 0 + try: + while is_word_char(src[pos]): + pos += 1 + except IndexError: + return src, "" + if pos == 0: + raise ValueError(f"expecting name at {src}") + return src[:pos], src[pos:] -#static const char * parse_int(const char * src) { -# const char * pos = src; -# while (is_digit_char(*pos)) { -# pos++; -# } -# if (pos == src) { -# throw std::runtime_error(std::string("expecting integer at ") + src); -# } -# return pos; -#} -def parse_int(src: const_char_p) -> const_char_p: - pos = const_char_p(src) # type: const_char_p - while is_digit_char(pos[0]): +def parse_int(src: str) -> typing.Tuple[int, str]: + pos = 0 + while is_digit_char(src[pos]): pos += 1 - if pos == src: - raise RuntimeError("expecting integer at " + str(src)) - return pos + if pos == 0: + raise ValueError(f"expecting integer at {src}") + return int(src[:pos]), src[pos:] -# const char * parse_space(const char * src, bool newline_ok) { -# const char * pos = src; -# while (*pos == ' ' || *pos == '\t' || *pos == '#' || -# (newline_ok && (*pos == '\r' || *pos == '\n'))) { -# if (*pos == '#') { -# while (*pos && *pos != '\r' && *pos != '\n') { -# pos++; -# } -# } else { -# pos++; -# } -# } -# return pos; -# } -def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: - pos = const_char_p(src) # type: const_char_p - while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")): - if pos[0] == "#": - while pos[0] is not None and pos[0] not in ("\r", "\n"): - pos += 1 +def parse_char(src: str) -> typing.Tuple[int, str]: + if src[0] == "\\": + if src[1] == "x": + return parse_hex(src[2:], 2) + elif src[1] == "u": + return parse_hex(src[2:], 4) + elif src[1] == "U": + return parse_hex(src[2:], 8) + elif src[1] == "t": + return ord("\t"), src[2:] + elif src[1] == "r": + return ord("\r"), src[2:] + elif src[1] == "n": + return ord("\n"), src[2:] + elif src[1] in ('\\', '"', '[', ']'): + return ord(src[1]), src[2:] else: - pos += 1 - return pos + raise ValueError(f"unknown escape at {src}") + elif src: + return decode_utf8(src) + raise ValueError("unexpected end of input") -# const char * parse_sequence( -# parse_state & state, -# const char * src, -# const std::string & rule_name, -# std::vector<llama_grammar_element> & out_elements, -# bool is_nested) { -def parse_sequence( - state: parse_state, - src: const_char_p, - rule_name: str, - out_elements: std.vector[LlamaGrammarElement], - is_nested: bool, -) -> const_char_p: - # size_t last_sym_start = out_elements.size(); - # const char * pos = src; - last_sym_start = out_elements.size() # type: int - pos = const_char_p(src) # type: const_char_p - - - # auto handle_repetitions = [&](int min_times, int max_times) { - - # if (last_sym_start == out_elements.size()) { - # throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - # } - - # // apply transformation to previous symbol (last_sym_start to end) according to - # // the following rewrite rules: - # // S{m,n} --> S S S (m times) S'(n-m) - # // S'(x) ::= S S'(x-1) | - # // (... n-m definitions of these S' rules ...) - # // S'(1) ::= S | - # // S{m,} --> S S S (m times) S' - # // S' ::= S S' | - # // S* --> S{0,} - # // --> S' ::= S S' | - # // S+ --> S{1,} - # // --> S S' - # // S' ::= S S' | - # // S? --> S{0,1} - # // --> S' - # // S' ::= S | - - # std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - # if (min_times == 0) { - # out_elements.resize(last_sym_start); - # } else { - # // Repeat the previous elements (min_times - 1) times - # for (int i = 1; i < min_times; i++) { - # out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - # } - # } - - # uint32_t last_rec_rule_id = 0; - # auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - # std::vector<llama_grammar_element> rec_rule(previous_elements); - # for (int i = 0; i < n_opt; i++) { - # rec_rule.resize(previous_elements.size()); - # uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - # if (i > 0 || max_times < 0) { - # rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - # } - # rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - # rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - # add_rule(state, rec_rule_id, rec_rule); - # last_rec_rule_id = rec_rule_id; - # } - # if (n_opt > 0) { - # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - # } - # }; - - def handle_repetitions(min_times: int, max_times: int) -> None: - if last_sym_start == out_elements.size(): - raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos)) +def parse_sequence(state: ParseState, src: str, rule_name: str, out_elements: typing.List[GrammarElement], is_nested: bool) -> str: + last_sym_start = len(out_elements) + pos = src + def handle_repetitions(min_times: int, max_times: int) -> None: + nonlocal last_sym_start + nonlocal pos + nonlocal out_elements - previous_elements:std.vector[LlamaGrammarElement] = std.vector(out_elements[last_sym_start:]) # type: std.vector[LlamaGrammarElement] + if last_sym_start == len(out_elements): + raise ValueError(f"expecting preceding item to */+/?/{{ at {pos}") + previous_elements = out_elements[last_sym_start:] if min_times == 0: - out_elements.resize(last_sym_start) + out_elements = out_elements[:last_sym_start] else: - # Repeat the previous elements (min_times - 1) times for i in range(1, min_times): - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()) - - last_rec_rule_id = 0 # type: int - n_opt = 1 if max_times < 0 else max_times - min_times # type: int - rec_rule = std.vector(previous_elements) # type: List[LlamaGrammarElement] + out_elements.extend(previous_elements) + last_rec_rule_id = 0 + n_opt = 1 if max_times < 0 else max_times - min_times + rec_rule = list(previous_elements) for i in range(n_opt): - rec_rule.resize(previous_elements.size()) - rec_rule_id = generate_symbol_id(state, rule_name) # type: int + rec_rule = rec_rule[:len(previous_elements)] + rec_rule_id = generate_symbol_id(state, rule_name) if i > 0 or max_times < 0: - rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id)) - rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) - rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) + rec_rule.append(GrammarElement(GrammarElementType.RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id)) + rec_rule.append(GrammarElement(GrammarElementType.ALT, 0)) + rec_rule.append(GrammarElement(GrammarElementType.END, 0)) add_rule(state, rec_rule_id, rec_rule) last_rec_rule_id = rec_rule_id - if n_opt > 0: - out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id)) - - - - # while (*pos) { - while pos[0]: - - # if (*pos == '"') { // literal string - # pos++; - # last_sym_start = out_elements.size(); - # while (*pos != '"') { - # auto char_pair = parse_char(pos); - # pos = char_pair.second; - # out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - # } - # pos = parse_space(pos + 1, is_nested); - if pos[0] == '"': # literal string - pos += 1 - last_sym_start = out_elements.size() + out_elements.append(GrammarElement(GrammarElementType.RULE_REF, last_rec_rule_id)) + + while pos: + if pos.startswith('"'): + pos = pos[1:] + last_sym_start = len(out_elements) while pos[0] != '"': - assert pos[0] is not None, "Unexpected end of input" - char_pair = parse_char(pos) # type: Tuple[int, const_char_p] - pos = char_pair[1] - out_elements.push_back( - LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0]) - ) - pos = parse_space(pos + 1, is_nested) - # } else if (*pos == '[') { // char range(s) - # pos++; - # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - elif pos[0] == "[": # char range(s) - pos += 1 - start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype - # if (*pos == '^') { - # pos++; - # start_type = LLAMA_GRETYPE_CHAR_NOT; - # } - # last_sym_start = out_elements.size(); + if not pos: + raise ValueError("unexpected end of input") + char, pos = parse_char(pos) + out_elements.append(GrammarElement(GrammarElementType.CHAR, char)) + pos = parse_space(pos[1:], is_nested) + elif pos.startswith("["): + pos = pos[1:] + start_type = GrammarElementType.CHAR if pos[0] == "^": - pos += 1 - start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT - last_sym_start = out_elements.size() - # while (*pos != ']') { - # auto char_pair = parse_char(pos); - # pos = char_pair.second; - # enum llama_gretype type = last_sym_start < out_elements.size() - # ? LLAMA_GRETYPE_CHAR_ALT - # : start_type; - # out_elements.push_back({type, char_pair.first}); + start_type = GrammarElementType.CHAR_NOT + pos = pos[1:] + last_sym_start = len(out_elements) while pos[0] != "]": - assert pos[0] is not None, "Unexpected end of input" - char_pair = parse_char(pos) # type: Tuple[int, const_char_p] - pos = char_pair[1] - _type = ( - llama_gretype.LLAMA_GRETYPE_CHAR_ALT - if last_sym_start < out_elements.size() - else start_type - ) # type: llama_gretype - out_elements.push_back(LlamaGrammarElement(_type, char_pair[0])) - # if (pos[0] == '-' && pos[1] != ']') { - # auto endchar_pair = parse_char(pos + 1); - # pos = endchar_pair.second; - # out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - # } - # } + if not pos: + raise ValueError("unexpected end of input") + char, pos = parse_char(pos) + type = GrammarElementType.CHAR_ALT if last_sym_start < len(out_elements) else start_type + out_elements.append(GrammarElement(type, char)) if pos[0] == "-" and pos[1] != "]": - assert pos[1] is not None, "Unexpected end of input" - endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] - pos = endchar_pair[1] - out_elements.push_back( - LlamaGrammarElement( - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, - endchar_pair[0], - ) - ) - # pos = parse_space(pos + 1, is_nested); - pos = parse_space(pos + 1, is_nested) - # } else if (is_word_char(*pos)) { // rule reference - # const char * name_end = parse_name(pos); - # uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - # pos = parse_space(name_end, is_nested); - # last_sym_start = out_elements.size(); - # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - elif is_word_char(pos[0]): # rule reference - name_end = parse_name(pos) # type: const_char_p - ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int - pos = parse_space(name_end, is_nested) - last_sym_start = out_elements.size() - out_elements.push_back( - LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id) - ) - # } else if (*pos == '(') { // grouping - # // parse nested alternates into synthesized rule - # pos = parse_space(pos + 1, true); - # uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - # pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - # last_sym_start = out_elements.size(); - # // output reference to synthesized rule - # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - # if (*pos != ')') { - # throw std::runtime_error(std::string("expecting ')' at ") + pos); - # } - # pos = parse_space(pos + 1, is_nested); - elif pos[0] == "(": # grouping - # parse nested alternates into synthesized rule - pos = parse_space(pos + 1, True) - sub_rule_id = generate_symbol_id(state, rule_name) # type: int - pos = parse_alternates(state, pos, rule_name, sub_rule_id, True) - last_sym_start = out_elements.size() - # output reference to synthesized rule - out_elements.push_back( - LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) - ) + if not pos[1]: + raise ValueError("unexpected end of input") + endchar, pos = parse_char(pos[1:]) + out_elements.append(GrammarElement(GrammarElementType.CHAR_RNG_UPPER, endchar)) + pos = parse_space(pos[1:], is_nested) + elif is_word_char(pos[0]): + name, rest = parse_name(pos) + ref_rule_id = get_symbol_id(state, name) + pos = parse_space(rest, is_nested) + last_sym_start = len(out_elements) + out_elements.append(GrammarElement(GrammarElementType.RULE_REF, ref_rule_id)) + elif pos.startswith("("): + pos = parse_space(pos[1:], newline_ok=True) + sub_rule_id = generate_symbol_id(state, rule_name) + pos = parse_alternates(state, pos, rule_name, sub_rule_id, is_nested=True) + last_sym_start = len(out_elements) + out_elements.append(GrammarElement(GrammarElementType.RULE_REF, sub_rule_id)) if pos[0] != ")": - raise RuntimeError("expecting ')' at " + str(pos)) - pos = parse_space(pos + 1, is_nested) - elif pos[0] == '.': - last_sym_start = out_elements.size() - out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR_ANY, 0)) - pos = parse_space(pos + 1, is_nested) - - # } else if (*pos == '*') { - # pos = parse_space(pos + 1, is_nested); - # handle_repetitions(0, -1); - # } else if (*pos == '+') { - # pos = parse_space(pos + 1, is_nested); - # handle_repetitions(1, -1); - # } else if (*pos == '?') { - # pos = parse_space(pos + 1, is_nested); - # handle_repetitions(0, 1); - # } else if (*pos == '{') { - # pos = parse_space(pos + 1, is_nested); - - # if (!is_digit_char(*pos)) { - # throw std::runtime_error(std::string("expecting an int at ") + pos); - # } - # const char * int_end = parse_int(pos); - # int min_times = std::stoul(std::string(pos, int_end - pos)); - # pos = parse_space(int_end, is_nested); - - # int max_times = -1; - - # if (*pos == '}') { - # max_times = min_times; - # pos = parse_space(pos + 1, is_nested); - # } else if (*pos == ',') { - # pos = parse_space(pos + 1, is_nested); - - # if (is_digit_char(*pos)) { - # const char * int_end = parse_int(pos); - # max_times = std::stoul(std::string(pos, int_end - pos)); - # pos = parse_space(int_end, is_nested); - # } - - # if (*pos != '}') { - # throw std::runtime_error(std::string("expecting '}' at ") + pos); - # } - # pos = parse_space(pos + 1, is_nested); - # } else { - # throw std::runtime_error(std::string("expecting ',' at ") + pos); - # } - # handle_repetitions(min_times, max_times); - # } else { - # break; - # } - elif pos[0] == "*": - pos = parse_space(pos + 1, is_nested) + raise ValueError(f"expecting ')' at {pos}") + pos = parse_space(pos[1:], is_nested) + elif pos.startswith("."): + last_sym_start = len(out_elements) + out_elements.append(GrammarElement(GrammarElementType.CHAR_ANY, 0)) + pos = parse_space(pos[1:], is_nested) + elif pos.startswith("*"): + pos = parse_space(pos[1:], is_nested) handle_repetitions(0, -1) - elif pos[0] == "+": - pos = parse_space(pos + 1, is_nested) + elif pos.startswith("+"): + pos = parse_space(pos[1:], is_nested) handle_repetitions(1, -1) - elif pos[0] == "?": - pos = parse_space(pos + 1, is_nested) + elif pos.startswith("?"): + pos = parse_space(pos[1:], is_nested) handle_repetitions(0, 1) - elif pos[0] == "{": - pos = parse_space(pos + 1, is_nested) - + elif pos.startswith("{"): + pos = parse_space(pos[1:], is_nested) if not is_digit_char(pos[0]): - raise RuntimeError("expecting an int at " + str(pos)) - + raise ValueError(f"expecting an int at {pos}") + min_times, pos = parse_int(pos) + pos = parse_space(pos, is_nested) - - int_end = parse_int(pos) - min_times = int(str(pos)[:int_end - pos]) - pos = parse_space(int_end, is_nested) max_times = -1 + if pos[0] == "}": max_times = min_times - pos = parse_space(pos + 1, is_nested) + pos = parse_space(pos[1:], is_nested) elif pos[0] == ",": - - pos = parse_space(pos + 1, is_nested) + pos = parse_space(pos[1:], is_nested) if is_digit_char(pos[0]): - int_end = parse_int(pos) - max_times = int(str(pos)[:int_end - pos]) - pos = parse_space(int_end, is_nested) + max_times, pos = parse_int(pos) + pos = parse_space(pos, is_nested) if pos[0] != "}": - raise RuntimeError("expecting '}' at " + str(pos)) - pos = parse_space(pos + 1, is_nested) + raise ValueError("expecting '}' at {}".format(pos)) + pos = parse_space(pos[1:], is_nested) else: - raise RuntimeError("expecting ',' at " + str(pos)) + raise ValueError(f"expecting ',' at {pos}") handle_repetitions(min_times, max_times) - else: break - - - - return pos -# const char * parse_alternates( -# parse_state & state, -# const char * src, -# const std::string & rule_name, -# uint32_t rule_id, -# bool is_nested) { -# std::vector<llama_grammar_element> rule; -# const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); -# while (*pos == '|') { -# rule.push_back({LLAMA_GRETYPE_ALT, 0}); -# pos = parse_space(pos + 1, true); -# pos = parse_sequence(state, pos, rule_name, rule, is_nested); -# } -# rule.push_back({LLAMA_GRETYPE_END, 0}); -# add_rule(state, rule_id, rule); -# return pos; -# } -def parse_alternates( - state: parse_state, - src: const_char_p, - rule_name: str, - rule_id: int, - is_nested: bool, -) -> const_char_p: - rule = std.vector() # type: std.vector[LlamaGrammarElement] - pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p - while pos[0] == "|": - rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) - pos = parse_space(pos + 1, True) +def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int, is_nested: bool) -> str: + rule = [] + pos = parse_sequence(state, src, rule_name, rule, is_nested) + while pos.startswith("|"): + rule.append(GrammarElement(GrammarElementType.ALT, 0)) + pos = parse_space(pos[1:], newline_ok=True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) + rule.append(GrammarElement(GrammarElementType.END, 0)) add_rule(state, rule_id, rule) return pos -# const char * parse_rule(parse_state & state, const char * src) { -# const char * name_end = parse_name(src); -# const char * pos = parse_space(name_end, false); -# size_t name_len = name_end - src; -# uint32_t rule_id = get_symbol_id(state, src, name_len); -# const std::string name(src, name_len); - -# if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { -# throw std::runtime_error(std::string("expecting ::= at ") + pos); -# } -# pos = parse_space(pos + 3, true); - -# pos = parse_alternates(state, pos, name, rule_id, false); - - -# if (*pos == '\r') { -# pos += pos[1] == '\n' ? 2 : 1; -# } else if (*pos == '\n') { -# pos++; -# } else if (*pos) { -# throw std::runtime_error(std::string("expecting newline or end at ") + pos); -# } -# return parse_space(pos, true); -# } -def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: - name_end = parse_name(src) # type: const_char_p - pos = parse_space(name_end, False) # type: const_char_p - name_len = name_end - src # type: int - rule_id = get_symbol_id(state, src, name_len) # type: int - name = std.string(src, name_len) # type: str - - if not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="): - raise RuntimeError("expecting ::= at " + str(pos)) - - pos = parse_space(pos + 3, True) # type: const_char_p - pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p - - if pos[0] == "\r": - pos += 2 if pos[1] == "\n" else 1 - elif pos[0] == "\n": - pos += 1 - elif pos[0]: - raise RuntimeError("expecting newline or end at " + str(pos)) - return parse_space(pos, True) - -#parse_state parse(const char * src) { -# try { -# parse_state state; -# const char * pos = parse_space(src, true); -# while (*pos) { -# pos = parse_rule(state, pos); -# } -# // Validate the state to ensure that all rules are defined -# for (const auto & rule : state.rules) { -# for (const auto & elem : rule) { -# if (elem.type == LLAMA_GRETYPE_RULE_REF) { -# // Ensure that the rule at that location exists -# if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { -# // Get the name of the rule that is missing -# for (const auto & kv : state.symbol_ids) { -# if (kv.second == elem.value) { -# throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); -# } -# } -# } -# } -# } -# } -# return state; -# } catch (const std::exception & err) { -# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); -# return parse_state(); -# } -#} - - -def parse(src: const_char_p) -> parse_state: - try: - state = parse_state() # type: parse_state - pos = parse_space(src, True) # type: const_char_p - while pos[0]: - pos = parse_rule(state, pos) - # Validate the state to ensure that all rules are defined - for rule in state.rules: - for elem in rule: - if elem.type == llama_gretype.LLAMA_GRETYPE_RULE_REF: - # Ensure that the rule at that location exists - if elem.value >= len(state.rules) or not state.rules[elem.value]: - # Get the name of the rule that is missing - for kv in state.symbol_ids: - if kv[1] == elem.value: - raise RuntimeError("Undefined rule identifier '" + kv[0] + "'") - return state - except Exception as err: - print(f"{parse.__name__}: error parsing grammar: {err}") - return parse_state() - -# void print_grammar_char(FILE * file, uint32_t c) { -# if (0x20 <= c && c <= 0x7f) { -# fprintf(file, "%c", static_cast<char>(c)); -# } else { -# // cop out of encoding UTF-8 -# fprintf(file, "<U+%04X>", c); -# } -# } -def print_grammar_char(file: TextIO, c: int) -> None: - if 0x20 <= c and c <= 0x7F: - file.write(chr(c)) - else: - # cop out of encoding UTF-8 - file.write(f"<U+{c:04X}>") - - -# bool is_char_element(llama_grammar_element elem) { -# switch (elem.type) { -# case LLAMA_GRETYPE_CHAR: return true; -# case LLAMA_GRETYPE_CHAR_NOT: return true; -# case LLAMA_GRETYPE_CHAR_ALT: return true; -# case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; -# default: return false; -# } -# } -def is_char_element(elem: LlamaGrammarElement) -> bool: +def parse_rule(state: ParseState, src: str) -> str: + name, s = parse_name(src) + s = parse_space(s, newline_ok=False) + rule_id = get_symbol_id(state, name) + + if not s.startswith("::="): + raise ValueError(f"expecting ::= at {s}") + + s = s[3:] + + s = parse_space(s, newline_ok=True) + + s = parse_alternates(state, s, name, rule_id, is_nested=False) + + if s.startswith("\r"): + s = s[2:] if s[1] == "\n" else s[1:] + elif s.startswith("\n"): + s = s[1:] + elif s: + raise ValueError(f"expecting newline or end at {s}") + return parse_space(s, newline_ok=True) + + +def parse(gbnf: str) -> ParseState: + state = ParseState() + s = parse_space(gbnf, newline_ok=True) + while s: + s = parse_rule(state, s) + # validate + for rule in state.rules: + for elem in rule: + if elem.type == GrammarElementType.RULE_REF: + if elem.value >= len(state.rules) or not state.rules[elem.value]: + for k, v in state.symbol_ids.items(): + if v == elem.value: + raise ValueError(f"Undefined rule identifier '{k}'") + return state + + +def is_char_element(elem: GrammarElement) -> bool: return elem.type in ( - llama_gretype.LLAMA_GRETYPE_CHAR, - llama_gretype.LLAMA_GRETYPE_CHAR_NOT, - llama_gretype.LLAMA_GRETYPE_CHAR_ALT, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, - llama_gretype.LLAMA_GRETYPE_CHAR_ANY, + GrammarElementType.CHAR, + GrammarElementType.CHAR_NOT, + GrammarElementType.CHAR_ALT, + GrammarElementType.CHAR_RNG_UPPER, + GrammarElementType.CHAR_ANY ) -# void print_rule( -# FILE * file, -# uint32_t rule_id, -# const std::vector<llama_grammar_element> & rule, -# const std::map<uint32_t, std::string> & symbol_id_names) { +def print_grammar_char(file: typing.TextIO, c: int) -> None: + if 0x20 <= c <= 0x7f: + print(chr(c), end="", file=file) + else: + print(f"<U+{c:04X}>", end="", file=file) + + def print_rule( - file: TextIO, + file: typing.TextIO, rule_id: int, - rule: std.vector[LlamaGrammarElement], - symbol_id_names: std.map[int, str], + rule: typing.List[GrammarElement], + symbol_id_names: typing.Dict[int, str], ) -> None: - # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { - # throw std::runtime_error( - # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); - # } - # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: - raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) - ) - print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") - # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { - # llama_grammar_element elem = rule[i]; - # switch (elem.type) { - # case LLAMA_GRETYPE_END: - # throw std::runtime_error( - # "unexpected end of rule: " + std::to_string(rule_id) + "," + - # std::to_string(i)); - # case LLAMA_GRETYPE_ALT: - # fprintf(file, "| "); - # break; - # case LLAMA_GRETYPE_RULE_REF: - # fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - # break; - # case LLAMA_GRETYPE_CHAR: - # fprintf(file, "["); - # print_grammar_char(file, elem.value); - # break; - # case LLAMA_GRETYPE_CHAR_NOT: - # fprintf(file, "[^"); - # print_grammar_char(file, elem.value); - # break; - # case LLAMA_GRETYPE_CHAR_RNG_UPPER: - # if (i == 0 || !is_char_element(rule[i - 1])) { - # throw std::runtime_error( - # "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - # std::to_string(rule_id) + "," + std::to_string(i)); - # } - # fprintf(file, "-"); - # print_grammar_char(file, elem.value); - # break; - # case LLAMA_GRETYPE_CHAR_ALT: - # if (i == 0 || !is_char_element(rule[i - 1])) { - # throw std::runtime_error( - # "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - # std::to_string(rule_id) + "," + std::to_string(i)); - # } - # print_grammar_char(file, elem.value); - # break; - # } - + if not rule or rule[-1].type != GrammarElementType.END: + raise ValueError(f"malformed rule, does not end with LLAMA_GRETYPE_END: {rule_id}") + print(f"{symbol_id_names[rule_id]} ::=", end=" ", file=file) for i, elem in enumerate(rule[:-1]): - case = elem.type # type: llama_gretype - if case is llama_gretype.LLAMA_GRETYPE_END: - raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) - elif case is llama_gretype.LLAMA_GRETYPE_ALT: - print("| ", file=file, end="") - elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: - print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") - elif case is llama_gretype.LLAMA_GRETYPE_CHAR: - print("[", file=file, end="") + if elem.type == GrammarElementType.END: + raise ValueError(f"unexpected end of rule: {rule_id}, {i}") + if elem.type == GrammarElementType.ALT: + print("| ", end="", file=file) + elif elem.type == GrammarElementType.RULE_REF: + print(f"{symbol_id_names[elem.value]} ", end="", file=file) + elif elem.type == GrammarElementType.CHAR: + print("[", end="", file=file) print_grammar_char(file, elem.value) - elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT: - print("[^", file=file, end="") + elif elem.type == GrammarElementType.CHAR_NOT: + print("[^", end="", file=file) print_grammar_char(file, elem.value) - elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: + elif elem.type == GrammarElementType.CHAR_RNG_UPPER: if i == 0 or not is_char_element(rule[i - 1]): - raise RuntimeError( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " - + str(rule_id) - + "," - + str(i) - ) - print("-", file=file, end="") + raise ValueError(f"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: {rule_id}, {i}") + print(f"-", end="", file=file) print_grammar_char(file, elem.value) - elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT: + elif elem.type == GrammarElementType.CHAR_ALT: if i == 0 or not is_char_element(rule[i - 1]): - raise RuntimeError( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " - + str(rule_id) - + "," - + str(i) - ) + raise ValueError(f"LLAMA_GRETYPE_CHAR_ALT without preceding char: {rule_id}, {i}") print_grammar_char(file, elem.value) - elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ANY: - print(".", file=file, end="") - # if (is_char_element(elem)) { - # switch (rule[i + 1].type) { - # case LLAMA_GRETYPE_CHAR_ALT: - # case LLAMA_GRETYPE_CHAR_RNG_UPPER: - # break; - # default: - # fprintf(file, "] "); + elif elem.type == GrammarElementType.CHAR_ANY: + print(".", end="", file=file) if is_char_element(elem): - if rule[i + 1].type in ( - llama_gretype.LLAMA_GRETYPE_CHAR_ALT, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, - ): - pass - else: - print("] ", file=file, end="") - # } - # } - # } - # fprintf(file, "\n"); - # } + if rule[i + 1].type in (GrammarElementType.CHAR_ALT, GrammarElementType.CHAR_RNG_UPPER, GrammarElementType.CHAR_ANY): + continue + print("] ", end="", file=file) print(file=file) -# void print_grammar(FILE * file, const parse_state & state) { -# try { -# std::map<uint32_t, std::string> symbol_id_names; -# for (auto kv : state.symbol_ids) { -# symbol_id_names[kv.second] = kv.first; -# } -# for (size_t i = 0, end = state.rules.size(); i < end; i++) { -# // fprintf(file, "%zu: ", i); -# // print_rule_binary(file, state.rules[i]); -# print_rule(file, i, state.rules[i], symbol_id_names); -# // fprintf(file, "\n"); -# } -# } catch (const std::exception & err) { -# fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); -# } -# } -def print_grammar(file: TextIO, state: parse_state) -> None: +def print_grammar(file: typing.TextIO, state: ParseState) -> None: try: - symbol_id_names = std.map() # type: std.map[int, str] - for kv in state.symbol_ids.items(): - symbol_id_names[kv[1]] = kv[0] - + symbol_id_names = {v: k for k, v in state.symbol_ids.items()} for i, rule in enumerate(state.rules): print_rule(file, i, rule, symbol_id_names) except Exception as err: - print( - f"{print_grammar.__name__}: error printing grammar: {err}", - file=sys.stderr, + print(f"\nerror printing grammar: {err}", file=file) + raise err + +import ctypes + +class LlamaGrammar: + def __init__(self, parse_state: ParseState): + self.parse_state = parse_state + + self._grammar_rules = parse_state.rules + self._n_rules = len(self._grammar_rules) + self._start_rule_index = parse_state.symbol_ids["root"] + + self._element_lists = [ + [ + llama_cpp.llama_grammar_element(ctypes.c_int(elem.type.value), ctypes.c_uint32(elem.value)) + for elem in subvector + ] + for subvector in self._grammar_rules + ] + + # Step 2: Convert each list to llama_grammar_element array and get pointer + self._element_arrays = [ + (llama_cpp.llama_grammar_element * len(sublist))(*sublist) + for sublist in self._element_lists + ] + + # Step 3: Get pointer of each array + self._element_array_pointers = [ + ctypes.cast(subarray, llama_cpp.llama_grammar_element_p) for subarray in self._element_arrays + ] + + # Step 4: Make array of these pointers and get its pointer + self._rules = (llama_cpp.llama_grammar_element_p * len(self._element_array_pointers))( + *self._element_array_pointers ) + self.grammar = None + self._init_grammar() + + + def _init_grammar(self): + grammar = llama_cpp.llama_grammar_init( + self._rules, ctypes.c_size_t(self._n_rules), ctypes.c_size_t(self._start_rule_index) + ) + + if grammar is None: + raise ValueError("Failed to create grammar") + + self.grammar = grammar + + def __del__(self): + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = None + + def reset(self): + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self._init_grammar() + + @classmethod + def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": + parsed_grammar = parse(grammar) + return cls(parsed_grammar) + + @classmethod + def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar": + return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose) + """llama.cpp gbnf rules from vendor/llama.cpp/grammars""" @@ -1564,12 +631,13 @@ def print_grammar(file: TextIO, state: parse_state) -> None: string ::= "\"" ( [^"\\\x7F\x00-\x1F] | - "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes )* "\"" ws -number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws +number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws -ws ::= ([ \t\n] ws)? +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= | " " | "\n" [ \t]{0,20} """ LIST_GBNF = r""" From 7308d5392a4ae7fecc635c5ebf3176149bb8a14d Mon Sep 17 00:00:00 2001 From: Andrei Betlen <abetlen@gmail.com> Date: Tue, 6 Aug 2024 20:15:32 -0400 Subject: [PATCH 8/8] bugfix --- llama_cpp/llama_grammar.py | 578 ++++++++++++++++++++++++++++++++----- 1 file changed, 507 insertions(+), 71 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 4d48e34ab..e60817e1d 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,6 +1,12 @@ """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp.""" # flake8: noqa +import sys +import ctypes +import enum +import typing +import dataclasses + from itertools import groupby from typing import ( Any, @@ -11,9 +17,6 @@ Union, ) -import enum -import typing - import llama_cpp.llama_cpp as llama_cpp class GrammarElementType(enum.IntEnum): @@ -26,19 +29,33 @@ class GrammarElementType(enum.IntEnum): CHAR_ALT = llama_cpp.LLAMA_GRETYPE_CHAR_ALT CHAR_ANY = llama_cpp.LLAMA_GRETYPE_CHAR_ANY -import dataclasses @dataclasses.dataclass class GrammarElement: type: GrammarElementType value: int + @dataclasses.dataclass class ParseState: symbol_ids: typing.Dict[str, int] = dataclasses.field(default_factory=dict) rules: typing.List[typing.List[GrammarElement]] = dataclasses.field(default_factory=list) +# static std::pair<uint32_t, const char *> decode_utf8(const char * src) { +# static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; +# uint8_t first_byte = static_cast<uint8_t>(*src); +# uint8_t highbits = first_byte >> 4; +# int len = lookup[highbits]; +# uint8_t mask = (1 << (8 - len)) - 1; +# uint32_t value = first_byte & mask; +# const char * end = src + len; // may overrun! +# const char * pos = src + 1; +# for ( ; pos < end && *pos; pos++) { +# value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); +# } +# return std::make_pair(value, pos); +# } def decode_utf8(src: str) -> typing.Tuple[int, str]: lookup: list[int] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4] first_byte: int = ord(src[0]) @@ -57,31 +74,78 @@ def decode_utf8(src: str) -> typing.Tuple[int, str]: return value, src[pos:] if pos < len(src) else "" +# static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { +# uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); +# auto result = state.symbol_ids.emplace(std::string(src, len), next_id); +# return result.first->second; +# } def get_symbol_id(state: ParseState, name: str) -> int: next_id = len(state.symbol_ids) return state.symbol_ids.setdefault(name, next_id) +# static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { +# uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size()); +# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; +# return next_id; +# } def generate_symbol_id(state: ParseState, base_name: str) -> int: next_id = len(state.symbol_ids) state.symbol_ids[f"{base_name}_{next_id}"] = next_id return next_id +# static void add_rule( +# parse_state & state, +# uint32_t rule_id, +# const std::vector<llama_grammar_element> & rule) { +# if (state.rules.size() <= rule_id) { +# state.rules.resize(rule_id + 1); +# } +# state.rules[rule_id] = rule; +# } def add_rule(state: ParseState, rule_id: int, rule: typing.List[GrammarElement]) -> None: if len(state.rules) <= rule_id: state.rules.extend([[]] * (rule_id + 1 - len(state.rules))) state.rules[rule_id] = rule +# static bool is_digit_char(char c) { +# return '0' <= c && c <= '9'; +# } def is_digit_char(c: str) -> bool: return "0" <= c <= "9" +# static bool is_word_char(char c) { +# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); +# } def is_word_char(c: str) -> bool: return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or is_digit_char(c) +# static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { +# const char * pos = src; +# const char * end = src + size; +# uint32_t value = 0; +# for ( ; pos < end && *pos; pos++) { +# value <<= 4; +# char c = *pos; +# if ('a' <= c && c <= 'f') { +# value += c - 'a' + 10; +# } else if ('A' <= c && c <= 'F') { +# value += c - 'A' + 10; +# } else if ('0' <= c && c <= '9') { +# value += c - '0'; +# } else { +# break; +# } +# } +# if (pos != end) { +# throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); +# } +# return std::make_pair(value, pos); +# } def parse_hex(src: str, size: int) -> typing.Tuple[int, str]: pos = 0 value = 0 @@ -102,38 +166,93 @@ def parse_hex(src: str, size: int) -> typing.Tuple[int, str]: return value, src[pos:] +# static const char * parse_space(const char * src, bool newline_ok) { +# const char * pos = src; +# while (*pos == ' ' || *pos == '\t' || *pos == '#' || +# (newline_ok && (*pos == '\r' || *pos == '\n'))) { +# if (*pos == '#') { +# while (*pos && *pos != '\r' && *pos != '\n') { +# pos++; +# } +# } else { +# pos++; +# } +# } +# return pos; +# } def parse_space(src: str, newline_ok: bool) -> str: - pos = 0 - while pos < len(src) and (src[pos] in (" ", "\t", "#") or (newline_ok and src[pos] in ("\r", "\n"))): - if src[pos] == "#": - while pos < len(src) and src[pos] not in ("\r", "\n"): - pos += 1 - pos += 1 - return src[pos:] + pos = src + while pos and (pos[0] in (' ', '\t', '#') or (newline_ok and pos[0] in ('\r', '\n'))): + if pos[0] == "#": + while pos and pos[0] not in ("\r", "\n"): + pos = pos[1:] + else: + pos = pos[1:] + return pos +# static const char * parse_name(const char * src) { +# const char * pos = src; +# while (is_word_char(*pos)) { +# pos++; +# } +# if (pos == src) { +# throw std::runtime_error(std::string("expecting name at ") + src); +# } +# return pos; +# } def parse_name(src: str) -> typing.Tuple[str, str]: - pos = 0 - try: - while is_word_char(src[pos]): - pos += 1 - except IndexError: - return src, "" - if pos == 0: + pos = src + while pos and is_word_char(pos[0]): + pos = pos[1:] + if pos == src: raise ValueError(f"expecting name at {src}") - return src[:pos], src[pos:] - - + return src[:len(src) - len(pos)], pos + +# static const char * parse_int(const char * src) { +# const char * pos = src; +# while (is_digit_char(*pos)) { +# pos++; +# } +# if (pos == src) { +# throw std::runtime_error(std::string("expecting integer at ") + src); +# } +# return pos; +# } def parse_int(src: str) -> typing.Tuple[int, str]: - pos = 0 - while is_digit_char(src[pos]): - pos += 1 - if pos == 0: + pos = src + while pos and is_digit_char(pos[0]): + pos = pos[1:] + if pos == src: raise ValueError(f"expecting integer at {src}") - return int(src[:pos]), src[pos:] - - + return int(src[:len(src) - len(pos)]), pos + + +# static std::pair<uint32_t, const char *> parse_char(const char * src) { +# if (*src == '\\') { +# switch (src[1]) { +# case 'x': return parse_hex(src + 2, 2); +# case 'u': return parse_hex(src + 2, 4); +# case 'U': return parse_hex(src + 2, 8); +# case 't': return std::make_pair('\t', src + 2); +# case 'r': return std::make_pair('\r', src + 2); +# case 'n': return std::make_pair('\n', src + 2); +# case '\\': +# case '"': +# case '[': +# case ']': +# return std::make_pair(src[1], src + 2); +# default: +# throw std::runtime_error(std::string("unknown escape at ") + src); +# } +# } else if (*src) { +# return decode_utf8(src); +# } +# throw std::runtime_error("unexpected end of input"); +# } def parse_char(src: str) -> typing.Tuple[int, str]: + if not src: + raise ValueError("unexpected end of input") if src[0] == "\\": if src[1] == "x": return parse_hex(src[2:], 2) @@ -151,33 +270,202 @@ def parse_char(src: str) -> typing.Tuple[int, str]: return ord(src[1]), src[2:] else: raise ValueError(f"unknown escape at {src}") - elif src: - return decode_utf8(src) - raise ValueError("unexpected end of input") - - + return decode_utf8(src) + +# static const char * parse_sequence( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# std::vector<llama_grammar_element> & out_elements, +# bool is_nested) { +# size_t last_sym_start = out_elements.size(); +# const char * pos = src; +# +# auto handle_repetitions = [&](int min_times, int max_times) { +# +# if (last_sym_start == out_elements.size()) { +# throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); +# } +# +# // apply transformation to previous symbol (last_sym_start to end) according to +# // the following rewrite rules: +# // S{m,n} --> S S S (m times) S'(n-m) +# // S'(x) ::= S S'(x-1) | +# // (... n-m definitions of these S' rules ...) +# // S'(1) ::= S | +# // S{m,} --> S S S (m times) S' +# // S' ::= S S' | +# // S* --> S{0,} +# // --> S' ::= S S' | +# // S+ --> S{1,} +# // --> S S' +# // S' ::= S S' | +# // S? --> S{0,1} +# // --> S' +# // S' ::= S | +# +# std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); +# if (min_times == 0) { +# out_elements.resize(last_sym_start); +# } else { +# // Repeat the previous elements (min_times - 1) times +# for (int i = 1; i < min_times; i++) { +# out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); +# } +# } +# +# uint32_t last_rec_rule_id = 0; +# auto n_opt = max_times < 0 ? 1 : max_times - min_times; +# +# std::vector<llama_grammar_element> rec_rule(previous_elements); +# for (int i = 0; i < n_opt; i++) { +# rec_rule.resize(previous_elements.size()); +# uint32_t rec_rule_id = generate_symbol_id(state, rule_name); +# if (i > 0 || max_times < 0) { +# rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); +# } +# rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); +# rec_rule.push_back({LLAMA_GRETYPE_END, 0}); +# add_rule(state, rec_rule_id, rec_rule); +# last_rec_rule_id = rec_rule_id; +# } +# if (n_opt > 0) { +# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); +# } +# }; +# +# while (*pos) { +# if (*pos == '"') { // literal string +# pos++; +# last_sym_start = out_elements.size(); +# while (*pos != '"') { +# if (!*pos) { +# throw std::runtime_error("unexpected end of input"); +# } +# auto char_pair = parse_char(pos); +# pos = char_pair.second; +# out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); +# } +# pos = parse_space(pos + 1, is_nested); +# } else if (*pos == '[') { // char range(s) +# pos++; +# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; +# if (*pos == '^') { +# pos++; +# start_type = LLAMA_GRETYPE_CHAR_NOT; +# } +# last_sym_start = out_elements.size(); +# while (*pos != ']') { +# if (!*pos) { +# throw std::runtime_error("unexpected end of input"); +# } +# auto char_pair = parse_char(pos); +# pos = char_pair.second; +# enum llama_gretype type = last_sym_start < out_elements.size() +# ? LLAMA_GRETYPE_CHAR_ALT +# : start_type; +# +# out_elements.push_back({type, char_pair.first}); +# if (pos[0] == '-' && pos[1] != ']') { +# if (!pos[1]) { +# throw std::runtime_error("unexpected end of input"); +# } +# auto endchar_pair = parse_char(pos + 1); +# pos = endchar_pair.second; +# out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); +# } +# } +# pos = parse_space(pos + 1, is_nested); +# } else if (is_word_char(*pos)) { // rule reference +# const char * name_end = parse_name(pos); +# uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); +# pos = parse_space(name_end, is_nested); +# last_sym_start = out_elements.size(); +# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); +# } else if (*pos == '(') { // grouping +# // parse nested alternates into synthesized rule +# pos = parse_space(pos + 1, true); +# uint32_t sub_rule_id = generate_symbol_id(state, rule_name); +# pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); +# last_sym_start = out_elements.size(); +# // output reference to synthesized rule +# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); +# if (*pos != ')') { +# throw std::runtime_error(std::string("expecting ')' at ") + pos); +# } +# pos = parse_space(pos + 1, is_nested); +# } else if (*pos == '.') { // any char +# last_sym_start = out_elements.size(); +# out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); +# pos = parse_space(pos + 1, is_nested); +# } else if (*pos == '*') { +# pos = parse_space(pos + 1, is_nested); +# handle_repetitions(0, -1); +# } else if (*pos == '+') { +# pos = parse_space(pos + 1, is_nested); +# handle_repetitions(1, -1); +# } else if (*pos == '?') { +# pos = parse_space(pos + 1, is_nested); +# handle_repetitions(0, 1); +# } else if (*pos == '{') { +# pos = parse_space(pos + 1, is_nested); +# +# if (!is_digit_char(*pos)) { +# throw std::runtime_error(std::string("expecting an int at ") + pos); +# } +# const char * int_end = parse_int(pos); +# int min_times = std::stoul(std::string(pos, int_end - pos)); +# pos = parse_space(int_end, is_nested); +# +# int max_times = -1; +# +# if (*pos == '}') { +# max_times = min_times; +# pos = parse_space(pos + 1, is_nested); +# } else if (*pos == ',') { +# pos = parse_space(pos + 1, is_nested); +# +# if (is_digit_char(*pos)) { +# const char * int_end = parse_int(pos); +# max_times = std::stoul(std::string(pos, int_end - pos)); +# pos = parse_space(int_end, is_nested); +# } +# +# if (*pos != '}') { +# throw std::runtime_error(std::string("expecting '}' at ") + pos); +# } +# pos = parse_space(pos + 1, is_nested); +# } else { +# throw std::runtime_error(std::string("expecting ',' at ") + pos); +# } +# handle_repetitions(min_times, max_times); +# } else { +# break; +# } +# } +# return pos; +# } def parse_sequence(state: ParseState, src: str, rule_name: str, out_elements: typing.List[GrammarElement], is_nested: bool) -> str: last_sym_start = len(out_elements) pos = src def handle_repetitions(min_times: int, max_times: int) -> None: - nonlocal last_sym_start - nonlocal pos - nonlocal out_elements + nonlocal state, src, rule_name, out_elements, is_nested, last_sym_start, pos if last_sym_start == len(out_elements): raise ValueError(f"expecting preceding item to */+/?/{{ at {pos}") previous_elements = out_elements[last_sym_start:] if min_times == 0: - out_elements = out_elements[:last_sym_start] + del out_elements[last_sym_start:] else: for i in range(1, min_times): out_elements.extend(previous_elements) + last_rec_rule_id = 0 n_opt = 1 if max_times < 0 else max_times - min_times - rec_rule = list(previous_elements) + rec_rule = previous_elements[:] for i in range(n_opt): rec_rule = rec_rule[:len(previous_elements)] rec_rule_id = generate_symbol_id(state, rule_name) @@ -191,21 +479,21 @@ def handle_repetitions(min_times: int, max_times: int) -> None: out_elements.append(GrammarElement(GrammarElementType.RULE_REF, last_rec_rule_id)) while pos: - if pos.startswith('"'): + if pos[0] == '"': pos = pos[1:] last_sym_start = len(out_elements) - while pos[0] != '"': + while not pos.startswith('"'): if not pos: raise ValueError("unexpected end of input") char, pos = parse_char(pos) out_elements.append(GrammarElement(GrammarElementType.CHAR, char)) pos = parse_space(pos[1:], is_nested) - elif pos.startswith("["): + elif pos[0] == "[": pos = pos[1:] start_type = GrammarElementType.CHAR if pos[0] == "^": - start_type = GrammarElementType.CHAR_NOT pos = pos[1:] + start_type = GrammarElementType.CHAR_NOT last_sym_start = len(out_elements) while pos[0] != "]": if not pos: @@ -219,7 +507,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None: endchar, pos = parse_char(pos[1:]) out_elements.append(GrammarElement(GrammarElementType.CHAR_RNG_UPPER, endchar)) pos = parse_space(pos[1:], is_nested) - elif is_word_char(pos[0]): + elif pos and is_word_char(pos[0]): name, rest = parse_name(pos) ref_rule_id = get_symbol_id(state, name) pos = parse_space(rest, is_nested) @@ -249,7 +537,8 @@ def handle_repetitions(min_times: int, max_times: int) -> None: handle_repetitions(0, 1) elif pos.startswith("{"): pos = parse_space(pos[1:], is_nested) - if not is_digit_char(pos[0]): + + if not is_digit_char(pos): raise ValueError(f"expecting an int at {pos}") min_times, pos = parse_int(pos) pos = parse_space(pos, is_nested) @@ -261,11 +550,14 @@ def handle_repetitions(min_times: int, max_times: int) -> None: pos = parse_space(pos[1:], is_nested) elif pos[0] == ",": pos = parse_space(pos[1:], is_nested) - if is_digit_char(pos[0]): + + if is_digit_char(pos): max_times, pos = parse_int(pos) pos = parse_space(pos, is_nested) + if pos[0] != "}": raise ValueError("expecting '}' at {}".format(pos)) + pos = parse_space(pos[1:], is_nested) else: raise ValueError(f"expecting ',' at {pos}") @@ -275,6 +567,23 @@ def handle_repetitions(min_times: int, max_times: int) -> None: return pos +# const char * parse_alternates( +# parse_state & state, +# const char * src, +# const std::string & rule_name, +# uint32_t rule_id, +# bool is_nested) { +# std::vector<llama_grammar_element> rule; +# const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); +# while (*pos == '|') { +# rule.push_back({LLAMA_GRETYPE_ALT, 0}); +# pos = parse_space(pos + 1, true); +# pos = parse_sequence(state, pos, rule_name, rule, is_nested); +# } +# rule.push_back({LLAMA_GRETYPE_END, 0}); +# add_rule(state, rule_id, rule); +# return pos; +# } def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int, is_nested: bool) -> str: rule = [] pos = parse_sequence(state, src, rule_name, rule, is_nested) @@ -287,34 +596,86 @@ def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int, return pos +# static const char * parse_rule(parse_state & state, const char * src) { +# const char * name_end = parse_name(src); +# const char * pos = parse_space(name_end, false); +# size_t name_len = name_end - src; +# uint32_t rule_id = get_symbol_id(state, src, name_len); +# const std::string name(src, name_len); +# +# if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { +# throw std::runtime_error(std::string("expecting ::= at ") + pos); +# } +# pos = parse_space(pos + 3, true); +# +# pos = parse_alternates(state, pos, name, rule_id, false); +# +# if (*pos == '\r') { +# pos += pos[1] == '\n' ? 2 : 1; +# } else if (*pos == '\n') { +# pos++; +# } else if (*pos) { +# throw std::runtime_error(std::string("expecting newline or end at ") + pos); +# } +# return parse_space(pos, true); +# } def parse_rule(state: ParseState, src: str) -> str: - name, s = parse_name(src) - s = parse_space(s, newline_ok=False) + pos = src + name, pos = parse_name(pos) + pos = parse_space(pos, newline_ok=False) rule_id = get_symbol_id(state, name) - if not s.startswith("::="): - raise ValueError(f"expecting ::= at {s}") - - s = s[3:] - - s = parse_space(s, newline_ok=True) - - s = parse_alternates(state, s, name, rule_id, is_nested=False) - - if s.startswith("\r"): - s = s[2:] if s[1] == "\n" else s[1:] - elif s.startswith("\n"): - s = s[1:] - elif s: - raise ValueError(f"expecting newline or end at {s}") - return parse_space(s, newline_ok=True) - - -def parse(gbnf: str) -> ParseState: + if not pos.startswith("::="): + raise ValueError(f"expecting ::= at {pos}") + + pos = parse_space(pos[3:], newline_ok=True) + + pos = parse_alternates(state, pos, name, rule_id, is_nested=False) + + if pos.startswith("\r"): + pos = pos[2:] if pos[1] == "\n" else pos[1:] + elif pos.startswith("\n"): + pos = pos[1:] + elif pos: + raise ValueError(f"expecting newline or end at {pos}") + return parse_space(pos, newline_ok=True) + + +# parse_state parse(const char * src) { +# try { +# parse_state state; +# const char * pos = parse_space(src, true); +# while (*pos) { +# pos = parse_rule(state, pos); +# } +# // Validate the state to ensure that all rules are defined +# for (const auto & rule : state.rules) { +# for (const auto & elem : rule) { +# if (elem.type == LLAMA_GRETYPE_RULE_REF) { +# // Ensure that the rule at that location exists +# if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { +# // Get the name of the rule that is missing +# for (const auto & kv : state.symbol_ids) { +# if (kv.second == elem.value) { +# throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); +# } +# } +# } +# } +# } +# } +# return state; +# } catch (const std::exception & err) { +# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); +# return parse_state(); +# } +# } +def parse(src: str) -> ParseState: state = ParseState() - s = parse_space(gbnf, newline_ok=True) - while s: - s = parse_rule(state, s) + pos = src + pos = parse_space(pos, newline_ok=True) + while pos: + pos = parse_rule(state, pos) # validate for rule in state.rules: for elem in rule: @@ -326,6 +687,16 @@ def parse(gbnf: str) -> ParseState: return state +# static bool is_char_element(llama_grammar_element elem) { +# switch (elem.type) { +# case LLAMA_GRETYPE_CHAR: return true; +# case LLAMA_GRETYPE_CHAR_NOT: return true; +# case LLAMA_GRETYPE_CHAR_ALT: return true; +# case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; +# case LLAMA_GRETYPE_CHAR_ANY: return true; +# default: return false; +# } +# } def is_char_element(elem: GrammarElement) -> bool: return elem.type in ( GrammarElementType.CHAR, @@ -343,6 +714,71 @@ def print_grammar_char(file: typing.TextIO, c: int) -> None: print(f"<U+{c:04X}>", end="", file=file) +# static void print_rule( +# FILE * file, +# uint32_t rule_id, +# const std::vector<llama_grammar_element> & rule, +# const std::map<uint32_t, std::string> & symbol_id_names) { +# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { +# throw std::runtime_error( +# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); +# } +# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); +# for (size_t i = 0, end = rule.size() - 1; i < end; i++) { +# llama_grammar_element elem = rule[i]; +# switch (elem.type) { +# case LLAMA_GRETYPE_END: +# throw std::runtime_error( +# "unexpected end of rule: " + std::to_string(rule_id) + "," + +# std::to_string(i)); +# case LLAMA_GRETYPE_ALT: +# fprintf(file, "| "); +# break; +# case LLAMA_GRETYPE_RULE_REF: +# fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); +# break; +# case LLAMA_GRETYPE_CHAR: +# fprintf(file, "["); +# print_grammar_char(file, elem.value); +# break; +# case LLAMA_GRETYPE_CHAR_NOT: +# fprintf(file, "[^"); +# print_grammar_char(file, elem.value); +# break; +# case LLAMA_GRETYPE_CHAR_RNG_UPPER: +# if (i == 0 || !is_char_element(rule[i - 1])) { +# throw std::runtime_error( +# "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + +# std::to_string(rule_id) + "," + std::to_string(i)); +# } +# fprintf(file, "-"); +# print_grammar_char(file, elem.value); +# break; +# case LLAMA_GRETYPE_CHAR_ALT: +# if (i == 0 || !is_char_element(rule[i - 1])) { +# throw std::runtime_error( +# "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + +# std::to_string(rule_id) + "," + std::to_string(i)); +# } +# print_grammar_char(file, elem.value); +# break; +# case LLAMA_GRETYPE_CHAR_ANY: +# fprintf(file, "."); +# break; +# } +# if (is_char_element(elem)) { +# switch (rule[i + 1].type) { +# case LLAMA_GRETYPE_CHAR_ALT: +# case LLAMA_GRETYPE_CHAR_RNG_UPPER: +# case LLAMA_GRETYPE_CHAR_ANY: +# break; +# default: +# fprintf(file, "] "); +# } +# } +# } +# fprintf(file, "\n"); +# } def print_rule( file: typing.TextIO, rule_id: int, @@ -394,7 +830,6 @@ def print_grammar(file: typing.TextIO, state: ParseState) -> None: print(f"\nerror printing grammar: {err}", file=file) raise err -import ctypes class LlamaGrammar: def __init__(self, parse_state: ParseState): @@ -406,7 +841,7 @@ def __init__(self, parse_state: ParseState): self._element_lists = [ [ - llama_cpp.llama_grammar_element(ctypes.c_int(elem.type.value), ctypes.c_uint32(elem.value)) + llama_cpp.llama_grammar_element(ctypes.c_int(elem.type), ctypes.c_uint32(elem.value)) for elem in subvector ] for subvector in self._grammar_rules @@ -455,6 +890,7 @@ def reset(self): @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": parsed_grammar = parse(grammar) + print_grammar(file=sys.stdout, state=parsed_grammar) return cls(parsed_grammar) @classmethod