@@ -94,13 +94,6 @@ typedef struct {
94
94
} block_q4_1;
95
95
static_assert (sizeof (block_q4_1) == sizeof(float ) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
96
96
97
- #define QK4_2 16
98
- typedef struct {
99
- half d; // delta
100
- uint8_t qs[QK4_2 / 2 ]; // nibbles / quants
101
- } block_q4_2;
102
- static_assert (sizeof (block_q4_2) == sizeof(ggml_fp16_t ) + QK4_2 / 2, "wrong q4_2 block size/padding");
103
-
104
97
#define QK5_0 32
105
98
typedef struct {
106
99
half d; // delta
@@ -126,147 +119,102 @@ typedef struct {
126
119
static_assert (sizeof (block_q8_0) == sizeof(float ) + QK8_0, "wrong q8_0 block size/padding");
127
120
128
121
static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
122
+ static const int qk = QK4_0;
123
+
129
124
const block_q4_0 * x = (const block_q4_0 *) vx;
130
125
131
126
const int i = blockIdx .x ;
132
127
133
128
const float d = x[i].d ;
134
129
135
- const uint8_t * pp = x[i].qs ;
136
-
137
- for (int l = 0 ; l < QK4_0; l += 2 ) {
138
- const uint8_t vi = pp[l/2 ];
139
-
140
- const int8_t vi0 = vi & 0xf ;
141
- const int8_t vi1 = vi >> 4 ;
130
+ for (int j = 0 ; j < qk/2 ; ++j) {
131
+ const int x0 = (x[i].qs [j] & 0xf ) - 8 ;
132
+ const int x1 = (x[i].qs [j] >> 4 ) - 8 ;
142
133
143
- const float v0 = (vi0 - 8 )*d;
144
- const float v1 = (vi1 - 8 )*d;
145
-
146
- y[i*QK4_0 + l + 0 ] = v0;
147
- y[i*QK4_0 + l + 1 ] = v1;
134
+ y[i*qk + j + 0 ] = x0*d;
135
+ y[i*qk + j + qk/2 ] = x1*d;
148
136
}
149
137
}
150
138
151
139
static __global__ void dequantize_block_q4_1 (const void * vx, float * y) {
140
+ static const int qk = QK4_1;
141
+
152
142
const block_q4_1 * x = (const block_q4_1 *) vx;
153
143
154
144
const int i = blockIdx .x ;
155
145
156
146
const float d = x[i].d ;
157
147
const float m = x[i].m ;
158
148
159
- const uint8_t * pp = x[i].qs ;
160
-
161
- for (int l = 0 ; l < QK4_1; l += 2 ) {
162
- const uint8_t vi = pp[l/2 ];
163
-
164
- const int8_t vi0 = vi & 0xf ;
165
- const int8_t vi1 = vi >> 4 ;
149
+ for (int j = 0 ; j < qk/2 ; ++j) {
150
+ const int x0 = (x[i].qs [j] & 0xf );
151
+ const int x1 = (x[i].qs [j] >> 4 );
166
152
167
- const float v0 = vi0*d + m;
168
- const float v1 = vi1*d + m;
169
-
170
- y[i*QK4_1 + l + 0 ] = v0;
171
- y[i*QK4_1 + l + 1 ] = v1;
172
- }
173
- }
174
-
175
- static __global__ void dequantize_block_q4_2 (const void * vx, float * y) {
176
- const block_q4_2 * x = (const block_q4_2 *) vx;
177
-
178
- const int i = blockIdx .x ;
179
-
180
- const float d = x[i].d ;
181
-
182
- const uint8_t * pp = x[i].qs ;
183
-
184
- for (int l = 0 ; l < QK4_2; l += 2 ) {
185
- const uint8_t vi = pp[l/2 ];
186
-
187
- const int8_t vi0 = vi & 0xf ;
188
- const int8_t vi1 = vi >> 4 ;
189
-
190
- const float v0 = (vi0 - 8 )*d;
191
- const float v1 = (vi1 - 8 )*d;
192
-
193
- y[i*QK4_2 + l + 0 ] = v0;
194
- y[i*QK4_2 + l + 1 ] = v1;
153
+ y[i*qk + j + 0 ] = x0*d + m;
154
+ y[i*qk + j + qk/2 ] = x1*d + m;
195
155
}
196
156
}
197
157
198
158
static __global__ void dequantize_block_q5_0 (const void * vx, float * y) {
159
+ static const int qk = QK5_0;
160
+
199
161
const block_q5_0 * x = (const block_q5_0 *) vx;
200
162
201
163
const int i = blockIdx .x ;
202
164
203
165
const float d = x[i].d ;
204
166
205
- const uint8_t * pp = x[i].qs ;
206
-
207
167
uint32_t qh;
208
168
memcpy (&qh, x[i].qh , sizeof (qh));
209
169
210
- for (int l = 0 ; l < QK5_0; l += 2 ) {
211
- const uint8_t vi = pp[l/2 ];
212
-
213
- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
214
- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
170
+ for (int j = 0 ; j < qk/2 ; ++j) {
171
+ const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
172
+ const uint8_t xh_1 = ((qh >> (j + 12 )) ) & 0x10 ;
215
173
216
- const int8_t vi0 = ((vi & 0xf ) | vh0) ;
217
- const int8_t vi1 = ((vi >> 4 ) | vh1) ;
174
+ const int32_t x0 = ((x[i]. qs [j] & 0xf ) | xh_0) - 16 ;
175
+ const int32_t x1 = ((x[i]. qs [j] >> 4 ) | xh_1) - 16 ;
218
176
219
- const float v0 = (vi0 - 16 )*d;
220
- const float v1 = (vi1 - 16 )*d;
221
-
222
- y[i*QK5_0 + l + 0 ] = v0;
223
- y[i*QK5_0 + l + 1 ] = v1;
177
+ y[i*qk + j + 0 ] = x0*d;
178
+ y[i*qk + j + qk/2 ] = x1*d;
224
179
}
225
180
}
226
181
227
182
static __global__ void dequantize_block_q5_1 (const void * vx, float * y) {
183
+ static const int qk = QK5_1;
184
+
228
185
const block_q5_1 * x = (const block_q5_1 *) vx;
229
186
230
187
const int i = blockIdx .x ;
231
188
232
189
const float d = x[i].d ;
233
190
const float m = x[i].m ;
234
191
235
- const uint8_t * pp = x[i].qs ;
236
-
237
192
uint32_t qh;
238
193
memcpy (&qh, x[i].qh , sizeof (qh));
239
194
240
- for (int l = 0 ; l < QK5_1; l += 2 ) {
241
- const uint8_t vi = pp[l/2 ];
242
-
243
- const int8_t vh0 = ((qh & (1 << (l + 0 ))) >> (l + 0 )) << 4 ;
244
- const int8_t vh1 = ((qh & (1 << (l + 1 ))) >> (l + 1 )) << 4 ;
195
+ for (int j = 0 ; j < qk/2 ; ++j) {
196
+ const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
197
+ const uint8_t xh_1 = ((qh >> (j + 12 )) ) & 0x10 ;
245
198
246
- const int8_t vi0 = (vi & 0xf ) | vh0 ;
247
- const int8_t vi1 = (vi >> 4 ) | vh1 ;
199
+ const int x0 = (x[i]. qs [j] & 0xf ) | xh_0 ;
200
+ const int x1 = (x[i]. qs [j] >> 4 ) | xh_1 ;
248
201
249
- const float v0 = vi0*d + m;
250
- const float v1 = vi1*d + m;
251
-
252
- y[i*QK5_1 + l + 0 ] = v0;
253
- y[i*QK5_1 + l + 1 ] = v1;
202
+ y[i*qk + j + 0 ] = x0*d + m;
203
+ y[i*qk + j + qk/2 ] = x1*d + m;
254
204
}
255
205
}
256
206
257
207
static __global__ void dequantize_block_q8_0 (const void * vx, float * y) {
208
+ static const int qk = QK8_0;
209
+
258
210
const block_q8_0 * x = (const block_q8_0 *) vx;
259
211
260
212
const int i = blockIdx .x ;
261
213
262
214
const float d = x[i].d ;
263
215
264
- const int8_t * pp = x[i].qs ;
265
-
266
- for (int l = 0 ; l < QK8_0; l++) {
267
- const int8_t vi = pp[l];
268
-
269
- y[i*QK8_0 + l] = vi*d;
216
+ for (int j = 0 ; j < qk; ++j) {
217
+ y[i*qk + j] = x[i].qs [j]*d;
270
218
}
271
219
}
272
220
@@ -280,11 +228,6 @@ static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStre
280
228
dequantize_block_q4_1<<<nb, 1 , 0 , stream>>> (vx, y);
281
229
}
282
230
283
- static void dequantize_row_q4_2_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
284
- const int nb = k / QK4_2;
285
- dequantize_block_q4_2<<<nb, 1 , 0 , stream>>> (vx, y);
286
- }
287
-
288
231
static void dequantize_row_q5_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
289
232
const int nb = k / QK5_0;
290
233
dequantize_block_q5_0<<<nb, 1 , 0 , stream>>> (vx, y);
@@ -319,8 +262,6 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
319
262
return dequantize_row_q4_0_cuda;
320
263
case GGML_TYPE_Q4_1:
321
264
return dequantize_row_q4_1_cuda;
322
- case GGML_TYPE_Q4_2:
323
- return dequantize_row_q4_2_cuda;
324
265
case GGML_TYPE_Q5_0:
325
266
return dequantize_row_q5_0_cuda;
326
267
case GGML_TYPE_Q5_1:
0 commit comments