
Commit ec7d0d6

mmoskal authored and NeoZhangJianyu committed
sampling : support for llguidance grammars (ggml-org#10224)
* initial porting of previous LLG patch
* update for new APIs
* build: integrate llguidance as an external project
* use '%llguidance' as marker to enable llg lark syntax
* add some docs
* clarify docs
* code style fixes
* remove llguidance.h from .gitignore
* fix tests when llg is enabled
* pass vocab not model to llama_sampler_init_llg()
* copy test-grammar-integration.cpp to test-llguidance.cpp
* clang fmt
* fix ref-count bug
* build and run test
* gbnf -> lark syntax
* conditionally include llguidance test based on LLAMA_LLGUIDANCE flag
* rename llguidance test file to test-grammar-llguidance.cpp
* add gh action for llg test
* align tests with LLG grammar syntax and JSON Schema spec
* llama_tokenizer() in fact requires valid utf8
* update llg
* format file
* add $LLGUIDANCE_LOG_LEVEL support
* fix whitespace
* fix warning
* include <cmath> for INFINITY
* add final newline
* fail llama_sampler_init_llg() at runtime
* Link gbnf_to_lark.py script; fix links; refer to llg docs for lexemes
* simplify #includes
* improve doc string for LLAMA_LLGUIDANCE
* typo in merge
* bump llguidance to 0.6.12
1 parent be4e8dd commit ec7d0d6

13 files changed (+1555 −9 lines)
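For context on the '%llguidance' marker mentioned in the commit message: grammars that begin with it are parsed with llguidance's Lark-style syntax instead of GBNF. A minimal sketch of such a grammar as a C++ string constant follows; the digit rule itself is an illustrative assumption, not code from this commit, and lexeme details are covered in the llguidance docs the commit links to.

// Illustrative only: not code from this commit. The "%llguidance {}" marker
// switches grammar parsing to llguidance's Lark dialect; the digit rule below
// is an assumed example.
static const char * k_digits_grammar =
    "%llguidance {}\n"      // marker line enabling llg Lark syntax
    "start: /[0-9]+/\n";    // constrain generation to one or more ASCII digits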

.github/workflows/build.yml (+30)

@@ -302,6 +302,36 @@ jobs:
           cd build
           ctest -L main --verbose --timeout 900
 
+  ubuntu-latest-llguidance:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Dependencies
+        id: depends
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential
+
+      - name: Build
+        id: cmake_build
+        run: |
+          mkdir build
+          cd build
+          cmake .. \
+            -DLLAMA_FATAL_WARNINGS=ON \
+            -DLLAMA_LLGUIDANCE=ON
+          cmake --build . --config Release -j $(nproc)
+
+      - name: Test
+        id: cmake_test
+        run: |
+          cd build
+          ctest -L main --verbose --timeout 900
+
   ubuntu-latest-cmake-rpc:
     runs-on: ubuntu-latest
 
CMakeLists.txt (+1)

@@ -80,6 +80,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
 
 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
+option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
 
 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)

common/CMakeLists.txt (+28)

@@ -65,6 +65,7 @@ add_library(${TARGET} STATIC
     console.h
     json-schema-to-grammar.cpp
     json.hpp
+    llguidance.cpp
     log.cpp
     log.h
     minja.hpp
@@ -91,6 +92,33 @@ if (LLAMA_CURL)
     set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
 endif ()
 
+if (LLAMA_LLGUIDANCE)
+    include(ExternalProject)
+    set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
+    set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
+    ExternalProject_Add(llguidance_ext
+        GIT_REPOSITORY https://github.com/guidance-ai/llguidance
+        # v0.6.12:
+        GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
+        PREFIX ${CMAKE_BINARY_DIR}/llguidance
+        SOURCE_DIR ${LLGUIDANCE_SRC}
+        BUILD_IN_SOURCE TRUE
+        CONFIGURE_COMMAND ""
+        BUILD_COMMAND cargo build --release
+        INSTALL_COMMAND ""
+        BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
+        UPDATE_COMMAND ""
+    )
+    target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
+
+    add_library(llguidance STATIC IMPORTED)
+    set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
+    add_dependencies(llguidance llguidance_ext)
+
+    target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
+endif ()
+
 target_include_directories(${TARGET} PUBLIC .)
 target_compile_features (${TARGET} PUBLIC cxx_std_17)
 target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)

common/json-schema-to-grammar.cpp (+8 −1)

@@ -991,7 +991,14 @@ class SchemaConverter {
     }
 };
 
-std::string json_schema_to_grammar(const json & schema) {
+std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
+#ifdef LLAMA_USE_LLGUIDANCE
+    if (!force_gbnf) {
+        return "%llguidance {}\nstart: %json " + schema.dump();
+    }
+#else
+    (void)force_gbnf;
+#endif // LLAMA_USE_LLGUIDANCE
     return build_grammar([&](const common_grammar_builder & callbacks) {
         auto copy = schema;
         callbacks.resolve_refs(copy);

common/json-schema-to-grammar.h (+2 −1)

@@ -5,7 +5,8 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
-std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
+std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
+                                   bool force_gbnf = false);
 
 struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
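Taken together, the two changes above mean JSON-Schema-driven sampling routes through llguidance whenever the library is compiled in, unless the caller explicitly asks for GBNF. A minimal sketch of the new behaviour; the schema and the wrapper function are illustrative assumptions, not code from this commit.

// Sketch only: assumes common/ is built with -DLLAMA_LLGUIDANCE=ON so that
// LLAMA_USE_LLGUIDANCE is defined. The schema here is an arbitrary example.
#include <string>
#include "json-schema-to-grammar.h"

static std::string integer_grammar() {
    nlohmann::ordered_json schema = { { "type", "integer" } };

    // force_gbnf defaults to false, so this is expected to return the
    // llguidance form rather than a GBNF conversion:
    //   %llguidance {}
    //   start: %json {"type":"integer"}
    return json_schema_to_grammar(schema);

    // json_schema_to_grammar(schema, /*force_gbnf=*/true) would still run the
    // original GBNF converter.
}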

common/llguidance.cpp (new file, +270)

@@ -0,0 +1,270 @@
#include "sampling.h"
#include "log.h"

#ifdef LLAMA_USE_LLGUIDANCE

# include "llguidance.h"
# include <cmath>

struct llama_sampler_llg {
    const llama_vocab * vocab;
    std::string grammar_kind;
    std::string grammar_data;
    LlgTokenizer * tokenizer;
    LlgConstraint * grammar;
    LlgMaskResult llg_res;
    bool has_llg_res;
};

static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
                                             const char * grammar_data) {
    LlgConstraintInit cinit;
    llg_constraint_init_set_defaults(&cinit, tokenizer);
    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
    if (log_level && *log_level) {
        cinit.log_stderr_level = atoi(log_level);
    }
    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
    if (llg_get_error(c)) {
        LOG_ERR("llg error: %s\n", llg_get_error(c));
        llg_free_constraint(c);
        return nullptr;
    }
    return c;
}

static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
    return "llguidance";
}

static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        LlgCommitResult res;
        llg_commit_token(ctx->grammar, token, &res);
        ctx->has_llg_res = false;
    }
}

static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (ctx->grammar) {
        if (!ctx->has_llg_res) {
            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
                ctx->has_llg_res = true;
            } else {
                LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
                llg_free_constraint(ctx->grammar);
                ctx->grammar = nullptr;
            }
        }
        if (ctx->has_llg_res) {
            if (ctx->llg_res.is_stop) {
                for (size_t i = 0; i < cur_p->size; ++i) {
                    if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            } else {
                const uint32_t * mask = ctx->llg_res.sample_mask;
                for (size_t i = 0; i < cur_p->size; ++i) {
                    auto token = cur_p->data[i].id;
                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
                        cur_p->data[i].logit = -INFINITY;
                    }
                }
            }
        }
    }
}

static void llama_sampler_llg_reset(llama_sampler * smpl) {
    auto * ctx = (llama_sampler_llg *) smpl->ctx;
    if (!ctx->grammar) {
        return;
    }

    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
    llg_free_constraint(ctx->grammar);
    ctx->grammar = grammar_new;
    ctx->has_llg_res = false;
}

static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

    // copy the state
    {
        auto * result_ctx = (llama_sampler_llg *) result->ctx;

        if (ctx->grammar) {
            result_ctx->grammar_kind = ctx->grammar_kind;
            result_ctx->grammar_data = ctx->grammar_data;
            result_ctx->grammar = llg_clone_constraint(ctx->grammar);
            result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
        }
    }

    return result;
}

static void llama_sampler_llg_free(llama_sampler * smpl) {
    const auto * ctx = (llama_sampler_llg *) smpl->ctx;

    if (ctx->grammar) {
        llg_free_constraint(ctx->grammar);
        llg_free_tokenizer(ctx->tokenizer);
    }

    delete ctx;
}

static llama_sampler_i llama_sampler_llg_i = {
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
};

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
                                            uint32_t * output_tokens, size_t output_tokens_len) {
    const llama_vocab * vocab = (const llama_vocab *) user_data;
    int r = 0;
    try {
        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
                           true);
    } catch (const std::exception & e) {
        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
    }
    if (r < 0) {
        return -r;
    }
    return r;
}

static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
    // TODO store the tokenizer in the vocab somehow
    static const llama_vocab * vocab_cache;
    static LlgTokenizer * tokenizer_cache;

    if (vocab_cache == vocab) {
        return llg_clone_tokenizer(tokenizer_cache);
    }

    auto tok_eos = llama_vocab_eot(vocab);
    if (tok_eos == LLAMA_TOKEN_NULL) {
        tok_eos = llama_vocab_eos(vocab);
    }

    size_t vocab_size = llama_vocab_n_tokens(vocab);

    auto token_lens = new uint32_t[vocab_size];
    // we typically have ~7 bytes per token; let's go on the safe side here
    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
    auto token_bytes = new uint8_t[token_bytes_size];

    size_t offset = 0;
    for (size_t i = 0; i < vocab_size; i++) {
        size_t max_token = 1024;
        if (token_bytes_size - offset < max_token) {
            GGML_ABORT("token_bytes buffer too small\n");
        }

        llama_token token = i;
        auto dp = (char *) token_bytes + offset;
        auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
        if (size < 0) {
            GGML_ABORT("llama_detokenize failed\n");
        }
        if (size == 0) {
            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
            if (size < 0) {
                GGML_ABORT("llama_detokenize failed\n");
            }
            if (size != 0) {
                *dp = '\xff'; // special token prefix marker
                size += 1;
            }
        }

        token_lens[i] = size;
        offset += size;
    }

    LlgTokenizerInit tinit = {
        /* .vocab_size                         = */ (uint32_t) vocab_size,
        /* .tok_eos                            = */ (uint32_t) tok_eos,
        /* .token_lens                         = */ token_lens,
        /* .token_bytes                        = */ token_bytes,
        /* .tokenizer_json                     = */ nullptr,
        /* .tokenize_assumes_string            = */ true,
        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
        /* .use_approximate_greedy_tokenize_fn = */ false,
        /* .tokenize_user_data                 = */ vocab,
    };

    char error_buffer[1024];
    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

    delete[] token_bytes;
    delete[] token_lens;

    if (tokenizer == nullptr) {
        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
        return tokenizer;
    }

    if (tokenizer_cache) {
        llg_free_tokenizer(tokenizer_cache);
    }
    vocab_cache = vocab;
    tokenizer_cache = tokenizer;

    return llg_clone_tokenizer(tokenizer_cache);
}

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
                                       const char * grammar_data) {
    auto * ctx = new llama_sampler_llg;

    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ grammar_kind,
            /* .grammar_data = */ grammar_data,
            /* .tokenizer    = */ tokenizer,
            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    } else {
        *ctx = {
            /* .vocab        = */ vocab,
            /* .grammar_kind = */ {},
            /* .grammar_data = */ {},
            /* .tokenizer    = */ nullptr,
            /* .grammar      = */ nullptr,
            /* .llg_res      = */ {},
            /* .has_llg_res  = */ false,
        };
    }

    return new llama_sampler{
        /* .iface = */ &llama_sampler_llg_i,
        /* .ctx   = */ ctx,
    };
}

#else

llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
    return nullptr;
}

#endif // LLAMA_USE_LLGUIDANCE
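For reference, a hedged sketch of how the new entry point might be wired into a sampler chain. The "lark" grammar-kind string, the example grammar text, and the chain boilerplate are assumptions for illustration and are not shown in this diff.

// Sketch under assumptions: llama_sampler_init_llg() is declared in common's
// sampling.h (as included above), and "lark" is accepted as a grammar kind by
// llg_new_constraint_any(). Grammar text and chain setup are illustrative.
#include "llama.h"
#include "sampling.h"

static llama_sampler * make_yes_no_chain(const llama_vocab * vocab) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // Constraint creation honours $LLGUIDANCE_LOG_LEVEL for stderr logging.
    llama_sampler * llg = llama_sampler_init_llg(vocab, "lark",
        "%llguidance {}\n"
        "start: \"yes\" | \"no\"\n");

    if (llg) { // the stub built without -DLLAMA_LLGUIDANCE=ON returns nullptr
        llama_sampler_chain_add(chain, llg);
    }
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());
    return chain;
}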
