@@ -10,7 +10,7 @@
 import sys
 from enum import IntEnum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
+from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
 
 import numpy as np
 import torch
@@ -168,6 +168,8 @@ def from_model_architecture(model_architecture):
             return PersimmonModel
         if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
             return StableLMModel
+        if model_architecture == "QWenLMHeadModel":
+            return QwenModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -203,6 +205,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.PERSIMMON
         if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
             return gguf.MODEL_ARCH.STABLELM
+        if arch == "QWenLMHeadModel":
+            return gguf.MODEL_ARCH.QWEN
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -832,6 +836,131 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
         self.gguf_writer.add_layer_norm_eps(1e-5)
 
+
+class QwenModel(Model):
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer  # type: ignore[attr-defined]
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[self.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) == 2
+            merges.append(' '.join(map(self.token_bytes_to_string, merged)))
+
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
+        added_vocab = tokenizer.special_tokens
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                pad_token = f"[PAD{i}]".encode("utf-8")
+                tokens.append(bytearray(pad_token))
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.merges = merges
+        special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name("Qwen")
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
+
+    def write_tensors(self):
+        block_count = self.hparams["num_hidden_layers"]
+        model_kv = dict(self.get_tensors())
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 values as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            self.gguf_writer.add_tensor(new_name, data)
+
 ###### CONVERSION LOGIC ######
 
 
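
Side note, not part of the diff above: `QwenModel.bpe()` re-derives BPE merges from a tiktoken-style rank table. Capping `max_rank` at a token's own rank is intended to stop the merge loop one step before the token itself would be assembled, which is why `set_vocab()` asserts that exactly two parts remain and records them as a `merges` entry. A minimal, self-contained sketch of that behaviour; the toy rank table and the standalone `bpe` copy are illustrative only:

```python
# Toy rank table, invented for illustration; real tables come from tokenizer.mergeable_ranks.
mergeable_ranks = {
    b"h": 0, b"e": 1, b"l": 2, b"o": 3,
    b"he": 4, b"ll": 5, b"hell": 6, b"hello": 7,
}

def bpe(mergeable_ranks, token, max_rank=None):
    # Same algorithm as QwenModel.bpe: repeatedly merge the adjacent pair
    # with the lowest rank, stopping once no pair ranks below max_rank.
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

# With max_rank equal to the token's own rank, the result is the last pair
# that would have been merged, i.e. the merge rule for b"hello".
print(bpe(mergeable_ranks, b"hello", max_rank=mergeable_ranks[b"hello"]))  # [b'hell', b'o']
```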