
Commit 7923b70

llama : add llm_build_inp_embd helper
1 parent 2073347 commit 7923b70
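
In brief, this commit factors the input-embedding setup, previously repeated in every per-architecture graph builder, into a single llm_build_inp_embd helper, and renames model.tok_embeddings/pos_embeddings to tok_embd/pos_embd. A condensed before/after sketch of the pattern, drawn from the diff below (illustration only, not part of the commit):

    // before: duplicated in each llm_build_* function
    if (batch.token) {
        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
        cb(inp_tokens, "inp_tokens", -1);

        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
    } else {
        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
    }

    // after: a single call into the shared helper
    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
    cb(inpL, "inp_embd", -1);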

1 file changed: llama.cpp (+50 −111 lines)

Diff for: llama.cpp

@@ -1228,8 +1228,8 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab   vocab;
 
-    struct ggml_tensor * tok_embeddings;
-    struct ggml_tensor * pos_embeddings;
+    struct ggml_tensor * tok_embd;
+    struct ggml_tensor * pos_embd;
     struct ggml_tensor * tok_norm;
     struct ggml_tensor * tok_norm_b;
 
@@ -2484,7 +2484,7 @@ static void llm_load_tensors(
             case LLM_ARCH_LLAMA:
             case LLM_ARCH_REFACT:
                 {
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
 
                    // output
                    {
@@ -2552,7 +2552,7 @@ static void llm_load_tensors(
                } break;
             case LLM_ARCH_BAICHUAN:
                 {
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
                    {
                        ggml_backend_type backend_norm;
                        ggml_backend_type backend_output;
@@ -2620,7 +2620,7 @@ static void llm_load_tensors(
                {
                    // TODO: CPU-only for now
 
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
 
                    // output
                    {
@@ -2696,8 +2696,8 @@ static void llm_load_tensors(
                } break;
             case LLM_ARCH_STARCODER:
                 {
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
-                    model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.pos_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU);
 
                    // output
                    {
@@ -2775,7 +2775,7 @@ static void llm_load_tensors(
                } break;
             case LLM_ARCH_PERSIMMON:
                 {
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
 
                    {
                        ggml_backend_type backend_norm;
@@ -2838,9 +2838,9 @@ static void llm_load_tensors(
                {
                    // TODO: CPU-only for now
 
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
-                    model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU);
-                    model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU);
+                    model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU);
 
                    // output
                    {
@@ -2918,7 +2918,7 @@ static void llm_load_tensors(
                } break;
             case LLM_ARCH_MPT:
                 {
-                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
 
                    // output
                    {
@@ -3099,6 +3099,31 @@ enum llm_rope_type {
     LLM_ROPE_GLM,
 };
 
+static struct ggml_tensor * llm_build_inp_embd(
+        struct ggml_context * ctx,
+        const llama_batch & batch,
+        struct ggml_tensor * tok_embd,
+        int64_t n_embd,
+        int32_t n_tokens,
+        const llm_build_cb & cb) {
+    struct ggml_tensor * inpL;
+
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
+        cb(inp_tokens, "inp_tokens", -1);
+
+        inpL = ggml_get_rows(ctx, tok_embd, inp_tokens);
+    } else {
+#ifdef GGML_USE_MPI
+        GGML_ASSERT(false && "not implemented");
+#endif
+
+        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+    }
+
+    return inpL;
+}
+
 // Persimmon: n_rot = n_embd_head/2
 // Other: n_rot = n_embd_head
 static void llm_build_k_shift(
@@ -3463,18 +3488,7 @@ static struct ggml_cgraph * llm_build_llama(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
     cb(inpL, "inp_embd", -1);
 
     // inp_pos - contains the positions
@@ -3619,18 +3633,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
     cb(inpL, "inp_embd", -1);
 
     // inp_pos - contains the positions
@@ -3789,18 +3792,7 @@ static struct ggml_cgraph * llm_build_falcon(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
     cb(inpL, "inp_embd", -1);
 
     // inp_pos - contains the positions
@@ -3953,23 +3945,11 @@ static struct ggml_cgraph * llm_build_starcoder(
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * cur;
-    struct ggml_tensor * embd;
     struct ggml_tensor * pos;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        embd = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
-    cb(embd, "inp_embd", -1);
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
+    cb(inpL, "inp_embd", -1);
 
     // inp_pos - contains the positions
     struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@@ -3983,10 +3963,10 @@ static struct ggml_cgraph * llm_build_starcoder(
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     cb(KQ_mask, "KQ_mask", -1);
 
-    pos = ggml_get_rows(ctx0, model.pos_embeddings, inp_pos);
+    pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
     cb(pos, "pos_embd", -1);
 
-    inpL = ggml_add(ctx0, embd, pos);
+    inpL = ggml_add(ctx0, inpL, pos);
     cb(inpL, "inpL", -1);
 
     for (int il = 0; il < n_layer; ++il) {
@@ -4108,14 +4088,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
     cb(inpL, "imp_embd", -1);
 
     struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
@@ -4358,18 +4331,7 @@ static struct ggml_cgraph * llm_build_refact(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
     cb(inpL, "inp_embd", -1);
 
     // KQ_scale
@@ -4499,22 +4461,10 @@ static struct ggml_cgraph * llm_build_bloom(
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor * cur;
-    struct ggml_tensor * embd;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        embd = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
-    cb(embd, "inp_embd", -1);
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
+    cb(inpL, "inp_embd", -1);
 
     // KQ_scale
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
@@ -4524,7 +4474,7 @@ static struct ggml_cgraph * llm_build_bloom(
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     cb(KQ_mask, "KQ_mask", -1);
 
-    inpL = llm_build_norm(ctx0, embd,
+    inpL = llm_build_norm(ctx0, inpL,
             model.tok_norm,
             model.tok_norm_b,
             LLM_NORM, norm_eps, cb, -1);
@@ -4648,18 +4598,7 @@ static struct ggml_cgraph * llm_build_mpt(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        cb(inp_tokens, "inp_tokens", -1);
-
-        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-    } else {
-#ifdef GGML_USE_MPI
-        GGML_ASSERT(false && "not implemented");
-#endif
-
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
-    }
+    inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
     cb(inpL, "inp_embd", -1);
 
     // KQ_scale

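As a usage note (a minimal sketch, not part of the diff: the lambda below is hypothetical, and llm_build_cb is assumed to be the std::function-based naming callback the graph builders pass around), a caller would hand the helper a callback that tags the tensors it creates:

    // hypothetical naming callback: tag each tensor so it can be located later
    const llm_build_cb cb = [](struct ggml_tensor * cur, const char * name, int il) {
        (void) il; // layer index, unused in this sketch
        ggml_set_name(cur, name);
    };

    // produces either a row lookup into tok_embd (token input) or an empty
    // F32 tensor of shape [n_embd, n_tokens] to be filled with float embeddings
    struct ggml_tensor * inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb);
    cb(inpL, "inp_embd", -1);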