@@ -6,16 +6,34 @@ enum LlamaError: Error {
     case couldNotInitializeContext
 }
 
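+// Batch helpers for the Swift example; these appear to mirror the llama_batch_clear/llama_batch_add
+// helpers used by the C++ examples (free functions here, not part of the llama.h API).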
+func llama_batch_clear(_ batch: inout llama_batch) {
+    batch.n_tokens = 0
+}
+
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+    batch.token   [Int(batch.n_tokens)] = id
+    batch.pos     [Int(batch.n_tokens)] = pos
+    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
+    for i in 0..<seq_ids.count {
+        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
+    }
+    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
+
+    batch.n_tokens += 1
+}
+
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
+
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]
 
-    var n_len: Int32 = 512
+    var n_len: Int32 = 64
     var n_cur: Int32 = 0
+
     var n_decode: Int32 = 0
 
     init(model: OpaquePointer, context: OpaquePointer) {
@@ -27,25 +45,34 @@ actor LlamaContext {
     }
 
     deinit {
+        llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
         llama_backend_free()
     }
 
-    static func createContext(path: String) throws -> LlamaContext {
+    static func create_context(path: String) throws -> LlamaContext {
         llama_backend_init(false)
-        let model_params = llama_model_default_params()
+        var model_params = llama_model_default_params()
 
+#if targetEnvironment(simulator)
+        model_params.n_gpu_layers = 0
+        print("Running on simulator, force use n_gpu_layers = 0")
+#endif
         let model = llama_load_model_from_file(path, model_params)
         guard let model else {
             print("Could not load model at \(path)")
             throw LlamaError.couldNotInitializeContext
         }
+
+        let n_threads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
+        print("Using \(n_threads) threads")
+
         var ctx_params = llama_context_default_params()
-        ctx_params.seed = 1234
+        ctx_params.seed  = 1234
         ctx_params.n_ctx = 2048
-        ctx_params.n_threads = 8
-        ctx_params.n_threads_batch = 8
+        ctx_params.n_threads       = UInt32(n_threads)
+        ctx_params.n_threads_batch = UInt32(n_threads)
 
         let context = llama_new_context_with_model(model, ctx_params)
         guard let context else {
@@ -56,6 +83,26 @@
         return LlamaContext(model: model, context: context)
     }
 
+    func model_info() -> String {
+        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 256)
+        result.initialize(repeating: Int8(0), count: 256)
+        defer {
+            result.deallocate()
+        }
+
+        // TODO: this is probably very stupid way to get the string from C
+
+        let nChars = llama_model_desc(model, result, 256)
+        let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nChars))
+
+        var SwiftString = ""
+        for char in bufferPointer {
+            SwiftString.append(Character(UnicodeScalar(UInt8(char))))
+        }
+
+        return SwiftString
+    }
+
     func get_n_tokens() -> Int32 {
         return batch.n_tokens;
     }
@@ -79,16 +126,11 @@
             print(String(cString: token_to_piece(token: id) + [0]))
         }
 
-        // batch = llama_batch_init(512, 0) // done in init()
-        batch.n_tokens = Int32(tokens_list.count)
+        llama_batch_clear(&batch)
 
-        for i1 in 0..<batch.n_tokens {
+        for i1 in 0..<tokens_list.count {
             let i = Int(i1)
-            batch.token[i] = tokens_list[i]
-            batch.pos[i] = i1
-            batch.n_seq_id[Int(i)] = 1
-            batch.seq_id[Int(i)]![0] = 0
-            batch.logits[i] = 0
+            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
         batch.logits[Int(batch.n_tokens) - 1] = 1 // true
 
@@ -141,18 +183,11 @@
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
-        batch.n_tokens = 0
-
-        batch.token[Int(batch.n_tokens)] = new_token_id
-        batch.pos[Int(batch.n_tokens)] = n_cur
-        batch.n_seq_id[Int(batch.n_tokens)] = 1
-        batch.seq_id[Int(batch.n_tokens)]![0] = 0
-        batch.logits[Int(batch.n_tokens)] = 1 // true
-        batch.n_tokens += 1
+        llama_batch_clear(&batch)
+        llama_batch_add(&batch, new_token_id, n_cur, [0], true)
 
         n_decode += 1
-
-        n_cur += 1
+        n_cur    += 1
 
         if llama_decode(context, batch) != 0 {
             print("failed to evaluate llama!")
@@ -161,14 +196,111 @@
         return new_token_str
     }
 
+    func bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) -> String {
+        var pp_avg: Double = 0
+        var tg_avg: Double = 0
+
+        var pp_std: Double = 0
+        var tg_std: Double = 0
+
+        for r in 0..<nr {
+            // bench prompt processing
+
+            llama_batch_clear(&batch)
+
+            let n_tokens = pp
+
+            for i in 0..<n_tokens {
+                llama_batch_add(&batch, 0, Int32(i), [0], false)
+            }
+            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+
+            llama_kv_cache_clear(context)
+
+            let t_pp_start = ggml_time_us()
+
+            if llama_decode(context, batch) != 0 {
+                print("llama_decode() failed during prompt")
+            }
+
+            let t_pp_end = ggml_time_us()
+
+            // bench text generation
+
+            llama_kv_cache_clear(context)
+
+            let t_tg_start = ggml_time_us()
+
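+            // each generation step submits a batch of pl tokens, one per sequence id,
+            // so every llama_decode() call simulates pl parallel streams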
+            for i in 0..<tg {
+                llama_batch_clear(&batch)
+
+                for j in 0..<pl {
+                    llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+                }
+
+                if llama_decode(context, batch) != 0 {
+                    print("llama_decode() failed during text generation")
+                }
+            }
+
+            let t_tg_end = ggml_time_us()
+
+            llama_kv_cache_clear(context)
+
+            let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
+            let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
+
+            let speed_pp = Double(pp) / t_pp
+            let speed_tg = Double(pl*tg) / t_tg
+
+            pp_avg += speed_pp
+            tg_avg += speed_tg
+
+            pp_std += speed_pp * speed_pp
+            tg_std += speed_tg * speed_tg
+
+            print("pp \(speed_pp) t/s, tg \(speed_tg) t/s")
+        }
+
+        pp_avg /= Double(nr)
+        tg_avg /= Double(nr)
+
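+        // unbiased sample standard deviation from the running sum of squares:
+        // sqrt((sum(x^2) - n * mean^2) / (n - 1))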
+        if nr > 1 {
+            pp_std = sqrt(pp_std / Double(nr - 1) - pp_avg * pp_avg * Double(nr) / Double(nr - 1))
+            tg_std = sqrt(tg_std / Double(nr - 1) - tg_avg * tg_avg * Double(nr) / Double(nr - 1))
+        } else {
+            pp_std = 0
+            tg_std = 0
+        }
+
+        let model_desc = model_info();
+        let model_size = String(format: "%.2f GiB", Double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0);
+        let model_n_params = String(format: "%.2f B", Double(llama_model_n_params(model)) / 1e9);
+        let backend = "Metal";
+        let pp_avg_str = String(format: "%.2f", pp_avg);
+        let tg_avg_str = String(format: "%.2f", tg_avg);
+        let pp_std_str = String(format: "%.2f", pp_std);
+        let tg_std_str = String(format: "%.2f", tg_std);
+
+        var result = ""
+
+        result += String("| model | size | params | backend | test | t/s |\n")
+        result += String("| --- | --- | --- | --- | --- | --- |\n")
+        result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | pp \(pp) | \(pp_avg_str) ± \(pp_std_str) |\n")
+        result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | tg \(tg) | \(tg_avg_str) ± \(tg_std_str) |\n")
+
+        return result;
+    }
+
     func clear() {
         tokens_list.removeAll()
         temporary_invalid_cchars.removeAll()
+        llama_kv_cache_clear(context)
     }
 
     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
         let utf8Count = text.utf8.count
-        let n_tokens = utf8Count + (add_bos ? 1 : 0)
+        let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
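+        // the +1 reserves one extra slot beyond the utf8Count + BOS upper bound used for the buffer below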
         let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
         let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
 