@@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "--mlock") {
             params.use_mlock = true;
+        } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_gpu_layers = std::stoi(argv[i]);
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
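A minimal standalone sketch of the flag-parsing pattern used in this hunk (a hypothetical reduced example, not part of the patch): the option's value lives in the next argv slot, so the index is advanced and bounds-checked before std::stoi() reads it.

    // sketch.cpp -- reduced illustration of the ++i bounds-check pattern
    #include <cstdio>
    #include <string>

    int main(int argc, char ** argv) {
        int n_gpu_layers = 0;
        for (int i = 1; i < argc; i++) {
            std::string arg = argv[i];
            if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
                if (++i >= argc) {
                    // same guard as the patch: the flag was given without a value
                    fprintf(stderr, "error: %s requires a value\n", arg.c_str());
                    return 1;
                }
                n_gpu_layers = std::stoi(argv[i]);
            }
        }
        printf("n_gpu_layers = %d\n", n_gpu_layers);
        return 0;
    }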
@@ -421,6 +427,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     if (llama_mmap_supported()) {
         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
     }
+    fprintf(stderr, "  -ngl N, --n-gpu-layers N\n");
+    fprintf(stderr, "                        number of layers to store in VRAM\n");
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
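With this help text in place, a user would request GPU offloading along the lines of ./main -m models/7B/ggml-model-q4_0.bin -ngl 32 (an illustrative invocation; the binary name and model path depend on the local build).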
@@ -463,14 +471,15 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();

-    lparams.n_ctx      = params.n_ctx;
-    lparams.n_parts    = params.n_parts;
-    lparams.seed       = params.seed;
-    lparams.f16_kv     = params.memory_f16;
-    lparams.use_mmap   = params.use_mmap;
-    lparams.use_mlock  = params.use_mlock;
-    lparams.logits_all = params.perplexity;
-    lparams.embedding  = params.embedding;
+    lparams.n_ctx        = params.n_ctx;
+    lparams.n_parts      = params.n_parts;
+    lparams.n_gpu_layers = params.n_gpu_layers;
+    lparams.seed         = params.seed;
+    lparams.f16_kv       = params.memory_f16;
+    lparams.use_mmap     = params.use_mmap;
+    lparams.use_mlock    = params.use_mlock;
+    lparams.logits_all   = params.perplexity;
+    lparams.embedding    = params.embedding;

     llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);

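A minimal sketch of what the forwarded field enables through the C API as it appears in this revision (the model path and layer count are illustrative assumptions, not values from the patch):

    // init_sketch.cpp -- request GPU offloading of the first 32 layers
    #include "llama.h"

    int main() {
        llama_context_params lparams = llama_context_default_params();
        lparams.n_gpu_layers = 32; // same field the patch forwards from gpt_params

        llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
        if (ctx == NULL) {
            return 1;
        }
        llama_free(ctx);
        return 0;
    }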