Skip to content

Commit 40b0a47

Browse files
committed
cuda : reuse ggml-common (cont)
ggml-ci
1 parent 3656c76 commit 40b0a47

File tree

3 files changed

+158
-117
lines changed

3 files changed

+158
-117
lines changed

ggml-common.h

+111-71
Original file line number | Diff line number | Diff line change
@@ -3,19 +3,29 @@
33
#if defined(GGML_COMMON_DECL_C)
44
#include <stdint.h>
55

6-
typedef uint16_t ggml_fp16_t;
6+
typedef uint16_t ggml_half;
7+
typedef uint32_t ggml_half2;
8+
9+
#define GGML_COMMON_AGGR
710

811
#define GGML_COMMON_DECL
912
#elif defined(GGML_COMMON_DECL_METAL)
1013
#include <metal_stdlib>
1114

12-
typedef half ggml_fp16_t;
15+
typedef half ggml_half;
16+
typedef half2 ggml_half2;
17+
18+
#define GGML_COMMON_AGGR
1319

1420
#define GGML_COMMON_DECL
1521
#elif defined(GGML_COMMON_DECL_CUDA)
22+
#include <cuda_fp16.h>
1623
#include <cstdint>
1724

18-
typedef half ggml_fp16_t;
25+
typedef half ggml_half;
26+
typedef half2 ggml_half2;
27+
28+
#define GGML_COMMON_AGGR data
1929

2030
#define GGML_COMMON_DECL
2131
#endif
@@ -40,60 +50,75 @@ typedef half ggml_fp16_t;
4050
#define QI4_0 (QK4_0 / (4 * QR4_0))
4151
#define QR4_0 2
4252
typedef struct {
43-
ggml_fp16_t d; // delta
44-
uint8_t qs[QK4_0 / 2]; // nibbles / quants
53+
ggml_half d; // delta
54+
uint8_t qs[QK4_0 / 2]; // nibbles / quants
4555
} block_q4_0;
46-
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
56+
static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding");
4757

4858
#define QK4_1 32
4959
#define QI4_1 (QK4_1 / (4 * QR4_1))
5060
#define QR4_1 2
5161
typedef struct {
52-
ggml_fp16_t d; // delta
53-
ggml_fp16_t m; // min
54-
uint8_t qs[QK4_1 / 2]; // nibbles / quants
62+
union {
63+
struct {
64+
ggml_half d; // delta
65+
ggml_half m; // min
66+
} GGML_COMMON_AGGR;
67+
ggml_half2 dm;
68+
};
69+
uint8_t qs[QK4_1 / 2]; // nibbles / quants
5570
} block_q4_1;
56-
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
71+
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
5772

5873
#define QK5_0 32
5974
#define QI5_0 (QK5_0 / (4 * QR5_0))
6075
#define QR5_0 2
6176
typedef struct {
62-
ggml_fp16_t d; // delta
77+
ggml_half d; // delta
6378
uint8_t qh[4]; // 5-th bit of quants
6479
uint8_t qs[QK5_0 / 2]; // nibbles / quants
6580
} block_q5_0;
66-
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
81+
static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
6782

6883
#define QK5_1 32
6984
#define QI5_1 (QK5_1 / (4 * QR5_1))
7085
#define QR5_1 2
7186
typedef struct {
72-
ggml_fp16_t d; // delta
73-
ggml_fp16_t m; // min
87+
union {
88+
struct {
89+
ggml_half d; // delta
90+
ggml_half m; // min
91+
} GGML_COMMON_AGGR;
92+
ggml_half2 dm;
93+
};
7494
uint8_t qh[4]; // 5-th bit of quants
7595
uint8_t qs[QK5_1 / 2]; // nibbles / quants
7696
} block_q5_1;
77-
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
97+
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
7898

7999
#define QK8_0 32
80100
#define QI8_0 (QK8_0 / (4 * QR8_0))
81101
#define QR8_0 1
82102
typedef struct {
83-
ggml_fp16_t d; // delta
84-
int8_t qs[QK8_0]; // quants
103+
ggml_half d; // delta
104+
int8_t qs[QK8_0]; // quants
85105
} block_q8_0;
86-
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
106+
static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding");
87107

88108
#define QK8_1 32
89109
#define QI8_1 (QK8_1 / (4 * QR8_1))
90110
#define QR8_1 1
91111
typedef struct {
92-
float d; // delta
93-
float s; // d * sum(qs[i])
94-
int8_t qs[QK8_1]; // quants
112+
union {
113+
struct {
114+
ggml_half xxxd; // delta
115+
ggml_half xxxs; // d * sum(qs[i])
116+
} GGML_COMMON_AGGR;
117+
ggml_half2 ds;
118+
};
119+
int8_t qs[QK8_1]; // quants
95120
} block_q8_1;
96-
static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
121+
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
97122

98123
//
99124
// Super-block quantization structures
@@ -117,10 +142,15 @@ static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block s
117142
typedef struct {
118143
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
119144
uint8_t qs[QK_K/4]; // quants
120-
ggml_fp16_t d; // super-block scale for quantized scales
121-
ggml_fp16_t dmin; // super-block scale for quantized mins
145+
union {
146+
struct {
147+
ggml_half d; // super-block scale for quantized scales
148+
ggml_half dmin; // super-block scale for quantized mins
149+
} GGML_COMMON_AGGR;
150+
ggml_half2 dm;
151+
};
122152
} block_q2_K;
123-
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
153+
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
124154

125155
// 3-bit quantization
126156
// weight is represented as x = a * q
@@ -130,20 +160,20 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
130160
#define QR3_K 4
131161
#ifdef GGML_QKK_64
132162
typedef struct {
133-
uint8_t hmask[QK_K/8]; // quants - high bit
134-
uint8_t qs[QK_K/4]; // quants - low 2 bits
163+
uint8_t hmask[QK_K/8]; // quants - high bit
164+
uint8_t qs[QK_K/4]; // quants - low 2 bits
135165
uint8_t scales[2];
136-
ggml_fp16_t d; // super-block scale
166+
ggml_half d; // super-block scale
137167
} block_q3_K;
138-
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
168+
static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
139169
#else
140170
typedef struct {
141-
uint8_t hmask[QK_K/8]; // quants - high bit
142-
uint8_t qs[QK_K/4]; // quants - low 2 bits
143-
uint8_t scales[12]; // scales, quantized with 6 bits
144-
ggml_fp16_t d; // super-block scale
171+
uint8_t hmask[QK_K/8]; // quants - high bit
172+
uint8_t qs[QK_K/4]; // quants - low 2 bits
173+
uint8_t scales[12]; // scales, quantized with 6 bits
174+
ggml_half d; // super-block scale
145175
} block_q3_K;
146-
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
176+
static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
147177
#endif
148178

149179
// 4-bit quantization
@@ -154,19 +184,24 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 +
154184
#define QR4_K 2
155185
#ifdef GGML_QKK_64
156186
typedef struct {
157-
ggml_fp16_t d[2]; // super-block scales/mins
158-
uint8_t scales[2]; // 4-bit block scales/mins
159-
uint8_t qs[QK_K/2]; // 4-bit quants
187+
ggml_half d[2]; // super-block scales/mins
188+
uint8_t scales[2]; // 4-bit block scales/mins
189+
uint8_t qs[QK_K/2]; // 4-bit quants
160190
} block_q4_K;
161-
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
191+
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + QK_K/2 + 2, "wrong q4_K block size/padding");
162192
#else
163193
typedef struct {
164-
ggml_fp16_t d; // super-block scale for quantized scales
165-
ggml_fp16_t dmin; // super-block scale for quantized mins
194+
union {
195+
struct {
196+
ggml_half d; // super-block scale for quantized scales
197+
ggml_half dmin; // super-block scale for quantized mins
198+
} GGML_COMMON_AGGR;
199+
ggml_half2 dm;
200+
};
166201
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
167-
uint8_t qs[QK_K/2]; // 4-bit quants
202+
uint8_t qs[QK_K/2]; // 4-bit quants
168203
} block_q4_K;
169-
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
204+
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
170205
#endif
171206

172207
// 5-bit quantization
@@ -177,21 +212,26 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/
177212
#define QR5_K 2
178213
#ifdef GGML_QKK_64
179214
typedef struct {
180-
ggml_fp16_t d; // super-block scale
181-
int8_t scales[QK_K/16]; // 8-bit block scales
182-
uint8_t qh[QK_K/8]; // quants, high bit
183-
uint8_t qs[QK_K/2]; // quants, low 4 bits
215+
ggml_half d; // super-block scale
216+
int8_t scales[QK_K/16]; // 8-bit block scales
217+
uint8_t qh[QK_K/8]; // quants, high bit
218+
uint8_t qs[QK_K/2]; // quants, low 4 bits
184219
} block_q5_K;
185-
static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
220+
static_assert(sizeof(block_q5_K) == sizeof(ggml_half) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
186221
#else
187222
typedef struct {
188-
ggml_fp16_t d; // super-block scale for quantized scales
189-
ggml_fp16_t dmin; // super-block scale for quantized mins
190-
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
191-
uint8_t qh[QK_K/8]; // quants, high bit
192-
uint8_t qs[QK_K/2]; // quants, low 4 bits
223+
union {
224+
struct {
225+
ggml_half d; // super-block scale for quantized scales
226+
ggml_half dmin; // super-block scale for quantized mins
227+
} GGML_COMMON_AGGR;
228+
ggml_half2 dm;
229+
};
230+
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
231+
uint8_t qh[QK_K/8]; // quants, high bit
232+
uint8_t qs[QK_K/2]; // quants, low 4 bits
193233
} block_q5_K;
194-
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
234+
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
195235
#endif
196236

197237
// 6-bit quantization
@@ -204,9 +244,9 @@ typedef struct {
204244
uint8_t ql[QK_K/2]; // quants, lower 4 bits
205245
uint8_t qh[QK_K/4]; // quants, upper 2 bits
206246
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
207-
ggml_fp16_t d; // super-block scale
247+
ggml_half d; // super-block scale
208248
} block_q6_K;
209-
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
249+
static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
210250

211251
// This is only used for intermediate quantization and dot products
212252
typedef struct {
@@ -222,42 +262,42 @@ static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_
222262
#define QI2_XXS (QK_K / (4*QR2_XXS))
223263
#define QR2_XXS 8
224264
typedef struct {
225-
ggml_fp16_t d;
265+
ggml_half d;
226266
uint16_t qs[QK_K/8];
227267
} block_iq2_xxs;
228-
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
268+
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
229269

230270
// 2.3125 bpw quants
231271
#define QI2_XS (QK_K / (4*QR2_XS))
232272
#define QR2_XS 8
233273
typedef struct {
234-
ggml_fp16_t d;
274+
ggml_half d;
235275
uint16_t qs[QK_K/8];
236276
uint8_t scales[QK_K/32];
237277
} block_iq2_xs;
238-
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
278+
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
239279

240280
// 2.5625 bpw quants
241281
#define QI2_S (QK_K / (4*QR2_S))
242282
#define QR2_S 8
243283
typedef struct {
244-
ggml_fp16_t d;
284+
ggml_half d;
245285
uint8_t qs[QK_K/4];
246286
uint8_t qh[QK_K/32];
247287
uint8_t scales[QK_K/32];
248288
} block_iq2_s;
249-
static_assert(sizeof(block_iq2_s) == sizeof(ggml_fp16_t) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
289+
static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
250290

251291
// (Almost) "true" 3-bit quantization.
252292
// Due to the need to use blocks as per ggml design, it ends up using
253293
// 3.0625 bpw because of the 16-bit scale for each block of 256.
254294
#define QI3_XXS (QK_K / (4*QR3_XXS))
255295
#define QR3_XXS 8
256296
typedef struct {
257-
ggml_fp16_t d;
297+
ggml_half d;
258298
uint8_t qs[3*QK_K/8];
259299
} block_iq3_xxs;
260-
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
300+
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
261301

262302
// 3.4375 bpw
263303
#if QK_K == 64
@@ -268,32 +308,32 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
268308
#define QI3_XS (QK_K / (4*QR3_XS))
269309
#define QR3_XS 8
270310
typedef struct {
271-
ggml_fp16_t d;
311+
ggml_half d;
272312
uint8_t qs[QK_K/4];
273313
uint8_t qh[QK_K/32];
274314
uint8_t signs[QK_K/8];
275315
uint8_t scales[IQ3S_N_SCALE];
276316
} block_iq3_s;
277-
static_assert(sizeof(block_iq3_s) == sizeof(ggml_fp16_t) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
317+
static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
278318

279319
#define QI1_S (QK_K / (4*QR1_S))
280320
#define QR1_S 8
281321
typedef struct {
282-
ggml_fp16_t d;
322+
ggml_half d;
283323
uint8_t qs[QK_K/8];
284324
uint8_t scales[QK_K/16];
285325
} block_iq1_s;
286-
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
326+
static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
287327

288328
// Non-linear quants
289329
#define QK4_NL 32
290330
#define QI4_NL (QK4_NL / (4*QR4_NL))
291331
#define QR4_NL 2
292332
typedef struct {
293-
ggml_fp16_t d;
333+
ggml_half d;
294334
uint8_t qs[QK4_NL/2];
295335
} block_iq4_nl;
296-
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
336+
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
297337

298338
#if QK_K == 64
299339
#define block_iq4_xs block_iq4_nl
@@ -304,12 +344,12 @@ static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4
304344
#define QI4_XS (QK_K / (4*QR4_XS))
305345
#define QR4_XS 8
306346
typedef struct {
307-
ggml_fp16_t d;
347+
ggml_half d;
308348
uint16_t scales_h;
309349
uint8_t scales_l[QK_K/64];
310350
uint8_t qs[QK_K/2];
311351
} block_iq4_xs;
312-
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
352+
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
313353
#endif
314354

315355
#endif // GGML_COMMON_DECL

0 commit comments

Comments
 (0)