Skip to content

Commit ab9a324

Browse files
authored
JSON schema conversion: ⚡️ faster repetitions, min/maxLength for strings, cap number length (#6555)
* json: rename python schema converter to make import easier * server: skip null json_schema / grammar fields * json: deps management for primitive rules (+ allow null values) * json: optimize repetitions for minItems/maxItems and regexps: `a{,3}` goes from `"a"? "a"? "a"?` (explosive combos) to `(a (a (a)?)?)?` * grammars: add troubleshooting section to readme * json: cap length of numbers to 15 digits before/after decimal point (avoids infinite gen, e.g. "one third" -> `0.333333333333...`) * json: unify all repetition code (w/ or w/o sep) * json: support string minLength/maxLength * server+json: update server/README w/ result_format * nits * json: fix type error w/ python 3.8 * json: fix server/README (json_schema in /completion vs. result_format in /v1/chat/completions) * json: simplify DOT `{"type": "string", "pattern": "^.$"}` * json: remove recursion in opt_repetitions (avoids Python stack overflow) * json: rm dead code * json: rm useless assert & ggml.h import
1 parent fbbc030 commit ab9a324

10 files changed

+2326
-1907
lines changed

common/json-schema-to-grammar.cpp

+138-95
Large diffs are not rendered by default.

examples/json-schema-to-grammar.py examples/json_schema_to_grammar.py

+139-73
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,94 @@
66
import sys
77
from typing import Any, Dict, List, Set, Tuple, Union
88

9+
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
10+
if not separator_rule:
11+
if min_items == 0 and max_items == 1:
12+
return f'{item_rule}?'
13+
elif min_items == 1 and max_items is None:
14+
return f'{item_rule}+'
15+
16+
result = ''
17+
18+
if min_items > 0:
19+
if item_rule_is_literal and separator_rule is None:
20+
result = '"' + (item_rule[1:-1] * min_items) + '"'
21+
else:
22+
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
23+
24+
def opt_repetitions(up_to_n, prefix_with_sep=False):
25+
'''
26+
- n=4, no sep: '(a (a (a (a)?)?)?)?'
27+
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
28+
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
29+
'''
30+
31+
content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
32+
if up_to_n == 0:
33+
return ''
34+
elif up_to_n == 1:
35+
return f'({content})?'
36+
elif separator_rule and not prefix_with_sep:
37+
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
38+
else:
39+
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
40+
41+
if min_items > 0 and max_items != min_items:
42+
result += ' '
43+
44+
if max_items is not None:
45+
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
46+
else:
47+
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
48+
49+
if min_items == 0 and separator_rule:
50+
result = f'({item_rule} {item_operator}*)?'
51+
else:
52+
result += f'{item_operator}*'
53+
54+
return result
55+
56+
57+
class BuiltinRule:
58+
def __init__(self, content: str, deps: list = None):
59+
self.content = content
60+
self.deps = deps or []
61+
62+
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
63+
964
# whitespace is constrained to a single space char to prevent model "running away" in
1065
# whitespace. Also maybe improves generation quality?
1166
SPACE_RULE = '" "?'
1267

1368
PRIMITIVE_RULES = {
14-
'boolean': '("true" | "false") space',
15-
'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
16-
'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space',
17-
'value' : 'object | array | string | number | boolean',
18-
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
19-
'array' : '"[" space ( value ("," space value)* )? "]" space',
20-
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space',
21-
'string': r''' "\"" (
22-
[^"\\] |
23-
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
24-
)* "\"" space''',
25-
'null': '"null" space',
69+
'boolean' : BuiltinRule('("true" | "false") space', []),
70+
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
71+
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
72+
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
73+
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
74+
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
75+
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
76+
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
77+
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
78+
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
79+
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
80+
'null' : BuiltinRule('"null" space', []),
2681
}
27-
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'number', 'boolean', 'null', 'value']
2882

2983
# TODO: support "uri", "email" string formats
30-
DATE_RULES = {
31-
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
32-
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
33-
'date-time': 'date "T" time',
34-
'date-string': '"\\"" date "\\"" space',
35-
'time-string': '"\\"" time "\\"" space',
36-
'date-time-string': '"\\"" date-time "\\"" space',
84+
STRING_FORMAT_RULES = {
85+
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
86+
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
87+
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
88+
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
89+
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
90+
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
3791
}
3892

39-
RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()])
93+
DOTALL = '[\\U00000000-\\U0010FFFF]'
94+
DOT = '[^\\x0A\\x0D]'
95+
96+
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
4097

4198
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
4299
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
@@ -46,16 +103,16 @@
46103
NON_LITERAL_SET = set('|.()[]{}*+?')
47104
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
48105

49-
DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])'
50-
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits
51106

52107
class SchemaConverter:
53108
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
54109
self._prop_order = prop_order
55110
self._allow_fetch = allow_fetch
56111
self._dotall = dotall
57112
self._raw_pattern = raw_pattern
58-
self._rules = {'space': SPACE_RULE}
113+
self._rules = {
114+
'space': SPACE_RULE,
115+
}
59116
self._refs = {}
60117
self._refs_being_resolved = set()
61118

@@ -65,6 +122,29 @@ def _format_literal(self, literal):
65122
)
66123
return f'"{escaped}"'
67124

125+
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
126+
'''
127+
not_literal('a') -> '[^a]'
128+
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
129+
'''
130+
assert len(literal) > 0, 'Empty literal not supported'
131+
def recurse(i: int):
132+
c = literal[i]
133+
if maybe_escaped_underscores and c == '_':
134+
yield f'[^{c}\\\\]'
135+
yield ' | '
136+
yield f'"\\\\"? "{c}"'
137+
else:
138+
yield f'[^{c}]'
139+
if i < len(literal) - 1:
140+
yield ' | '
141+
yield self._format_literal(c)
142+
yield ' ('
143+
yield from recurse(i + 1)
144+
yield ')?'
145+
146+
return ''.join(('(', *recurse(0), ')'))
147+
68148
def _add_rule(self, name, rule):
69149
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
70150
if esc_name not in self._rules or self._rules[esc_name] == rule:
@@ -169,10 +249,10 @@ def transform() -> Tuple[str, bool]:
169249

170250
def get_dot():
171251
if self._dotall:
172-
rule = '[\\U00000000-\\U0010FFFF]'
252+
rule = DOTALL
173253
else:
174254
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
175-
rule = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'
255+
rule = DOT
176256
return self._add_rule(f'dot', rule)
177257

178258
def join_seq():
@@ -246,26 +326,14 @@ def join_seq():
246326

247327
(sub, sub_is_literal) = seq[-1]
248328

249-
if min_times == 0 and max_times is None:
250-
seq[-1] = (f'{sub}*', False)
251-
elif min_times == 0 and max_times == 1:
252-
seq[-1] = (f'{sub}?', False)
253-
elif min_times == 1 and max_times is None:
254-
seq[-1] = (f'{sub}+', False)
255-
else:
256-
if not sub_is_literal:
257-
id = sub_rule_ids.get(sub)
258-
if id is None:
259-
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
260-
sub_rule_ids[sub] = id
261-
sub = id
262-
263-
seq[-1] = (
264-
' '.join(
265-
([f'"{sub[1:-1] * min_times}"'] if sub_is_literal else [sub] * min_times) +
266-
([f'{sub}?'] * (max_times - min_times) if max_times is not None else [f'{sub}*'])),
267-
False
268-
)
329+
if not sub_is_literal:
330+
id = sub_rule_ids.get(sub)
331+
if id is None:
332+
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
333+
sub_rule_ids[sub] = id
334+
sub = id
335+
336+
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
269337
else:
270338
literal = ''
271339
while i < length:
@@ -373,49 +441,47 @@ def add_component(comp_schema, is_required):
373441
' "]" space')
374442
else:
375443
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
376-
list_item_operator = f'( "," space {item_rule_name} )'
377-
successive_items = ""
378444
min_items = schema.get("minItems", 0)
379445
max_items = schema.get("maxItems")
380-
if min_items > 0:
381-
successive_items = list_item_operator * (min_items - 1)
382-
min_items -= 1
383-
if max_items is not None and max_items > min_items:
384-
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
385-
else:
386-
successive_items += list_item_operator + "*"
387-
if min_items == 0:
388-
rule = f'"[" space ( {item_rule_name} {successive_items} )? "]" space'
389-
else:
390-
rule = f'"[" space {item_rule_name} {successive_items} "]" space'
391-
return self._add_rule(rule_name, rule)
446+
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
392447

393448
elif schema_type in (None, 'string') and 'pattern' in schema:
394449
return self._visit_pattern(schema['pattern'], rule_name)
395450

396451
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
397-
return self._add_rule(
452+
return self._add_primitive(
398453
'root' if rule_name == 'root' else schema_format,
399454
PRIMITIVE_RULES['uuid']
400455
)
401456

402-
elif schema_type in (None, 'string') and schema_format in DATE_RULES:
403-
for t, r in DATE_RULES.items():
404-
self._add_rule(t, r)
405-
return schema_format + '-string'
457+
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
458+
prim_name = f'{schema_format}-string'
459+
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
460+
461+
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
462+
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
463+
min_len = schema.get('minLength', 0)
464+
max_len = schema.get('maxLength')
465+
466+
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
406467

407468
elif (schema_type == 'object') or (len(schema) == 0):
408-
for n in OBJECT_RULE_NAMES:
409-
self._add_rule(n, PRIMITIVE_RULES[n])
410-
return self._add_rule(rule_name, 'object')
469+
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
411470

412471
else:
413472
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
414473
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
415-
return self._add_rule(
416-
'root' if rule_name == 'root' else schema_type,
417-
PRIMITIVE_RULES[schema_type]
418-
)
474+
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
475+
476+
def _add_primitive(self, name: str, rule: BuiltinRule):
477+
n = self._add_rule(name, rule.content)
478+
479+
for dep in rule.deps:
480+
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
481+
assert dep_rule, f'Rule {dep} not known'
482+
if dep not in self._rules:
483+
self._add_primitive(dep, dep_rule)
484+
return n
419485

420486
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
421487
prop_order = self._prop_order
@@ -437,7 +503,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
437503
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
438504
prop_kv_rule_names["*"] = self._add_rule(
439505
f'{sub_name}-kv',
440-
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
506+
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
441507
)
442508
optional_props.append("*")
443509

examples/regex-to-grammar.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"python",
99
os.path.join(
1010
os.path.dirname(os.path.realpath(__file__)),
11-
"json-schema-to-grammar.py"),
11+
"json_schema_to_grammar.py"),
1212
*rest,
1313
"-",
1414
"--raw-pattern",

examples/server/README.md

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
1111
* Continuous batching
1212
* Multimodal (wip)
1313
* Monitoring endpoints
14+
* Schema-constrained JSON response format
1415

1516
The project is under active development, and we are [looking for feedback and contributors](https://github.com/ggerganov/llama.cpp/issues/4216).
1617

@@ -250,6 +251,8 @@ node index.js
250251

251252
`grammar`: Set grammar for grammar-based sampling. Default: no grammar
252253

254+
`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema.
255+
253256
`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed.
254257

255258
`ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
@@ -365,6 +368,8 @@ Notice that each `probs` is an array of length `n_probs`.
365368

366369
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported.
367370

371+
The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}`), similar to other OpenAI-inspired API providers.
372+
368373
*Examples:*
369374

370375
You can use either Python `openai` library with appropriate checkpoints:

0 commit comments

Comments
 (0)