@@ -477,6 +477,14 @@ typedef struct {
477
477
} block_q6_K;
478
478
static_assert (sizeof (block_q6_K) == sizeof(ggml_fp16_t ) + 13*QK_K/16, "wrong q6_K block size/padding");
479
479
480
// IQ2_XXS: 2.0625 bpw quantization. Each 256-value super-block stores one
// fp16 scale plus QK_K/8 = 32 uint16 words of packed codebook indices/signs.
#define QR2_XXS 8
#define QI2_XXS (QK_K / (4*QR2_XXS))
typedef struct {
    half d;                 // super-block scale
    uint16_t qs[QK_K/8];    // packed grid indices, sign indices and 4-bit sub-scales
} block_iq2_xxs;
static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");

#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -1292,6 +1300,128 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
1292
1300
#endif
1293
1301
}
1294
1302
1303
// IQ2_XXS codebook: 256 entries, each packing 8 byte-lanes. Every byte is one
// of {0x08, 0x19, 0x2b}, i.e. the quant magnitudes {1, 3, 5} scaled by 8/2.
// An 8-bit index selects one 8-value group; signs are applied separately via
// ksigns_iq2xs/kmask_iq2xs.
static const __device__ uint64_t kgrid_iq2xxs[256] = {
    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
};
1369
+
1370
// Sign expansion table for iq2 quants: entry i carries the 7 sign bits i in
// its low bits; bit 7 is the parity bit that makes the popcount of the full
// byte even (the 8th sign is derived so signs come in groups of even parity).
static const __device__ uint8_t ksigns_iq2xs[128] = {
      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
};
1380
+
1381
// Per-lane bit masks (1 << j) used to test individual sign bits from ksigns_iq2xs.
static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
1382
+
1383
// Reports whether a quantization type has a mul_mat_q (quantized mat-mat)
// kernel. Types not listed here (e.g. the IQ2_XXS type added in this file)
// must fall back to a different matrix-multiplication path.
inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
            return true;
        default:
            return false;
    }
}
1400
+
1401
// Dequantizes one block_iq2_xxs super-block per thread block.
// Launch layout: grid.x = number of super-blocks, 32 threads per block;
// each thread expands 8 consecutive output values.
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const block_iq2_xxs * blocks = (const block_iq2_xxs *) vx;
    const int ibl  = blockIdx.x;   // super-block index
    const int tid  = threadIdx.x;  // 0...31
#if QK_K == 256
    const int il = tid/8;          // 0...3: 8-value group inside a 32-value sub-block
    const int ib = tid%8;          // 0...7: 32-value sub-block inside the super-block
    dst_t * y = yy + ibl*QK_K + 32*ib + 8*il;
    // each sub-block uses 4 uint16 words: 4 grid indices + a packed sign/scale word pair
    const uint16_t * q2   = blocks[ibl].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;                       // grid indices as bytes
    const uint8_t  * grid = (const uint8_t *)(kgrid_iq2xxs + aux8[il]);
    const uint32_t  aux32 = q2[2] | (q2[3] << 16);                     // 28 sign bits + 4-bit scale
    const float     d     = (float)blocks[ibl].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t   signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 8; ++j) {
        y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
    }
#else
    assert(false); // only the QK_K == 256 layout is implemented
#endif

}
1424
+
1295
1425
static __global__ void dequantize_mul_mat_vec_q2_k (const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
1296
1426
1297
1427
static_assert (16 %K_QUANTS_PER_ITERATION == 0 , " 16 must be divisible by K_QUANTS_PER_ITERATION" );
@@ -3825,6 +3955,55 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
3825
3955
return vec_dot_q6_K_q8_1_impl_mmq (&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
3826
3956
}
3827
3957
3958
// Dot product of one 32-value iq2_xxs sub-block against the matching q8_1
// sub-block, used by the mul_mat_vec_q template. The meaning of iqs depends
// on QR2_XXS: with QR2_XXS == 8 it is the sub-block index (0...7), otherwise
// it addresses half sub-blocks (0...15).
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;

#if QR2_XXS == 8
    const int        ib32 = iqs;
    const uint16_t * q2   = bq2->qs + 4*ib32;
    const uint8_t  * aux8 = (const uint8_t *)q2;          // four 8-bit grid indices
    const int8_t   * q8   = bq8_1[ib32].qs;
    uint32_t aux32 = q2[2] | (q2[3] << 16);               // sign bits, consumed 7 at a time
    int sumi = 0;
    for (int l = 0; l < 4; ++l) {
        const uint8_t * grid  = (const uint8_t *)(kgrid_iq2xxs + aux8[l]);
        const uint8_t   signs = ksigns_iq2xs[aux32 & 127];
        for (int j = 0; j < 8; ++j) {
            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
        }
        q8    += 8;
        aux32 >>= 7;
    }
    // after four 7-bit shifts only the 4-bit sub-scale remains in aux32
    const float d = (float)bq2->d * (0.5f + aux32) * (float)bq8_1[ib32].ds.x * 0.25f;
    return d * sumi;
#else
    // iqs is 0...15: process one 16-value half of a sub-block
    const int        ib32  = iqs/2;
    const int        il    = iqs%2;
    const uint16_t * q2    = bq2->qs + 4*ib32;
    const uint8_t  * aux8  = (const uint8_t *)q2;
    const uint8_t  * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
    const uint8_t  * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
    const uint32_t   aux32 = q2[2] | (q2[3] << 16);
    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
    int sumi1 = 0, sumi2 = 0;
    for (int j = 0; j < 8; ++j) {
        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
    }
    return d * (sumi1 + sumi2);
#endif
#else
    assert(false); // only the QK_K == 256 layout is implemented
    return 0.f;
#endif
}
4006
+
3828
4007
template <int qk, int qr, int qi, bool need_sum, typename block_q_t , int mmq_x, int mmq_y, int nwarps,
3829
4008
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
3830
4009
static __device__ __forceinline__ void mul_mat_q (
@@ -5664,6 +5843,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
5664
5843
#endif
5665
5844
}
5666
5845
5846
// Host-side launcher: dequantizes k values of iq2_xxs data into y on the
// given stream. k is expected to be a multiple of QK_K (one 32-thread block
// is launched per super-block; a remainder would be silently dropped).
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int nblocks = k / QK_K;
    dequantize_block_iq2_xxs<<<nblocks, 32, 0, stream>>>(vx, y);
}
5851
+
5667
5852
template <typename src_t , typename dst_t >
5668
5853
static void convert_unary_cuda (const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
5669
5854
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -5692,6 +5877,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
5692
5877
return dequantize_row_q5_K_cuda;
5693
5878
case GGML_TYPE_Q6_K:
5694
5879
return dequantize_row_q6_K_cuda;
5880
+ case GGML_TYPE_IQ2_XXS:
5881
+ return dequantize_row_iq2_xxs_cuda;
5695
5882
case GGML_TYPE_F32:
5696
5883
return convert_unary_cuda<float >;
5697
5884
default :
@@ -5721,6 +5908,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
5721
5908
return dequantize_row_q5_K_cuda;
5722
5909
case GGML_TYPE_Q6_K:
5723
5910
return dequantize_row_q6_K_cuda;
5911
+ case GGML_TYPE_IQ2_XXS:
5912
+ return dequantize_row_iq2_xxs_cuda;
5724
5913
case GGML_TYPE_F16:
5725
5914
return convert_unary_cuda<half>;
5726
5915
default :
@@ -5915,6 +6104,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
5915
6104
<<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols, nrows);
5916
6105
}
5917
6106
6107
// Host-side launcher for the iq2_xxs x q8_1 matrix-vector product.
// One warp per GGML_CUDA_MMV_Y rows; ncols must be a multiple of QK_K.
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const int  nblocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; // ceil-div over rows
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
6115
+
5918
6116
static void ggml_mul_mat_q4_0_q8_1_cuda (
5919
6117
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
5920
6118
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7407,6 +7605,7 @@ static int64_t get_row_rounding(ggml_type type) {
7407
7605
case GGML_TYPE_Q4_K:
7408
7606
case GGML_TYPE_Q5_K:
7409
7607
case GGML_TYPE_Q6_K:
7608
+ case GGML_TYPE_IQ2_XXS:
7410
7609
return max_compute_capability >= CC_RDNA2 ? 128 : 64 ;
7411
7610
default :
7412
7611
GGML_ASSERT (false );
@@ -7427,6 +7626,7 @@ static int64_t get_row_rounding(ggml_type type) {
7427
7626
case GGML_TYPE_Q3_K:
7428
7627
case GGML_TYPE_Q4_K:
7429
7628
case GGML_TYPE_Q5_K:
7629
+ case GGML_TYPE_IQ2_XXS:
7430
7630
return max_compute_capability >= CC_VOLTA ? 128 : 64 ;
7431
7631
case GGML_TYPE_Q6_K:
7432
7632
return 64 ;
@@ -7477,6 +7677,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
7477
7677
case GGML_TYPE_Q6_K:
7478
7678
mul_mat_vec_q6_K_q8_1_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
7479
7679
break ;
7680
+ case GGML_TYPE_IQ2_XXS:
7681
+ mul_mat_vec_iq2_xxs_q8_1_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
7682
+ break ;
7480
7683
default :
7481
7684
GGML_ASSERT (false );
7482
7685
break ;
@@ -8693,6 +8896,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
8693
8896
8694
8897
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8695
8898
8899
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq (src0->type );
8900
+
8696
8901
// debug helpers
8697
8902
// printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
8698
8903
// printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
0 commit comments