Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

json-schema-to-grammar: fix order of props in C++, support non-string const/enum #6232

Merged
merged 3 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <unordered_set>
#include <vector>

using json = nlohmann::json;
using json = nlohmann::ordered_json;

const std::string SPACE_RULE = "\" \"?";

Expand Down Expand Up @@ -124,7 +124,7 @@ static std::string replacePattern(const std::string & input, const std::regex &
}

static std::string format_literal(const std::string & literal) {
std::string escaped = replacePattern(json(literal).dump(), GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
char c = match.str()[0];
return GRAMMAR_LITERAL_ESCAPES.at(c);
});
Expand All @@ -137,7 +137,7 @@ class SchemaConverter {
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
std::unordered_map<std::string, nlohmann::json> _refs;
std::unordered_map<std::string, json> _refs;
std::unordered_set<std::string> _refs_being_resolved;
std::vector<std::string> _errors;
std::vector<std::string> _warnings;
Expand Down Expand Up @@ -413,7 +413,7 @@ class SchemaConverter {
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
prop_kv_rule_names[prop_name] = _add_rule(
name + (name.empty() ? "" : "-") + prop_name + "-kv",
format_literal(prop_name) + " space \":\" space " + prop_rule_name
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
);
if (required.find(prop_name) != required.end()) {
required_props.push_back(prop_name);
Expand Down Expand Up @@ -495,7 +495,7 @@ class SchemaConverter {
_rules["space"] = SPACE_RULE;
}

void resolve_refs(nlohmann::json & schema, const std::string & url) {
void resolve_refs(json & schema, const std::string & url) {
/*
* Resolves all $ref fields in the given schema, fetching any remote schemas,
* replacing each $ref with absolute reference URL and populates _refs with the
Expand Down Expand Up @@ -557,11 +557,7 @@ class SchemaConverter {
}

/*
 * Builds the grammar literal for a `const` / `enum` value.
 *
 * The value is serialized with dump() and the serialized text is handed to
 * format_literal, so constants of any JSON type (string, number, null,
 * array, ...) are accepted. No error is recorded for non-string values.
 */
std::string _generate_constant_rule(const json & value) {
    return format_literal(value.dump());
}

std::string visit(const json & schema, const std::string & name) {
Expand Down
2 changes: 1 addition & 1 deletion common/json-schema-to-grammar.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#pragma once
#include "json.hpp"

std::string json_schema_to_grammar(const nlohmann::json& schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
7 changes: 3 additions & 4 deletions examples/json-schema-to-grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):

def _format_literal(self, literal):
    """Wrap ``literal`` in double quotes for use as a grammar string literal.

    ``literal`` is used as given: this method does not JSON-serialize it.
    Every match of ``GRAMMAR_LITERAL_ESCAPE_RE`` in it is replaced with the
    corresponding entry of ``GRAMMAR_LITERAL_ESCAPES``.
    """
    escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
        lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
    )
    return f'"{escaped}"'

Expand Down Expand Up @@ -308,8 +308,7 @@ def _resolve_ref(self, ref):
return ref_name

def _generate_constant_rule(self, value):
assert isinstance(value, str), f'Only string constants are supported, got {value}'
return self._format_literal(value)
return self._format_literal(json.dumps(value))

def visit(self, schema, name):
schema_type = schema.get('type')
Expand Down Expand Up @@ -428,7 +427,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(prop_name)} space ":" space {prop_rule_name}'
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
Expand Down
2,791 changes: 1,384 additions & 1,407 deletions examples/server/json-schema-to-grammar.mjs.hpp

Large diffs are not rendered by default.

12 changes: 3 additions & 9 deletions examples/server/public/json-schema-to-grammar.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export class SchemaConverter {
}

_formatLiteral(literal) {
const escaped = JSON.stringify(literal).replace(
const escaped = literal.replace(
GRAMMAR_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m]
);
Expand Down Expand Up @@ -327,10 +327,7 @@ export class SchemaConverter {
}

_generateConstantRule(value) {
if (typeof value !== 'string') {
throw new Error('Only string constants are supported, got ' + JSON.stringify(value));
}
return this._formatLiteral(value);
return this._formatLiteral(JSON.stringify(value));
}

visit(schema, name) {
Expand All @@ -346,9 +343,6 @@ export class SchemaConverter {
} else if (Array.isArray(schemaType)) {
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({ type: t }))));
} else if ('const' in schema) {
if (typeof schema.const !== 'string') {
throw new Error('Only string constants are supported, got ' + JSON.stringify(schema.const));
}
return this._addRule(ruleName, this._generateConstantRule(schema.const));
} else if ('enum' in schema) {
const rule = schema.enum.map(v => this._generateConstantRule(v)).join(' | ');
Expand Down Expand Up @@ -457,7 +451,7 @@ export class SchemaConverter {
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
propKvRuleNames[propName] = this._addRule(
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
`${this._formatLiteral(propName)} space ":" space ${propRuleName}`
`${this._formatLiteral(JSON.stringify(propName))} space ":" space ${propRuleName}`
);
}
const requiredProps = sortedProps.filter(k => required.has(k));
Expand Down
2 changes: 1 addition & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <signal.h>
#include <memory>

using json = nlohmann::json;
using json = nlohmann::ordered_json;

bool server_verbose = false;
bool server_log_json = true;
Expand Down
2 changes: 1 addition & 1 deletion examples/server/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

using json = nlohmann::json;
using json = nlohmann::ordered_json;

// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
enum error_type {
Expand Down
101 changes: 53 additions & 48 deletions tests/test-json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase

test({
FAILURE,
"invalid type type",
"invalid type",
R"""({
"type": 123
})""",
Expand Down Expand Up @@ -193,21 +193,27 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
});

test({
FAILURE,
SUCCESS,
"non-string const",
R"""({
"const": 123
})""",
""
R"""(
root ::= "123"
space ::= " "?
)"""
});

test({
FAILURE,
SUCCESS,
"non-string enum",
R"""({
"enum": [123]
"enum": ["red", "amber", "green", null, 42, ["foo"]]
})""",
""
R"""(
root ::= "\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]"
space ::= " "?
)"""
});

test({
Expand Down Expand Up @@ -378,28 +384,27 @@ static void test_all(const std::string & lang, std::function<void(const TestCase

test({
SUCCESS,
"required props",
"required props in original order",
R"""({
"type": "object",
"properties": {
"a": {
"type": "string"
},
"b": {
"type": "string"
}
"b": {"type": "string"},
"c": {"type": "string"},
"a": {"type": "string"}
},
"required": [
"a",
"b"
"b",
"c"
],
"additionalProperties": false,
"definitions": {}
})""",
R"""(
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
root ::= "{" space a-kv "," space b-kv "}" space
c-kv ::= "\"c\"" space ":" space string
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
Expand Down Expand Up @@ -458,13 +463,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase

test({
SUCCESS,
"required + optional props",
"required + optional props each in original order",
R"""({
"properties": {
"a": {"type": "string"},
"b": {"type": "string"},
"c": {"type": "string"},
"d": {"type": "string"}
"a": {"type": "string"},
"d": {"type": "string"},
"c": {"type": "string"}
},
"required": ["a", "b"],
"additionalProperties": false
Expand All @@ -473,14 +478,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
a-kv ::= "\"a\"" space ":" space string
b-kv ::= "\"b\"" space ":" space string
c-kv ::= "\"c\"" space ":" space string
c-rest ::= ( "," space d-kv )?
d-kv ::= "\"d\"" space ":" space string
root ::= "{" space a-kv "," space b-kv ( "," space ( c-kv c-rest | d-kv ) )? "}" space
d-rest ::= ( "," space c-kv )?
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
)* "\"" space
)"""
});

Expand Down Expand Up @@ -648,16 +653,16 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"$ref": "#/definitions/MyType",
"definitions": {
"MyType": {
"type": "object",
"properties": {
"a": {
"type": "string"
}
},
"required": [
"a"
],
"additionalProperties": false
"type": "object",
"properties": {
"a": {
"type": "string"
}
},
"required": [
"a"
],
"additionalProperties": false
}
}
})""",
Expand All @@ -683,10 +688,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
"properties": {"b": {"type": "number"}}
}
},
"type": "object"
Expand Down Expand Up @@ -720,16 +725,16 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
],
"definitions": {
"foo": {
"properties": {"a": {"type": "number"}}
"properties": {"a": {"type": "number"}}
},
"bar": {
"properties": {"b": {"type": "number"}}
"properties": {"b": {"type": "number"}}
},
"bam": {
"properties": {"c": {"type": "number"}}
"properties": {"c": {"type": "number"}}
},
"baz": {
"properties": {"d": {"type": "number"}}
"properties": {"d": {"type": "number"}}
}
},
"type": "object"
Expand Down Expand Up @@ -757,15 +762,15 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
"properties": {
"number": {
"type": "object",
"properties": {
"root": {
"type": "number"
}
},
"required": [
"root"
],
"additionalProperties": false
"properties": {
"root": {
"type": "number"
}
},
"required": [
"root"
],
"additionalProperties": false
}
},
"required": [
Expand Down Expand Up @@ -796,7 +801,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
int main() {
test_all("C++", [](const TestCase & tc) {
try {
tc.verify(json_schema_to_grammar(nlohmann::json::parse(tc.schema)));
tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
tc.verify_status(SUCCESS);
} catch (const std::runtime_error & ex) {
fprintf(stderr, "Error: %s\n", ex.what());
Expand Down
Loading