Hotfix prompt caching introduced in #1169, fixes #1257 #1260
@@ -2701,7 +2701,7 @@ size_t llama_load_session_file(struct llama_context * ctx, const char * path_ses
     const uint32_t magic   = file.read_u32();
     const uint32_t version = file.read_u32();

-    if (!(magic == 'ggsn' && version == 0)) {
+    if (!(magic == 'ggsn' && version == 1)) {
        fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
        return 0;
    }
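In isolation, the corrected header check amounts to comparing the two u32 fields against fixed constants. A minimal sketch follows; the `session_header_ok` helper is hypothetical, and the numeric value of the `'ggsn'` multi-character literal is an assumption based on how typical compilers pack it, not something taken from llama.cpp:

```cpp
#include <cstdint>

// Assumed packing of the multi-char literal 'ggsn' on common compilers
// ('g' = 0x67, 's' = 0x73, 'n' = 0x6e, packed big-end-first).
constexpr uint32_t SESSION_MAGIC   = 0x6767736e;
constexpr uint32_t SESSION_VERSION = 1; // bumped from 0 by this PR

// Returns true when the (magic, version) pair identifies a loadable
// session file, mirroring the corrected condition in the diff above.
bool session_header_ok(uint32_t magic, uint32_t version) {
    return magic == SESSION_MAGIC && version == SESSION_VERSION;
}
```

With this helper, a version-0 file written before the PR is rejected rather than silently misread.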
@@ -2724,6 +2724,7 @@ size_t llama_load_session_file(struct llama_context * ctx, const char * path_ses
     const size_t n_orig_state_size = llama_get_state_size(ctx);
     if (n_state_size != n_orig_state_size) {
        fprintf(stderr, "%s : failed to validate state size\n", __func__);
+       return 0;
    }
    std::unique_ptr<uint8_t[]> state_data(new uint8_t[n_state_size]);
    file.read_raw(state_data.get(), n_state_size);

Review comment on the added `return 0;`: Thanks! Just noticed this, oops...
@@ -2739,7 +2740,7 @@ size_t llama_save_session_file(struct llama_context * ctx, const char * path_ses
     llama_copy_state_data(ctx, state_data.get());

    file.write_u32('ggsn'); // magic
-   file.write_u32(0); // version
+   file.write_u32(1); // version

    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));

    file.write_u32((uint32_t) n_token_count); // REVIEW

Review comment on the version bump: nit: it's harmless, but I don't believe it's necessary to change the version here? The binary format isn't affected by the decision of how many tokens to store.
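The save-side layout after this change can be sketched as a sequence of back-to-back writes: magic, version, hparams, token count. A hedged model of just the fixed-size prefix — `build_session_header` is a hypothetical helper, `fake_hparams_size` stands in for `sizeof(llama_hparams)`, and the numeric magic value assumes a typical compiler's packing of the `'ggsn'` literal:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch of the session-file prefix written by llama_save_session_file
// after this PR, assuming the fields are laid out back-to-back in the
// order shown in the diff: magic, version, hparams, n_token_count.
std::vector<uint8_t> build_session_header(uint32_t n_token_count,
                                          size_t fake_hparams_size) {
    std::vector<uint8_t> out;
    auto write_u32 = [&](uint32_t v) {
        uint8_t b[4];
        std::memcpy(b, &v, sizeof v); // native byte order, as write_u32 would
        out.insert(out.end(), b, b + 4);
    };
    write_u32(0x6767736e);                        // magic 'ggsn' (assumed value)
    write_u32(1);                                 // version, bumped to 1 here
    out.insert(out.end(), fake_hparams_size, 0);  // placeholder for hparams blob
    write_u32(n_token_count);                     // token count precedes the tokens
    return out;
}
```

The tokens themselves and the KV/state blob would follow this prefix; they are omitted since the PR only touches the header fields.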
Review comment on `file.write_u32((uint32_t) n_token_count);`:

This could actually be the right approach in general. It will just have the effect of forcing eval one token earlier, which will of course have effectively the same performance. That said, I'm wondering whether we should ideally do this on the read side instead, e.g. by decrementing `n_matching_session_tokens` by one in the startup code above. The reason I say that is that I eventually envision sessions being used to restore the full transcript rather than just caching the prompt, and in that case the approach here would clip the last token. In addition, I believe there's an (however unlikely) edge case here with empty `session_tokens`.

I'd actually been wondering whether we should do this anyway, because of the flow of state from eval to sampling, specifically the logits vector. In other words, whether we should force an eval to ensure up-to-date logits. This corresponds to the fact that, in `main` at least, sampling is never done without first eval-ing. I haven't deeply explored this. Perhaps this could be related to your issue.
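The reviewer's read-side suggestion can be sketched as: reuse one fewer token than matched, so at least one eval always runs and the logits vector is fresh before sampling. The `tokens_to_reuse` helper below is illustrative, not llama.cpp code; only `n_matching_session_tokens` is a name taken from the discussion, and the zero guard covers the empty-`session_tokens` edge case mentioned above:

```cpp
#include <cstddef>

// Given how many leading session tokens match the current prompt,
// decide how many to actually reuse from the cache. Dropping one
// forces the final matching token through eval, guaranteeing that
// the logits reflect the restored context before sampling.
size_t tokens_to_reuse(size_t n_matching_session_tokens) {
    if (n_matching_session_tokens == 0) {
        return 0; // nothing cached: a full eval happens anyway
    }
    return n_matching_session_tokens - 1; // re-eval the last matched token
}
```

Compared to writing one fewer token on the save side, this keeps the stored session complete (useful for full-transcript restore later) while paying the same one-token eval cost on load.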