From 45d2252153cd72fbaef273705d8b6fc848845f2a Mon Sep 17 00:00:00 2001
From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com>
Date: Mon, 29 Jul 2024 14:35:58 +0200
Subject: [PATCH 1/8] Backported . (any chat) from llama.cpp

---
 llama_cpp/llama_grammar.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 0ac7354bb..01798e74f 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -432,10 +432,12 @@ def end(self) -> "std.map[T, U].iterator[T, U]":
 #     // be an inclusive range ([a-z])
 #     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
 
-
 #     // modifies a preceding LLAMA_GRETYPE_CHAR or
 #     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
 #     LLAMA_GRETYPE_CHAR_ALT       = 6,
+
+#     // any character (.)
+#     LLAMA_GRETYPE_CHAR_ANY       = 7,
 # };
 class llama_gretype(Enum):
     """grammar element type"""
@@ -447,6 +449,7 @@ class llama_gretype(Enum):
     LLAMA_GRETYPE_CHAR_NOT = 4  # inverse char(s) ([^a], [^a-b] [^abc])
     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z])
     LLAMA_GRETYPE_CHAR_ALT = 6  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+    LLAMA_GRETYPE_CHAR_ANY = 7  # any character (.)
 
 
 # struct parse_state {
@@ -830,6 +833,10 @@ def parse_sequence(
         #     if (last_sym_start == out_elements.size()) {
         #         throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
         #     }
+        elif pos[0] == '.':
+            last_sym_start = out_elements.size()
+            out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR_ANY, 0))
+            pos = parse_space(pos + 1, is_nested)
         elif pos[0] in ("*", "+", "?"):  # repetition operator
             if last_sym_start == out_elements.size():
                 raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
@@ -1039,6 +1046,7 @@ def is_char_element(elem: LlamaGrammarElement) -> bool:
         llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
         llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
         llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
+        llama_gretype.LLAMA_GRETYPE_CHAR_ANY,
     )
 
 
@@ -1135,6 +1143,8 @@ def print_rule(
                     + str(i)
                 )
             print_grammar_char(file, elem.value)
+        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ANY:
+            print(".", file=file, end="")
         # if (is_char_element(elem)) {
         #     switch (rule[i + 1].type) {
         #         case LLAMA_GRETYPE_CHAR_ALT:

From 90c2bc4922265c4cac9266b7550ed9e41f1fd477 Mon Sep 17 00:00:00 2001
From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com>
Date: Mon, 29 Jul 2024 15:25:09 +0200
Subject: [PATCH 2/8] unfinished {count,optionalmax)

---
 llama_cpp/llama_grammar.py | 295 ++++++++++++++++++++++++++++---------
 1 file changed, 224 insertions(+), 71 deletions(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 01798e74f..39ba8e9dc 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -563,11 +563,34 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
     return value, pos
 
 
+"""#static bool is_digit_char(char c) {
+#    return '0' <= c && c <= '9';
+#}
+def is_digit_char(c: str) -> bool:
+    return "0" <= c <= "9"
+
+
 # bool is_word_char(char c) {
 #     return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
 # }
 def is_word_char(c: str) -> bool:
-    return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
+    return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") or is_digit_char(c)
+"""
+
+##optimized version
+# Original is_digit_char time: 2.868295 seconds
+# Optimized is_digit_char time: 1.993195 seconds
+# Original is_word_char time: 3.856689 seconds
+# Optimized is_word_char time: 2.052832 seconds
+
+digit_chars = set("0123456789")
+word_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789")
+
+def is_digit_char(c: str) -> bool:
+    return c in digit_chars
+
+def is_word_char(c: str) -> bool:
+    return c in word_chars
 
 
 # std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@@ -679,6 +702,25 @@ def parse_name(src: const_char_p) -> const_char_p:
     return pos
 
 
+#static const char * parse_int(const char * src) {
+#    const char * pos = src;
+#    while (is_digit_char(*pos)) {
+#        pos++;
+#    }
+#    if (pos == src) {
+#        throw std::runtime_error(std::string("expecting integer at ") + src);
+#    }
+#    return pos;
+#}
+def parse_int(src: const_char_p) -> const_char_p:
+    pos = const_char_p(src)  # type: const_char_p
+    while is_digit_char(pos[0]):
+        pos += 1
+    if pos == src:
+        raise RuntimeError("expecting integer at " + str(src))
+    return pos
+
+
 # const char * parse_space(const char * src, bool newline_ok) {
 #     const char * pos = src;
 #     while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
@@ -721,8 +763,104 @@ def parse_sequence(
     # const char * pos = src;
     last_sym_start = out_elements.size()  # type: int
     pos = const_char_p(src)  # type: const_char_p
+    
+    
+    # auto handle_repetitions = [&](int min_times, int max_times) {
+
+    #     if (last_sym_start == out_elements.size()) {
+    #         throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+    #     }
+
+    #     // apply transformation to previous symbol (last_sym_start to end) according to
+    #     // the following rewrite rules:
+    #     // S{m,n} --> S S S (m times) S'(n-m)
+    #     //            S'(x)   ::= S S'(x-1) |
+    #     //            (... n-m definitions of these S' rules ...)
+    #     //            S'(1)   ::= S |
+    #     // S{m,} -->  S S S (m times) S'
+    #     //            S'     ::= S S' |
+    #     // S*     --> S{0,}
+    #     //        --> S'     ::= S S' |
+    #     // S+     --> S{1,}
+    #     //        --> S S'
+    #     //            S'     ::= S S' |
+    #     // S?     --> S{0,1}
+    #     //        --> S'
+    #     //            S'     ::= S |
+
+    #     std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
+    #     if (min_times == 0) {
+    #         out_elements.resize(last_sym_start);
+    #     } else {
+    #         // Repeat the previous elements (min_times - 1) times
+    #         for (int i = 1; i < min_times; i++) {
+    #             out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
+    #         }
+    #     }
+
+    #     uint32_t last_rec_rule_id = 0;
+    #     auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+
+    #     std::vector<llama_grammar_element> rec_rule(previous_elements);
+    #     for (int i = 0; i < n_opt; i++) {
+    #         rec_rule.resize(previous_elements.size());
+    #         uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
+    #         if (i > 0 || max_times < 0) {
+    #             rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+    #         }
+    #         rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+    #         rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+    #         add_rule(state, rec_rule_id, rec_rule);
+    #         last_rec_rule_id = rec_rule_id;
+    #     }
+    #     if (n_opt > 0) {
+    #         out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+    #     }
+    # };
+    
+    def handle_repetitions(min_times: int, max_times: int) -> None:
+        if last_sym_start == out_elements.size():
+            raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos))
+
+
+        previous_elements = out_elements[last_sym_start:]
+        print("type-1 ", type(out_elements))
+        if min_times == 0:
+            out_elements.resize(last_sym_start)
+        else:
+            # Repeat the previous elements (min_times - 1) times
+            for i in range(1, min_times):
+                out_elements.extend(previous_elements)
+        
+        last_rec_rule_id = 0  # type: int
+        n_opt = 1 if max_times < 0 else max_times - min_times  # type: int
+        rec_rule = previous_elements  # type: List[LlamaGrammarElement]
+        print("type1", type(rec_rule))
+        print('ahhhhhhhhh')
+        for i in range(n_opt):
+            rec_rule = previous_elements
+            rec_rule.resize(len(previous_elements))
+            rec_rule_id = generate_symbol_id(state, rule_name)  # type: int
+            if i > 0 or max_times < 0:
+                rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id))
+            rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
+            rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
+            add_rule(state, rec_rule_id, rec_rule)
+            
+            last_rec_rule_id = rec_rule_id
+        if n_opt > 0:
+            out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id))
+    
+    
+    
+    
+    
+    
+    
+    
     # while (*pos) {
     while pos[0]:
+
         # if (*pos == '"') { // literal string
         #     pos++;
         #     last_sym_start = out_elements.size();
@@ -767,12 +905,12 @@ def parse_sequence(
             while pos[0] != "]":
                 char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                 pos = char_pair[1]
-                type = (
+                _type = (
                     llama_gretype.LLAMA_GRETYPE_CHAR_ALT
                     if last_sym_start < out_elements.size()
                     else start_type
                 )  # type: llama_gretype
-                out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
+                out_elements.push_back(LlamaGrammarElement(_type, char_pair[0]))
                 #     if (pos[0] == '-' && pos[1] != ']') {
                 #         auto endchar_pair = parse_char(pos + 1);
                 #                 pos          = endchar_pair.second;
@@ -829,83 +967,98 @@ def parse_sequence(
             if pos[0] != ")":
                 raise RuntimeError("expecting ')' at " + str(pos))
             pos = parse_space(pos + 1, is_nested)
-        # } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
-        #     if (last_sym_start == out_elements.size()) {
-        #         throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
-        #     }
         elif pos[0] == '.':
             last_sym_start = out_elements.size()
             out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR_ANY, 0))
             pos = parse_space(pos + 1, is_nested)
-        elif pos[0] in ("*", "+", "?"):  # repetition operator
-            if last_sym_start == out_elements.size():
-                raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
-            # // apply transformation to previous symbol (last_sym_start to end) according to
-            # // rewrite rules:
-            # // S* --> S' ::= S S' |
-            # // S+ --> S' ::= S S' | S
-            # // S? --> S' ::= S |
-            # uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
-            # std::vector<llama_grammar_element> sub_rule;
-            # // add preceding symbol to generated rule
-            # sub_rule.insert(
-            #     sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
-            sub_rule_id = generate_symbol_id(state, rule_name)  # type: int
-            sub_rule = std.vector[
-                LlamaGrammarElement
-            ]()  # type: std.vector[LlamaGrammarElement]
-            sub_rule.insert(
-                sub_rule.end(),
-                out_elements.begin() + last_sym_start,
-                out_elements.end(),
-            )
-            # if (*pos == '*' || *pos == '+') {
-            #     // cause generated rule to recurse
-            #     sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-            # }
-            # // mark start of alternate def
-            # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
-            if pos[0] in ("*", "+"):
-                sub_rule.push_back(
-                    LlamaGrammarElement(
-                        llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
-                    )
-                )
-            sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
-            # if (*pos == '+') {
-            #     // add preceding symbol as alternate only for '+' (otherwise empty)
-            #     sub_rule.insert(
-            #         sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
-            # }
-            # sub_rule.push_back({LLAMA_GRETYPE_END, 0});
-            # add_rule(state, sub_rule_id, sub_rule);
-            # // in original rule, replace previous symbol with reference to generated rule
-            # out_elements.resize(last_sym_start);
-            # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-            # pos = parse_space(pos + 1, is_nested);
-            if pos[0] == "+":
-                # add preceding symbol as alternate only for '+' (otherwise empty)
-                sub_rule.insert(
-                    sub_rule.end(),
-                    out_elements.begin() + last_sym_start,
-                    out_elements.end(),
-                )
-            sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
-            add_rule(state, sub_rule_id, sub_rule)
-            # in original rule, replace previous symbol with reference to generated rule
-            out_elements.resize(last_sym_start)
-            out_elements.push_back(
-                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
-            )
-            pos = parse_space(pos + 1, is_nested)
+           
+        # } else if (*pos == '*') {
+        #         pos = parse_space(pos + 1, is_nested);
+        #         handle_repetitions(0, -1);
+        # } else if (*pos == '+') {
+        #     pos = parse_space(pos + 1, is_nested);
+        #     handle_repetitions(1, -1);
+        # } else if (*pos == '?') {
+        #     pos = parse_space(pos + 1, is_nested);
+        #     handle_repetitions(0, 1);
+        # } else if (*pos == '{') {
+        #     pos = parse_space(pos + 1, is_nested);
+
+        #     if (!is_digit_char(*pos)) {
+        #         throw std::runtime_error(std::string("expecting an int at ") + pos);
+        #     }
+        #     const char * int_end = parse_int(pos);
+        #     int min_times = std::stoul(std::string(pos, int_end - pos));
+        #     pos = parse_space(int_end, is_nested);
+
+        #     int max_times = -1;
+
+        #     if (*pos == '}') {
+        #         max_times = min_times;
+        #         pos = parse_space(pos + 1, is_nested);
+        #     } else if (*pos == ',') {
+        #         pos = parse_space(pos + 1, is_nested);
+
+        #         if (is_digit_char(*pos)) {
+        #             const char * int_end = parse_int(pos);
+        #             max_times = std::stoul(std::string(pos, int_end - pos));
+        #             pos = parse_space(int_end, is_nested);
+        #         }
+
+        #         if (*pos != '}') {
+        #             throw std::runtime_error(std::string("expecting '}' at ") + pos);
+        #         }
+        #         pos = parse_space(pos + 1, is_nested);
+        #     } else {
+        #         throw std::runtime_error(std::string("expecting ',' at ") + pos);
+        #     }
+        #     handle_repetitions(min_times, max_times);
         # } else {
         #     break;
         # }
+        elif pos[0] == "*":
+            pos = parse_space(pos + 1, is_nested)
+            handle_repetitions(0, -1)
+        elif pos[0] == "+":
+            pos = parse_space(pos + 1, is_nested)
+            handle_repetitions(1, -1)
+        elif pos[0] == "?":
+            pos = parse_space(pos + 1, is_nested)
+            handle_repetitions(0, 1)
+        elif pos[0] == "{":
+            pos = parse_space(pos + 1, is_nested)
+            
+            if not is_digit_char(pos[0]):
+                raise RuntimeError("expecting an int at " + str(pos))
+            
+
+            
+            int_end = parse_int(pos)
+            min_times = int(str(pos)[:int_end - pos])
+            pos = parse_space(int_end, is_nested)
+            max_times = -1
+            if pos[0] == "}":
+                max_times = min_times
+                pos = parse_space(pos + 1, is_nested)
+            elif pos[0] == ",":
+                pos = parse_space(pos + 1, is_nested)
+                if is_digit_char(pos[0]):
+                    int_end = parse_int(pos)
+                    max_times = int(str(pos)[:int_end - pos])
+                    pos = parse_space(int_end, is_nested)
+                if pos[0] != "}":
+                    raise RuntimeError("expecting '}' at " + str(pos))
+                pos = parse_space(pos + 1, is_nested)
+            else:
+                raise RuntimeError("expecting ',' at " + str(pos))
+            handle_repetitions(min_times, max_times)
+            
         else:
             break
-    #     }
-    #     return pos;
-    # }
+        
+        
+        
+
     return pos
 
 

From 5c050e82b2f8b9ce0f9606a7e70f0944578a6bcd Mon Sep 17 00:00:00 2001
From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com>
Date: Mon, 29 Jul 2024 15:48:19 +0200
Subject: [PATCH 3/8] implemented slice function in std:vector

---
 llama_cpp/llama_grammar.py | 30 ++++++++++++++++++++++++------
 1 file changed, 24 insertions(+), 6 deletions(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 39ba8e9dc..bda3b162f 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -244,7 +244,8 @@ def __add__(self, value: int) -> "std.vector[T].iterator":
             def __sub__(self, value: int) -> "std.vector[T].iterator":
                 return self.__class__(self._vector, self._index - value)
 
-        def __init__(self):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
             self._version = 0
 
         def modify(self):
@@ -309,7 +310,7 @@ def insert(
             first: "std.vector[T].iterator",
             last: "std.vector[T].iterator",
         ) -> None:
-            self[pos._index : pos._index] = list(
+            self[pos._index:pos._index] = list(
                 islice(first._vector, first._index, last._index)
             )
 
@@ -319,6 +320,24 @@ def begin(self) -> "std.vector[T].iterator":
         def end(self) -> "std.vector[T].iterator":
             return self.iterator(self, self.size())
 
+        def __getitem__(self, index):
+            if isinstance(index, slice):
+                return std.vector(super().__getitem__(index))
+            return super().__getitem__(index)
+
+        def __setitem__(self, index, value):
+            self.modify()
+            if isinstance(index, slice):
+                if isinstance(value, std.vector):
+                    value = list(value)
+                super().__setitem__(index, value)
+            else:
+                super().__setitem__(index, value)
+
+        def __delitem__(self, index):
+            self.modify()
+            super().__delitem__(index)
+
     class map(Generic[T, U], OrderedDict[T, U]):
         """C++ implementation of std::map."""
 
@@ -410,7 +429,6 @@ def begin(self) -> "std.map[T, U].iterator[T, U]":
         def end(self) -> "std.map[T, U].iterator[T, U]":
             return self.iterator(self, Sentinel())
 
-
 # // grammar element type
 # enum llama_gretype {
 #     // end of rule definition
@@ -824,7 +842,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
 
 
         previous_elements = out_elements[last_sym_start:]
-        print("type-1 ", type(out_elements))
+
         if min_times == 0:
             out_elements.resize(last_sym_start)
         else:
@@ -835,8 +853,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
         last_rec_rule_id = 0  # type: int
         n_opt = 1 if max_times < 0 else max_times - min_times  # type: int
         rec_rule = previous_elements  # type: List[LlamaGrammarElement]
-        print("type1", type(rec_rule))
-        print('ahhhhhhhhh')
+
         for i in range(n_opt):
             rec_rule = previous_elements
             rec_rule.resize(len(previous_elements))
@@ -1263,6 +1280,7 @@ def print_rule(
     #                 print_grammar_char(file, elem.value);
     #                 break;
     #         }
+    
     for i, elem in enumerate(rule[:-1]):
         case = elem.type  # type: llama_gretype
         if case is llama_gretype.LLAMA_GRETYPE_END:

From 4c74a82281872fdce86bcc135c80d45e5703a986 Mon Sep 17 00:00:00 2001
From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com>
Date: Mon, 29 Jul 2024 16:31:24 +0200
Subject: [PATCH 4/8] fixed mistake done while reading

---
 llama_cpp/llama_grammar.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index bda3b162f..1fae0e688 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -841,7 +841,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos))
 
 
-        previous_elements = out_elements[last_sym_start:]
+        previous_elements:std.vector[LlamaGrammarElement] = out_elements[last_sym_start:out_elements.size()]
 
         if min_times == 0:
             out_elements.resize(last_sym_start)
@@ -859,12 +859,12 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             rec_rule.resize(len(previous_elements))
             rec_rule_id = generate_symbol_id(state, rule_name)  # type: int
             if i > 0 or max_times < 0:
-                rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id))
+                rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id))
             rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
             rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
             add_rule(state, rec_rule_id, rec_rule)
-            
             last_rec_rule_id = rec_rule_id
+
         if n_opt > 0:
             out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id))
     
@@ -1058,6 +1058,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
                 max_times = min_times
                 pos = parse_space(pos + 1, is_nested)
             elif pos[0] == ",":
+            
                 pos = parse_space(pos + 1, is_nested)
                 if is_digit_char(pos[0]):
                     int_end = parse_int(pos)
@@ -1281,6 +1282,7 @@ def print_rule(
     #                 break;
     #         }
     
+
     for i, elem in enumerate(rule[:-1]):
         case = elem.type  # type: llama_gretype
         if case is llama_gretype.LLAMA_GRETYPE_END:

From 1fd884069f8b081f3b4258ce844ee5988755bf18 Mon Sep 17 00:00:00 2001
From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com>
Date: Mon, 29 Jul 2024 16:50:27 +0200
Subject: [PATCH 5/8] ported https://github.com/ggerganov/llama.cpp/pull/7194

---
 llama_cpp/llama_grammar.py | 58 +++++++++++++++++++++++++++++---------
 1 file changed, 44 insertions(+), 14 deletions(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 1fae0e688..fded061b0 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -891,6 +891,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             pos += 1
             last_sym_start = out_elements.size()
             while pos[0] != '"':
+                assert pos[0] is not None, "Unexpected end of input"
                 char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                 pos = char_pair[1]
                 out_elements.push_back(
@@ -920,6 +921,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             #         : start_type;
             #     out_elements.push_back({type, char_pair.first});
             while pos[0] != "]":
+                assert pos[0] is not None, "Unexpected end of input"
                 char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
                 pos = char_pair[1]
                 _type = (
@@ -935,6 +937,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
                 #     }
                 # }
                 if pos[0] == "-" and pos[1] != "]":
+                    assert pos[1] is not None, "Unexpected end of input"
                     endchar_pair = parse_char(pos + 1)  # type: Tuple[int, const_char_p]
                     pos = endchar_pair[1]
                     out_elements.push_back(
@@ -1159,33 +1162,59 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
     elif pos[0]:
         raise RuntimeError("expecting newline or end at " + str(pos))
     return parse_space(pos, True)
+    
+#parse_state parse(const char * src) {
+#    try {
+#        parse_state state;
+#        const char * pos = parse_space(src, true);
+#        while (*pos) {
+#            pos = parse_rule(state, pos);
+#        }
+#        // Validate the state to ensure that all rules are defined
+#        for (const auto & rule : state.rules) {
+#            for (const auto & elem : rule) {
+#                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
+#                    // Ensure that the rule at that location exists
+#                    if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
+#                        // Get the name of the rule that is missing
+#                        for (const auto & kv : state.symbol_ids) {
+#                            if (kv.second == elem.value) {
+#                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
+#                            }
+#                        }
+#                    }
+#                }
+#            }
+#        }
+#        return state;
+#    } catch (const std::exception & err) {
+#        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+#        return parse_state();
+#    }
+#}
 
 
-# parse_state parse(const char * src) {
-#     try {
-#         parse_state state;
-#         const char * pos = parse_space(src, true);
-#         while (*pos) {
-#             pos = parse_rule(state, pos);
-#         }
-#         return state;
-#     } catch (const std::exception & err) {
-#         fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
-#         return parse_state();
-#     }
-# }
 def parse(src: const_char_p) -> parse_state:
     try:
         state = parse_state()  # type: parse_state
         pos = parse_space(src, True)  # type: const_char_p
         while pos[0]:
             pos = parse_rule(state, pos)
+        # Validate the state to ensure that all rules are defined
+        for rule in state.rules:
+            for elem in rule:
+                if elem.type == llama_gretype.LLAMA_GRETYPE_RULE_REF:
+                    # Ensure that the rule at that location exists
+                    if elem.value >= len(state.rules) or not state.rules[elem.value]:
+                        # Get the name of the rule that is missing
+                        for kv in state.symbol_ids:
+                            if kv.second == elem.value:
+                                raise RuntimeError("Undefined rule identifier '" + kv.first + "'")
         return state
     except Exception as err:
         print(f"{parse.__name__}: error parsing grammar: {err}")
         return parse_state()
 
-
 # void print_grammar_char(FILE * file, uint32_t c) {
 #     if (0x20 <= c && c <= 0x7f) {
 #         fprintf(file, "%c", static_cast<char>(c));
@@ -1283,6 +1312,7 @@ def print_rule(
     #         }
     
 
+
     for i, elem in enumerate(rule[:-1]):
         case = elem.type  # type: llama_gretype
         if case is llama_gretype.LLAMA_GRETYPE_END:

From 81cf9092225dd6719682ba7d31084e61d7b66fd9 Mon Sep 17 00:00:00 2001
From: ExtReMLapin <3909752+ExtReMLapin@users.noreply.github.com>
Date: Tue, 30 Jul 2024 07:44:19 +0200
Subject: [PATCH 6/8] multiple fixes, var copy

---
 llama_cpp/llama_grammar.py | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index fded061b0..8fd679b15 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -523,7 +523,6 @@ def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int:
     result = state.symbol_ids.insert(std.string(src, len), next_id)
     return result[0].second  # type: ignore
 
-
 # uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
 #     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
 #     state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
@@ -841,23 +840,22 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos))
 
 
-        previous_elements:std.vector[LlamaGrammarElement] = out_elements[last_sym_start:out_elements.size()]
+        previous_elements:std.vector[LlamaGrammarElement] = std.vector(out_elements[last_sym_start:])  # type: std.vector[LlamaGrammarElement]
 
         if min_times == 0:
             out_elements.resize(last_sym_start)
         else:
             # Repeat the previous elements (min_times - 1) times
             for i in range(1, min_times):
-                out_elements.extend(previous_elements)
+                out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end())
         
         last_rec_rule_id = 0  # type: int
         n_opt = 1 if max_times < 0 else max_times - min_times  # type: int
-        rec_rule = previous_elements  # type: List[LlamaGrammarElement]
+        rec_rule = std.vector(previous_elements)  # type: List[LlamaGrammarElement]
 
         for i in range(n_opt):
-            rec_rule = previous_elements
-            rec_rule.resize(len(previous_elements))
-            rec_rule_id = generate_symbol_id(state, rule_name)  # type: int
+            rec_rule.resize(previous_elements.size())
+            rec_rule_id = generate_symbol_id(state, rule_name) # type: int
             if i > 0 or max_times < 0:
                 rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id))
             rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
@@ -868,12 +866,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
         if n_opt > 0:
             out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id))
     
-    
-    
-    
-    
-    
-    
+
     
     # while (*pos) {
     while pos[0]:
@@ -1208,8 +1201,8 @@ def parse(src: const_char_p) -> parse_state:
                     if elem.value >= len(state.rules) or not state.rules[elem.value]:
                         # Get the name of the rule that is missing
                         for kv in state.symbol_ids:
-                            if kv.second == elem.value:
-                                raise RuntimeError("Undefined rule identifier '" + kv.first + "'")
+                            if kv[1] == elem.value:
+                                raise RuntimeError("Undefined rule identifier '" + kv[0] + "'")
         return state
     except Exception as err:
         print(f"{parse.__name__}: error parsing grammar: {err}")

From 6d53877fe81e350768cc41433817e91a1d532261 Mon Sep 17 00:00:00 2001
From: Andrei Betlen <abetlen@gmail.com>
Date: Sun, 4 Aug 2024 17:14:41 -0400
Subject: [PATCH 7/8] Rewrite LlamaGrammar internals in python style

---
 llama_cpp/llama_cpp.py     |    2 +-
 llama_cpp/llama_grammar.py | 1632 ++++++++----------------------------
 2 files changed, 351 insertions(+), 1283 deletions(-)

diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 727195cf5..d598cf9f1 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3002,7 +3002,7 @@ def llama_grammar_init(
     n_rules: Union[ctypes.c_size_t, int],
     start_rule_index: Union[ctypes.c_size_t, int],
     /,
-) -> llama_grammar_p:
+) -> Optional[llama_grammar_p]:
     """Initialize a grammar from a set of rules."""
     ...
 
diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 8fd679b15..4d48e34ab 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -1,644 +1,93 @@
 """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
 
 # flake8: noqa
-from pathlib import Path
-import sys
-from ctypes import *  # type: ignore
-from enum import Enum
-from itertools import islice, groupby
+from itertools import groupby
 from typing import (
     Any,
-    Callable,
-    Dict,
     Set,
-    Generic,
     List,
     Optional,
-    OrderedDict,
-    TextIO,
     Tuple,
-    TypeVar,
     Union,
-    overload,
 )
 
-import llama_cpp.llama_cpp as llama_cpp
-
-# Type aliases
-llama_grammar_element = llama_cpp.llama_grammar_element
-llama_grammar_element_p = llama_cpp.llama_grammar_element_p
-llama_grammar_p = llama_cpp.llama_grammar_p
-
-# Type variables
-Ptr = TypeVar("Ptr", bound="const_char_p")
-T = TypeVar("T")
-U = TypeVar("U")
-V = TypeVar("V")
-W = TypeVar("W")
-
-
-class Sentinel:
-    """Used to mark the end of a iterator of std::vector & std::map."""
-
-
-class LlamaGrammar:
-    """Keeps reference counts of all the arguments, so that they are not
-    garbage collected by Python."""
-
-    def __del__(self) -> None:
-        """Free the grammar pointer when the object is deleted."""
-        if self.grammar is not None:
-            llama_cpp.llama_grammar_free(self.grammar)
-            self.grammar = None
-
-    def __init__(
-        self,
-        parsed_grammar: "parse_state",
-    ) -> None:
-        """Initialize the grammar pointer from the parsed state."""
-        self._grammar_rules = (
-            parsed_grammar.c_rules()
-        )  # type: std.vector[std.vector[LlamaGrammarElement]]
-        self._n_rules = self._grammar_rules.size()  # type: int
-        self._start_rule_index = parsed_grammar.symbol_ids.at("root")  # type: int
-        self.init()
-
-    @classmethod
-    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
-        """Convert a GBNF grammar to a Llama grammar."""
-        parsed_grammar = parse(const_char_p(grammar))  # type: parse_state
-        if parsed_grammar.rules.empty():
-            raise ValueError(
-                f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
-            )
-        if verbose:
-            print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
-            print_grammar(sys.stderr, parsed_grammar)
-            print(file=sys.stderr)
-        return cls(parsed_grammar)
-
-    @classmethod
-    def from_json_schema(
-        cls,
-        json_schema: str,
-        verbose: bool = True,
-    ) -> "LlamaGrammar":
-        """Convert a JSON schema to a Llama grammar."""
-        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
-
-    @classmethod
-    def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
-        try:
-            with open(file) as f:
-                grammar = f.read()
-        except Exception as err:
-            raise Exception(
-                f"{cls.from_file.__name__}: error reading grammar file: {err}"
-            )
-
-        if grammar:
-            return cls.from_string(grammar, verbose=verbose)
-
-        raise ValueError(
-            f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
-        )
-
-    def init(self) -> None:
-        # Step 1: Convert LlamaGrammarElement to llama_grammar_element
-        self._element_lists = [
-            [
-                llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value))
-                for elem in subvector
-            ]
-            for subvector in self._grammar_rules
-        ]  # type: List[List[llama_grammar_element]]
-
-        # Step 2: Convert each list to llama_grammar_element array and get pointer
-        self._element_arrays = [
-            (llama_grammar_element * len(sublist))(*sublist)
-            for sublist in self._element_lists
-        ]  # type: List[Array[llama_grammar_element]]
-
-        # Step 3: Get pointer of each array
-        self._element_array_pointers = [
-            cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays
-        ]  # type: List[llama_grammar_element_p]
-
-        # Step 4: Make array of these pointers and get its pointer
-        self._rules = (llama_grammar_element_p * len(self._element_array_pointers))(
-            *self._element_array_pointers
-        )
-        self.grammar = llama_cpp.llama_grammar_init(
-            self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index)
-        )
-
-    def reset(self) -> None:
-        if self.grammar is not None:
-            llama_cpp.llama_grammar_free(self.grammar)
-        self.init()
-
-
-class LlamaGrammarElement:
-    def __init__(self, type: "llama_gretype", value: int):
-        self.type = type
-        self.value = value  # Unicode code point or rule ID
-
-
-class const_char_p:
-    """C++ implementation of const char *."""
-
-    def __init__(self, value: Union[str, Ptr], move: Optional[int] = None):
-        if isinstance(value, const_char_p):
-            # We're copying an existing const_char_p
-            self.value = value.value
-            self.pos = value.pos + (move or 0)
-            return
-
-        # We're creating a new const_char_p
-        self.value = value
-        self.pos = move or 0
-
-    def __str__(self) -> str:
-        assert self.value is not None, "null pointer"
-        return self.value[self.pos :]
-
-    def __getitem__(self, index: int) -> str:
-        value = str(self)
-        return value[index] if index < len(value) else ""
-
-    @overload
-    def __add__(self: Ptr, other: int) -> Ptr: ...
-
-    @overload
-    def __add__(self: Ptr, other: Ptr) -> int: ...
-
-    def __add__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]:
-        return (
-            self.__class__(self.value, self.pos + other)
-            if isinstance(other, int)
-            else self.pos + other.pos
-        )
-
-    @overload
-    def __sub__(self: Ptr, other: int) -> Ptr: ...
-
-    @overload
-    def __sub__(self: Ptr, other: Ptr) -> int: ...
-
-    def __sub__(self: Ptr, other: Union[int, Ptr]) -> Union[int, Ptr]:
-        return (
-            self.__class__(self.value, self.pos - other)
-            if isinstance(other, int)
-            else self.pos - other.pos
-        )
-
-    def __eq__(self: Ptr, other: Ptr) -> bool:
-        assert self.value == other.value, "comparing pointers from different strings"
-        return self.pos == other.pos
-
-    def __lt__(self: Ptr, other: Ptr) -> bool:
-        assert self.value == other.value, "comparing pointers from different strings"
-        return self.pos < other.pos
-
-    def __gt__(self: Ptr, other: Ptr) -> bool:
-        assert self.value == other.value, "comparing pointers from different strings"
-        return self.pos > other.pos
-
-
-class std:
-    @staticmethod
-    def string(ptr: const_char_p, length: Optional[int] = None) -> str:
-        """C++ implementation of std::string constructor."""
-        value = str(ptr)
-        if length is not None:
-            value = value[:length]
-        return value
-
-    class vector(Generic[T], List[T]):
-        """C++ implementation of std::vector."""
-
-        class iterator:
-            def __init__(self, vector: "std.vector[T]", index: int):
-                self._vector = vector
-                self._index = index
-                self._version = vector._version
-
-            def _check_version(self):
-                if self._version != self._vector._version:
-                    raise RuntimeError("Iterator used after vector was modified.")
-
-            def __iter__(self):
-                return self
-
-            def __next__(self) -> T:
-                self._check_version()
-                if self._index >= self._vector.size():
-                    raise StopIteration
-                value = self._vector[self._index]
-                self._index += 1
-                return value
-
-            def __add__(self, value: int) -> "std.vector[T].iterator":
-                return self.__class__(self._vector, self._index + value)
-
-            def __sub__(self, value: int) -> "std.vector[T].iterator":
-                return self.__class__(self._vector, self._index - value)
-
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self._version = 0
-
-        def modify(self):
-            # This is a bit of a hack to make sure iterators are invalidated
-            self._version += 1
-
-        def push_back(self, value: T) -> None:
-            self.modify()
-            self.append(value)
-
-        def pop_back(self) -> None:
-            self.modify()
-            if not self.empty():
-                self.pop()
-
-        def back(self) -> T:
-            return self[-1]
-
-        def size(self) -> int:
-            return len(self)
-
-        def clear(self) -> None:
-            self.modify()
-            super().clear()
-
-        def empty(self) -> bool:
-            return self.size() == 0
-
-        def data(self) -> "std.vector[T]":
-            return self
-
-        def resize(
-            self,
-            new_size: int,
-            fill_value_factory: Optional[Callable[[], T]] = None,
-        ) -> None:
-            if new_size > self.size():
-                if fill_value_factory is None:
-                    raise ValueError("A fill value factory function must be provided.")
-                self.reserve(new_size, fill_value_factory)
-            elif new_size < self.size():
-                self[:] = self[:new_size]
-
-        def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
-            if capacity > self.size():
-                fill_value = fill_value_factory()
-                self.extend([fill_value] * (capacity - self.size()))
-
-        def front(self) -> T:
-            if not self.empty():
-                return self[0]
-            else:
-                raise IndexError("Vector is empty.")
-
-        def assign(self, count: int, value: T) -> None:
-            self.clear()
-            self.extend([value] * count)
-
-        def insert(
-            self,
-            pos: "std.vector[T].iterator",
-            first: "std.vector[T].iterator",
-            last: "std.vector[T].iterator",
-        ) -> None:
-            self[pos._index:pos._index] = list(
-                islice(first._vector, first._index, last._index)
-            )
-
-        def begin(self) -> "std.vector[T].iterator":
-            return self.iterator(self, 0)
-
-        def end(self) -> "std.vector[T].iterator":
-            return self.iterator(self, self.size())
-
-        def __getitem__(self, index):
-            if isinstance(index, slice):
-                return std.vector(super().__getitem__(index))
-            return super().__getitem__(index)
-
-        def __setitem__(self, index, value):
-            self.modify()
-            if isinstance(index, slice):
-                if isinstance(value, std.vector):
-                    value = list(value)
-                super().__setitem__(index, value)
-            else:
-                super().__setitem__(index, value)
-
-        def __delitem__(self, index):
-            self.modify()
-            super().__delitem__(index)
-
-    class map(Generic[T, U], OrderedDict[T, U]):
-        """C++ implementation of std::map."""
-
-        class iterator(Generic[V, W]):
-            def __init__(self, _map: "std.map[T, U]", key: Union[T, Sentinel]):
-                self._map = _map
-                self.iter = iter(_map)
-                self.key = key
-                self._advance()
-
-            def _sanitize_key(self) -> T:
-                if isinstance(self.key, Sentinel):
-                    raise StopIteration
-                return self.key
-
-            def _advance(self) -> None:
-                try:
-                    while next(self.iter) != self.key:
-                        pass
-                except StopIteration:
-                    self.key = Sentinel()
-
-            def __next__(self) -> Tuple[T, U]:
-                key = self._sanitize_key()
-                if key in self._map:
-                    value = self._map[key]
-                    self._advance()
-                    return key, value
-                else:
-                    raise StopIteration
-
-            def get(self) -> Tuple[T, U]:
-                key = self._sanitize_key()
-                return key, self._map[key]
+import enum
+import typing
 
-            @property
-            def first(self) -> T:
-                return self._sanitize_key()
+import llama_cpp.llama_cpp as llama_cpp
 
-            @property
-            def second(self) -> U:
-                return self._map[self._sanitize_key()]
+class GrammarElementType(enum.IntEnum):
+    END = llama_cpp.LLAMA_GRETYPE_END
+    ALT = llama_cpp.LLAMA_GRETYPE_ALT
+    RULE_REF = llama_cpp.LLAMA_GRETYPE_RULE_REF
+    CHAR = llama_cpp.LLAMA_GRETYPE_CHAR
+    CHAR_NOT = llama_cpp.LLAMA_GRETYPE_CHAR_NOT
+    CHAR_RNG_UPPER = llama_cpp.LLAMA_GRETYPE_CHAR_RNG_UPPER
+    CHAR_ALT = llama_cpp.LLAMA_GRETYPE_CHAR_ALT
+    CHAR_ANY = llama_cpp.LLAMA_GRETYPE_CHAR_ANY
+
+import dataclasses
+
+@dataclasses.dataclass
+class GrammarElement:
+    type: GrammarElementType
+    value: int
+
+@dataclasses.dataclass
+class ParseState:
+    symbol_ids: typing.Dict[str, int] = dataclasses.field(default_factory=dict)
+    rules: typing.List[typing.List[GrammarElement]] = dataclasses.field(default_factory=list)
+
+
+def decode_utf8(src: str) -> typing.Tuple[int, str]:
+    lookup: list[int] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]
+    first_byte: int = ord(src[0])
+    highbits: int = first_byte >> 4
+    length: int = lookup[highbits]
+    mask: int = (1 << (8 - length)) - 1
+    value: int = first_byte & mask
+    end: int = min(len(src), length)  # Prevent overrun
+
+    pos: int = 1
+    for pos in range(1, end):
+        if not src[pos]:
+            break
+        value = (value << 6) + (ord(src[pos]) & 0x3F)
 
-        def insert(
-            self, key: T, value: U
-        ) -> Tuple["std.map[T, U].iterator[T, U]", bool]:
-            if key in self:
-                return self.iterator(self, key), False
-            else:
-                self[key] = value
-                return self.iterator(self, key), True
+    return value, src[pos:] if pos < len(src) else ""
 
-        def find(self, key: T) -> "std.map[T, U].iterator[T, U]":
-            if key in self:
-                return self.iterator(self, key)
-            else:
-                return self.end()
 
-        def at(self, key: T) -> U:
-            if key in self:
-                return self[key]
-            else:
-                raise KeyError("The provided key is not found in the map.")
-
-        def erase(self, iterator: "std.map[T, U].iterator[T, U]") -> None:
-            key = iterator.first
-            if key in self:
-                del self[key]
-
-        def size(self) -> int:
-            return len(self)
-
-        def empty(self) -> bool:
-            return self.size() == 0
-
-        def lower_bound(self, key: T) -> "std.map[T, U].iterator[T, U]":
-            try:
-                keys = sorted(list(self.keys()))  # type: ignore
-                for k in keys:
-                    if k >= key:
-                        return self.iterator(self, k)
-                raise ValueError("No key found that is not less than the input key")
-            except TypeError:
-                raise TypeError("Keys of type T cannot be sorted.")
-
-        def begin(self) -> "std.map[T, U].iterator[T, U]":
-            return self.iterator(self, next(iter(self)))
-
-        def end(self) -> "std.map[T, U].iterator[T, U]":
-            return self.iterator(self, Sentinel())
-
-# // grammar element type
-# enum llama_gretype {
-#     // end of rule definition
-#     LLAMA_GRETYPE_END            = 0,
-
-#     // start of alternate definition for rule
-#     LLAMA_GRETYPE_ALT            = 1,
-
-#     // non-terminal element: reference to rule
-#     LLAMA_GRETYPE_RULE_REF       = 2,
-
-#     // terminal element: character (code point)
-#     LLAMA_GRETYPE_CHAR           = 3,
-
-#     // inverse char(s) ([^a], [^a-b] [^abc])
-#     LLAMA_GRETYPE_CHAR_NOT       = 4,
-
-#     // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-#     // be an inclusive range ([a-z])
-#     LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-#     // modifies a preceding LLAMA_GRETYPE_CHAR or
-#     // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-#     LLAMA_GRETYPE_CHAR_ALT       = 6,
-
-#     // any character (.)
-#     LLAMA_GRETYPE_CHAR_ANY       = 7,
-# };
-class llama_gretype(Enum):
-    """grammar element type"""
-
-    LLAMA_GRETYPE_END = 0  # end of rule definition
-    LLAMA_GRETYPE_ALT = 1  # start of alternate definition for rule
-    LLAMA_GRETYPE_RULE_REF = 2  # non-terminal element: reference to rule
-    LLAMA_GRETYPE_CHAR = 3  # terminal element: character (code point)
-    LLAMA_GRETYPE_CHAR_NOT = 4  # inverse char(s) ([^a], [^a-b] [^abc])
-    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to be an inclusive range ([a-z])
-    LLAMA_GRETYPE_CHAR_ALT = 6  # modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-    LLAMA_GRETYPE_CHAR_ANY = 7  # any character (.)
-
-
-# struct parse_state {
-#     std::map<std::string, uint32_t>                 symbol_ids;
-#     std::vector<std::vector<llama_grammar_element>> rules;
-#     std::vector<const llama_grammar_element *> c_rules();
-# };
-class parse_state:
-    def __init__(self):
-        self.symbol_ids: std.map[str, int] = std.map()
-        self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()
-
-    # std::vector<const llama_grammar_element *> parse_state::c_rules() {
-    #     std::vector<const llama_grammar_element *> ret;
-    #     for (const auto & rule : rules) {
-    #         ret.push_back(rule.data());
-    #     }
-    #     return ret;
-    # }
-    def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
-        ret = std.vector()  # type: std.vector[std.vector[LlamaGrammarElement]]
-        for rule in self.rules:
-            ret.push_back(rule.data())
-        return ret
-
-    def __repr__(self) -> str:
-        return (
-            f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
-        )
+def get_symbol_id(state: ParseState, name: str) -> int:
+    next_id = len(state.symbol_ids)
+    return state.symbol_ids.setdefault(name, next_id)
 
 
-# struct llama_grammar {
-#     const std::vector<std::vector<llama_grammar_element>>   rules;
-#     std::vector<std::vector<const llama_grammar_element *>> stacks;
-# };
-# class llama_grammar:
-#     def __init__(
-#         self,
-#         rules: std.vector[std.vector[llama_grammar_element]],
-#         stacks: std.vector[std.vector[llama_grammar_element]],
-#     ):
-#         self.rules = rules
-#         self.stacks = stacks
-
-
-# uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
-#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
-#     auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
-#     return result.first->second;
-# }
-def get_symbol_id(state: parse_state, src: const_char_p, len: int) -> int:
-    next_id = state.symbol_ids.size()  # type: int
-    result = state.symbol_ids.insert(std.string(src, len), next_id)
-    return result[0].second  # type: ignore
-
-# uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
-#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
-#     state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
-#     return next_id;
-# }
-def generate_symbol_id(state: parse_state, base_name: str) -> int:
-    next_id = state.symbol_ids.size()  # type: int
-    state.symbol_ids[base_name + "_" + str(next_id)] = next_id
+def generate_symbol_id(state: ParseState, base_name: str) -> int:
+    next_id = len(state.symbol_ids)
+    state.symbol_ids[f"{base_name}_{next_id}"] = next_id
     return next_id
 
 
-# void add_rule(
-#         parse_state & state,
-#         uint32_t      rule_id,
-#         const std::vector<llama_grammar_element> & rule) {
-#     if (state.rules.size() <= rule_id) {
-#         state.rules.resize(rule_id + 1);
-#     }
-#     state.rules[rule_id] = rule;
-# }
-def add_rule(
-    state: parse_state,
-    rule_id: int,
-    rule: std.vector[LlamaGrammarElement],
-) -> None:
-    if state.rules.size() <= rule_id:
-        state.rules.resize(
-            rule_id + 1,
-            fill_value_factory=std.vector[LlamaGrammarElement],
-        )
+def add_rule(state: ParseState, rule_id: int, rule: typing.List[GrammarElement]) -> None:
+    if len(state.rules) <= rule_id:
+        state.rules.extend([[]] * (rule_id + 1 - len(state.rules)))
     state.rules[rule_id] = rule
 
 
-# std::pair<uint32_t, const char *> decode_utf8(const char * src) {
-#     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-#     uint8_t  first_byte = static_cast<uint8_t>(*src);
-#     uint8_t  highbits   = first_byte >> 4;
-#     int      len        = lookup[highbits];
-#     uint8_t  mask       = (1 << (8 - len)) - 1;
-#     uint32_t value      = first_byte & mask;
-#     const char * end    = src + len; // may overrun!
-#     const char * pos    = src + 1;
-#     for ( ; pos < end && *pos; pos++) {
-#         value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
-#     }
-#     return std::make_pair(value, pos);
-# }
-def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
-    """Decodes a UTF-8 character from the source string."""
-    # Get the codepoint of the first character
-    value = ord(src[0])
-    # Move the pointer ahead one character
-    pos = src + 1
-
-    return value, pos
-
-
-"""#static bool is_digit_char(char c) {
-#    return '0' <= c && c <= '9';
-#}
 def is_digit_char(c: str) -> bool:
     return "0" <= c <= "9"
 
 
-# bool is_word_char(char c) {
-#     return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
-# }
 def is_word_char(c: str) -> bool:
-    return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") or is_digit_char(c)
-"""
-
-##optimized version
-# Original is_digit_char time: 2.868295 seconds
-# Optimized is_digit_char time: 1.993195 seconds
-# Original is_word_char time: 3.856689 seconds
-# Optimized is_word_char time: 2.052832 seconds
-
-digit_chars = set("0123456789")
-word_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-0123456789")
+    return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or is_digit_char(c)
 
-def is_digit_char(c: str) -> bool:
-    return c in digit_chars
 
-def is_word_char(c: str) -> bool:
-    return c in word_chars
-
-
-# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
-#     const char * pos   = src;
-#     const char * end   = src + size;
-#     uint32_t     value = 0;
-#     for ( ; pos < end && *pos; pos++) {
-#         value <<= 4;
-#         char c = *pos;
-#         if ('a' <= c && c <= 'f') {
-#             value += c - 'a' + 10;
-#         } else if ('A' <= c && c <= 'F') {
-#             value += c - 'A' + 10;
-#         } else if ('0' <= c && c <= '9') {
-#             value += c - '0';
-#         } else {
-#             break;
-#         }
-#     }
-#     if (pos != end) {
-#         throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
-#     }
-#     return std::make_pair(value, pos);
-# }
-def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
-    pos = const_char_p(src)  # type: const_char_p
-    end = src + size  # type: const_char_p
-    value = 0  # type: int
-    while pos < end and pos[0]:
+def parse_hex(src: str, size: int) -> typing.Tuple[int, str]:
+    pos = 0
+    value = 0
+    for _ in range(size):
         value <<= 4
-        c = pos[0]  # type: str
+        c = src[pos]
         if "a" <= c <= "f":
             value += ord(c) - ord("a") + 10
         elif "A" <= c <= "F":
@@ -648,752 +97,370 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
         else:
             break
         pos += 1
-    if pos != end:
-        raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
-    return (value, pos)
-
-
-# std::pair<uint32_t, const char *> parse_char(const char * src) {
-#     if (*src == '\\') {
-#         switch (src[1]) {
-#             case 'x': return parse_hex(src + 2, 2);
-#             case 'u': return parse_hex(src + 2, 4);
-#             case 'U': return parse_hex(src + 2, 8);
-#             case 't': return std::make_pair('\t', src + 2);
-#             case 'r': return std::make_pair('\r', src + 2);
-#             case 'n': return std::make_pair('\n', src + 2);
-#             case '\\':
-#             case '"':
-#             case '[':
-#             case ']':
-#                 return std::make_pair(src[1], src + 2);
-#             default:
-#                 throw std::runtime_error(std::string("unknown escape at ") + src);
-#         }
-#     } else if (*src) {
-#         return decode_utf8(src);
-#     }
-#     throw std::runtime_error("unexpected end of input");
-# }
-def parse_char(src: const_char_p) -> Tuple[int, const_char_p]:
-    if src[0] == "\\":
-        case = src[1]  # type: str
-        if case == "x":
-            return parse_hex(src + 2, 2)
-        elif case == "u":
-            return parse_hex(src + 2, 4)
-        elif case == "U":
-            return parse_hex(src + 2, 8)
-        elif case == "t":
-            return (ord("\t"), src + 2)  # implicit cast
-        elif case == "r":
-            return (ord("\r"), src + 2)  # implicit cast
-        elif case == "n":
-            return (ord("\n"), src + 2)  # implicit cast
-        elif case in ("\\", '"', "[", "]"):
-            return (ord(case), src + 2)  # implicit cast
-        else:
-            raise RuntimeError("unknown escape at " + str(src))
-    elif src[0]:
-        return decode_utf8(src)
-    else:
-        raise RuntimeError("unexpected end of input")
-
-
-# const char * parse_name(const char * src) {
-#     const char * pos = src;
-#     while (is_word_char(*pos)) {
-#         pos++;
-#     }
-#     if (pos == src) {
-#         throw std::runtime_error(std::string("expecting name at ") + src);
-#     }
-#     return pos;
-# }
-def parse_name(src: const_char_p) -> const_char_p:
-    pos = const_char_p(src)  # type: const_char_p
-    while is_word_char(pos[0]):
+    if pos != size:
+        raise ValueError(f"expecting {size} hex chars at {src}")
+    return value, src[pos:]
+
+
+def parse_space(src: str, newline_ok: bool) -> str:
+    pos = 0
+    while pos < len(src) and (src[pos] in (" ", "\t", "#") or (newline_ok and src[pos] in ("\r", "\n"))):
+        if src[pos] == "#":
+            while pos < len(src) and src[pos] not in ("\r", "\n"):
+                pos += 1
         pos += 1
-    if pos == src:
-        raise RuntimeError("expecting name at " + str(src))
-    return pos
+    return src[pos:]
+
+
+def parse_name(src: str) -> typing.Tuple[str, str]:
+    pos = 0
+    try:
+        while is_word_char(src[pos]):
+            pos += 1
+    except IndexError:
+        return src, ""
+    if pos == 0:
+        raise ValueError(f"expecting name at {src}")
+    return src[:pos], src[pos:]
 
 
-#static const char * parse_int(const char * src) {
-#    const char * pos = src;
-#    while (is_digit_char(*pos)) {
-#        pos++;
-#    }
-#    if (pos == src) {
-#        throw std::runtime_error(std::string("expecting integer at ") + src);
-#    }
-#    return pos;
-#}
-def parse_int(src: const_char_p) -> const_char_p:
-    pos = const_char_p(src)  # type: const_char_p
-    while is_digit_char(pos[0]):
+def parse_int(src: str) -> typing.Tuple[int, str]:
+    pos = 0
+    while is_digit_char(src[pos]):
         pos += 1
-    if pos == src:
-        raise RuntimeError("expecting integer at " + str(src))
-    return pos
+    if pos == 0:
+        raise ValueError(f"expecting integer at {src}")
+    return int(src[:pos]), src[pos:]
 
 
-# const char * parse_space(const char * src, bool newline_ok) {
-#     const char * pos = src;
-#     while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
-#             (newline_ok && (*pos == '\r' || *pos == '\n'))) {
-#         if (*pos == '#') {
-#             while (*pos && *pos != '\r' && *pos != '\n') {
-#                 pos++;
-#             }
-#         } else {
-#             pos++;
-#         }
-#     }
-#     return pos;
-# }
-def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
-    pos = const_char_p(src)  # type: const_char_p
-    while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
-        if pos[0] == "#":
-            while pos[0] is not None and pos[0] not in ("\r", "\n"):
-                pos += 1
+def parse_char(src: str) -> typing.Tuple[int, str]:
+    if src[0] == "\\":
+        if src[1] == "x":
+            return parse_hex(src[2:], 2)
+        elif src[1] == "u":
+            return parse_hex(src[2:], 4)
+        elif src[1] == "U":
+            return parse_hex(src[2:], 8)
+        elif src[1] == "t":
+            return ord("\t"), src[2:]
+        elif src[1] == "r":
+            return ord("\r"), src[2:]
+        elif src[1] == "n":
+            return ord("\n"), src[2:]
+        elif src[1] in ('\\', '"', '[', ']'):
+            return ord(src[1]), src[2:]
         else:
-            pos += 1
-    return pos
+            raise ValueError(f"unknown escape at {src}")
+    elif src:
+        return decode_utf8(src)
+    raise ValueError("unexpected end of input")
 
 
-# const char * parse_sequence(
-#         parse_state                        & state,
-#         const char                         * src,
-#         const std::string                  & rule_name,
-#         std::vector<llama_grammar_element> & out_elements,
-#         bool                                 is_nested) {
-def parse_sequence(
-    state: parse_state,
-    src: const_char_p,
-    rule_name: str,
-    out_elements: std.vector[LlamaGrammarElement],
-    is_nested: bool,
-) -> const_char_p:
-    # size_t last_sym_start = out_elements.size();
-    # const char * pos = src;
-    last_sym_start = out_elements.size()  # type: int
-    pos = const_char_p(src)  # type: const_char_p
-    
-    
-    # auto handle_repetitions = [&](int min_times, int max_times) {
-
-    #     if (last_sym_start == out_elements.size()) {
-    #         throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
-    #     }
-
-    #     // apply transformation to previous symbol (last_sym_start to end) according to
-    #     // the following rewrite rules:
-    #     // S{m,n} --> S S S (m times) S'(n-m)
-    #     //            S'(x)   ::= S S'(x-1) |
-    #     //            (... n-m definitions of these S' rules ...)
-    #     //            S'(1)   ::= S |
-    #     // S{m,} -->  S S S (m times) S'
-    #     //            S'     ::= S S' |
-    #     // S*     --> S{0,}
-    #     //        --> S'     ::= S S' |
-    #     // S+     --> S{1,}
-    #     //        --> S S'
-    #     //            S'     ::= S S' |
-    #     // S?     --> S{0,1}
-    #     //        --> S'
-    #     //            S'     ::= S |
-
-    #     std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
-    #     if (min_times == 0) {
-    #         out_elements.resize(last_sym_start);
-    #     } else {
-    #         // Repeat the previous elements (min_times - 1) times
-    #         for (int i = 1; i < min_times; i++) {
-    #             out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
-    #         }
-    #     }
-
-    #     uint32_t last_rec_rule_id = 0;
-    #     auto n_opt = max_times < 0 ? 1 : max_times - min_times;
-
-    #     std::vector<llama_grammar_element> rec_rule(previous_elements);
-    #     for (int i = 0; i < n_opt; i++) {
-    #         rec_rule.resize(previous_elements.size());
-    #         uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
-    #         if (i > 0 || max_times < 0) {
-    #             rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
-    #         }
-    #         rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
-    #         rec_rule.push_back({LLAMA_GRETYPE_END, 0});
-    #         add_rule(state, rec_rule_id, rec_rule);
-    #         last_rec_rule_id = rec_rule_id;
-    #     }
-    #     if (n_opt > 0) {
-    #         out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
-    #     }
-    # };
-    
-    def handle_repetitions(min_times: int, max_times: int) -> None:
-        if last_sym_start == out_elements.size():
-            raise RuntimeError("expecting preceding item to */+/?/{ at " + str(pos))
+def parse_sequence(state: ParseState, src: str, rule_name: str, out_elements: typing.List[GrammarElement], is_nested: bool) -> str:
+    last_sym_start = len(out_elements)
+    pos = src
 
+    def handle_repetitions(min_times: int, max_times: int) -> None:
+        nonlocal last_sym_start
+        nonlocal pos
+        nonlocal out_elements
 
-        previous_elements:std.vector[LlamaGrammarElement] = std.vector(out_elements[last_sym_start:])  # type: std.vector[LlamaGrammarElement]
+        if last_sym_start == len(out_elements):
+            raise ValueError(f"expecting preceding item to */+/?/{{ at {pos}")
 
+        previous_elements = out_elements[last_sym_start:]
         if min_times == 0:
-            out_elements.resize(last_sym_start)
+            out_elements = out_elements[:last_sym_start]
         else:
-            # Repeat the previous elements (min_times - 1) times
             for i in range(1, min_times):
-                out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end())
-        
-        last_rec_rule_id = 0  # type: int
-        n_opt = 1 if max_times < 0 else max_times - min_times  # type: int
-        rec_rule = std.vector(previous_elements)  # type: List[LlamaGrammarElement]
+                out_elements.extend(previous_elements)
+        last_rec_rule_id = 0
+        n_opt = 1 if max_times < 0 else max_times - min_times
 
+        rec_rule = list(previous_elements)
         for i in range(n_opt):
-            rec_rule.resize(previous_elements.size())
-            rec_rule_id = generate_symbol_id(state, rule_name) # type: int
+            rec_rule = rec_rule[:len(previous_elements)]
+            rec_rule_id = generate_symbol_id(state, rule_name)
             if i > 0 or max_times < 0:
-                rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id))
-            rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
-            rec_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
+                rec_rule.append(GrammarElement(GrammarElementType.RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id))
+            rec_rule.append(GrammarElement(GrammarElementType.ALT, 0))
+            rec_rule.append(GrammarElement(GrammarElementType.END, 0))
             add_rule(state, rec_rule_id, rec_rule)
             last_rec_rule_id = rec_rule_id
-
         if n_opt > 0:
-            out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, last_rec_rule_id))
-    
-
-    
-    # while (*pos) {
-    while pos[0]:
-
-        # if (*pos == '"') { // literal string
-        #     pos++;
-        #     last_sym_start = out_elements.size();
-        #     while (*pos != '"') {
-        #         auto char_pair = parse_char(pos);
-        #                 pos       = char_pair.second;
-        #         out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
-        #     }
-        #     pos = parse_space(pos + 1, is_nested);
-        if pos[0] == '"':  # literal string
-            pos += 1
-            last_sym_start = out_elements.size()
+            out_elements.append(GrammarElement(GrammarElementType.RULE_REF, last_rec_rule_id))
+
+    while pos:
+        if pos.startswith('"'):
+            pos = pos[1:]
+            last_sym_start = len(out_elements)
             while pos[0] != '"':
-                assert pos[0] is not None, "Unexpected end of input"
-                char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
-                pos = char_pair[1]
-                out_elements.push_back(
-                    LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
-                )
-            pos = parse_space(pos + 1, is_nested)
-        # } else if (*pos == '[') { // char range(s)
-        #     pos++;
-        #     enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
-        elif pos[0] == "[":  # char range(s)
-            pos += 1
-            start_type = llama_gretype.LLAMA_GRETYPE_CHAR  # type: llama_gretype
-            # if (*pos == '^') {
-            #     pos++;
-            #     start_type = LLAMA_GRETYPE_CHAR_NOT;
-            # }
-            # last_sym_start = out_elements.size();
+                if not pos:
+                    raise ValueError("unexpected end of input")
+                char, pos = parse_char(pos)
+                out_elements.append(GrammarElement(GrammarElementType.CHAR, char))
+            pos = parse_space(pos[1:], is_nested)
+        elif pos.startswith("["):
+            pos = pos[1:]
+            start_type = GrammarElementType.CHAR
             if pos[0] == "^":
-                pos += 1
-                start_type = llama_gretype.LLAMA_GRETYPE_CHAR_NOT
-            last_sym_start = out_elements.size()
-            # while (*pos != ']') {
-            #     auto char_pair = parse_char(pos);
-            #             pos       = char_pair.second;
-            #     enum llama_gretype type = last_sym_start < out_elements.size()
-            #         ? LLAMA_GRETYPE_CHAR_ALT
-            #         : start_type;
-            #     out_elements.push_back({type, char_pair.first});
+                start_type = GrammarElementType.CHAR_NOT
+                pos = pos[1:]
+            last_sym_start = len(out_elements)
             while pos[0] != "]":
-                assert pos[0] is not None, "Unexpected end of input"
-                char_pair = parse_char(pos)  # type: Tuple[int, const_char_p]
-                pos = char_pair[1]
-                _type = (
-                    llama_gretype.LLAMA_GRETYPE_CHAR_ALT
-                    if last_sym_start < out_elements.size()
-                    else start_type
-                )  # type: llama_gretype
-                out_elements.push_back(LlamaGrammarElement(_type, char_pair[0]))
-                #     if (pos[0] == '-' && pos[1] != ']') {
-                #         auto endchar_pair = parse_char(pos + 1);
-                #                 pos          = endchar_pair.second;
-                #         out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
-                #     }
-                # }
+                if not pos:
+                    raise ValueError("unexpected end of input")
+                char, pos = parse_char(pos)
+                type = GrammarElementType.CHAR_ALT if last_sym_start < len(out_elements) else start_type
+                out_elements.append(GrammarElement(type, char))
                 if pos[0] == "-" and pos[1] != "]":
-                    assert pos[1] is not None, "Unexpected end of input"
-                    endchar_pair = parse_char(pos + 1)  # type: Tuple[int, const_char_p]
-                    pos = endchar_pair[1]
-                    out_elements.push_back(
-                        LlamaGrammarElement(
-                            llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
-                            endchar_pair[0],
-                        )
-                    )
-            # pos = parse_space(pos + 1, is_nested);
-            pos = parse_space(pos + 1, is_nested)
-        # } else if (is_word_char(*pos)) { // rule reference
-        #     const char * name_end    = parse_name(pos);
-        #     uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
-        #     pos = parse_space(name_end, is_nested);
-        #     last_sym_start = out_elements.size();
-        #     out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
-        elif is_word_char(pos[0]):  # rule reference
-            name_end = parse_name(pos)  # type: const_char_p
-            ref_rule_id = get_symbol_id(state, pos, name_end - pos)  # type: int
-            pos = parse_space(name_end, is_nested)
-            last_sym_start = out_elements.size()
-            out_elements.push_back(
-                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
-            )
-        # } else if (*pos == '(') { // grouping
-        #     // parse nested alternates into synthesized rule
-        #     pos = parse_space(pos + 1, true);
-        #     uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
-        #     pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
-        #     last_sym_start = out_elements.size();
-        #     // output reference to synthesized rule
-        #     out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-        #     if (*pos != ')') {
-        #         throw std::runtime_error(std::string("expecting ')' at ") + pos);
-        #     }
-        #     pos = parse_space(pos + 1, is_nested);
-        elif pos[0] == "(":  # grouping
-            # parse nested alternates into synthesized rule
-            pos = parse_space(pos + 1, True)
-            sub_rule_id = generate_symbol_id(state, rule_name)  # type: int
-            pos = parse_alternates(state, pos, rule_name, sub_rule_id, True)
-            last_sym_start = out_elements.size()
-            # output reference to synthesized rule
-            out_elements.push_back(
-                LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
-            )
+                    if not pos[1]:
+                        raise ValueError("unexpected end of input")
+                    endchar, pos = parse_char(pos[1:])
+                    out_elements.append(GrammarElement(GrammarElementType.CHAR_RNG_UPPER, endchar))
+            pos = parse_space(pos[1:], is_nested)
+        elif is_word_char(pos[0]):
+            name, rest = parse_name(pos)
+            ref_rule_id = get_symbol_id(state, name)
+            pos = parse_space(rest, is_nested)
+            last_sym_start = len(out_elements)
+            out_elements.append(GrammarElement(GrammarElementType.RULE_REF, ref_rule_id))
+        elif pos.startswith("("):
+            pos = parse_space(pos[1:], newline_ok=True)
+            sub_rule_id = generate_symbol_id(state, rule_name)
+            pos = parse_alternates(state, pos, rule_name, sub_rule_id, is_nested=True)
+            last_sym_start = len(out_elements)
+            out_elements.append(GrammarElement(GrammarElementType.RULE_REF, sub_rule_id))
             if pos[0] != ")":
-                raise RuntimeError("expecting ')' at " + str(pos))
-            pos = parse_space(pos + 1, is_nested)
-        elif pos[0] == '.':
-            last_sym_start = out_elements.size()
-            out_elements.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR_ANY, 0))
-            pos = parse_space(pos + 1, is_nested)
-           
-        # } else if (*pos == '*') {
-        #         pos = parse_space(pos + 1, is_nested);
-        #         handle_repetitions(0, -1);
-        # } else if (*pos == '+') {
-        #     pos = parse_space(pos + 1, is_nested);
-        #     handle_repetitions(1, -1);
-        # } else if (*pos == '?') {
-        #     pos = parse_space(pos + 1, is_nested);
-        #     handle_repetitions(0, 1);
-        # } else if (*pos == '{') {
-        #     pos = parse_space(pos + 1, is_nested);
-
-        #     if (!is_digit_char(*pos)) {
-        #         throw std::runtime_error(std::string("expecting an int at ") + pos);
-        #     }
-        #     const char * int_end = parse_int(pos);
-        #     int min_times = std::stoul(std::string(pos, int_end - pos));
-        #     pos = parse_space(int_end, is_nested);
-
-        #     int max_times = -1;
-
-        #     if (*pos == '}') {
-        #         max_times = min_times;
-        #         pos = parse_space(pos + 1, is_nested);
-        #     } else if (*pos == ',') {
-        #         pos = parse_space(pos + 1, is_nested);
-
-        #         if (is_digit_char(*pos)) {
-        #             const char * int_end = parse_int(pos);
-        #             max_times = std::stoul(std::string(pos, int_end - pos));
-        #             pos = parse_space(int_end, is_nested);
-        #         }
-
-        #         if (*pos != '}') {
-        #             throw std::runtime_error(std::string("expecting '}' at ") + pos);
-        #         }
-        #         pos = parse_space(pos + 1, is_nested);
-        #     } else {
-        #         throw std::runtime_error(std::string("expecting ',' at ") + pos);
-        #     }
-        #     handle_repetitions(min_times, max_times);
-        # } else {
-        #     break;
-        # }
-        elif pos[0] == "*":
-            pos = parse_space(pos + 1, is_nested)
+                raise ValueError(f"expecting ')' at {pos}")
+            pos = parse_space(pos[1:], is_nested)
+        elif pos.startswith("."):
+            last_sym_start = len(out_elements)
+            out_elements.append(GrammarElement(GrammarElementType.CHAR_ANY, 0))
+            pos = parse_space(pos[1:], is_nested)
+        elif pos.startswith("*"):
+            pos = parse_space(pos[1:], is_nested)
             handle_repetitions(0, -1)
-        elif pos[0] == "+":
-            pos = parse_space(pos + 1, is_nested)
+        elif pos.startswith("+"):
+            pos = parse_space(pos[1:], is_nested)
             handle_repetitions(1, -1)
-        elif pos[0] == "?":
-            pos = parse_space(pos + 1, is_nested)
+        elif pos.startswith("?"):
+            pos = parse_space(pos[1:], is_nested)
             handle_repetitions(0, 1)
-        elif pos[0] == "{":
-            pos = parse_space(pos + 1, is_nested)
-            
+        elif pos.startswith("{"):
+            pos = parse_space(pos[1:], is_nested)
             if not is_digit_char(pos[0]):
-                raise RuntimeError("expecting an int at " + str(pos))
-            
+                raise ValueError(f"expecting an int at {pos}")
+            min_times, pos = parse_int(pos)
+            pos = parse_space(pos, is_nested)
 
-            
-            int_end = parse_int(pos)
-            min_times = int(str(pos)[:int_end - pos])
-            pos = parse_space(int_end, is_nested)
             max_times = -1
+
             if pos[0] == "}":
                 max_times = min_times
-                pos = parse_space(pos + 1, is_nested)
+                pos = parse_space(pos[1:], is_nested)
             elif pos[0] == ",":
-            
-                pos = parse_space(pos + 1, is_nested)
+                pos = parse_space(pos[1:], is_nested)
                 if is_digit_char(pos[0]):
-                    int_end = parse_int(pos)
-                    max_times = int(str(pos)[:int_end - pos])
-                    pos = parse_space(int_end, is_nested)
+                    max_times, pos = parse_int(pos)
+                    pos = parse_space(pos, is_nested)
                 if pos[0] != "}":
-                    raise RuntimeError("expecting '}' at " + str(pos))
-                pos = parse_space(pos + 1, is_nested)
+                    raise ValueError("expecting '}' at {}".format(pos))
+                pos = parse_space(pos[1:], is_nested)
             else:
-                raise RuntimeError("expecting ',' at " + str(pos))
+                raise ValueError(f"expecting ',' at {pos}")
             handle_repetitions(min_times, max_times)
-            
         else:
             break
-        
-        
-        
-
     return pos
 
 
-# const char * parse_alternates(
-#         parse_state       & state,
-#         const char        * src,
-#         const std::string & rule_name,
-#         uint32_t            rule_id,
-#         bool                is_nested) {
-#     std::vector<llama_grammar_element> rule;
-#     const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
-#     while (*pos == '|') {
-#         rule.push_back({LLAMA_GRETYPE_ALT, 0});
-#         pos = parse_space(pos + 1, true);
-#         pos = parse_sequence(state, pos, rule_name, rule, is_nested);
-#     }
-#     rule.push_back({LLAMA_GRETYPE_END, 0});
-#     add_rule(state, rule_id, rule);
-#     return pos;
-# }
-def parse_alternates(
-    state: parse_state,
-    src: const_char_p,
-    rule_name: str,
-    rule_id: int,
-    is_nested: bool,
-) -> const_char_p:
-    rule = std.vector()  # type: std.vector[LlamaGrammarElement]
-    pos = parse_sequence(state, src, rule_name, rule, is_nested)  # type: const_char_p
-    while pos[0] == "|":
-        rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
-        pos = parse_space(pos + 1, True)
+def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int, is_nested: bool) -> str:
+    rule = []
+    pos = parse_sequence(state, src, rule_name, rule, is_nested)
+    while pos.startswith("|"):
+        rule.append(GrammarElement(GrammarElementType.ALT, 0))
+        pos = parse_space(pos[1:], newline_ok=True)
         pos = parse_sequence(state, pos, rule_name, rule, is_nested)
-    rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
+    rule.append(GrammarElement(GrammarElementType.END, 0))
     add_rule(state, rule_id, rule)
     return pos
 
 
-# const char * parse_rule(parse_state & state, const char * src) {
-#     const char * name_end = parse_name(src);
-#     const char * pos      = parse_space(name_end, false);
-#     size_t       name_len = name_end - src;
-#     uint32_t     rule_id  = get_symbol_id(state, src, name_len);
-#     const std::string name(src, name_len);
-
-#     if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
-#         throw std::runtime_error(std::string("expecting ::= at ") + pos);
-#     }
-#     pos = parse_space(pos + 3, true);
-
-#     pos = parse_alternates(state, pos, name, rule_id, false);
-
-
-#     if (*pos == '\r') {
-#         pos += pos[1] == '\n' ? 2 : 1;
-#     } else if (*pos == '\n') {
-#         pos++;
-#     } else if (*pos) {
-#         throw std::runtime_error(std::string("expecting newline or end at ") + pos);
-#     }
-#     return parse_space(pos, true);
-# }
-def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
-    name_end = parse_name(src)  # type: const_char_p
-    pos = parse_space(name_end, False)  # type: const_char_p
-    name_len = name_end - src  # type: int
-    rule_id = get_symbol_id(state, src, name_len)  # type: int
-    name = std.string(src, name_len)  # type: str
-
-    if not (pos[0] == ":" and pos[1] == ":" and pos[2] == "="):
-        raise RuntimeError("expecting ::= at " + str(pos))
-
-    pos = parse_space(pos + 3, True)  # type: const_char_p
-    pos = parse_alternates(state, pos, name, rule_id, False)  # type: const_char_p
-
-    if pos[0] == "\r":
-        pos += 2 if pos[1] == "\n" else 1
-    elif pos[0] == "\n":
-        pos += 1
-    elif pos[0]:
-        raise RuntimeError("expecting newline or end at " + str(pos))
-    return parse_space(pos, True)
-    
-#parse_state parse(const char * src) {
-#    try {
-#        parse_state state;
-#        const char * pos = parse_space(src, true);
-#        while (*pos) {
-#            pos = parse_rule(state, pos);
-#        }
-#        // Validate the state to ensure that all rules are defined
-#        for (const auto & rule : state.rules) {
-#            for (const auto & elem : rule) {
-#                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
-#                    // Ensure that the rule at that location exists
-#                    if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
-#                        // Get the name of the rule that is missing
-#                        for (const auto & kv : state.symbol_ids) {
-#                            if (kv.second == elem.value) {
-#                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
-#                            }
-#                        }
-#                    }
-#                }
-#            }
-#        }
-#        return state;
-#    } catch (const std::exception & err) {
-#        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
-#        return parse_state();
-#    }
-#}
-
-
-def parse(src: const_char_p) -> parse_state:
-    try:
-        state = parse_state()  # type: parse_state
-        pos = parse_space(src, True)  # type: const_char_p
-        while pos[0]:
-            pos = parse_rule(state, pos)
-        # Validate the state to ensure that all rules are defined
-        for rule in state.rules:
-            for elem in rule:
-                if elem.type == llama_gretype.LLAMA_GRETYPE_RULE_REF:
-                    # Ensure that the rule at that location exists
-                    if elem.value >= len(state.rules) or not state.rules[elem.value]:
-                        # Get the name of the rule that is missing
-                        for kv in state.symbol_ids:
-                            if kv[1] == elem.value:
-                                raise RuntimeError("Undefined rule identifier '" + kv[0] + "'")
-        return state
-    except Exception as err:
-        print(f"{parse.__name__}: error parsing grammar: {err}")
-        return parse_state()
-
-# void print_grammar_char(FILE * file, uint32_t c) {
-#     if (0x20 <= c && c <= 0x7f) {
-#         fprintf(file, "%c", static_cast<char>(c));
-#     } else {
-#         // cop out of encoding UTF-8
-#         fprintf(file, "<U+%04X>", c);
-#     }
-# }
-def print_grammar_char(file: TextIO, c: int) -> None:
-    if 0x20 <= c and c <= 0x7F:
-        file.write(chr(c))
-    else:
-        # cop out of encoding UTF-8
-        file.write(f"<U+{c:04X}>")
-
-
-# bool is_char_element(llama_grammar_element elem) {
-#     switch (elem.type) {
-#         case LLAMA_GRETYPE_CHAR:           return true;
-#         case LLAMA_GRETYPE_CHAR_NOT:       return true;
-#         case LLAMA_GRETYPE_CHAR_ALT:       return true;
-#         case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
-#         default:                           return false;
-#     }
-# }
-def is_char_element(elem: LlamaGrammarElement) -> bool:
+def parse_rule(state: ParseState, src: str) -> str:
+    name, s = parse_name(src)
+    s = parse_space(s, newline_ok=False)
+    rule_id = get_symbol_id(state, name)
+
+    if not s.startswith("::="):
+        raise ValueError(f"expecting ::= at {s}")
+
+    s = s[3:]
+
+    s = parse_space(s, newline_ok=True)
+
+    s = parse_alternates(state, s, name, rule_id, is_nested=False)
+
+    if s.startswith("\r"):
+        s = s[2:] if s[1] == "\n" else s[1:]
+    elif s.startswith("\n"):
+        s = s[1:]
+    elif s:
+        raise ValueError(f"expecting newline or end at {s}")
+    return parse_space(s, newline_ok=True)
+
+
+def parse(gbnf: str) -> ParseState:
+    state = ParseState()
+    s = parse_space(gbnf, newline_ok=True)
+    while s:
+        s = parse_rule(state, s)
+    # validate
+    for rule in state.rules:
+        for elem in rule:
+            if elem.type == GrammarElementType.RULE_REF:
+                if elem.value >= len(state.rules) or not state.rules[elem.value]:
+                    for k, v in state.symbol_ids.items():
+                        if v == elem.value:
+                            raise ValueError(f"Undefined rule identifier '{k}'")
+    return state
+
+
+def is_char_element(elem: GrammarElement) -> bool:
     return elem.type in (
-        llama_gretype.LLAMA_GRETYPE_CHAR,
-        llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
-        llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
-        llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
-        llama_gretype.LLAMA_GRETYPE_CHAR_ANY,
+        GrammarElementType.CHAR, 
+        GrammarElementType.CHAR_NOT,
+        GrammarElementType.CHAR_ALT,
+        GrammarElementType.CHAR_RNG_UPPER,
+        GrammarElementType.CHAR_ANY
     )
 
 
-# void print_rule(
-#         FILE     * file,
-#         uint32_t   rule_id,
-#         const std::vector<llama_grammar_element> & rule,
-#         const std::map<uint32_t, std::string>    & symbol_id_names) {
+def print_grammar_char(file: typing.TextIO, c: int) -> None:
+    if 0x20 <= c <= 0x7f:
+        print(chr(c), end="", file=file)
+    else:
+        print(f"<U+{c:04X}>", end="", file=file)
+
+
 def print_rule(
-    file: TextIO,
+    file: typing.TextIO,
     rule_id: int,
-    rule: std.vector[LlamaGrammarElement],
-    symbol_id_names: std.map[int, str],
+    rule: typing.List[GrammarElement],
+    symbol_id_names: typing.Dict[int, str],
 ) -> None:
-    #     if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
-    #         throw std::runtime_error(
-    #             "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
-    #     }
-    #     fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
-    if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
-        raise RuntimeError(
-            "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
-        )
-    print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
-    #     for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
-    #         llama_grammar_element elem = rule[i];
-    #         switch (elem.type) {
-    #             case LLAMA_GRETYPE_END:
-    #                 throw std::runtime_error(
-    #                     "unexpected end of rule: " + std::to_string(rule_id) + "," +
-    #                     std::to_string(i));
-    #             case LLAMA_GRETYPE_ALT:
-    #                 fprintf(file, "| ");
-    #                 break;
-    #             case LLAMA_GRETYPE_RULE_REF:
-    #                 fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
-    #                 break;
-    #             case LLAMA_GRETYPE_CHAR:
-    #                 fprintf(file, "[");
-    #                 print_grammar_char(file, elem.value);
-    #                 break;
-    #             case LLAMA_GRETYPE_CHAR_NOT:
-    #                 fprintf(file, "[^");
-    #                 print_grammar_char(file, elem.value);
-    #                 break;
-    #             case LLAMA_GRETYPE_CHAR_RNG_UPPER:
-    #                 if (i == 0 || !is_char_element(rule[i - 1])) {
-    #                     throw std::runtime_error(
-    #                         "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
-    #                         std::to_string(rule_id) + "," + std::to_string(i));
-    #                 }
-    #                 fprintf(file, "-");
-    #                 print_grammar_char(file, elem.value);
-    #                 break;
-    #             case LLAMA_GRETYPE_CHAR_ALT:
-    #                 if (i == 0 || !is_char_element(rule[i - 1])) {
-    #                     throw std::runtime_error(
-    #                         "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
-    #                         std::to_string(rule_id) + "," + std::to_string(i));
-    #                 }
-    #                 print_grammar_char(file, elem.value);
-    #                 break;
-    #         }
-    
+    if not rule or rule[-1].type != GrammarElementType.END:
+        raise ValueError(f"malformed rule, does not end with LLAMA_GRETYPE_END: {rule_id}")
 
+    print(f"{symbol_id_names[rule_id]} ::=", end=" ", file=file)
 
     for i, elem in enumerate(rule[:-1]):
-        case = elem.type  # type: llama_gretype
-        if case is llama_gretype.LLAMA_GRETYPE_END:
-            raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
-        elif case is llama_gretype.LLAMA_GRETYPE_ALT:
-            print("| ", file=file, end="")
-        elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
-            print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
-        elif case is llama_gretype.LLAMA_GRETYPE_CHAR:
-            print("[", file=file, end="")
+        if elem.type == GrammarElementType.END:
+            raise ValueError(f"unexpected end of rule: {rule_id}, {i}")
+        if elem.type == GrammarElementType.ALT:
+            print("| ", end="", file=file)
+        elif elem.type == GrammarElementType.RULE_REF:
+            print(f"{symbol_id_names[elem.value]} ", end="", file=file)
+        elif elem.type == GrammarElementType.CHAR:
+            print("[", end="", file=file)
             print_grammar_char(file, elem.value)
-        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT:
-            print("[^", file=file, end="")
+        elif elem.type == GrammarElementType.CHAR_NOT:
+            print("[^", end="", file=file)
             print_grammar_char(file, elem.value)
-        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER:
+        elif elem.type == GrammarElementType.CHAR_RNG_UPPER:
             if i == 0 or not is_char_element(rule[i - 1]):
-                raise RuntimeError(
-                    "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
-                    + str(rule_id)
-                    + ","
-                    + str(i)
-                )
-            print("-", file=file, end="")
+                raise ValueError(f"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: {rule_id}, {i}")
+            print(f"-", end="", file=file)
             print_grammar_char(file, elem.value)
-        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT:
+        elif elem.type == GrammarElementType.CHAR_ALT:
             if i == 0 or not is_char_element(rule[i - 1]):
-                raise RuntimeError(
-                    "LLAMA_GRETYPE_CHAR_ALT without preceding char: "
-                    + str(rule_id)
-                    + ","
-                    + str(i)
-                )
+                raise ValueError(f"LLAMA_GRETYPE_CHAR_ALT without preceding char: {rule_id}, {i}")
             print_grammar_char(file, elem.value)
-        elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ANY:
-            print(".", file=file, end="")
-        # if (is_char_element(elem)) {
-        #     switch (rule[i + 1].type) {
-        #         case LLAMA_GRETYPE_CHAR_ALT:
-        #         case LLAMA_GRETYPE_CHAR_RNG_UPPER:
-        #             break;
-        #         default:
-        #             fprintf(file, "] ");
+        elif elem.type == GrammarElementType.CHAR_ANY:
+            print(".", end="", file=file)
         if is_char_element(elem):
-            if rule[i + 1].type in (
-                llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
-                llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
-            ):
-                pass
-            else:
-                print("] ", file=file, end="")
-    #             }
-    #         }
-    #     }
-    #     fprintf(file, "\n");
-    # }
+            if rule[i + 1].type in (GrammarElementType.CHAR_ALT, GrammarElementType.CHAR_RNG_UPPER, GrammarElementType.CHAR_ANY):
+                continue
+            print("] ", end="", file=file)
     print(file=file)
 
 
-# void print_grammar(FILE * file, const parse_state & state) {
-#     try {
-#         std::map<uint32_t, std::string> symbol_id_names;
-#         for (auto kv : state.symbol_ids) {
-#             symbol_id_names[kv.second] = kv.first;
-#         }
-#         for (size_t i = 0, end = state.rules.size(); i < end; i++) {
-#             // fprintf(file, "%zu: ", i);
-#             // print_rule_binary(file, state.rules[i]);
-#             print_rule(file, i, state.rules[i], symbol_id_names);
-#             // fprintf(file, "\n");
-#         }
-#     } catch (const std::exception & err) {
-#         fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
-#     }
-# }
-def print_grammar(file: TextIO, state: parse_state) -> None:
+def print_grammar(file: typing.TextIO, state: ParseState) -> None:
     try:
-        symbol_id_names = std.map()  # type: std.map[int, str]
-        for kv in state.symbol_ids.items():
-            symbol_id_names[kv[1]] = kv[0]
-
+        symbol_id_names = {v: k for k, v in state.symbol_ids.items()}
         for i, rule in enumerate(state.rules):
             print_rule(file, i, rule, symbol_id_names)
     except Exception as err:
-        print(
-            f"{print_grammar.__name__}: error printing grammar: {err}",
-            file=sys.stderr,
+        print(f"\nerror printing grammar: {err}", file=file)
+        raise err
+
+import ctypes
+
+class LlamaGrammar:
+    def __init__(self, parse_state: ParseState):
+        self.parse_state = parse_state
+
+        self._grammar_rules = parse_state.rules
+        self._n_rules = len(self._grammar_rules)
+        self._start_rule_index = parse_state.symbol_ids["root"]
+
+        self._element_lists = [
+            [
+                llama_cpp.llama_grammar_element(ctypes.c_int(elem.type.value), ctypes.c_uint32(elem.value))
+                for elem in subvector
+            ]
+            for subvector in self._grammar_rules
+        ]
+
+        # Step 2: Convert each list to llama_grammar_element array and get pointer
+        self._element_arrays = [
+            (llama_cpp.llama_grammar_element * len(sublist))(*sublist)
+            for sublist in self._element_lists
+        ]
+
+        # Step 3: Get pointer of each array
+        self._element_array_pointers = [
+            ctypes.cast(subarray, llama_cpp.llama_grammar_element_p) for subarray in self._element_arrays
+        ]
+
+        # Step 4: Make array of these pointers and get its pointer
+        self._rules = (llama_cpp.llama_grammar_element_p * len(self._element_array_pointers))(
+            *self._element_array_pointers
         )
 
+        self.grammar = None
+        self._init_grammar()
+
+
+    def _init_grammar(self):
+        grammar = llama_cpp.llama_grammar_init(
+            self._rules, ctypes.c_size_t(self._n_rules), ctypes.c_size_t(self._start_rule_index)
+        )
+
+        if grammar is None:
+            raise ValueError("Failed to create grammar")
+
+        self.grammar = grammar
+
+    def __del__(self):
+        if self.grammar is not None:
+            llama_cpp.llama_grammar_free(self.grammar)
+            self.grammar = None
+
+    def reset(self):
+        if self.grammar is not None:
+            llama_cpp.llama_grammar_free(self.grammar)
+        self._init_grammar()
+
+    @classmethod
+    def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
+        parsed_grammar = parse(grammar)
+        return cls(parsed_grammar)
+
+    @classmethod
+    def from_json_schema(cls, json_schema: str, verbose: bool = True) -> "LlamaGrammar":
+        return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
+
 
 """llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
 
@@ -1564,12 +631,13 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
 string ::=
   "\"" (
     [^"\\\x7F\x00-\x1F] |
-    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+    "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
   )* "\"" ws
 
-number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
+number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws
 
-ws ::= ([ \t\n] ws)?
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws ::= | " " | "\n" [ \t]{0,20}
 """
 
 LIST_GBNF = r"""

From 7308d5392a4ae7fecc635c5ebf3176149bb8a14d Mon Sep 17 00:00:00 2001
From: Andrei Betlen <abetlen@gmail.com>
Date: Tue, 6 Aug 2024 20:15:32 -0400
Subject: [PATCH 8/8] bugfix

---
 llama_cpp/llama_grammar.py | 578 ++++++++++++++++++++++++++++++++-----
 1 file changed, 507 insertions(+), 71 deletions(-)

diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py
index 4d48e34ab..e60817e1d 100644
--- a/llama_cpp/llama_grammar.py
+++ b/llama_cpp/llama_grammar.py
@@ -1,6 +1,12 @@
 """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
 
 # flake8: noqa
+import sys
+import ctypes
+import enum
+import typing
+import dataclasses
+
 from itertools import groupby
 from typing import (
     Any,
@@ -11,9 +17,6 @@
     Union,
 )
 
-import enum
-import typing
-
 import llama_cpp.llama_cpp as llama_cpp
 
 class GrammarElementType(enum.IntEnum):
@@ -26,19 +29,33 @@ class GrammarElementType(enum.IntEnum):
     CHAR_ALT = llama_cpp.LLAMA_GRETYPE_CHAR_ALT
     CHAR_ANY = llama_cpp.LLAMA_GRETYPE_CHAR_ANY
 
-import dataclasses
 
 @dataclasses.dataclass
 class GrammarElement:
     type: GrammarElementType
     value: int
 
+
 @dataclasses.dataclass
 class ParseState:
     symbol_ids: typing.Dict[str, int] = dataclasses.field(default_factory=dict)
     rules: typing.List[typing.List[GrammarElement]] = dataclasses.field(default_factory=list)
 
 
+# static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+#     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+#     uint8_t  first_byte = static_cast<uint8_t>(*src);
+#     uint8_t  highbits   = first_byte >> 4;
+#     int      len        = lookup[highbits];
+#     uint8_t  mask       = (1 << (8 - len)) - 1;
+#     uint32_t value      = first_byte & mask;
+#     const char * end    = src + len; // may overrun!
+#     const char * pos    = src + 1;
+#     for ( ; pos < end && *pos; pos++) {
+#         value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+#     }
+#     return std::make_pair(value, pos);
+# }
 def decode_utf8(src: str) -> typing.Tuple[int, str]:
     lookup: list[int] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]
     first_byte: int = ord(src[0])
@@ -57,31 +74,78 @@ def decode_utf8(src: str) -> typing.Tuple[int, str]:
     return value, src[pos:] if pos < len(src) else ""
 
 
+# static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
+#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+#     auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
+#     return result.first->second;
+# }
 def get_symbol_id(state: ParseState, name: str) -> int:
     next_id = len(state.symbol_ids)
     return state.symbol_ids.setdefault(name, next_id)
 
 
+# static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
+#     uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
+#     state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+#     return next_id;
+# }
 def generate_symbol_id(state: ParseState, base_name: str) -> int:
     next_id = len(state.symbol_ids)
     state.symbol_ids[f"{base_name}_{next_id}"] = next_id
     return next_id
 
 
+# static void add_rule(
+#         parse_state & state,
+#         uint32_t      rule_id,
+#         const std::vector<llama_grammar_element> & rule) {
+#     if (state.rules.size() <= rule_id) {
+#         state.rules.resize(rule_id + 1);
+#     }
+#     state.rules[rule_id] = rule;
+# }
 def add_rule(state: ParseState, rule_id: int, rule: typing.List[GrammarElement]) -> None:
     if len(state.rules) <= rule_id:
         state.rules.extend([[]] * (rule_id + 1 - len(state.rules)))
     state.rules[rule_id] = rule
 
 
+# static bool is_digit_char(char c) {
+#     return '0' <= c && c <= '9';
+# }
 def is_digit_char(c: str) -> bool:
     return "0" <= c <= "9"
 
 
+# static bool is_word_char(char c) {
+#     return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
+# }
 def is_word_char(c: str) -> bool:
     return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or is_digit_char(c)
 
 
+# static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+#     const char * pos   = src;
+#     const char * end   = src + size;
+#     uint32_t     value = 0;
+#     for ( ; pos < end && *pos; pos++) {
+#         value <<= 4;
+#         char c = *pos;
+#         if ('a' <= c && c <= 'f') {
+#             value += c - 'a' + 10;
+#         } else if ('A' <= c && c <= 'F') {
+#             value += c - 'A' + 10;
+#         } else if ('0' <= c && c <= '9') {
+#             value += c - '0';
+#         } else {
+#             break;
+#         }
+#     }
+#     if (pos != end) {
+#         throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+#     }
+#     return std::make_pair(value, pos);
+# }
 def parse_hex(src: str, size: int) -> typing.Tuple[int, str]:
     pos = 0
     value = 0
@@ -102,38 +166,93 @@ def parse_hex(src: str, size: int) -> typing.Tuple[int, str]:
     return value, src[pos:]
 
 
+# static const char * parse_space(const char * src, bool newline_ok) {
+#     const char * pos = src;
+#     while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+#             (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+#         if (*pos == '#') {
+#             while (*pos && *pos != '\r' && *pos != '\n') {
+#                 pos++;
+#             }
+#         } else {
+#             pos++;
+#         }
+#     }
+#     return pos;
+# }
 def parse_space(src: str, newline_ok: bool) -> str:
-    pos = 0
-    while pos < len(src) and (src[pos] in (" ", "\t", "#") or (newline_ok and src[pos] in ("\r", "\n"))):
-        if src[pos] == "#":
-            while pos < len(src) and src[pos] not in ("\r", "\n"):
-                pos += 1
-        pos += 1
-    return src[pos:]
+    pos = src
+    while pos and (pos[0] in (' ', '\t', '#') or (newline_ok and pos[0] in ('\r', '\n'))):
+        if pos[0] == "#":
+            while pos and pos[0] not in ("\r", "\n"):
+                pos = pos[1:]
+        else:
+            pos = pos[1:]
+    return pos
 
 
+# static const char * parse_name(const char * src) {
+#     const char * pos = src;
+#     while (is_word_char(*pos)) {
+#         pos++;
+#     }
+#     if (pos == src) {
+#         throw std::runtime_error(std::string("expecting name at ") + src);
+#     }
+#     return pos;
+# }
 def parse_name(src: str) -> typing.Tuple[str, str]:
-    pos = 0
-    try:
-        while is_word_char(src[pos]):
-            pos += 1
-    except IndexError:
-        return src, ""
-    if pos == 0:
+    pos = src
+    while pos and is_word_char(pos[0]):
+        pos = pos[1:]
+    if pos == src:
         raise ValueError(f"expecting name at {src}")
-    return src[:pos], src[pos:]
-
-
+    return src[:len(src) - len(pos)], pos
+
+# static const char * parse_int(const char * src) {
+#     const char * pos = src;
+#     while (is_digit_char(*pos)) {
+#         pos++;
+#     }
+#     if (pos == src) {
+#         throw std::runtime_error(std::string("expecting integer at ") + src);
+#     }
+#     return pos;
+# }
 def parse_int(src: str) -> typing.Tuple[int, str]:
-    pos = 0
-    while is_digit_char(src[pos]):
-        pos += 1
-    if pos == 0:
+    pos = src
+    while pos and is_digit_char(pos[0]):
+        pos = pos[1:]
+    if pos == src:
         raise ValueError(f"expecting integer at {src}")
-    return int(src[:pos]), src[pos:]
-
-
+    return int(src[:len(src) - len(pos)]), pos
+
+
+# static std::pair<uint32_t, const char *> parse_char(const char * src) {
+#     if (*src == '\\') {
+#         switch (src[1]) {
+#             case 'x': return parse_hex(src + 2, 2);
+#             case 'u': return parse_hex(src + 2, 4);
+#             case 'U': return parse_hex(src + 2, 8);
+#             case 't': return std::make_pair('\t', src + 2);
+#             case 'r': return std::make_pair('\r', src + 2);
+#             case 'n': return std::make_pair('\n', src + 2);
+#             case '\\':
+#             case '"':
+#             case '[':
+#             case ']':
+#                 return std::make_pair(src[1], src + 2);
+#             default:
+#                 throw std::runtime_error(std::string("unknown escape at ") + src);
+#         }
+#     } else if (*src) {
+#         return decode_utf8(src);
+#     }
+#     throw std::runtime_error("unexpected end of input");
+# }
 def parse_char(src: str) -> typing.Tuple[int, str]:
+    if not src:
+        raise ValueError("unexpected end of input")
     if src[0] == "\\":
         if src[1] == "x":
             return parse_hex(src[2:], 2)
@@ -151,33 +270,202 @@ def parse_char(src: str) -> typing.Tuple[int, str]:
             return ord(src[1]), src[2:]
         else:
             raise ValueError(f"unknown escape at {src}")
-    elif src:
-        return decode_utf8(src)
-    raise ValueError("unexpected end of input")
-
-
+    return decode_utf8(src)
+
+# static const char * parse_sequence(
+#         parse_state                        & state,
+#         const char                         * src,
+#         const std::string                  & rule_name,
+#         std::vector<llama_grammar_element> & out_elements,
+#         bool                                 is_nested) {
+#     size_t last_sym_start = out_elements.size();
+#     const char * pos = src;
+#
+#     auto handle_repetitions = [&](int min_times, int max_times) {
+#
+#         if (last_sym_start == out_elements.size()) {
+#             throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+#         }
+#
+#         // apply transformation to previous symbol (last_sym_start to end) according to
+#         // the following rewrite rules:
+#         // S{m,n} --> S S S (m times) S'(n-m)
+#         //            S'(x)   ::= S S'(x-1) |
+#         //            (... n-m definitions of these S' rules ...)
+#         //            S'(1)   ::= S |
+#         // S{m,} -->  S S S (m times) S'
+#         //            S'     ::= S S' |
+#         // S*     --> S{0,}
+#         //        --> S'     ::= S S' |
+#         // S+     --> S{1,}
+#         //        --> S S'
+#         //            S'     ::= S S' |
+#         // S?     --> S{0,1}
+#         //        --> S'
+#         //            S'     ::= S |
+#
+#         std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
+#         if (min_times == 0) {
+#             out_elements.resize(last_sym_start);
+#         } else {
+#             // Repeat the previous elements (min_times - 1) times
+#             for (int i = 1; i < min_times; i++) {
+#                 out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
+#             }
+#         }
+#
+#         uint32_t last_rec_rule_id = 0;
+#         auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+#
+#         std::vector<llama_grammar_element> rec_rule(previous_elements);
+#         for (int i = 0; i < n_opt; i++) {
+#             rec_rule.resize(previous_elements.size());
+#             uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
+#             if (i > 0 || max_times < 0) {
+#                 rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+#             }
+#             rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+#             rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+#             add_rule(state, rec_rule_id, rec_rule);
+#             last_rec_rule_id = rec_rule_id;
+#         }
+#         if (n_opt > 0) {
+#             out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+#         }
+#     };
+#
+#     while (*pos) {
+#         if (*pos == '"') { // literal string
+#             pos++;
+#             last_sym_start = out_elements.size();
+#             while (*pos != '"') {
+#                 if (!*pos) {
+#                     throw std::runtime_error("unexpected end of input");
+#                 }
+#                 auto char_pair = parse_char(pos);
+#                      pos       = char_pair.second;
+#                 out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+#             }
+#             pos = parse_space(pos + 1, is_nested);
+#         } else if (*pos == '[') { // char range(s)
+#             pos++;
+#             enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+#             if (*pos == '^') {
+#                 pos++;
+#                 start_type = LLAMA_GRETYPE_CHAR_NOT;
+#             }
+#             last_sym_start = out_elements.size();
+#             while (*pos != ']') {
+#                 if (!*pos) {
+#                     throw std::runtime_error("unexpected end of input");
+#                 }
+#                 auto char_pair = parse_char(pos);
+#                      pos       = char_pair.second;
+#                 enum llama_gretype type = last_sym_start < out_elements.size()
+#                     ? LLAMA_GRETYPE_CHAR_ALT
+#                     : start_type;
+#
+#                 out_elements.push_back({type, char_pair.first});
+#                 if (pos[0] == '-' && pos[1] != ']') {
+#                     if (!pos[1]) {
+#                         throw std::runtime_error("unexpected end of input");
+#                     }
+#                     auto endchar_pair = parse_char(pos + 1);
+#                          pos          = endchar_pair.second;
+#                     out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+#                 }
+#             }
+#             pos = parse_space(pos + 1, is_nested);
+#         } else if (is_word_char(*pos)) { // rule reference
+#             const char * name_end    = parse_name(pos);
+#             uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
+#             pos = parse_space(name_end, is_nested);
+#             last_sym_start = out_elements.size();
+#             out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+#         } else if (*pos == '(') { // grouping
+#             // parse nested alternates into synthesized rule
+#             pos = parse_space(pos + 1, true);
+#             uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+#             pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
+#             last_sym_start = out_elements.size();
+#             // output reference to synthesized rule
+#             out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+#             if (*pos != ')') {
+#                 throw std::runtime_error(std::string("expecting ')' at ") + pos);
+#             }
+#             pos = parse_space(pos + 1, is_nested);
+#         } else if (*pos == '.') { // any char
+#             last_sym_start = out_elements.size();
+#             out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+#             pos = parse_space(pos + 1, is_nested);
+#         } else if (*pos == '*') {
+#             pos = parse_space(pos + 1, is_nested);
+#             handle_repetitions(0, -1);
+#         } else if (*pos == '+') {
+#             pos = parse_space(pos + 1, is_nested);
+#             handle_repetitions(1, -1);
+#         } else if (*pos == '?') {
+#             pos = parse_space(pos + 1, is_nested);
+#             handle_repetitions(0, 1);
+#         } else if (*pos == '{') {
+#             pos = parse_space(pos + 1, is_nested);
+#
+#             if (!is_digit_char(*pos)) {
+#                 throw std::runtime_error(std::string("expecting an int at ") + pos);
+#             }
+#             const char * int_end = parse_int(pos);
+#             int min_times = std::stoul(std::string(pos, int_end - pos));
+#             pos = parse_space(int_end, is_nested);
+#
+#             int max_times = -1;
+#
+#             if (*pos == '}') {
+#                 max_times = min_times;
+#                 pos = parse_space(pos + 1, is_nested);
+#             } else if (*pos == ',') {
+#                 pos = parse_space(pos + 1, is_nested);
+#
+#                 if (is_digit_char(*pos)) {
+#                     const char * int_end = parse_int(pos);
+#                     max_times = std::stoul(std::string(pos, int_end - pos));
+#                     pos = parse_space(int_end, is_nested);
+#                 }
+#
+#                 if (*pos != '}') {
+#                     throw std::runtime_error(std::string("expecting '}' at ") + pos);
+#                 }
+#                 pos = parse_space(pos + 1, is_nested);
+#             } else {
+#                 throw std::runtime_error(std::string("expecting ',' at ") + pos);
+#             }
+#             handle_repetitions(min_times, max_times);
+#         } else {
+#             break;
+#         }
+#     }
+#     return pos;
+# }
 def parse_sequence(state: ParseState, src: str, rule_name: str, out_elements: typing.List[GrammarElement], is_nested: bool) -> str:
     last_sym_start = len(out_elements)
     pos = src
 
     def handle_repetitions(min_times: int, max_times: int) -> None:
-        nonlocal last_sym_start
-        nonlocal pos
-        nonlocal out_elements
+        nonlocal state, src, rule_name, out_elements, is_nested, last_sym_start, pos
 
         if last_sym_start == len(out_elements):
             raise ValueError(f"expecting preceding item to */+/?/{{ at {pos}")
 
         previous_elements = out_elements[last_sym_start:]
         if min_times == 0:
-            out_elements = out_elements[:last_sym_start]
+            del out_elements[last_sym_start:]
         else:
             for i in range(1, min_times):
                 out_elements.extend(previous_elements)
+
         last_rec_rule_id = 0
         n_opt = 1 if max_times < 0 else max_times - min_times
 
-        rec_rule = list(previous_elements)
+        rec_rule = previous_elements[:]
         for i in range(n_opt):
             rec_rule = rec_rule[:len(previous_elements)]
             rec_rule_id = generate_symbol_id(state, rule_name)
@@ -191,21 +479,21 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             out_elements.append(GrammarElement(GrammarElementType.RULE_REF, last_rec_rule_id))
 
     while pos:
-        if pos.startswith('"'):
+        if pos[0] == '"':
             pos = pos[1:]
             last_sym_start = len(out_elements)
-            while pos[0] != '"':
+            while not pos.startswith('"'):
                 if not pos:
                     raise ValueError("unexpected end of input")
                 char, pos = parse_char(pos)
                 out_elements.append(GrammarElement(GrammarElementType.CHAR, char))
             pos = parse_space(pos[1:], is_nested)
-        elif pos.startswith("["):
+        elif pos[0] == "[":
             pos = pos[1:]
             start_type = GrammarElementType.CHAR
             if pos[0] == "^":
-                start_type = GrammarElementType.CHAR_NOT
                 pos = pos[1:]
+                start_type = GrammarElementType.CHAR_NOT
             last_sym_start = len(out_elements)
             while pos[0] != "]":
                 if not pos:
@@ -219,7 +507,7 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
                     endchar, pos = parse_char(pos[1:])
                     out_elements.append(GrammarElement(GrammarElementType.CHAR_RNG_UPPER, endchar))
             pos = parse_space(pos[1:], is_nested)
-        elif is_word_char(pos[0]):
+        elif pos and is_word_char(pos[0]):
             name, rest = parse_name(pos)
             ref_rule_id = get_symbol_id(state, name)
             pos = parse_space(rest, is_nested)
@@ -249,7 +537,8 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
             handle_repetitions(0, 1)
         elif pos.startswith("{"):
             pos = parse_space(pos[1:], is_nested)
-            if not is_digit_char(pos[0]):
+
+            if not is_digit_char(pos):
                 raise ValueError(f"expecting an int at {pos}")
             min_times, pos = parse_int(pos)
             pos = parse_space(pos, is_nested)
@@ -261,11 +550,14 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
                 pos = parse_space(pos[1:], is_nested)
             elif pos[0] == ",":
                 pos = parse_space(pos[1:], is_nested)
-                if is_digit_char(pos[0]):
+
+                if is_digit_char(pos):
                     max_times, pos = parse_int(pos)
                     pos = parse_space(pos, is_nested)
+
                 if pos[0] != "}":
                     raise ValueError("expecting '}' at {}".format(pos))
+
                 pos = parse_space(pos[1:], is_nested)
             else:
                 raise ValueError(f"expecting ',' at {pos}")
@@ -275,6 +567,23 @@ def handle_repetitions(min_times: int, max_times: int) -> None:
     return pos
 
 
+# const char * parse_alternates(
+#         parse_state       & state,
+#         const char        * src,
+#         const std::string & rule_name,
+#         uint32_t            rule_id,
+#         bool                is_nested) {
+#     std::vector<llama_grammar_element> rule;
+#     const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
+#     while (*pos == '|') {
+#         rule.push_back({LLAMA_GRETYPE_ALT, 0});
+#         pos = parse_space(pos + 1, true);
+#         pos = parse_sequence(state, pos, rule_name, rule, is_nested);
+#     }
+#     rule.push_back({LLAMA_GRETYPE_END, 0});
+#     add_rule(state, rule_id, rule);
+#     return pos;
+# }
 def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int, is_nested: bool) -> str:
     rule = []
     pos = parse_sequence(state, src, rule_name, rule, is_nested)
@@ -287,34 +596,86 @@ def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int,
     return pos
 
 
+# static const char * parse_rule(parse_state & state, const char * src) {
+#     const char * name_end = parse_name(src);
+#     const char * pos      = parse_space(name_end, false);
+#     size_t       name_len = name_end - src;
+#     uint32_t     rule_id  = get_symbol_id(state, src, name_len);
+#     const std::string name(src, name_len);
+#
+#     if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+#         throw std::runtime_error(std::string("expecting ::= at ") + pos);
+#     }
+#     pos = parse_space(pos + 3, true);
+#
+#     pos = parse_alternates(state, pos, name, rule_id, false);
+#
+#     if (*pos == '\r') {
+#         pos += pos[1] == '\n' ? 2 : 1;
+#     } else if (*pos == '\n') {
+#         pos++;
+#     } else if (*pos) {
+#         throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+#     }
+#     return parse_space(pos, true);
+# }
 def parse_rule(state: ParseState, src: str) -> str:
-    name, s = parse_name(src)
-    s = parse_space(s, newline_ok=False)
+    pos = src
+    name, pos = parse_name(pos)
+    pos = parse_space(pos, newline_ok=False)
     rule_id = get_symbol_id(state, name)
 
-    if not s.startswith("::="):
-        raise ValueError(f"expecting ::= at {s}")
-
-    s = s[3:]
-
-    s = parse_space(s, newline_ok=True)
-
-    s = parse_alternates(state, s, name, rule_id, is_nested=False)
-
-    if s.startswith("\r"):
-        s = s[2:] if s[1] == "\n" else s[1:]
-    elif s.startswith("\n"):
-        s = s[1:]
-    elif s:
-        raise ValueError(f"expecting newline or end at {s}")
-    return parse_space(s, newline_ok=True)
-
-
-def parse(gbnf: str) -> ParseState:
+    if not pos.startswith("::="):
+        raise ValueError(f"expecting ::= at {pos}")
+
+    pos = parse_space(pos[3:], newline_ok=True)
+
+    pos = parse_alternates(state, pos, name, rule_id, is_nested=False)
+
+    if pos.startswith("\r"):
+        pos = pos[2:] if pos[1] == "\n" else pos[1:]
+    elif pos.startswith("\n"):
+        pos = pos[1:]
+    elif pos:
+        raise ValueError(f"expecting newline or end at {pos}")
+    return parse_space(pos, newline_ok=True)
+
+
+# parse_state parse(const char * src) {
+#     try {
+#         parse_state state;
+#         const char * pos = parse_space(src, true);
+#         while (*pos) {
+#             pos = parse_rule(state, pos);
+#         }
+#         // Validate the state to ensure that all rules are defined
+#         for (const auto & rule : state.rules) {
+#             for (const auto & elem : rule) {
+#                 if (elem.type == LLAMA_GRETYPE_RULE_REF) {
+#                     // Ensure that the rule at that location exists
+#                     if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
+#                         // Get the name of the rule that is missing
+#                         for (const auto & kv : state.symbol_ids) {
+#                             if (kv.second == elem.value) {
+#                                 throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
+#                             }
+#                         }
+#                     }
+#                 }
+#             }
+#         }
+#         return state;
+#     } catch (const std::exception & err) {
+#         fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+#         return parse_state();
+#     }
+# }
+def parse(src: str) -> ParseState:
     state = ParseState()
-    s = parse_space(gbnf, newline_ok=True)
-    while s:
-        s = parse_rule(state, s)
+    pos = src
+    pos = parse_space(pos, newline_ok=True)
+    while pos:
+        pos = parse_rule(state, pos)
     # validate
     for rule in state.rules:
         for elem in rule:
@@ -326,6 +687,16 @@ def parse(gbnf: str) -> ParseState:
     return state
 
 
+# static bool is_char_element(llama_grammar_element elem) {
+#     switch (elem.type) {
+#         case LLAMA_GRETYPE_CHAR:           return true;
+#         case LLAMA_GRETYPE_CHAR_NOT:       return true;
+#         case LLAMA_GRETYPE_CHAR_ALT:       return true;
+#         case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+#         case LLAMA_GRETYPE_CHAR_ANY:       return true;
+#         default:                           return false;
+#     }
+# }
 def is_char_element(elem: GrammarElement) -> bool:
     return elem.type in (
         GrammarElementType.CHAR, 
@@ -343,6 +714,71 @@ def print_grammar_char(file: typing.TextIO, c: int) -> None:
         print(f"<U+{c:04X}>", end="", file=file)
 
 
+# static void print_rule(
+#         FILE     * file,
+#         uint32_t   rule_id,
+#         const std::vector<llama_grammar_element> & rule,
+#         const std::map<uint32_t, std::string>    & symbol_id_names) {
+#     if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+#         throw std::runtime_error(
+#             "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+#     }
+#     fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+#     for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+#         llama_grammar_element elem = rule[i];
+#         switch (elem.type) {
+#             case LLAMA_GRETYPE_END:
+#                 throw std::runtime_error(
+#                     "unexpected end of rule: " + std::to_string(rule_id) + "," +
+#                     std::to_string(i));
+#             case LLAMA_GRETYPE_ALT:
+#                 fprintf(file, "| ");
+#                 break;
+#             case LLAMA_GRETYPE_RULE_REF:
+#                 fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+#                 break;
+#             case LLAMA_GRETYPE_CHAR:
+#                 fprintf(file, "[");
+#                 print_grammar_char(file, elem.value);
+#                 break;
+#             case LLAMA_GRETYPE_CHAR_NOT:
+#                 fprintf(file, "[^");
+#                 print_grammar_char(file, elem.value);
+#                 break;
+#             case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+#                 if (i == 0 || !is_char_element(rule[i - 1])) {
+#                     throw std::runtime_error(
+#                         "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+#                         std::to_string(rule_id) + "," + std::to_string(i));
+#                 }
+#                 fprintf(file, "-");
+#                 print_grammar_char(file, elem.value);
+#                 break;
+#             case LLAMA_GRETYPE_CHAR_ALT:
+#                 if (i == 0 || !is_char_element(rule[i - 1])) {
+#                     throw std::runtime_error(
+#                         "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+#                         std::to_string(rule_id) + "," + std::to_string(i));
+#                 }
+#                 print_grammar_char(file, elem.value);
+#                 break;
+#             case LLAMA_GRETYPE_CHAR_ANY:
+#                 fprintf(file, ".");
+#                 break;
+#         }
+#         if (is_char_element(elem)) {
+#             switch (rule[i + 1].type) {
+#                 case LLAMA_GRETYPE_CHAR_ALT:
+#                 case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+#                 case LLAMA_GRETYPE_CHAR_ANY:
+#                     break;
+#                 default:
+#                     fprintf(file, "] ");
+#             }
+#         }
+#     }
+#     fprintf(file, "\n");
+# }
 def print_rule(
     file: typing.TextIO,
     rule_id: int,
@@ -394,7 +830,6 @@ def print_grammar(file: typing.TextIO, state: ParseState) -> None:
         print(f"\nerror printing grammar: {err}", file=file)
         raise err
 
-import ctypes
 
 class LlamaGrammar:
     def __init__(self, parse_state: ParseState):
@@ -406,7 +841,7 @@ def __init__(self, parse_state: ParseState):
 
         self._element_lists = [
             [
-                llama_cpp.llama_grammar_element(ctypes.c_int(elem.type.value), ctypes.c_uint32(elem.value))
+                llama_cpp.llama_grammar_element(ctypes.c_int(elem.type), ctypes.c_uint32(elem.value))
                 for elem in subvector
             ]
             for subvector in self._grammar_rules
@@ -455,6 +890,7 @@ def reset(self):
     @classmethod
     def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
         parsed_grammar = parse(grammar)
+        print_grammar(file=sys.stdout, state=parsed_grammar)
         return cls(parsed_grammar)
 
     @classmethod