
llama : add option for greedy sampling with probs #3813


Merged: 3 commits, Oct 28, 2023
Changes from 1 commit
common/sampling.cpp: 8 changes (6 additions, 2 deletions)

    @@ -167,9 +167,13 @@ llama_token llama_sampling_sample(
             llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
         }

    -    if (temp <= 0) {
    -        // greedy sampling
    +    if (temp < 0.0) {
Collaborator:
The result is the same in either case, right? I'm not entirely sure it's worth special-casing this instead of just changing greedy sampling to do:

        llama_sample_softmax(ctx_main, &cur_p);
        id = cur_p.data[0].id;

But if you did go that way, you'd probably also want to change the common argument-parsing code to clamp the user-specified temperature to 0.0, so that passing a negative value behaves the same.

It's only internal code that would care about probs being generated vs. no probs, unless I'm misunderstanding.

Member Author:
It's the same result, yes. The probs are not used only internally - we are using them in the speculative example. Before this PR, we had to resort to the temp = 0.01f hack to get probs. Now we get them with temp = 0.0f.

The user-specified input should not be affected by this change. Technically, the user would normally want to pass temp = -1.0f to save the extra softmax compute, but it's probably not something that would affect performance in a measurable way.

Collaborator (@KerfuffleV2), Oct 28, 2023:
Sorry, "internal" was a poor choice of words. I meant it's not something someone invoking the application and passing --temp on the command line would care about. So if they pass --temp -1 to an example that doesn't care about probs, it's kind of weird/unnecessary to turn on generating probs in that case.

So what I'm proposing is that the argument handling would do something like:

    params.sparams.temp = std::max(0.0f, atof(blah));

when parsing the command-line arguments, so even if the user passes --temp -1 it's still just 0.0. Then something like speculative, which cares about probs in the greedy sampling case, can do:

    if (params.sparams.temp == 0.0f) {
        params.sparams.temp = -1.0f;
    }

edit: Actually, you'd also need to reverse the logic for the softmax case a bit: 0.0 = greedy sampling, no softmax; < 0.0 = greedy sampling with softmax.

Member Author:
Got it. Should be good now.

    +        // greedy sampling, no probs
             id = llama_sample_token_greedy(ctx_main, &cur_p);
    +    } else if (temp == 0.0) {
    +        // greedy sampling, with probs
    +        llama_sample_softmax(ctx_main, &cur_p);
    +        id = cur_p.data[0].id;
         } else {
             if (mirostat == 1) {
                 const int mirostat_m = 100;
examples/speculative/speculative.cpp: 2 changes (1 addition, 1 deletion)

    @@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
         std::vector<seq_draft> drafts(n_seq_dft);

         params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
    -    params.sparams.temp = std::max(0.01f, params.sparams.temp);
    +    params.sparams.temp = 0.0f;

         for (int s = 0; s < n_seq_dft; ++s) {
             drafts[s].ctx_sampling = llama_sampling_init(params.sparams);