@@ -54,6 +54,10 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 #endif
 #endif
 
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
 #undef MIN
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
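
For readers new to the RVV C intrinsics this patch uses throughout, here is a minimal, self-contained sketch of the same __riscv_v_intrinsic guard together with a typical vl-driven loop. The helper name add_i8 is illustrative only and not part of the patch; note that the kernels below set vl to fixed element counts per step rather than strip-mining with vsetvl.

    #include <stddef.h>
    #include <stdint.h>

    #ifdef __riscv_v_intrinsic
    #include <riscv_vector.h>

    // c[i] = a[i] + b[i] over n int8 elements, vl elements per pass
    static void add_i8(int8_t * c, const int8_t * a, const int8_t * b, size_t n) {
        for (size_t i = 0; i < n; ) {
            size_t vl = __riscv_vsetvl_e8m1(n - i);         // elements handled this pass
            vint8m1_t va = __riscv_vle8_v_i8m1(a + i, vl);  // load both inputs
            vint8m1_t vb = __riscv_vle8_v_i8m1(b + i, vl);
            __riscv_vse8_v_i8m1(c + i, __riscv_vadd_vv_i8m1(va, vb, vl), vl); // add and store
            i += vl;
        }
    }
    #endif
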
@@ -1582,6 +1586,90 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        size_t vl = 16;
+
+        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+        sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+        vl = 32;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+        uint8_t is = 0;
+        int isum = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load Q2
+            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
+
+            // duplicate scale elements for product
+            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
+            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
+            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
+            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
+
+            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+            // load Q8
+            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
+            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
+
+            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+            q2 += 32; q8 += 128; is = 8;
+
+        }
+
+        sumf += dall * isum;
+
+    }
+
+    *s = sumf;
+
 #else
 
     float sumf = 0;
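
The block above vectorizes the Q2_K super-block dot product: the per-sub-block mins are folded in through y[i].bsums, and the 2-bit quants are peeled out of each byte as four planes. A scalar sketch of just that bit extraction, with an illustrative helper name (not part of the patch), is:

    #include <stdint.h>

    // q2 points at 32 packed bytes; plane k (k = 0..3) holds bits 2k and 2k+1 of every byte,
    // mirroring the vand/vsrl sequence that produces q2_0 .. q2_3 above.
    static void q2_planes(const uint8_t * q2, uint8_t planes[4][32]) {
        for (int k = 0; k < 4; ++k) {
            for (int l = 0; l < 32; ++l) {
                planes[k][l] = (q2[l] >> (2*k)) & 3;
            }
        }
    }
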
@@ -1807,6 +1895,64 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0;
+        int isum2 = 0;
+
+        size_t vl = 16;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q2
+        vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl);
+
+        vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl));
+        vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03, vl));
+        vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03, vl));
+        vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03, vl));
+
+        // load Q8, and take product with Q2
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8 + 16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8 + 32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8 + 48, vl), vl);
+
+        vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl);
+        vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl);
+        vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl);
+        vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl);
+
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1];
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3];
+
+        sumf += d * (isum1 + isum2);
+
+    }
+
+    *s = sumf;
+
 #else
 
     float sumf = 0;
@@ -2220,6 +2366,106 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+
+        size_t vl = 32;
+        uint8_t m = 1;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+        int sum_t = 0;
+
+        for (int j = 0; j < QK_K; j += 128) {
+
+            vl = 32;
+
+            // load Q3
+            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03, vl));
+            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03, vl));
+            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03, vl));
+
+            // compute mask for subtraction
+            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
+            m <<= 1;
+
+            // load Q8 and take product with Q3
+            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8 + 32, vl), vl);
+            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8 + 64, vl), vl);
+            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8 + 96, vl), vl);
+
+            vl = 16;
+
+            // retrieve lane to multiply with scale
+            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q3 += 32; q8 += 128; scale += 8;
+
+        }
+
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
 #else
     // scalar version
     // This function is written like this so the compiler can manage to vectorize most of it
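
The Q3_K path above reconstructs signed 3-bit quants by masked subtraction: wherever the corresponding bit of x[i].hmask is clear, 4 is subtracted from the 2-bit value (vmseq_vx builds the mask, the masked vsub_vx applies it). A scalar sketch of that step, with an illustrative helper name that is not part of the patch:

    #include <stdint.h>

    // two_bits is one 2-bit quant, hmask_byte the matching byte of the high-bit mask,
    // m the bit selecting the current 32-value group; a clear bit lowers the value by 4.
    static inline int8_t q3_value(uint8_t two_bits, uint8_t hmask_byte, uint8_t m) {
        return (int8_t)two_bits - ((hmask_byte & m) ? 0 : 4);
    }
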
@@ -2523,6 +2769,79 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __riscv_v_intrinsic
+
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+
+        const float d = y[i].d * (float)x[i].d;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8);
+        vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // extend and combine both qh_x1 and qh_x2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl);
+        vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl);
+
+        // load Q3
+        vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl);
+
+        vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl);
+        vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl);
+        vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl);
+        vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl);
+
+        vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0);
+        vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1);
+        vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2);
+        vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3);
+
+        // load Q8 and take product with Q3
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8 + 16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8 + 32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8 + 48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3];
+
+        sumf += d * isum;
+
+    }
+
+    *s = sumf;
+
 #else
 
     int8_t aux8[QK_K];
@@ -2823,6 +3142,78 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
 
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        size_t vl = 8;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums + 1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vl = 32;
+
+        int32_t sum_1 = 0;
+        int32_t sum_2 = 0;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q4
+            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+            // load Q8 and multiply it with lower Q4 nibble
+            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j + 0];
+
+            // load Q8 and multiply it with upper Q4 nibble
+            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j + 1];
+
+            q4 += 32; q8 += 64;
+
+        }
+
+        sumf += d * (sum_1 + sum_2);
+
+    }
+
+    *s = sumf;
+
 #else
 
 
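
For Q4_K each byte of x[i].qs carries two 4-bit quants: the inner loop above multiplies the lower nibbles against the first 32 bytes of q8 (scaled by scales[2*j]) and the upper nibbles against the next 32 (scaled by scales[2*j + 1]). A scalar sketch of the nibble split, with an illustrative helper name that is not part of the patch:

    #include <stdint.h>

    // Splits 32 packed bytes into the low-nibble and high-nibble quant groups,
    // matching the vand_vx/vsrl_vx pair applied to q4_x above.
    static void q4_nibbles(const uint8_t * q4, uint8_t lo[32], uint8_t hi[32]) {
        for (int l = 0; l < 32; ++l) {
            lo[l] = q4[l] & 0x0F;
            hi[l] = q4[l] >> 4;
        }
    }
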
@@ -3064,6 +3455,50 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) - summs;
 
+#elif defined __riscv_v_intrinsic
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+
+        size_t vl = 32;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q4
+        vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+        // load Q8 and multiply it with lower Q4 nibble
+        vint8m1_t  q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+        vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl);
+        vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl);
+
+        sumf += d * scales[0] * __riscv_vmv_x_s_i16m1_i16(aux1);
+
+        // load Q8 and multiply it with upper Q4 nibble
+        vint8m1_t  q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+        vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8 + 32, vl), vl);
+        vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl);
+
+        sumf += d * scales[1] * __riscv_vmv_x_s_i16m1_i16(aux2);
+
+    }
+
+    *s = sumf;
+
 #else
 
     uint8_t aux8[QK_K];
@@ -3394,6 +3829,93 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+    float sums = 0.0;
+
+    size_t vl;
+
+    for (int i = 0; i < nb; ++i) {
+
+        vl = 8;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums + 1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        vl = 32;
+        int32_t aux32 = 0;
+        int is = 0;
+
+        uint8_t m = 1;
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q5 and Q8
+            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+            vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+
+            // compute mask for addition
+            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+            m <<= 1;
+
+            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+            m <<= 1;
+
+            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            q5 += 32; q8 += 64;
+
+        }
+
+        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+    }
+
+    *s = sumf + sums;
+
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
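
Q5_K stores the fifth bit of each quant in the x[i].qh bitmask; the masked add above (vmsne_vx builds the mask, the masked vadd_vx applies it) adds 16 to the 4-bit nibble wherever that bit is set. A scalar sketch, with an illustrative helper name that is not part of the patch:

    #include <stdint.h>

    // nibble is the low 4 bits of one quant, qh_byte the matching byte of the high-bit
    // mask, m the bit selecting the current 32-value group.
    static inline uint8_t q5_value(uint8_t nibble, uint8_t qh_byte, uint8_t m) {
        return (uint8_t)(nibble + ((qh_byte & m) ? 16 : 0));
    }
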
@@ -3639,6 +4161,76 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8);
+        vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // combine both qh_1 and qh_2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+        vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl);
+        vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl);
+        vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+
+        vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0);
+        vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1);
+        vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2);
+        vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3);
+
+        // load q5
+        vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl);
+        vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5 + 16, vl);
+
+        vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl));
+        vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl));
+        vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl));
+        vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl));
+
+        vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl);
+        vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl);
+        vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl);
+        vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl);
+
+        // load Q8 and multiply it with Q5
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8 + 16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8 + 32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8 + 48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0);
+        int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1);
+        int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2);
+        int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3);
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+    }
+
+    *s = sumf;
+
 #else
 
     int8_t aux8[QK_K];
@@ -4023,6 +4615,91 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6 + 32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03, vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03, vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03, vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8 + 32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8 + 64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8 + 96, vl), vl);
+
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is + 0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is + 1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is + 2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is + 3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is + 4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is + 5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is + 6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is + 7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64; qh += 32; q8 += 128; is = 8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
 #else
 
     int8_t aux8[QK_K];
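
Q6_K keeps the low 4 bits of each quant in x[i].ql and the high 2 bits in x[i].qh; the loop above ORs the two together and re-centres the result by subtracting 32 before the q8 products. A scalar sketch of that reassembly, with an illustrative helper name that is not part of the patch:

    #include <stdint.h>

    // ql_nibble holds bits 0-3 and qh_crumb bits 4-5 of one quant; stored values are
    // offset by 32, matching the vor_vv/vsll_vx/vsub_vx sequence above.
    static inline int8_t q6_value(uint8_t ql_nibble, uint8_t qh_crumb) {
        return (int8_t)(ql_nibble | (qh_crumb << 4)) - 32;
    }
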
@@ -4276,6 +4953,73 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
 
     *s = hsum_float_8(acc);
 
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        size_t vl = 16;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load Q6
+        vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl);
+        vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6 + 16, vl);
+
+        // load qh
+        vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl);
+
+        vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+
+        vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl);
+        vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl);
+        vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl);
+        vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl);
+
+        vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl);
+        vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl);
+        vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl);
+        vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl);
+
+        // load Q8 and take product
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8 + 16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8 + 32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8 + 48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3];
+
+        sumf += isum * d_all * y[i].d;
+
+    }
+
+    *s = sumf;
+
 #else
 
     int8_t aux8[QK_K];