Skip to content

Commit cef3b8f

Browse files
ochafik authored and MagnusS0 committed
json: support integer minimum, maximum, exclusiveMinimum, exclusiveMaximum (ggml-org#7797)
* json: support minimum for positive integer values
* json: fix min 0
* json: min + max integer constraints
* json: handle negative min / max integer bounds
* json: fix missing paren min/max bug
* json: proper paren fix
* json: integration test for schemas
* json: fix bounds tests
* Update json-schema-to-grammar.cpp
* json: fix negative max
* json: fix negative min (w/ more than 1 digit)
* Update test-grammar-integration.cpp
* json: nit: move string rules together
* json: port min/max integer support to Python & JS
* nit: move + rename _build_min_max_int
* fix min in [1, 9]
* Update test-grammar-integration.cpp
* add C++11-compatible replacement for std::string_view
* add min/max constrained int field to pydantic json schema example
* fix merge
* json: add integration tests for min/max bounds
* reshuffle/merge min/max integ test cases
* nits / cleanups
* defensive code against string out of bounds (apparently different behaviour of libstdc++ vs. clang's libc++, can't read final NULL char w/ former)
1 parent 0655b4c commit cef3b8f

6 files changed

+1150
-3
lines changed

common/json-schema-to-grammar.cpp

+245-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,233 @@ static std::string build_repetition(const std::string & item_rule, int min_items
4040
return result;
4141
}
4242

43+
/*
 * Minimalistic replacement for std::string_view, which is only available from C++17 onwards.
 *
 * A view is a half-open window [_start, _end) onto a std::string that it does NOT own:
 * it only keeps a reference, so the viewed string must outlive the view and every
 * sub-view derived from it.
 *
 * The constructor is intentionally not `explicit`, so a std::string can be passed
 * wherever a `const string_view &` is expected.
 */
class string_view {
    const std::string & _str;
    const size_t _start;
    const size_t _end;

    /* Restricts `value` to [lo, hi]; avoids depending on <algorithm> for std::min/std::max. */
    static size_t _clamp(size_t value, size_t lo, size_t hi) {
        return value < lo ? lo : (value > hi ? hi : value);
    }
public:
    /*
     * Views str[start, end). `end == npos` (the default) means "up to the end of str".
     * Both bounds are clamped to the string, and `end` is never allowed below `start`,
     * so size() cannot underflow and str() cannot throw.
     */
    string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos)
        : _str(str),
          _start(_clamp(start, 0, str.length())),
          _end(_clamp(end, _start, str.length())) {}

    /* Number of characters in the view. */
    size_t size() const {
        return _end - _start;
    }

    /* Alias of size(). */
    size_t length() const {
        return size();
    }

    /* Implicit conversion to an owning copy of the viewed characters. */
    operator std::string() const {
        return str();
    }

    /* Returns an owning copy of the viewed characters. */
    std::string str() const {
        return _str.substr(_start, _end - _start);
    }

    /*
     * Returns the sub-view starting at `pos` (relative to this view) spanning at most
     * `len` characters. Like std::string_view::substr, `len` is clamped to what is
     * left in this view (so the result never extends past this view's end) and
     * std::out_of_range is thrown when `pos > size()`.
     */
    string_view substr(size_t pos, size_t len = std::string::npos) const {
        if (pos > size()) {
            throw std::out_of_range("string_view substr position out of range");
        }
        const size_t begin = _start + pos;
        const size_t available = _end - begin;
        return string_view(_str, begin, begin + (len < available ? len : available));
    }

    /* Bounds-checked element access; throws std::out_of_range when `pos >= size()`. */
    char operator[](size_t pos) const {
        auto index = _start + pos;
        if (index >= _end) {
            throw std::out_of_range("string_view index out of range");
        }
        return _str[index];
    }

    /* Character-wise equality, compared in place (no temporary strings). */
    bool operator==(const string_view & other) const {
        return size() == other.size()
            && _str.compare(_start, size(), other._str, other._start, other.size()) == 0;
    }
};
85+
86+
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
87+
auto has_min = min_value != std::numeric_limits<int>::min();
88+
auto has_max = max_value != std::numeric_limits<int>::max();
89+
90+
auto digit_range = [&](char from, char to) {
91+
out << "[";
92+
if (from == to) {
93+
out << from;
94+
} else {
95+
out << from << "-" << to;
96+
}
97+
out << "]";
98+
};
99+
auto more_digits = [&](int min_digits, int max_digits) {
100+
out << "[0-9]";
101+
if (min_digits == max_digits && min_digits == 1) {
102+
return;
103+
}
104+
out << "{";
105+
out << min_digits;
106+
if (max_digits != min_digits) {
107+
out << ",";
108+
if (max_digits != std::numeric_limits<int>::max()) {
109+
out << max_digits;
110+
}
111+
}
112+
out << "}";
113+
};
114+
std::function<void(const string_view &, const string_view &)> uniform_range =
115+
[&](const string_view & from, const string_view & to) {
116+
size_t i = 0;
117+
while (i < from.length() && i < to.length() && from[i] == to[i]) {
118+
i++;
119+
}
120+
if (i > 0) {
121+
out << "\"" << from.substr(0, i).str() << "\"";
122+
}
123+
if (i < from.length() && i < to.length()) {
124+
if (i > 0) {
125+
out << " ";
126+
}
127+
auto sub_len = from.length() - i - 1;
128+
if (sub_len > 0) {
129+
auto from_sub = from.substr(i + 1);
130+
auto to_sub = to.substr(i + 1);
131+
auto sub_zeros = repeat("0", sub_len);
132+
auto sub_nines = repeat("9", sub_len);
133+
134+
auto to_reached = false;
135+
out << "(";
136+
if (from_sub == sub_zeros) {
137+
digit_range(from[i], to[i] - 1);
138+
out << " ";
139+
more_digits(sub_len, sub_len);
140+
} else {
141+
out << "[" << from[i] << "] ";
142+
out << "(";
143+
uniform_range(from_sub, sub_nines);
144+
out << ")";
145+
if (from[i] < to[i] - 1) {
146+
out << " | ";
147+
if (to_sub == sub_nines) {
148+
digit_range(from[i] + 1, to[i]);
149+
to_reached = true;
150+
} else {
151+
digit_range(from[i] + 1, to[i] - 1);
152+
}
153+
out << " ";
154+
more_digits(sub_len, sub_len);
155+
}
156+
}
157+
if (!to_reached) {
158+
out << " | ";
159+
digit_range(to[i], to[i]);
160+
out << " ";
161+
uniform_range(sub_zeros, to_sub);
162+
}
163+
out << ")";
164+
} else {
165+
out << "[" << from[i] << "-" << to[i] << "]";
166+
}
167+
}
168+
};
169+
170+
if (has_min && has_max) {
171+
if (min_value < 0 && max_value < 0) {
172+
out << "\"-\" (";
173+
_build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
174+
out << ")";
175+
return;
176+
}
177+
178+
if (min_value < 0) {
179+
out << "\"-\" (";
180+
_build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
181+
out << ") | ";
182+
min_value = 0;
183+
}
184+
185+
auto min_s = std::to_string(min_value);
186+
auto max_s = std::to_string(max_value);
187+
auto min_digits = min_s.length();
188+
auto max_digits = max_s.length();
189+
190+
for (auto digits = min_digits; digits < max_digits; digits++) {
191+
uniform_range(min_s, repeat("9", digits));
192+
min_s = "1" + repeat("0", digits);
193+
out << " | ";
194+
}
195+
uniform_range(min_s, max_s);
196+
return;
197+
}
198+
199+
auto less_decimals = std::max(decimals_left - 1, 1);
200+
201+
if (has_min) {
202+
if (min_value < 0) {
203+
out << "\"-\" (";
204+
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
205+
out << ") | [0] | [1-9] ";
206+
more_digits(0, decimals_left - 1);
207+
} else if (min_value == 0) {
208+
if (top_level) {
209+
out << "[0] | [1-9] ";
210+
more_digits(0, less_decimals);
211+
} else {
212+
more_digits(1, decimals_left);
213+
}
214+
} else if (min_value <= 9) {
215+
char c = '0' + min_value;
216+
auto range_start = top_level ? '1' : '0';
217+
if (c > range_start) {
218+
digit_range(range_start, c - 1);
219+
out << " ";
220+
more_digits(1, less_decimals);
221+
out << " | ";
222+
}
223+
digit_range(c, '9');
224+
out << " ";
225+
more_digits(0, less_decimals);
226+
} else {
227+
auto min_s = std::to_string(min_value);
228+
auto len = min_s.length();
229+
auto c = min_s[0];
230+
231+
if (c > '1') {
232+
digit_range(top_level ? '1' : '0', c - 1);
233+
out << " ";
234+
more_digits(len, less_decimals);
235+
out << " | ";
236+
}
237+
digit_range(c, c);
238+
out << " (";
239+
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
240+
out << ")";
241+
if (c < '9') {
242+
out << " | ";
243+
digit_range(c + 1, '9');
244+
out << " ";
245+
more_digits(len - 1, less_decimals);
246+
}
247+
}
248+
return;
249+
}
250+
251+
if (has_max) {
252+
if (max_value >= 0) {
253+
if (top_level) {
254+
out << "\"-\" [1-9] ";
255+
more_digits(0, less_decimals);
256+
out << " | ";
257+
}
258+
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
259+
} else {
260+
out << "\"-\" (";
261+
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
262+
out << ")";
263+
}
264+
return;
265+
}
266+
267+
throw std::runtime_error("At least one of min_value or max_value must be set");
268+
}
269+
43270
// Grammar alternatives for optional whitespace: nothing (the leading `|` makes the first
// alternative empty), a single space, or a newline followed by up to 20 spaces/tabs.
// NOTE(review): presumably the right-hand side of the `space` rule — confirm at the use site.
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
44271

45272
struct BuiltinRule {
@@ -160,7 +387,6 @@ static std::string format_literal(const std::string & literal) {
160387
return "\"" + escaped + "\"";
161388
}
162389

163-
164390
class SchemaConverter {
165391
private:
166392
std::function<json(const std::string &)> _fetch_json;
@@ -686,6 +912,24 @@ class SchemaConverter {
686912
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
687913
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
688914
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
915+
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
916+
int min_value = std::numeric_limits<int>::min();
917+
int max_value = std::numeric_limits<int>::max();
918+
if (schema.contains("minimum")) {
919+
min_value = schema["minimum"].get<int>();
920+
} else if (schema.contains("exclusiveMinimum")) {
921+
min_value = schema["exclusiveMinimum"].get<int>() + 1;
922+
}
923+
if (schema.contains("maximum")) {
924+
max_value = schema["maximum"].get<int>();
925+
} else if (schema.contains("exclusiveMaximum")) {
926+
max_value = schema["exclusiveMaximum"].get<int>() - 1;
927+
}
928+
std::stringstream out;
929+
out << "(";
930+
_build_min_max_int(min_value, max_value, out);
931+
out << ") space";
932+
return _add_rule(rule_name, out.str());
689933
} else if (schema.empty() || schema_type == "object") {
690934
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
691935
} else {

examples/json-schema-pydantic-example.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class QAPair(BaseModel):
5353
question: str
5454
concise_answer: str
5555
justification: str
56+
stars: Annotated[int, Field(ge=1, le=5)]
5657

5758
class PyramidalSummary(BaseModel):
5859
title: str

0 commit comments

Comments
 (0)