Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6b026a9

Browse files
hxer7963willhewillhe
authored andcommittedApr 17, 2024
[Model] Add support for xverse (ggml-org#6301)
* Support xverse model convert to gguf format. * 1. Convert xverse models to gguf; 2. Add LLM_ARCH_XVERSE inference in llama.cpp; 3. Add xverse item in Supported models in README.md; * * gguf-py: remove redundant logs * llama: remove the init_mapping_prefetch custom parameter * llama.cpp: Include the changes from ggml-org#6122 to exclude the unused outputs of the last layers. * - Fix format issues - Remove duplicate set kqv_out to llm_build_kv * Update llama.cpp --------- Co-authored-by: willhe <[email protected]> Co-authored-by: willhe <[email protected]>
1 parent 4be177c commit 6b026a9

File tree

4 files changed

+329
-1
lines changed

4 files changed

+329
-1
lines changed
 

‎README.md

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Typically finetunes of the base models below are supported as well.
115115
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
116116
- [x] [Gemma](https://ai.google.dev/gemma)
117117
- [x] [Mamba](https://github.com/state-spaces/mamba)
118+
- [x] [Xverse](https://huggingface.co/models?search=xverse)
118119
- [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
119120

120121
**Multimodal models:**

‎convert-hf-to-gguf.py

+142
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,148 @@ def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor:
773773
return weights[r * n_part:r * n_part + r, ...]
774774

775775

776+
@Model.register("XverseForCausalLM")
777+
class XverseModel(Model):
778+
model_arch = gguf.MODEL_ARCH.XVERSE
779+
780+
def set_vocab(self):
781+
assert (self.dir_model / "tokenizer.json").is_file()
782+
dir_model = self.dir_model
783+
hparams = self.hparams
784+
785+
tokens: list[bytearray] = []
786+
toktypes: list[int] = []
787+
788+
from transformers import AutoTokenizer
789+
tokenizer = AutoTokenizer.from_pretrained(dir_model)
790+
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
791+
assert max(tokenizer.vocab.values()) < vocab_size
792+
793+
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
794+
added_vocab = tokenizer.get_added_vocab()
795+
796+
for token_id in range(vocab_size):
797+
token_text = reverse_vocab[token_id].encode('utf-8')
798+
# replace "\x00" to string with length > 0
799+
if token_text == b"\x00":
800+
toktype = gguf.TokenType.BYTE # special
801+
token_text = f"<{token_text}>".encode('utf-8')
802+
elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
803+
toktype = gguf.TokenType.BYTE # special
804+
elif reverse_vocab[token_id] in added_vocab:
805+
if tokenizer.added_tokens_decoder[token_id].special:
806+
toktype = gguf.TokenType.CONTROL
807+
else:
808+
toktype = gguf.TokenType.USER_DEFINED
809+
else:
810+
toktype = gguf.TokenType.NORMAL
811+
812+
tokens.append(token_text)
813+
toktypes.append(toktype)
814+
815+
self.gguf_writer.add_tokenizer_model("llama")
816+
self.gguf_writer.add_token_list(tokens)
817+
self.gguf_writer.add_token_types(toktypes)
818+
819+
special_vocab = gguf.SpecialVocab(dir_model, n_vocab=len(tokens))
820+
special_vocab.add_to_gguf(self.gguf_writer)
821+
822+
def set_gguf_parameters(self):
823+
block_count = self.hparams["num_hidden_layers"]
824+
head_count = self.hparams["num_attention_heads"]
825+
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
826+
hf_repo = self.hparams.get("_name_or_path", "")
827+
828+
ctx_length = 0
829+
if "max_sequence_length" in self.hparams:
830+
ctx_length = self.hparams["max_sequence_length"]
831+
elif "max_position_embeddings" in self.hparams:
832+
ctx_length = self.hparams["max_position_embeddings"]
833+
elif "model_max_length" in self.hparams:
834+
ctx_length = self.hparams["model_max_length"]
835+
else:
836+
print("gguf: can not find ctx length parameter.")
837+
sys.exit()
838+
839+
self.gguf_writer.add_name(self.dir_model.name)
840+
self.gguf_writer.add_source_hf_repo(hf_repo)
841+
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
842+
self.gguf_writer.add_context_length(ctx_length)
843+
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
844+
self.gguf_writer.add_block_count(block_count)
845+
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
846+
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
847+
self.gguf_writer.add_head_count(head_count)
848+
self.gguf_writer.add_head_count_kv(head_count_kv)
849+
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
850+
851+
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
852+
if self.hparams["rope_scaling"].get("type") == "linear":
853+
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
854+
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
855+
856+
def write_tensors(self):
857+
# Collect tensors from generator object
858+
model_kv = dict(self.get_tensors())
859+
block_count = self.hparams["num_hidden_layers"]
860+
head_count = self.hparams["num_attention_heads"]
861+
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
862+
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
863+
864+
for name, data_torch in model_kv.items():
865+
# we don't need these
866+
if name.endswith(".rotary_emb.inv_freq"):
867+
continue
868+
869+
old_dtype = data_torch.dtype
870+
871+
# convert any unsupported data types to float32
872+
if data_torch.dtype not in (torch.float16, torch.float32):
873+
data_torch = data_torch.to(torch.float32)
874+
875+
# HF models permute some of the tensors, so we need to undo that
876+
if name.endswith(("q_proj.weight")):
877+
data_torch = self._reverse_hf_permute(data_torch, head_count, head_count)
878+
if name.endswith(("k_proj.weight")):
879+
data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv)
880+
881+
data = data_torch.squeeze().numpy()
882+
883+
# map tensor names
884+
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
885+
if new_name is None:
886+
print(f"Can not map tensor {name!r}")
887+
sys.exit()
888+
889+
n_dims = len(data.shape)
890+
data_dtype = data.dtype
891+
892+
# if f32 desired, convert any float16 to float32
893+
if self.ftype == 0 and data_dtype == np.float16:
894+
data = data.astype(np.float32)
895+
896+
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
897+
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
898+
data = data.astype(np.float32)
899+
900+
# if f16 desired, convert any float32 2-dim weight tensors to float16
901+
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
902+
data = data.astype(np.float16)
903+
904+
print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
905+
self.gguf_writer.add_tensor(new_name, data)
906+
907+
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
908+
if n_kv_head is not None and n_head != n_kv_head:
909+
n_head //= n_kv_head
910+
911+
return (
912+
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
913+
.swapaxes(1, 2)
914+
.reshape(weights.shape)
915+
)
916+
917+
776918
@Model.register("FalconForCausalLM", "RWForCausalLM")
777919
class FalconModel(Model):
778920
model_arch = gguf.MODEL_ARCH.FALCON

‎gguf-py/gguf/constants.py

+22
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class MODEL_ARCH(IntEnum):
123123
GEMMA = auto()
124124
STARCODER2 = auto()
125125
MAMBA = auto()
126+
XVERSE = auto()
126127
COMMAND_R = auto()
127128

128129

@@ -191,6 +192,7 @@ class MODEL_TENSOR(IntEnum):
191192
MODEL_ARCH.GEMMA: "gemma",
192193
MODEL_ARCH.STARCODER2: "starcoder2",
193194
MODEL_ARCH.MAMBA: "mamba",
195+
MODEL_ARCH.XVERSE: "xverse",
194196
MODEL_ARCH.COMMAND_R: "command-r",
195197
}
196198

@@ -606,6 +608,22 @@ class MODEL_TENSOR(IntEnum):
606608
MODEL_TENSOR.SSM_D,
607609
MODEL_TENSOR.SSM_OUT,
608610
],
611+
MODEL_ARCH.XVERSE: [
612+
MODEL_TENSOR.TOKEN_EMBD,
613+
MODEL_TENSOR.OUTPUT_NORM,
614+
MODEL_TENSOR.OUTPUT,
615+
MODEL_TENSOR.ROPE_FREQS,
616+
MODEL_TENSOR.ATTN_NORM,
617+
MODEL_TENSOR.ATTN_Q,
618+
MODEL_TENSOR.ATTN_K,
619+
MODEL_TENSOR.ATTN_V,
620+
MODEL_TENSOR.ATTN_OUT,
621+
MODEL_TENSOR.ATTN_ROT_EMBD,
622+
MODEL_TENSOR.FFN_NORM,
623+
MODEL_TENSOR.FFN_GATE,
624+
MODEL_TENSOR.FFN_DOWN,
625+
MODEL_TENSOR.FFN_UP,
626+
],
609627
MODEL_ARCH.COMMAND_R: [
610628
MODEL_TENSOR.TOKEN_EMBD,
611629
MODEL_TENSOR.OUTPUT_NORM,
@@ -650,6 +668,10 @@ class MODEL_TENSOR(IntEnum):
650668
MODEL_TENSOR.ROPE_FREQS,
651669
MODEL_TENSOR.ATTN_ROT_EMBD,
652670
],
671+
MODEL_ARCH.XVERSE: [
672+
MODEL_TENSOR.ROPE_FREQS,
673+
MODEL_TENSOR.ATTN_ROT_EMBD,
674+
],
653675
}
654676

655677
#

0 commit comments

Comments
 (0)
Please sign in to comment.