
Commit 7cd7a8d

ggerganov authored and committed on Aug 29, 2023 (co-authored by klosax and slaren)
llm : add Falcon support (ggml-org#2717)
* llama : refactor GGUF constants into static maps
* llama : check if model architecture is known
* llama : refactor llama_model_load_internal()
* gguf : add KV constant maps
* llm : read arch-specific KVs
* convert : add dummy scores + types
* falcon : load tensor data (CPU only)
* llama : fix loading progress bar
* llama : add arch member to llama_model
* falcon : CPU inference working
* falcon : support non-40B models
* falcon : minor
* llama : minor updates ggml-ci
* convert-falcon-hf-to-gguf.py : fix special token mapping
* llama.cpp : llama default UNK token = id 0
* llama.cpp : fix bpe tokenizer
* llama.cpp : fix the fix of bpe tokenizer
* ggml : pass eps to ggml_norm
* metal : implement RoPE (mode = 2) + avoid ggml_repeat
* ggml : ggml_repeat always creates new tensor
* falcon : copy-paste self-attention from LLaMA
* metal : print extra compute pipeline info
* falcon : minor changes (still chasing the Metal problem)
* llama.cpp : fix linefeed token
* metal : fix GELU kernel numerical stability by using precise::tanh
* metal : temporary workaround for the concurrency optimization bug
* falcon : add CUDA offloading (ggml-org#2739)
* llama : better model naming and size reporting
* llama : prep new tokenizer support
* llama : advanced BPE tokenizer based on ggllm.cpp implementation
* llama : remove obsolete comment ggml-ci
* common : remove obsolete BPE API + disable test-tokenizer-1
* llama : revert BPE special-case in llama_byte_to_token()
* cuda : add TODOs for RoPE NeoX implementation
* llama : default special tokens based on vocab type
* perplexity : add log for start of tokenization

---------

Co-authored-by: klosax <[email protected]>
Co-authored-by: slaren <[email protected]>
1 parent e782ea0 commit 7cd7a8d

18 files changed: +1635, -707 lines
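Note on the "refactor GGUF constants into static maps" item in the commit message: the idea is to keep architecture and key names in lookup tables instead of scattered string literals, so the loader can check whether a model's architecture is known. A minimal sketch of that pattern, with hypothetical enum and table names (the actual tables in llama.cpp are larger and named differently):

#include <map>
#include <string>

// Hypothetical name table; the commit introduces similar static maps inside
// llama.cpp so that the "general.architecture" string read from a GGUF file
// can be mapped back to an enum and rejected if unknown.
enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_FALCON, LLM_ARCH_UNKNOWN };

static const std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
    { LLM_ARCH_LLAMA,  "llama"  },
    { LLM_ARCH_FALCON, "falcon" },
};

static llm_arch llm_arch_from_string(const std::string & name) {
    for (const auto & kv : LLM_ARCH_NAMES) {
        if (kv.second == name) {
            return kv.first;
        }
    }
    return LLM_ARCH_UNKNOWN;
}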
 

common/common.cpp (-32)
@@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token token)
 
     return std::string(result.data(), result.size());
 }
-
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-        const std::string & text,
-                        bool   add_bos) {
-    int n_tokens = text.length() + add_bos;
-    std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-    return result;
-}
-
-std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-
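The removed wrappers duplicated the generic two-pass pattern that the remaining llama_tokenize helper (declared in common.h just below) already implements: call once, let a negative return value report the required size, resize, and call again. A minimal sketch of that pattern against the C API, assuming a valid llama_context * ctx and the llama_tokenize signature used by the removed code above (a plain assert stands in for GGML_ASSERT):

#include <cassert>
#include <string>
#include <vector>
#include "llama.h"

// Two-pass tokenization: a negative return value reports how many tokens are
// needed, so the buffer is resized and the call is repeated.
static std::vector<llama_token> tokenize_two_pass(llama_context * ctx, const std::string & text, bool add_bos) {
    std::vector<llama_token> result(text.length() + add_bos);
    int n = llama_tokenize(ctx, text.c_str(), result.data(), (int) result.size(), add_bos);
    if (n < 0) {
        result.resize(-n);
        n = llama_tokenize(ctx, text.c_str(), result.data(), (int) result.size(), add_bos);
        assert(n >= 0);
    }
    result.resize(n);
    return result;
}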

common/common.h (-9)
@@ -120,15 +120,6 @@ std::vector<llama_token> llama_tokenize(
        const std::string & text,
        bool   add_bos);
 
-std::vector<llama_token> llama_tokenize_bpe(
-       struct llama_context * ctx,
-       const std::string & text,
-       bool   add_bos);
-
 std::string llama_token_to_str(
     const struct llama_context * ctx,
     llama_token token);
-
-std::string llama_token_to_str_bpe(
-    const struct llama_context * ctx,
-    llama_token token);
convert-falcon-hf-to-gguf.py (+25, -30)
@@ -95,21 +95,26 @@ def count_model_parts(dir_model: str) -> int:
 
 block_count = hparams["n_layer"]
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name("Falcon")
 gguf_writer.add_context_length(2048) # not in config.json
 gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
 gguf_writer.add_embedding_length(hparams["hidden_size"])
 gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_head_count(hparams["n_head"])
-if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+if "n_head_kv" in hparams:
+    gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+else:
+    gguf_writer.add_head_count_kv(1)
 gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 
 # TOKENIZATION
 
 print("gguf: get tokenizer metadata")
 
 tokens: List[str] = []
+scores: List[float] = []
+toktypes: List[int] = []
 merges: List[str] = []
 
 
@@ -153,50 +158,40 @@ def count_model_parts(dir_model: str) -> int:
        text = bytearray(pad_token)
 
    tokens.append(text)
+   scores.append(0.0)  # dummy
+   toktypes.append(gguf.TokenType.NORMAL)  # dummy
 
 gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
 
-if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
-    print("gguf: get special token ids")
-
-    with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-        tokenizer_config = json.load(f)
+print("gguf: get special token ids")
+# Look for special tokens in config.json
 
-    # find special token ids
+if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
+    gguf_writer.add_bos_token_id(hparams["bos_token_id"])
 
-    if "bos_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["bos_token"]:
-                gguf_writer.add_bos_token_id(key["id"])
+if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
+    gguf_writer.add_eos_token_id(hparams["eos_token_id"])
 
-    if "eos_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["eos_token"]:
-                gguf_writer.add_eos_token_id(key["id"])
+if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
+    gguf_writer.add_unk_token_id(hparams["unk_token_id"])
 
-    if "unk_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["unk_token"]:
-                gguf_writer.add_unk_token_id(key["id"])
+if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
+    gguf_writer.add_sep_token_id(hparams["sep_token_id"])
 
-    if "sep_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["sep_token"]:
-                gguf_writer.add_sep_token_id(key["id"])
-
-    if "pad_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["pad_token"]:
-                gguf_writer.add_pad_token_id(key["id"])
+if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
+    gguf_writer.add_pad_token_id(hparams["pad_token_id"])
 
 
 # TENSORS
 
 tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 
 # params for qkv transform
-n_head = hparams["n_head"]
+n_head    = hparams["n_head"]
 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
+
 head_dim = hparams["hidden_size"] // n_head
 
 # tensor info
convert.py (+5, -1)
@@ -733,7 +733,11 @@ def __init__(self, fname_out: Path) -> None:
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
 
     def add_meta_arch(self, params: Params) -> None:
-        self.gguf.add_name                ("LLaMA")
+        ver = None
+        if (params.n_ctx == 4096):
+            ver = "v2"
+
+        self.gguf.add_name                ("LLaMA" if ver == None else "LLaMA " + ver)
         self.gguf.add_context_length      (params.n_ctx)
         self.gguf.add_embedding_length    (params.n_embd)
         self.gguf.add_block_count         (params.n_layer)

examples/main/main.cpp (+8, -6)
@@ -43,7 +43,7 @@ static bool is_interacting = false;
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
         if (!is_interacting) {
-            is_interacting=true;
+            is_interacting = true;
         } else {
             console::cleanup();
             printf("\n");
@@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
         }
     }
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
     } else {
         embd_inp = session_tokens;
     }
@@ -208,9 +210,9 @@ int main(int argc, char ** argv) {
     int original_prompt_len = 0;
     if (ctx_guidance) {
         params.cfg_negative_prompt.insert(0, 1, ' ');
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
     }
@@ -257,8 +259,8 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false);
 
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {
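The recurring change in this file (and in perplexity.cpp below) is that BOS is prepended only for SentencePiece vocabularies; BPE-vocab models such as Falcon are tokenized without it. A condensed sketch of that decision, assuming a loaded context and the common llama_tokenize wrapper:

#include <string>
#include <vector>
#include "common.h"
#include "llama.h"

// Prepend BOS only for SPM (LLaMA-style) vocabularies, mirroring the
// is_spm checks added in main.cpp above.
static std::vector<llama_token> tokenize_prompt(llama_context * ctx, const std::string & prompt) {
    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    return ::llama_tokenize(ctx, prompt, /*add_bos=*/is_spm);
}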

examples/perplexity/perplexity.cpp (+21, -10)
@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
 }
 
 void perplexity_v2(llama_context * ctx, const gpt_params & params) {
-
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
         fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
         return;
     }
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int calc_chunk = params.n_ctx;
 
@@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
 }
 
 void perplexity(llama_context * ctx, const gpt_params & params) {
-
     if (params.ppl_stride > 0) {
         perplexity_v2(ctx, params);
         return;
@@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
@@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     size_t hs_task_count = prompt_lines.size()/6;
     fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // This is needed as usual for LLaMA models
-    bool prepend_bos = true;
+    const bool add_bos = is_spm;
 
     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count ) {
@@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     std::vector<float> tok_logits(n_vocab);
 
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-
         // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
+        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
         size_t context_size = context_embd.size();
 
         // Do the 1st ending
         // In this case we include the context when evaluating
-        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
         auto query_size = query_embd.size();
         //printf("First query: %d\n",(int)query_size);
 
ggml-alloc.c (+2, -2)
@@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
+void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
     int pos = 0;
     for (int i = 0; i < n; i++) {
         if (list[i] != -1) {
@@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
                 struct ggml_tensor * view_src = get_view_source(parent);
                 struct hash_node * view_src_hn = hash_get(ht, view_src);
                 view_src_hn->n_views -= 1;
-                AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
+                AT_PRINTF("view_src %s\n", view_src->name);
                 if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
                     ggml_allocator_free_tensor(alloc, view_src);
                 }
ggml-alloc.h (+1, -1)
@@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
 
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
+GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
 
 GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
 GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
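The only change to this header is const-qualifying the list parameter. One practical effect is that callers can pass read-only data directly; a small hypothetical helper, assuming an already-created allocator:

#include <vector>
#include "ggml-alloc.h"

// With the const-qualified parameter, order.data() (a const int *) is
// accepted without a cast, even when the vector itself is const.
static void set_eval_order(struct ggml_allocr * alloc, const std::vector<int> & order) {
    ggml_allocr_set_parse_seq(alloc, order.data(), (int) order.size());
}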

ggml-cuda.cu (+28, -1)
@@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, const float p_delta, const int p_delta_rows, const float theta_scale) {
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+// TODO: this implementation is wrong!
+//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+//                                     const float p_delta, const int p_delta_rows, const float theta_scale) {
+//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+//
+//    if (col >= ncols) {
+//        return;
+//    }
+//
+//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+//    const int i = row*ncols + col/2;
+//
+//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+//    const float sin_theta = sinf(theta);
+//    const float cos_theta = cosf(theta);
+//
+//    const float x0 = x[i + 0];
+//    const float x1 = x[i + ncols/2];
+//
+//    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+//}
+
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
@@ -5515,14 +5538,18 @@ inline void ggml_cuda_op_rope(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
-    const bool is_glm = mode & 4;
+    const bool is_neox = mode & 2;
+    const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
         const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else if (is_neox) {
+        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
+#pragma message("TODO: implement RoPE NeoX for CUDA")
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);