6
6
#include < string>
7
7
#include < iterator>
8
8
#include < algorithm>
9
+ #include < sstream>
10
+ #include < iostream>
9
11
10
12
#if defined (_WIN32)
11
13
#include < fcntl.h>
@@ -114,6 +116,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
114
116
break ;
115
117
}
116
118
params.temp = std::stof (argv[i]);
119
+ } else if (arg == " --tfs" ) {
120
+ if (++i >= argc) {
121
+ invalid_param = true ;
122
+ break ;
123
+ }
124
+ params.tfs_z = std::stof (argv[i]);
125
+ } else if (arg == " --typical" ) {
126
+ if (++i >= argc) {
127
+ invalid_param = true ;
128
+ break ;
129
+ }
130
+ params.typical_p = std::stof (argv[i]);
117
131
} else if (arg == " --repeat_last_n" ) {
118
132
if (++i >= argc) {
119
133
invalid_param = true ;
@@ -126,6 +140,36 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
126
140
break ;
127
141
}
128
142
params.repeat_penalty = std::stof (argv[i]);
143
+ } else if (arg == " --frequency_penalty" ) {
144
+ if (++i >= argc) {
145
+ invalid_param = true ;
146
+ break ;
147
+ }
148
+ params.frequency_penalty = std::stof (argv[i]);
149
+ } else if (arg == " --presence_penalty" ) {
150
+ if (++i >= argc) {
151
+ invalid_param = true ;
152
+ break ;
153
+ }
154
+ params.presence_penalty = std::stof (argv[i]);
155
+ } else if (arg == " --mirostat" ) {
156
+ if (++i >= argc) {
157
+ invalid_param = true ;
158
+ break ;
159
+ }
160
+ params.mirostat = std::stoi (argv[i]);
161
+ } else if (arg == " --mirostat_lr" ) {
162
+ if (++i >= argc) {
163
+ invalid_param = true ;
164
+ break ;
165
+ }
166
+ params.mirostat_eta = std::stof (argv[i]);
167
+ } else if (arg == " --mirostat_ent" ) {
168
+ if (++i >= argc) {
169
+ invalid_param = true ;
170
+ break ;
171
+ }
172
+ params.mirostat_tau = std::stof (argv[i]);
129
173
} else if (arg == " -b" || arg == " --batch_size" ) {
130
174
if (++i >= argc) {
131
175
invalid_param = true ;
@@ -185,7 +229,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
185
229
} else if (arg == " --perplexity" ) {
186
230
params.perplexity = true ;
187
231
} else if (arg == " --ignore-eos" ) {
188
- params.ignore_eos = true ;
232
+ params.logit_bias [llama_token_eos ()] = -INFINITY;
233
+ } else if (arg == " --no-penalize-nl" ) {
234
+ params.penalize_nl = false ;
235
+ } else if (arg == " -l" || arg == " --logit-bias" ) {
236
+ if (++i >= argc) {
237
+ invalid_param = true ;
238
+ break ;
239
+ }
240
+ std::stringstream ss (argv[i]);
241
+ llama_token key;
242
+ char sign;
243
+ std::string value_str;
244
+ try {
245
+ if (ss >> key && ss >> sign && std::getline (ss, value_str) && (sign == ' +' || sign == ' -' )) {
246
+ params.logit_bias [key] = std::stof (value_str) * ((sign == ' -' ) ? -1 .0f : 1 .0f );
247
+ } else {
248
+ throw std::exception ();
249
+ }
250
+ } catch (const std::exception &e) {
251
+ invalid_param = true ;
252
+ break ;
253
+ }
189
254
} else if (arg == " --n_parts" ) {
190
255
if (++i >= argc) {
191
256
invalid_param = true ;
@@ -240,12 +305,26 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
240
305
fprintf (stderr, " -f FNAME, --file FNAME\n " );
241
306
fprintf (stderr, " prompt file to start generation.\n " );
242
307
fprintf (stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n " , params.n_predict );
243
- fprintf (stderr, " --top_k N top-k sampling (default: %d)\n " , params.top_k );
244
- fprintf (stderr, " --top_p N top-p sampling (default: %.1f)\n " , (double )params.top_p );
245
- fprintf (stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n " , params.repeat_last_n );
246
- fprintf (stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n " , (double )params.repeat_penalty );
308
+ fprintf (stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n " , params.top_k );
309
+ fprintf (stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n " , (double )params.top_p );
310
+ fprintf (stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n " , (double )params.tfs_z );
311
+ fprintf (stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n " , (double )params.typical_p );
312
+ fprintf (stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n " , params.repeat_last_n );
313
+ fprintf (stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n " , (double )params.repeat_penalty );
314
+ fprintf (stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n " , (double )params.presence_penalty );
315
+ fprintf (stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n " , (double )params.frequency_penalty );
316
+ fprintf (stderr, " --mirostat N use Mirostat sampling.\n " );
317
+ fprintf (stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n " );
318
+ fprintf (stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n " , params.mirostat );
319
+ fprintf (stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n " , (double )params.mirostat_eta );
320
+ fprintf (stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n " , (double )params.mirostat_tau );
321
+ fprintf (stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n " );
322
+ fprintf (stderr, " modifies the likelihood of token appearing in the completion,\n " );
323
+ fprintf (stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n " );
324
+ fprintf (stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n " );
247
325
fprintf (stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n " , params.n_ctx );
248
- fprintf (stderr, " --ignore-eos ignore end of stream token and continue generating\n " );
326
+ fprintf (stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n " );
327
+ fprintf (stderr, " --no-penalize-nl do not penalize newline token\n " );
249
328
fprintf (stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n " );
250
329
fprintf (stderr, " --temp N temperature (default: %.1f)\n " , (double )params.temp );
251
330
fprintf (stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n " );
0 commit comments