
Commit d96ca7d

server : fix crash when prompt exceeds context size (#3996)
1 parent 34b0a08 commit d96ca7d
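
Previously, the n_keep clamp and the truncation of over-long prompts ran only in the branch taken when slot.params.cache_prompt is enabled; with prompt caching off, a prompt of n_ctx or more tokens was fed through untruncated and could crash the server. This commit moves that block ahead of the cache_prompt check, so every prompt is truncated to fit the slot's context before sampling begins.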

File tree

1 file changed: +29 -29 lines

examples/server/server.cpp (+29 -29)
@@ -1557,6 +1557,35 @@ struct llama_server_context
 
                 slot.num_prompt_tokens = prompt_tokens.size();
 
+                if (slot.params.n_keep < 0)
+                {
+                    slot.params.n_keep = slot.num_prompt_tokens;
+                }
+                slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+                // if input prompt is too big, truncate it
+                if (slot.num_prompt_tokens >= slot.n_ctx)
+                {
+                    const int n_left = slot.n_ctx - slot.params.n_keep;
+                    const int n_block_size = n_left / 2;
+                    const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+                    std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
+                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+                    LOG_VERBOSE("input truncated", {
+                        {"n_ctx",      slot.n_ctx},
+                        {"n_keep",     slot.params.n_keep},
+                        {"n_left",     n_left},
+                        {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                    });
+                    slot.truncated = true;
+                    prompt_tokens = new_tokens;
+
+                    slot.num_prompt_tokens = prompt_tokens.size();
+                    GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+                }
+
                 if (!slot.params.cache_prompt)
                 {
                     llama_sampling_reset(slot.ctx_sampling);
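
The same 29 lines are removed from their old location inside the prompt-caching branch:
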
@@ -1566,35 +1595,6 @@ struct llama_server_context
                 }
                 else
                 {
-                    if (slot.params.n_keep < 0)
-                    {
-                        slot.params.n_keep = slot.num_prompt_tokens;
-                    }
-                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                    // if input prompt is too big, truncate it
-                    if (slot.num_prompt_tokens >= slot.n_ctx)
-                    {
-                        const int n_left = slot.n_ctx - slot.params.n_keep;
-                        const int n_block_size = n_left / 2;
-                        const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-                        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
-                        LOG_VERBOSE("input truncated", {
-                            {"n_ctx",      slot.n_ctx},
-                            {"n_keep",     slot.params.n_keep},
-                            {"n_left",     n_left},
-                            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-                        });
-                        slot.truncated = true;
-                        prompt_tokens = new_tokens;
-
-                        slot.num_prompt_tokens = prompt_tokens.size();
-                        GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
-                    }
-
                     // push the prompt into the sampling context (do not apply grammar)
                     for (auto &token : prompt_tokens)
                     {
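
For reference, here is a minimal standalone sketch of the truncation arithmetic that the relocated block performs. The token type and the toy values for n_ctx, n_keep, and the prompt length are hypothetical, chosen only to show how n_left, n_block_size, and erased_blocks interact; this is not the server code itself, just the same math in isolation.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for llama_token, just for this sketch.
using token_t = int;

int main() {
    const int n_ctx  = 16;              // toy context size
    const int n_keep = 4;               // tokens always kept from the start
    std::vector<token_t> prompt(24);    // 24 tokens: too big for n_ctx
    for (int i = 0; i < (int) prompt.size(); i++) prompt[i] = i;

    if ((int) prompt.size() >= n_ctx) {
        const int n_left        = n_ctx - n_keep;   // 16 - 4 = 12
        const int n_block_size  = n_left / 2;       // 6
        const int erased_blocks =
            ((int) prompt.size() - n_keep - n_block_size) / n_block_size; // (24-4-6)/6 = 2

        // keep [0, n_keep), then skip erased_blocks * n_block_size middle tokens
        std::vector<token_t> truncated(prompt.begin(), prompt.begin() + n_keep);
        truncated.insert(truncated.end(),
                         prompt.begin() + n_keep + erased_blocks * n_block_size,
                         prompt.end());

        assert((int) truncated.size() < n_ctx);     // 4 kept + 8 tail = 12 < 16
        std::printf("kept %zu of %zu tokens\n", truncated.size(), prompt.size());
    }
    return 0;
}

Dropping whole blocks of n_left / 2 tokens, rather than only the exact overflow, echoes the context-shifting scheme used elsewhere in llama.cpp and guarantees the truncated prompt lands strictly below n_ctx, which the final GGML_ASSERT in the moved block re-checks.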
