llama : add option for greedy sampling with probs #3813
Conversation
e274fe3 to 4aa1fb0
```diff
@@ -167,9 +167,13 @@ llama_token llama_sampling_sample(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }

-    if (temp <= 0) {
-        // greedy sampling
+    if (temp < 0.0) {
```
The result is the same in either case, right? I'm not entirely sure it's worth special-casing this instead of just changing greedy sampling to do:

```cpp
llama_sample_softmax(ctx_main, &cur_p);
id = cur_p.data[0].id;
```

But if you did go that way, you'd probably also want to change the common args parsing to clamp the user-specified temperature to 0.0, so passing a negative value behaves the same. It's only internal callers that would care about probs generated vs. no probs, unless I'm misunderstanding.
It's the same result, yes. The probs are not used only internally - we are using them in `speculative`. Before this PR, we had to resort to the hack of `temp = 0.01f;` to get probs. Now we get them with `temp = 0.0f;`. The user-specified input should not be affected by this change. Technically, the user would normally want to pass `temp = -1.0f` to save the extra softmax compute, but it's probably not something that would affect performance in a measurable way.
Sorry, "internal" was a poor choice of words. I meant it's not something someone running the application and passing `--temp` on the command line would care about. So if they do `--temp -1` for an example that doesn't care about probs, it's kind of weird/unnecessary to turn on generating probs in that case.

What I'm proposing is that the argument handling would do something like:

```cpp
params.sparams.temp = std::max(0.0f, atof(blah));
```

when parsing the command-line arguments, so even if the user does `--temp -1` it's still just 0.0. Then something like `speculative`, which cares about probs in the greedy sampling case, can do:

```cpp
if (params.sparams.temp == 0.0f) {
    params.sparams.temp = -1.0f;
}
```
edit: Actually, you'd also need to reverse the logic for the softmax case a bit: `0.0` = greedy sampling, no softmax; `< 0.0` = greedy sampling with softmax.
Got it. Should be good now
* llama : add option for greedy sampling with probs
* llama : add comment about llama_sample_token_greedy() missing probs
* sampling : temp == 0.0 -> no probs, temp < 0.0 -> probs
On `master`, when using `temp <= 0.0` we get greedy sampling, but we don't have the probs of the tokens. This PR adds an option: when using `temp == 0.0`, do greedy sampling but also apply softmax, so we get the probs.