From 59b389f123b10d5542c6495c9d5c82f0b00fe006 Mon Sep 17 00:00:00 2001 From: Julius Arkenberg Date: Thu, 21 Mar 2024 13:44:59 +0000 Subject: [PATCH 1/6] Add support for Grok model architecture --- convert-hf-to-gguf.py | 13 +- gguf-py/gguf/constants.py | 24 +++ gguf-py/gguf/tensor_mapping.py | 25 ++- llama.cpp | 299 +++++++++++++++++++++++++++++++++ 4 files changed, 354 insertions(+), 7 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1e49d56c19514..f42d16c5472f6 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -19,8 +19,8 @@ if TYPE_CHECKING: from torch import Tensor -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +#if 'NO_LOCAL_GGUF' not in os.environ: +sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf from convert import HfVocab @@ -53,7 +53,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) - self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False) + self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=True) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) @property @@ -1051,6 +1051,13 @@ def set_vocab(self): self._set_vocab_sentencepiece() +@Model.register("GrokForCausalLM") +class GrokModel(Model): + model_arch = gguf.MODEL_ARCH.GROK + + def set_vocab(self): + self._set_vocab_sentencepiece() + @Model.register("MiniCPMForCausalLM") class MiniCPMModel(Model): model_arch = gguf.MODEL_ARCH.MINICPM diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 4a4facb06ea14..e47896e2a9d3e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -100,6 +100,7 @@ class MODEL_ARCH(IntEnum): LLAMA = auto() FALCON = auto() BAICHUAN = auto() + GROK = auto() GPT2 = auto() GPTJ = auto() GPTNEOX = auto() @@ -167,6 +168,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.LLAMA: "llama", MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", MODEL_ARCH.GPT2: "gpt2", MODEL_ARCH.GPTJ: "gptj", MODEL_ARCH.GPTNEOX: "gptneox", @@ -251,6 +253,28 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.GROK: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.LAYER_OUT_NORM, + ], MODEL_ARCH.GPTNEOX: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index ed89955d8970f..7f482dd77c3dc 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -23,6 +23,7 @@ class TensorNameMap: "model.embedding", # mamba-qbert "backbone.embedding", # mamba "backbone.embeddings", # mamba-hf + "transformer.in_out_embed", # Grok ), # Token type embeddings @@ -66,6 +67,7 @@ class TensorNameMap: "lm_head.ln", # phi2 "model.norm_f", # mamba-qbert "backbone.norm_f", # mamba + "transformer.rms_norm", # Grok ), # Rope frequencies @@ -93,6 +95,7 @@ class TensorNameMap: "model.layers.{bid}.attention_norm", # internlm2 "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba + "transformer.decoder_layer.{bid}.rms_norm", # Grok ), # Attention norm 2 @@ -121,7 +124,8 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.query", # bert "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo - "model.layers.{bid}.attention.wq" # internlm2 + "model.layers.{bid}.attention.wq", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok ), # Attention key @@ -131,7 +135,8 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.key", # bert "transformer.h.{bid}.attn.k_proj", # gpt-j "model.layers.layers.{bid}.self_attn.k_proj", # plamo - "model.layers.{bid}.attention.wk" # internlm2 + "model.layers.{bid}.attention.wk", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok ), # Attention value @@ -141,7 +146,8 @@ class TensorNameMap: "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j "model.layers.layers.{bid}.self_attn.v_proj", # plamo - "model.layers.{bid}.attention.wv" # internlm2 + "model.layers.{bid}.attention.wv", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok ), # Attention output @@ -162,12 +168,14 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert + "transformer.decoder_layer.{bid}.multi_head_attention.linear" # Grok ), # Attention output norm MODEL_TENSOR.ATTN_OUT_NORM: ( "encoder.layer.{bid}.attention.output.LayerNorm", # bert "encoder.layers.{bid}.norm1", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_1", # Grok ), # Rotary embeddings @@ -190,11 +198,15 @@ class TensorNameMap: "model.layers.{bid}.ln2", # yi "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 + + "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + ), MODEL_TENSOR.FFN_GATE_INP: ( "layers.{bid}.feed_forward.gate", # mixtral "model.layers.{bid}.block_sparse_moe.gate", # mixtral + "transformer.decoder_layer.{bid}.router" # Grok ), # Feed-forward up @@ -222,7 +234,8 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP_EXP: ( "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral - "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral + "transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok ), # AWQ-activation gate @@ -243,6 +256,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_EXP: ( "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral + "transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok ), # Feed-forward down @@ -270,6 +284,8 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN_EXP: ( "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral + "transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok + ), MODEL_TENSOR.ATTN_Q_NORM: ( @@ -289,6 +305,7 @@ class TensorNameMap: MODEL_TENSOR.LAYER_OUT_NORM: ( "encoder.layer.{bid}.output.LayerNorm", # bert "encoder.layers.{bid}.norm2", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_3", # Grok ), MODEL_TENSOR.SSM_IN: ( diff --git a/llama.cpp b/llama.cpp index 1a9fe0c4d2cea..9129cd799b50e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -192,6 +192,7 @@ enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, LLM_ARCH_GPT2, LLM_ARCH_GPTJ, LLM_ARCH_GPTNEOX, @@ -221,6 +222,7 @@ enum llm_arch { static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, { LLM_ARCH_GPT2, "gpt2" }, { LLM_ARCH_GPTJ, "gptj" }, { LLM_ARCH_GPTNEOX, "gptneox" }, @@ -483,6 +485,28 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_GROK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + }, + }, { LLM_ARCH_GPT2, { @@ -4265,6 +4289,57 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_GROK: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, false); + + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + for (uint32_t x = 0; x < hparams.n_expert; ++x) { + layer.ffn_gate_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), {n_embd, n_ff}); + layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}); + layer.ffn_up_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}); + } + + + + layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); + } + } break; case LLM_ARCH_BAICHUAN: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -5451,6 +5526,20 @@ static struct ggml_tensor * llm_build_kqv( ggml_mul_mat_set_prec(kq, GGML_PREC_F32); } + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } + #if defined(GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") @@ -6225,6 +6314,211 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_grok() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + + // multiply by embedding_multiplier_scale of 78.38367176906169 + inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + //for (int il = 0; il < 1; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + } + + // Grok + // if attn_out_norm is present then apply it before adding the input + if (model.layers[il].attn_out_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].attn_out_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_out_norm", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); + cb(weights, "ffn_moe_weights", il); + + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + cb(weights, "ffn_moe_weights_norm", il); + + // compute expert outputs + ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert; + + ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exp, n_expert, selected_experts, i, cur); + cb(cur_up, "ffn_moe_up", il); + + ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); + cb(cur_gate, "ffn_moe_gate", il); + + //cur_gate = ggml_silu(ctx0, cur_gate); + //cb(cur_gate, "ffn_moe_silu", il); + + //GeLU + cur_gate = ggml_gelu(ctx0, cur_gate); + cb(cur_gate, "ffn_moe_gelu", il); + + cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_gate_par", il); + + cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exp, n_expert, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_down", il); + + cur_expert = ggml_mul(ctx0, cur_expert, + ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + cb(cur_expert, "ffn_moe_weighted", il); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + cb(moe_out, "ffn_moe_out", il); + } + } + + cur = moe_out; + + + // Grok + // if layer_out_norm is present then apply it before adding the input + // Idea: maybe ffn_out_norm is a better name + if (model.layers[il].layer_out_norm) { + cur = llm_build_norm(ctx0, cur, hparams, + model.layers[il].layer_out_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "layer_out_norm", il); + } + + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + + // Grok + // multiply logits by output_multiplier_scale of 0.5773502691896257 + + cur = ggml_scale(ctx0, cur, 0.5773502691896257f); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_starcoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -8648,6 +8942,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_falcon(); } break; + case LLM_ARCH_GROK: + { + result = llm.build_grok(); + } break; case LLM_ARCH_STARCODER: { result = llm.build_starcoder(); @@ -13373,6 +13671,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_GROK: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: From 95612548a000852c33f8ede89011a47afc3c9b9d Mon Sep 17 00:00:00 2001 From: Julius Arkenberg Date: Thu, 21 Mar 2024 15:34:38 +0100 Subject: [PATCH 2/6] Revert convert-hf-to-gguf to default options --- convert-hf-to-gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index f42d16c5472f6..0107efe649e15 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -19,8 +19,8 @@ if TYPE_CHECKING: from torch import Tensor -#if 'NO_LOCAL_GGUF' not in os.environ: -sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf from convert import HfVocab @@ -53,7 +53,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) - self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=True) + self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False) self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) @property From 6052e3b3a7616f6cc506771bfcf93bb21a019ad2 Mon Sep 17 00:00:00 2001 From: Julius Arkenberg Date: Thu, 21 Mar 2024 16:58:51 +0000 Subject: [PATCH 3/6] Fixed f_norm_rms_eps bug --- convert-hf-to-gguf.py | 21 +++++++++++++++++++++ llama.cpp | 11 +++++++++++ 2 files changed, 32 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 0107efe649e15..c9cca34fa4106 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -93,31 +93,42 @@ def set_gguf_parameters(self): if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: self.gguf_writer.add_context_length(n_ctx) + print(f"gguf: context length = {n_ctx}") n_embd = self.find_hparam(["hidden_size", "n_embd"]) self.gguf_writer.add_embedding_length(n_embd) + print(f"gguf: embedding length = {n_embd}") if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: self.gguf_writer.add_feed_forward_length(n_ff) + print(f"gguf: feed forward length = {n_ff}") n_head = self.find_hparam(["num_attention_heads", "n_head"]) self.gguf_writer.add_head_count(n_head) + print(f"gguf: head count = {n_head}") if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: self.gguf_writer.add_head_count_kv(n_head_kv) + print(f"gguf: key-value head count = {n_head_kv}") if (rope_theta := self.hparams.get("rope_theta")) is not None: self.gguf_writer.add_rope_freq_base(rope_theta) + print(f"gguf: rope theta = {rope_theta}") if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + print(f"gguf: rms norm epsilon = {f_rms_eps}") if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: self.gguf_writer.add_layer_norm_eps(f_norm_eps) + print(f"gguf: layer norm epsilon = {f_norm_eps}") if (n_experts := self.hparams.get("num_local_experts")) is not None: self.gguf_writer.add_expert_count(n_experts) + print(f"gguf: expert count = {n_experts}") if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) + print(f"gguf: experts used count = {n_experts_used}") self.gguf_writer.add_file_type(self.ftype) + print(f"gguf: file type = {self.ftype}") def write_tensors(self): block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) @@ -1057,6 +1068,16 @@ class GrokModel(Model): def set_vocab(self): self._set_vocab_sentencepiece() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_name("Grok") + + + @Model.register("MiniCPMForCausalLM") class MiniCPMModel(Model): diff --git a/llama.cpp b/llama.cpp index 9129cd799b50e..a015d67b1a987 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1645,6 +1645,7 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_314B, MODEL_SMALL, MODEL_MEDIUM, MODEL_LARGE, @@ -3314,6 +3315,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_40B: return "40B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; + case MODEL_314B: return "314B"; case MODEL_SMALL: return "0.1B"; case MODEL_MEDIUM: return "0.4B"; case MODEL_LARGE: return "0.8B"; @@ -3452,6 +3454,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GROK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 64: model.type = e_model::MODEL_314B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_FALCON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); From 81ce9df3ee03d68f857d53a200b83b7269d695b1 Mon Sep 17 00:00:00 2001 From: Julius Arkenberg Date: Thu, 21 Mar 2024 19:59:15 +0000 Subject: [PATCH 4/6] Fix whitespaces --- convert-hf-to-gguf.py | 8 ++--- gguf-py/gguf/tensor_mapping.py | 66 +++++++++++++++++----------------- llama.cpp | 8 ++--- 3 files changed, 38 insertions(+), 44 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c9cca34fa4106..723ea18e34c65 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1068,16 +1068,14 @@ class GrokModel(Model): def set_vocab(self): self._set_vocab_sentencepiece() - + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_name("Grok") - - - + @Model.register("MiniCPMForCausalLM") class MiniCPMModel(Model): diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7f482dd77c3dc..11fd34b8b9103 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -23,7 +23,7 @@ class TensorNameMap: "model.embedding", # mamba-qbert "backbone.embedding", # mamba "backbone.embeddings", # mamba-hf - "transformer.in_out_embed", # Grok + "transformer.in_out_embed", # Grok ), # Token type embeddings @@ -67,7 +67,7 @@ class TensorNameMap: "lm_head.ln", # phi2 "model.norm_f", # mamba-qbert "backbone.norm_f", # mamba - "transformer.rms_norm", # Grok + "transformer.rms_norm", # Grok ), # Rope frequencies @@ -95,7 +95,7 @@ class TensorNameMap: "model.layers.{bid}.attention_norm", # internlm2 "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba - "transformer.decoder_layer.{bid}.rms_norm", # Grok + "transformer.decoder_layer.{bid}.rms_norm", # Grok ), # Attention norm 2 @@ -119,34 +119,34 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf - "layers.{bid}.attention.wq", # llama-pth - "encoder.layer.{bid}.attention.self.query", # bert - "transformer.h.{bid}.attn.q_proj", # gpt-j - "model.layers.layers.{bid}.self_attn.q_proj", # plamo - "model.layers.{bid}.attention.wq", # internlm2 + "model.layers.{bid}.self_attn.q_proj", # llama-hf + "layers.{bid}.attention.wq", # llama-pth + "encoder.layer.{bid}.attention.self.query", # bert + "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.q_proj", # plamo + "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok ), # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf - "layers.{bid}.attention.wk", # llama-pth - "encoder.layer.{bid}.attention.self.key", # bert - "transformer.h.{bid}.attn.k_proj", # gpt-j - "model.layers.layers.{bid}.self_attn.k_proj", # plamo - "model.layers.{bid}.attention.wk", # internlm2 + "model.layers.{bid}.self_attn.k_proj", # llama-hf + "layers.{bid}.attention.wk", # llama-pth + "encoder.layer.{bid}.attention.self.key", # bert + "transformer.h.{bid}.attn.k_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.k_proj", # plamo + "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok ), # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf - "layers.{bid}.attention.wv", # llama-pth - "encoder.layer.{bid}.attention.self.value", # bert - "transformer.h.{bid}.attn.v_proj", # gpt-j - "model.layers.layers.{bid}.self_attn.v_proj", # plamo - "model.layers.{bid}.attention.wv", # internlm2 + "model.layers.{bid}.self_attn.v_proj", # llama-hf + "layers.{bid}.attention.wv", # llama-pth + "encoder.layer.{bid}.attention.self.value", # bert + "transformer.h.{bid}.attn.v_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.v_proj", # plamo + "model.layers.{bid}.attention.wv", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok ), @@ -168,14 +168,14 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert - "transformer.decoder_layer.{bid}.multi_head_attention.linear" # Grok + "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok ), # Attention output norm MODEL_TENSOR.ATTN_OUT_NORM: ( "encoder.layer.{bid}.attention.output.LayerNorm", # bert "encoder.layers.{bid}.norm1", # nomic-bert - "transformer.decoder_layer.{bid}.rms_norm_1", # Grok + "transformer.decoder_layer.{bid}.rms_norm_1", # Grok ), # Rotary embeddings @@ -198,15 +198,13 @@ class TensorNameMap: "model.layers.{bid}.ln2", # yi "h.{bid}.ln_2", # gpt2 "model.layers.{bid}.ffn_norm", # internlm2 - - "transformer.decoder_layer.{bid}.rms_norm_2", # Grok - + "transformer.decoder_layer.{bid}.rms_norm_2", # Grok ), MODEL_TENSOR.FFN_GATE_INP: ( "layers.{bid}.feed_forward.gate", # mixtral "model.layers.{bid}.block_sparse_moe.gate", # mixtral - "transformer.decoder_layer.{bid}.router" # Grok + "transformer.decoder_layer.{bid}.router" # Grok ), # Feed-forward up @@ -234,8 +232,8 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP_EXP: ( "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral - "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral - "transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok + "model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", # mixtral + "transformer.decoder_layer.{bid}.moe.{xid}.linear_v", # Grok ), # AWQ-activation gate @@ -256,7 +254,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_GATE_EXP: ( "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", # mixtral - "transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok + "transformer.decoder_layer.{bid}.moe.{xid}.linear" # Grok ), # Feed-forward down @@ -284,7 +282,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN_EXP: ( "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral "model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", # mixtral - "transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok + "transformer.decoder_layer.{bid}.moe.{xid}.linear_1", # Grok ), @@ -303,9 +301,9 @@ class TensorNameMap: ), MODEL_TENSOR.LAYER_OUT_NORM: ( - "encoder.layer.{bid}.output.LayerNorm", # bert - "encoder.layers.{bid}.norm2", # nomic-bert - "transformer.decoder_layer.{bid}.rms_norm_3", # Grok + "encoder.layer.{bid}.output.LayerNorm", # bert + "encoder.layers.{bid}.norm2", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_3", # Grok ), MODEL_TENSOR.SSM_IN: ( diff --git a/llama.cpp b/llama.cpp index a015d67b1a987..5ad82e69b727b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4330,7 +4330,7 @@ static bool llm_load_tensors( layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); - + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, false); @@ -4345,8 +4345,6 @@ static bool llm_load_tensors( layer.ffn_down_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd}); layer.ffn_up_exp[x] = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXP, "weight", i, x), {n_embd, n_ff}); } - - layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); } @@ -6480,7 +6478,7 @@ struct llm_build_context { } cur = moe_out; - + // Grok // if layer_out_norm is present then apply it before adding the input @@ -6515,7 +6513,7 @@ struct llm_build_context { // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); - + // Grok // multiply logits by output_multiplier_scale of 0.5773502691896257 From abdc8ea34a8b6523117ae9bb1b073a2a5e91994e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 22 Mar 2024 22:18:47 +0200 Subject: [PATCH 5/6] llama : fix grok rope type --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 5ad82e69b727b..d81d4067df261 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13680,7 +13680,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: - case LLM_ARCH_GROK: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: case LLM_ARCH_PLAMO: @@ -13693,6 +13692,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // the pairs of head values are offset by n_rot/2 case LLM_ARCH_FALCON: + case LLM_ARCH_GROK: case LLM_ARCH_PERSIMMON: case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: From 9a9e6cde66d5f659a8166e40068a549d478206c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 23 Mar 2024 18:41:10 +0200 Subject: [PATCH 6/6] llama : minor --- llama.cpp | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/llama.cpp b/llama.cpp index d81d4067df261..0e64f38ba1319 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4307,7 +4307,7 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); // if output is NULL, init from the input tok embed if (model.output == NULL) { model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -4333,8 +4333,7 @@ static bool llm_load_tensors( layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}, false); - + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd}); GGML_ASSERT(hparams.n_expert > 0); GGML_ASSERT(hparams.n_expert_used > 0); @@ -6335,7 +6334,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - // multiply by embedding_multiplier_scale of 78.38367176906169 inpL = ggml_scale(ctx0, inpL, 78.38367176906169f); @@ -6346,7 +6344,6 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); for (int il = 0; il < n_layer; ++il) { - //for (int il = 0; il < 1; ++il) { struct ggml_tensor * inpSA = inpL; // norm @@ -6452,9 +6449,6 @@ struct llm_build_context { ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exp, n_expert, selected_experts, i, cur); cb(cur_gate, "ffn_moe_gate", il); - //cur_gate = ggml_silu(ctx0, cur_gate); - //cb(cur_gate, "ffn_moe_silu", il); - //GeLU cur_gate = ggml_gelu(ctx0, cur_gate); cb(cur_gate, "ffn_moe_gelu", il); @@ -6479,7 +6473,6 @@ struct llm_build_context { cur = moe_out; - // Grok // if layer_out_norm is present then apply it before adding the input // Idea: maybe ffn_out_norm is a better name @@ -6514,7 +6507,6 @@ struct llm_build_context { // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); - // Grok // multiply logits by output_multiplier_scale of 0.5773502691896257 @@ -6527,7 +6519,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_starcoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);