Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending grammar integration tests #6644

Merged
merged 10 commits into from
Apr 29, 2024
348 changes: 212 additions & 136 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
@@ -10,15 +10,10 @@
#include "unicode.h"
#include <cassert>
#include <string>
#include <vector>

static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
static llama_grammar* build_grammar(const std::string & grammar_str) {
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
@@ -30,168 +25,78 @@ number ::= [0-9]+)""";
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

std::string input = "123+456";
return grammar;
}

static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});

const auto & code_points = decoded.first;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
if (grammar->stacks.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}

bool completed_grammar = false;

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
// An empty stack means that the grammar has been completed
return true;
}
}

assert(completed_grammar);

// Clean up allocated memory
llama_grammar_free(grammar);
return false;
}

static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings
const std::string grammar_str = R"""(root ::= expression
expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""";

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());

// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
static void test_grammar(const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚪ Testing grammar: %s\n", grammar_str.c_str());
fflush(stderr);

std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
auto grammar = build_grammar(grammar_str);

// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;

// Test a few strings
std::vector<std::string> test_strings_pass = {
"42",
"1*2*3*4*5",
"x",
"x+10",
"x1+y2",
"(a+b)*(c-d)",
"func()",
"func(x,y+2)",
"a*(b+c)-d/e",
"f(g(x),h(y,z))",
"x + 10",
"x1 + y2",
"(a + b) * (c - d)",
"func()",
"func(x, y + 2)",
"a * (b + c) - d / e",
"f(g(x), h(y, z))",
"123+456",
"123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
};

std::vector<std::string> test_strings_fail = {
"+",
"/ 3x",
"x + + y",
"a * / b",
"func(,)",
"func(x y)",
"(a + b",
"x + y)",
"a + b * (c - d",
"42 +",
"x +",
"x + 10 +",
"(a + b) * (c - d",
"func(",
"func(x, y + 2",
"a * (b + c) - d /",
"f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
};
fprintf(stderr, " Checking valid strings:\n");

// Passing strings
for (const auto & test_string : test_strings_pass) {
auto decoded = decode_utf8(test_string, {});

const auto & code_points = decoded.first;

int pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);

// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
}
for (const auto & test_string : passing_strings) {
fprintf(stderr, " \"%s\" ", test_string.c_str());
fflush(stderr);

bool completed_grammar = false;
bool matched = match_string(test_string, grammar);

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
if (!matched) {
fprintf(stderr, "❌ (failed to match)\n");
} else {
fprintf(stdout, "✅︎\n");
}

assert(completed_grammar);
assert(matched);

// Reset the grammar stacks
grammar->stacks = original_stacks;
}

fprintf(stderr, " Checking invalid strings:\n");

// Failing strings
for (const auto & test_string : test_strings_fail) {
auto decoded = decode_utf8(test_string, {});

const auto & code_points = decoded.first;
bool parse_failed = false;

for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
}
assert(!grammar->stacks.empty());
}
for (const auto & test_string : failing_strings) {
fprintf(stderr, " \"%s\" ", test_string.c_str());
fflush(stderr);

bool completed_grammar = false;
bool matched = match_string(test_string, grammar);

for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
if (matched) {
fprintf(stderr, "❌ (incorrectly matched)\n");
} else {
fprintf(stdout, "✅︎\n");
}

// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);
assert(!matched);

// Reset the grammar stacks
grammar->stacks = original_stacks;
@@ -201,7 +106,170 @@ ws ::= [ \t\n\r]?)""";
llama_grammar_free(grammar);
}

static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";

auto grammar = build_grammar(grammar_str);

bool matched = match_string("123+456", grammar);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot this one I think :-)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could still call test_grammar from this one I think


assert(matched);

// Clean up allocated memory
llama_grammar_free(grammar);
}

static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings
test_grammar(
// Grammar
R"""(
root ::= expression
expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""",
// Passing strings
{
"42",
"1*2*3*4*5",
"x",
"x+10",
"x1+y2",
"(a+b)*(c-d)",
"func()",
"func(x,y+2)",
"a*(b+c)-d/e",
"f(g(x),h(y,z))",
"x + 10",
"x1 + y2",
"(a + b) * (c - d)",
"func()",
"func(x, y + 2)",
"a * (b + c) - d / e",
"f(g(x), h(y, z))",
"123+456",
"123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
},
// Failing strings
{
"+",
"/ 3x",
"x + + y",
"a * / b",
"func(,)",
"func(x y)",
"(a + b",
"x + y)",
"a + b * (c - d",
"42 +",
"x +",
"x + 10 +",
"(a + b) * (c - d",
"func(",
"func(x, y + 2",
"a * (b + c) - d /",
"f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
}
);
}

static void test_quantifiers() {
// A collection of tests to exercise * + and ? quantifiers

test_grammar(
// Grammar
R"""(root ::= "a"*)""",
// Passing strings
{
"",
"a",
"aaaaa",
"aaaaaaaaaaaaaaaaaa",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
},
// Failing strings
{
"b",
"ab",
"aab",
"ba",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
}
);
test_grammar(
// Grammar
R"""(root ::= "a"+)""",
// Passing strings
{
"a",
"aaaaa",
"aaaaaaaaaaaaaaaaaa",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
},
// Failing strings
{
"",
"b",
"ab",
"aab",
"ba",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
}
);
test_grammar(
// Grammar
R"""(root ::= "a"?)""",
// Passing strings
{
"",
"a"
},
// Failing strings
{
"b",
"ab",
"aa",
"ba",
}
);
test_grammar(
// Grammar
R"""(
root ::= cons+ vowel* cons? (vowel cons)*
vowel ::= [aeiouy]
cons ::= [bcdfghjklmnpqrstvwxyz]
)""",
// Passing strings
{
"yes",
"no",
"noyes",
"crwth",
"four",
"bryyyy",
},
// Failing strings
{
"yess",
"yesno",
"forty",
"catyyy",
}
);
}

static void test_failure_missing_root() {
fprintf(stderr, "🟢 Testing for missing root node:\n");
// Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)*
@@ -215,29 +283,37 @@ number ::= [0-9]+)""";

// Ensure we do NOT have a root node
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
fprintf(stderr, " ✅︎ Passed\n");
}

static void test_failure_missing_reference() {
fprintf(stderr, "🟢 Testing for missing reference node:\n");

// Test case for a grammar that is missing a referenced rule
const std::string grammar_str = R"""(root ::= expr
const std::string grammar_str =
R"""(root ::= expr
expr ::= term ("+" term)*
term ::= numero
number ::= [0-9]+)""";

fprintf(stderr, "Expected error: ");
fprintf(stderr, " Expected error: ");

grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());

// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());

fprintf(stderr, "End of expected error. Test successful.\n");
fprintf(stderr, " End of expected error.\n");
fprintf(stderr, " ✅︎ Passed\n");
}

int main() {
fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar();
test_complex_grammar();
test_quantifiers();
test_failure_missing_root();
test_failure_missing_reference();
fprintf(stdout, "All tests passed.\n");
return 0;
}