Skip to content

Commit 1d7cc4e

Browse files
ochafikMagnusS0
authored andcommitted
json: fix additionalProperties, allow space after enum/const (ggml-org#7840)
* json: default additionalProperty to true * json: don't force additional props after normal properties! * json: allow space after enum/const * json: update pydantic example to set additionalProperties: false * json: prevent additional props to redefine a typed prop * port not_strings to python, add trailing space * fix not_strings & port to js+py * Update json-schema-to-grammar.cpp * fix _not_strings for substring overlaps * json: fix additionalProperties default, uncomment tests * json: add integ. test case for additionalProperties * json: nit: simplify condition * reformat grammar integ tests w/ R"""()""" strings where there's escapes * update # tokens in server test: consts can now have trailing space
1 parent 753b1b5 commit 1d7cc4e

7 files changed

+497
-245
lines changed

common/json-schema-to-grammar.cpp

+86-13
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,75 @@ class SchemaConverter {
614614
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
615615
}
616616

617+
/*
618+
Returns a rule that matches a JSON string that is none of the provided strings
619+
620+
not_strings({"a"})
621+
-> ["] ( [a] char+ | [^"a] char* )? ["] space
622+
not_strings({"and", "also"})
623+
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
624+
*/
625+
std::string _not_strings(const std::vector<std::string> & strings) {
626+
627+
struct TrieNode {
628+
std::map<char, TrieNode> children;
629+
bool is_end_of_string;
630+
631+
TrieNode() : is_end_of_string(false) {}
632+
633+
void insert(const std::string & string) {
634+
auto node = this;
635+
for (char c : string) {
636+
node = &node->children[c];
637+
}
638+
node->is_end_of_string = true;
639+
}
640+
};
641+
642+
TrieNode trie;
643+
for (const auto & s : strings) {
644+
trie.insert(s);
645+
}
646+
647+
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
648+
std::ostringstream out;
649+
out << "[\"] ( ";
650+
std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
651+
std::ostringstream rejects;
652+
auto first = true;
653+
for (const auto & kv : node.children) {
654+
rejects << kv.first;
655+
if (first) {
656+
first = false;
657+
} else {
658+
out << " | ";
659+
}
660+
out << "[" << kv.first << "]";
661+
if (!kv.second.children.empty()) {
662+
out << " (";
663+
visit(kv.second);
664+
out << ")";
665+
} else if (kv.second.is_end_of_string) {
666+
out << " " << char_rule << "+";
667+
}
668+
}
669+
if (!node.children.empty()) {
670+
if (!first) {
671+
out << " | ";
672+
}
673+
out << "[^\"" << rejects.str() << "] " << char_rule << "*";
674+
}
675+
};
676+
visit(trie);
677+
678+
out << " )";
679+
if (!trie.is_end_of_string) {
680+
out << "?";
681+
}
682+
out << " [\"] space";
683+
return out.str();
684+
}
685+
617686
std::string _resolve_ref(const std::string & ref) {
618687
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
619688
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
@@ -634,6 +703,7 @@ class SchemaConverter {
634703
std::vector<std::string> required_props;
635704
std::vector<std::string> optional_props;
636705
std::unordered_map<std::string, std::string> prop_kv_rule_names;
706+
std::vector<std::string> prop_names;
637707
for (const auto & kv : properties) {
638708
const auto &prop_name = kv.first;
639709
const auto &prop_schema = kv.second;
@@ -648,11 +718,18 @@ class SchemaConverter {
648718
} else {
649719
optional_props.push_back(prop_name);
650720
}
721+
prop_names.push_back(prop_name);
651722
}
652-
if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
723+
if (!(additional_properties.is_boolean() && !additional_properties.get<bool>())) {
653724
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
654-
std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
655-
std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
725+
std::string value_rule =
726+
additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
727+
: _add_primitive("value", PRIMITIVE_RULES.at("value"));
728+
729+
auto key_rule =
730+
prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
731+
: _add_rule(sub_name + "-k", _not_strings(prop_names));
732+
std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
656733
prop_kv_rule_names["*"] = kv_rule;
657734
optional_props.push_back("*");
658735
}
@@ -678,15 +755,11 @@ class SchemaConverter {
678755
}
679756
std::string k = ks[0];
680757
std::string kv_rule_name = prop_kv_rule_names[k];
681-
if (k == "*") {
682-
res = _add_rule(
683-
name + (name.empty() ? "" : "-") + "additional-kvs",
684-
kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
685-
);
686-
} else if (first_is_optional) {
687-
res = "( \",\" space " + kv_rule_name + " )?";
758+
std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
759+
if (first_is_optional) {
760+
res = comma_ref + (k == "*" ? "*" : "?");
688761
} else {
689-
res = kv_rule_name;
762+
res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
690763
}
691764
if (ks.size() > 1) {
692765
res += " " + _add_rule(
@@ -824,13 +897,13 @@ class SchemaConverter {
824897
}
825898
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
826899
} else if (schema.contains("const")) {
827-
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
900+
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
828901
} else if (schema.contains("enum")) {
829902
std::vector<std::string> enum_values;
830903
for (const auto & v : schema["enum"]) {
831904
enum_values.push_back(_generate_constant_rule(v));
832905
}
833-
return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
906+
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
834907
} else if ((schema_type.is_null() || schema_type == "object")
835908
&& (schema.contains("properties") ||
836909
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {

examples/json-schema-pydantic-example.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#! pip install pydantic
44
#! python json-schema-pydantic-example.py
55

6-
from pydantic import BaseModel, TypeAdapter
6+
from pydantic import BaseModel, Extra, TypeAdapter
77
from annotated_types import MinLen
88
from typing import Annotated, List, Optional
99
import json, requests
@@ -50,12 +50,16 @@ def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1
5050
if __name__ == '__main__':
5151

5252
class QAPair(BaseModel):
53+
class Config:
54+
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
5355
question: str
5456
concise_answer: str
5557
justification: str
5658
stars: Annotated[int, Field(ge=1, le=5)]
5759

5860
class PyramidalSummary(BaseModel):
61+
class Config:
62+
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
5963
title: str
6064
summary: str
6165
question_answers: Annotated[List[QAPair], MinLen(2)]

examples/json_schema_to_grammar.py

+60-16
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import json
55
import re
66
import sys
7-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
8-
7+
from typing import Any, List, Optional, Set, Tuple, Union
98

109
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
1110

@@ -276,6 +275,51 @@ def recurse(i: int):
276275

277276
return ''.join(('(', *recurse(0), ')'))
278277

278+
def _not_strings(self, strings):
279+
class TrieNode:
280+
def __init__(self):
281+
self.children = {}
282+
self.is_end_of_string = False
283+
284+
def insert(self, string):
285+
node = self
286+
for c in string:
287+
node = node.children.setdefault(c, TrieNode())
288+
node.is_end_of_string = True
289+
290+
trie = TrieNode()
291+
for s in strings:
292+
trie.insert(s)
293+
294+
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
295+
out = ['["] ( ']
296+
297+
def visit(node):
298+
rejects = []
299+
first = True
300+
for c in sorted(node.children.keys()):
301+
child = node.children[c]
302+
rejects.append(c)
303+
if first:
304+
first = False
305+
else:
306+
out.append(' | ')
307+
out.append(f'[{c}]')
308+
if child.children:
309+
out.append(f' (')
310+
visit(child)
311+
out.append(')')
312+
elif child.is_end_of_string:
313+
out.append(f' {char_rule}+')
314+
if node.children:
315+
if not first:
316+
out.append(' | ')
317+
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
318+
visit(trie)
319+
320+
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
321+
return ''.join(out)
322+
279323
def _add_rule(self, name, rule):
280324
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
281325
if esc_name not in self._rules or self._rules[esc_name] == rule:
@@ -524,10 +568,10 @@ def visit(self, schema, name):
524568
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
525569

526570
elif 'const' in schema:
527-
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
571+
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
528572

529573
elif 'enum' in schema:
530-
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
574+
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
531575
return self._add_rule(rule_name, rule)
532576

533577
elif schema_type in (None, 'object') and \
@@ -632,7 +676,7 @@ def _add_primitive(self, name: str, rule: BuiltinRule):
632676
self._add_primitive(dep, dep_rule)
633677
return n
634678

635-
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
679+
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
636680
prop_order = self._prop_order
637681
# sort by position in prop_order (if specified) then by original order
638682
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
@@ -647,12 +691,16 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
647691
required_props = [k for k in sorted_props if k in required]
648692
optional_props = [k for k in sorted_props if k not in required]
649693

650-
if additional_properties == True or isinstance(additional_properties, dict):
694+
if additional_properties != False:
651695
sub_name = f'{name}{"-" if name else ""}additional'
652-
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
696+
value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
697+
self._add_primitive('value', PRIMITIVE_RULES['value'])
698+
key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
699+
else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
700+
653701
prop_kv_rule_names["*"] = self._add_rule(
654702
f'{sub_name}-kv',
655-
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
703+
f'{key_rule} ":" space {value_rule}'
656704
)
657705
optional_props.append("*")
658706

@@ -667,15 +715,11 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
667715
def get_recursive_refs(ks, first_is_optional):
668716
[k, *rest] = ks
669717
kv_rule_name = prop_kv_rule_names[k]
670-
if k == '*':
671-
res = self._add_rule(
672-
f'{name}{"-" if name else ""}additional-kvs',
673-
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
674-
)
675-
elif first_is_optional:
676-
res = f'( "," space {kv_rule_name} )?'
718+
comma_ref = f'( "," space {kv_rule_name} )'
719+
if first_is_optional:
720+
res = comma_ref + ('*' if k == '*' else '?')
677721
else:
678-
res = kv_rule_name
722+
res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
679723
if len(rest) > 0:
680724
res += ' ' + self._add_rule(
681725
f'{name}{"-" if name else ""}{k}-rest',

0 commit comments

Comments
 (0)