
Commit 9ebf302

ikawrakow (Kawrakow) authored and committed
SOTA 2-bit quants (ggml-org#4773)
* iq2_xxs: basics

* iq2_xxs: scalar and AVX2 dot products

  Needed to change Q8_K to have quants in the -127...127 range, else the IQ2_XXS AVX implementation becomes very awkward. The alternative would have been to use Q8_0 instead. Perhaps I'll change later, for now this is what we have.

* iq2_xxs: ARM_NEON dot product

  Somehow strangely slow (112 ms/token).

* iq2_xxs: WIP Metal

  Dequantize works, something is still wrong with the dot product.

* iq2_xxs: Metal dot product now works

  We have PP-512 = 475 t/s, TG-128 = 47.3 t/s. Not the greatest performance, but not complete garbage either.

* iq2_xxs: slightly faster dot product

  TG-128 is now 48.4 t/s.

* iq2_xxs: slightly faster dot product

  TG-128 is now 50.9 t/s.

* iq2_xxs: even faster Metal dot product

  TG-128 is now 54.1 t/s. Strangely enough, putting the signs lookup table into shared memory has a bigger impact than the grid values being in shared memory.

* iq2_xxs: dequantize CUDA kernel - fix conflict with master

* iq2_xxs: quantized CUDA dot product (MMVQ)

  We get TG-128 = 153.1 t/s.

* iq2_xxs: slightly faster CUDA dot product

  TG-128 is now at 155.1 t/s.

* iq2_xxs: add to llama ftype enum

* iq2_xxs: fix MoE on Metal

* Fix missing MMQ ops when on hipBLAS

  I had put the ggml_cuda_supports_mmq call at the wrong place.

* Fix bug in qequantize_row_iq2_xxs

  The 0.25f factor was missing. Great detective work by @ggerganov!

* Fixing tests

* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 6d1fae7 · commit 9ebf302

10 files changed: +902 −1 lines

ggml-cuda.cu (+205 lines)
@@ -477,6 +477,14 @@ typedef struct {
 } block_q6_K;
 static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");

+#define QR2_XXS 8
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+typedef struct {
+    half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
 #define WARP_SIZE 32
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
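With the usual QK_K == 256 super-block size, this layout costs 2 bytes for the fp16 scale plus (QK_K/8)*2 = 64 bytes of packed data per 256 weights, i.e. 66 bytes or 2.0625 bits per weight, which is the "2-bit" of the title. A minimal host-side sketch of that arithmetic, not part of the diff:

// Sketch: storage cost of one block_iq2_xxs, assuming QK_K == 256.
#include <cstdio>

int main() {
    const int QK_K  = 256;                                     // weights per super-block
    const int bytes = 2 /* fp16 d */ + (QK_K/8) * 2 /* qs[] */; // 66 bytes
    printf("%d bytes / %d weights = %.4f bits per weight\n",
           bytes, QK_K, 8.0 * bytes / QK_K);                   // -> 2.0625
    return 0;
}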

@@ -1292,6 +1300,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #endif
 }

+static const __device__ uint64_t kgrid_iq2xxs[256] = {
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
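An aside on this codebook: each uint64_t entry packs eight byte-sized magnitudes, and only three distinct byte values ever appear (0x08, 0x19 and 0x2b), so the 256 entries are a hand-picked subset of the 3^8 = 6561 possible magnitude vectors, addressable by a single byte. A host-side sketch that checks this property, assuming the table above is available as a plain host array of the same name:

// Sketch: verify every byte of every kgrid_iq2xxs entry is one of the
// three magnitudes 0x08, 0x19, 0x2b. Assumes a host copy of the table.
#include <cstdint>
#include <cstdio>

extern const uint64_t kgrid_iq2xxs[256]; // host copy of the device table above

int main() {
    for (int i = 0; i < 256; ++i) {
        for (int b = 0; b < 8; ++b) {
            const uint8_t v = (uint8_t)(kgrid_iq2xxs[i] >> 8*b);
            if (v != 0x08 && v != 0x19 && v != 0x2b) {
                printf("entry %d, byte %d: unexpected 0x%02x\n", i, b, v);
                return 1;
            }
        }
    }
    printf("OK: all 2048 bytes are in {0x08, 0x19, 0x2b}\n");
    return 0;
}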
+
+static const __device__ uint8_t ksigns_iq2xs[128] = {
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
+
+static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+
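The sign table also appears to be structured rather than arbitrary: entry k extends the 7 stored sign bits with an eighth bit equal to the parity of k, so every 8-bit pattern has an even number of set bits, i.e. an even number of negated weights per group of 8. A sketch that regenerates the table under that reading (__builtin_popcount assumes GCC/Clang):

// Sketch: regenerate ksigns_iq2xs. Entry k = the 7 stored sign bits plus
// an even-parity completion bit in position 7.
#include <cstdint>
#include <cstdio>

int main() {
    for (uint32_t k = 0; k < 128; ++k) {
        const uint8_t sign = (uint8_t)(k | ((__builtin_popcount(k) & 1u) << 7));
        printf("%3u,%s", sign, (k % 16 == 15) ? "\n" : " "); // 0, 129, 130, 3, ...
    }
    return 0;
}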
+inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
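(Note: GGML_TYPE_IQ2_XXS is deliberately absent from this list. The PR adds no MMQ kernels for the new type, so the use_mul_mat_q gate added at the end of this diff routes IQ2_XXS matrix-matrix products to the dequantize + cuBLAS path, while matrix-vector products get the dedicated vec_dot_iq2_xxs_q8_1 below.)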
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int il = tid/8; // 0...3
+    const int ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+#else
+    assert(false);
+#endif
+
+}
+
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

     static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
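The packing is easiest to see on the CPU. Each group of four uint16_t words encodes 32 weights: the four bytes of q2[0..1] index the 256-entry codebook, while q2[2] | (q2[3] << 16) packs four 7-bit sign patterns plus a 4-bit group scale in the top nibble. A host-side sketch mirroring the kernel above, assuming host copies of the three lookup tables:

// Sketch: dequantize one 32-weight IQ2_XXS group on the CPU, mirroring
// dequantize_block_iq2_xxs. Assumes host copies of the device tables.
#include <cstdint>

extern const uint64_t kgrid_iq2xxs[256];
extern const uint8_t  ksigns_iq2xs[128];
extern const uint8_t  kmask_iq2xs[8];

// d_block is the block's fp16 scale converted to float (hypothetical helper).
static void dequant_iq2xxs_group(const uint16_t q2[4], float d_block, float out[32]) {
    const uint8_t * aux8  = (const uint8_t *)q2;              // 4 codebook indices
    const uint32_t  aux32 = q2[2] | ((uint32_t)q2[3] << 16);  // 4x7 sign bits + 4-bit scale
    const float d = d_block * (0.5f + (aux32 >> 28)) * 0.25f; // per-group scale
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid  = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
        const uint8_t   signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
        for (int j = 0; j < 8; ++j) {
            out[8*l + j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
        }
    }
}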
@@ -3825,6 +3955,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }

+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+#if QR2_XXS == 8
+    const int ib32 = iqs;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const int8_t   * q8 = bq8_1[ib32].qs;
+    uint32_t aux32 = q2[2] | (q2[3] << 16);
+    int sumi = 0;
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
+        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+        for (int j = 0; j < 8; ++j) {
+            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+        }
+        q8 += 8;
+        aux32 >>= 7;
+    }
+    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
+    return d * sumi;
+#else
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#endif
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
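(Note the small trick in the QR2_XXS == 8 branch: after the four aux32 >>= 7 shifts in the loop, only the top four scale bits remain in aux32, so (0.5f + aux32) there is the same quantity as the (0.5f + (aux32 >> 28)) spelled out in the other branch, with no extra shift needed.)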
@@ -5664,6 +5843,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 #endif
 }

+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
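(Launch geometry: one CUDA block per super-block of QK_K = 256 weights with 32 threads each, and every thread writes one 8-value group, 32 x 8 = 256, matching the il/ib indexing inside dequantize_block_iq2_xxs.)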
@@ -5692,6 +5877,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cuda<float>;
         default:
@@ -5721,6 +5908,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_q5_K_cuda;
         case GGML_TYPE_Q6_K:
             return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
         default:
@@ -5915,6 +6104,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }

+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
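(The launcher instantiates the generic mul_mat_vec_q kernel with qk = QK_K, qi = QI2_XXS and vdr = 1; with QI2_XXS = QK_K/(4*QR2_XXS) = 8 at QK_K == 256, each vec_dot_iq2_xxs_q8_1 call should then receive iqs in 0...7, the ib32 index consumed by the QR2_XXS == 8 branch above. This reading is inferred from the template arguments and the in-code comment "iqs is 0...15" for the QR2_XXS == 4 variant, not verified against mul_mat_vec_q itself.)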
@@ -7407,6 +7605,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
             return max_compute_capability >= CC_RDNA2 ? 128 : 64;
         default:
             GGML_ASSERT(false);
@@ -7427,6 +7626,7 @@ static int64_t get_row_rounding(ggml_type type) {
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ2_XXS:
             return max_compute_capability >= CC_VOLTA ? 128 : 64;
         case GGML_TYPE_Q6_K:
             return 64;
@@ -7477,6 +7677,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -8693,6 +8896,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1

 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

+    use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type);
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
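(This is the relocated check from the "Fix missing MMQ ops when on hipBLAS" item in the commit message: placing ggml_cuda_supports_mmq here, after the HIP-specific block, ensures that types without MMQ kernels, such as the new IQ2_XXS, fall back to the dequantize + cuBLAS path on every backend.)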
