
Commit 79f34ab

Authored Oct 3, 2023
ggml : add RISC-V Vector Support for K-Quants and improved the existing intrinsics (#3453)
* Added RVV intrinsics support for the Q8 quantize-row functions and improved the existing RISC-V dot-product functions.

  RVV intrinsics are added for the following quantize-row functions:
    quantize_row_q8_0
    quantize_row_q8_1

  The following dot-product functions have also been optimized by using LMUL = 1/2 instead of LMUL = 1:
    ggml_vec_dot_q4_0_q8_0
    ggml_vec_dot_q4_1_q8_1
    ggml_vec_dot_q5_0_q8_0
    ggml_vec_dot_q5_1_q8_1

  In the Q5 functions, vector initialization through a temporary array is also replaced by the vid intrinsic.

  Signed-off-by: Ahmad Tameem <[email protected]>

* Added RVV intrinsics support for k_quants

  This adds RISC-V Vector intrinsics support for the following K-quants functions, for both QK_K = 256 and QK_K = 64:
    ggml_vec_dot_q2_K_q8_K
    ggml_vec_dot_q3_K_q8_K
    ggml_vec_dot_q4_K_q8_K
    ggml_vec_dot_q5_K_q8_K
    ggml_vec_dot_q6_K_q8_K

  Signed-off-by: Ahmad Tameem <[email protected]>

---------

Signed-off-by: Ahmad Tameem <[email protected]>
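For orientation, the fractional-LMUL change described above boils down to using mf2 (half-register-group) operand types for the 16-byte half-blocks, so the widening multiply lands in i16m1 and the widening reduction in i32m1. A minimal, hedged sketch of that pattern follows; it is not code from this commit, but every intrinsic used is one that appears in the hunks below.

#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>

// Sketch only: dot product of two 16-element int8 arrays with LMUL = 1/2,
// mirroring the operand widths the updated dot-product hunks use.
static int dot16_i8(const int8_t * a, const int8_t * b) {
    size_t vl = __riscv_vsetvl_e8mf2(16);                              // 16 bytes per half-block
    vint8mf2_t va = __riscv_vle8_v_i8mf2(a, vl);
    vint8mf2_t vb = __riscv_vle8_v_i8mf2(b, vl);
    vint16m1_t prod = __riscv_vwmul_vv_i16m1(va, vb, vl);              // 8b x 8b -> 16b, still one register group
    vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, 1);
    vint32m1_t acc  = __riscv_vwredsum_vs_i16m1_i32m1(prod, zero, vl); // widening sum into element 0
    return __riscv_vmv_x_s_i32m1_i32(acc);
}
#endif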
1 parent 8186242 commit 79f34ab

File tree

2 files changed: +897 −97 lines

 

ggml.c

+153-97
@@ -1272,6 +1272,33 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
 #endif
     }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
 #else
     // scalar
     quantize_row_q8_0_reference(x, y, k);
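What the RVV block above computes per block of QK8_0 = 32 floats, as a hedged scalar sketch (it reuses x, y and nb from the enclosing function and fmaxf/fabsf/roundf from <math.h>; the actual scalar fallback is quantize_row_q8_0_reference):

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f;                              // max(|x|) over the 32-value block
        for (int j = 0; j < QK8_0; j++) {
            amax = fmaxf(amax, fabsf(x[i*QK8_0 + j]));
        }
        const float d  = amax / ((1 << 7) - 1);         // scale so the largest magnitude maps to 127
        const float id = d ? 1.0f/d : 0.0f;
        y[i].d = GGML_FP32_TO_FP16(d);
        for (int j = 0; j < QK8_0; j++) {
            y[i].qs[j] = (int8_t) roundf(x[i*QK8_0 + j]*id);
        }
    }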
@@ -1490,6 +1517,41 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
         _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
 #endif
     }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = sum*d;
+    }
 #else
     // scalar
     quantize_row_q8_1_reference(x, y, k);
@@ -2662,30 +2724,32 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     size_t vl = __riscv_vsetvl_e8m1(qk/2);

     for (int i = 0; i < nb; i++) {
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);

-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);

-        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);

-        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);

-        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
-        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);

-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);

         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);

-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);

-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

         sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
     }
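One detail worth calling out in the hunk above (the same pattern recurs in the Q4_1, Q5_0 and Q5_1 hunks below): vwredsum adds every element of its vector operand to element 0 of its scalar operand, so passing vs1 into the second reduction chains the two sums and a single scalar move suffices at the end. A minimal sketch, with a and b standing in for the two i16m1 products:

    vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
    vint32m1_t r1   = __riscv_vwredsum_vs_i16m1_i32m1(a, zero, vl); // r1[0] = sum(a)
    vint32m1_t r2   = __riscv_vwredsum_vs_i16m1_i32m1(b, r1,   vl); // r2[0] = sum(b) + r1[0]
    int sumi = __riscv_vmv_x_s_i32m1_i32(r2);                       // one extraction instead of two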
@@ -2823,27 +2887,28 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     size_t vl = __riscv_vsetvl_e8m1(qk/2);

     for (int i = 0; i < nb; i++) {
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);

-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);

-        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);

-        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);

-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);

         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);

-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);

-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

         sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
@@ -3088,66 +3153,61 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *

     uint32_t qh;

-    // These temp values are for masking and shift operations
-    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-    uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80,
-                           0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);

+    // These temporary registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
     for (int i = 0; i < nb; i++) {
         memcpy(&qh, x[i].qh, sizeof(uint32_t));

-        // temporary registers
-        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
-        vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
-        vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
-        vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
-
         // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-        vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
-        vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
-        vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);

         // ((qh & (1u << (j + 16))) >> (j + 12));
-        vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
-        vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);

         // narrowing
-        vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
-        vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);

-        vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
-        vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);

         // load
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);

-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);

-        vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);

-        vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
-        vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);

-        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);

-        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
-        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);

-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);

         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);

-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);

-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

         sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
     }
@@ -3414,62 +3474,58 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *

     uint32_t qh;

-    // These temp values are for shift operations
-    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);

+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
     for (int i = 0; i < nb; i++) {
         memcpy(&qh, x[i].qh, sizeof(uint32_t));

-        // temporary registers
-        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
-        vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
-
         // load qh
-        vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);

         // ((qh >> (j + 0)) << 4) & 0x10;
-        vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
-        vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
-        vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);

         // ((qh >> (j + 12)) ) & 0x10;
-        vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
-        vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);

         // narrowing
-        vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
-        vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);

-        vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
-        vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);

         // load
-        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);

-        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
-        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);

-        vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
-        vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);

-        vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
-        vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);

-        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
-        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);

-        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
-        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);

         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);

-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);

-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
-        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

         sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }

k_quants.c

+744
@@ -54,6 +54,10 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
 #endif
 #endif

+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
 #undef MIN
 #undef MAX
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -1582,6 +1586,90 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        size_t vl = 16;
+
+        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+        sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+        vl = 32;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+        uint8_t is=0;
+        int isum=0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load Q2
+            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+
+            // duplicate scale elements for product
+            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
+            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
+            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
+            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+
+            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+            // load Q8
+            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
+            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+
+            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+            q2+=32;  q8+=128;  is=8;
+
+        }
+
+        sumf += dall * isum;
+
+    }
+
+    *s = sumf;
+
 #else

     float sumf = 0;
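The vrgather step above is the least obvious part: temp_01 holds sixteen 0s followed by sixteen 1s, so adding an offset k and gathering from aux broadcasts aux[k] across the first 16 lanes and aux[k+1] across the next 16, lining each sub-block scale up with the 32 two-bit values it applies to. A hedged restatement of one of the lines above:

    // indices {k, ..., k, k+1, ..., k+1}; here k = 2+is, as in the hunk above
    vuint8m1_t idx = __riscv_vadd_vx_u8m1(v_b, 2+is, vl);
    // sc1[0..15] = aux[2+is], sc1[16..31] = aux[3+is]
    vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, idx, vl);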
@@ -1807,6 +1895,64 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) + summs;

+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux32[2];
+    const uint8_t * scales = (const uint8_t *)aux32;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const float dmin = -y[i].d * (float)x[i].dmin;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
+
+        aux32[0] = sc[0] & 0x0f0f0f0f;
+        aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
+
+        sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
+
+        int isum1 = 0;
+        int isum2 = 0;
+
+        size_t vl = 16;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q2
+        vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl);
+
+        vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl));
+        vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl));
+        vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl));
+        vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl));
+
+        // load Q8, and take product with Q2
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl);
+        vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl);
+        vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl);
+        vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl);
+
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1];
+        isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2];
+        isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3];
+
+        sumf += d * (isum1 + isum2);
+
+    }
+
+    *s = sumf;
+
 #else

     float sumf = 0;
@@ -2220,6 +2366,106 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+
+        size_t vl = 32;
+        uint8_t m = 1;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+        int sum_t = 0;
+
+        for (int j = 0; j < QK_K; j += 128) {
+
+            vl = 32;
+
+            // load Q3
+            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+
+            // compute mask for subtraction
+            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl);
+            m <<= 1;
+
+            // load Q8 and take product with Q3
+            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            // retrieve lane to multiply with scale
+            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q3 += 32;    q8 += 128;   scale += 8;
+
+        }
+
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+
+        sumf += d*sum_t;
+
+    }
+
+    *s = sumf;
+
 #else
     // scalar version
     // This function is written like this so the compiler can manage to vectorize most of it
@@ -2523,6 +2769,79 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __riscv_v_intrinsic
+
+    uint16_t aux16[2];
+    int8_t * scales = (int8_t *)aux16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t a = *(const uint16_t *)x[i].scales;
+        aux16[0] = a & 0x0f0f;
+        aux16[1] = (a >> 4) & 0x0f0f;
+
+        for (int j = 0; j < 4; ++j) scales[j] -= 8;
+
+        int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
+
+        const float d = y[i].d * (float)x[i].d;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8);
+        vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // extend and combine both qh_x1 and qh_x2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl);
+        vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl);
+        vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl);
+
+        // load Q3
+        vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl);
+
+        vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl);
+        vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl);
+        vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl);
+        vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl);
+
+        vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0);
+        vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1);
+        vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2);
+        vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3);
+
+        // load Q8 and take product with Q3
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3];
+
+        sumf += d * isum;
+
+    }
+
+    *s = sumf;
+
 #else

     int8_t  aux8[QK_K];
@@ -2823,6 +3142,78 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);

+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        size_t vl = 8;
+
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
+        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        vl = 32;
+
+        int32_t sum_1 = 0;
+        int32_t sum_2 = 0;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q4
+            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+            // load Q8 and multiply it with lower Q4 nibble
+            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+            // load Q8 and multiply it with upper Q4 nibble
+            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+            q4 += 32;    q8 += 64;
+
+        }
+
+        sumf += d*(sum_1 + sum_2);
+
+    }
+
+    *s = sumf;
+
 #else


@@ -3064,6 +3455,50 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) - summs;

+#elif defined __riscv_v_intrinsic
+
+    uint16_t s16[2];
+    const uint8_t * restrict scales = (const uint8_t *)s16;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const uint16_t * restrict b = (const uint16_t *)x[i].scales;
+        s16[0] = b[0] & 0x0f0f;
+        s16[1] = (b[0] >> 4) & 0x0f0f;
+
+        sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]));
+        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]);
+
+        size_t vl = 32;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        // load Q4
+        vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+        // load Q8 and multiply it with lower Q4 nibble
+        vint8m1_t  q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+        vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl);
+        vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl);
+
+        sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1);
+
+        // load Q8 and multiply it with upper Q4 nibble
+        vint8m1_t  q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+        vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+        vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl);
+
+        sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2);
+
+    }
+
+    *s = sumf;
+
 #else

     uint8_t aux8[QK_K];
@@ -3394,6 +3829,93 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc) + summs;

+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+    float sums = 0.0;
+
+    size_t vl;
+
+    for (int i = 0; i < nb; ++i) {
+
+        vl = 8;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+        const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        vl = 32;
+        int32_t aux32 = 0;
+        int is = 0;
+
+        uint8_t m = 1;
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q5 and Q8
+            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+            vint8m1_t  q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+
+            // compute mask for addition
+            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl);
+            m <<= 1;
+
+            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl);
+            m <<= 1;
+
+            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            q5 += 32;    q8 += 64;
+
+        }
+
+        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+    }
+
+    *s = sumf+sums;
+
 #else

     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -3639,6 +4161,76 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * (float)x[i].d;
+        const int8_t * sc = x[i].scales;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load qh
+        vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8);
+        vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8));
+
+        size_t vl = 16;
+
+        // combine both qh_1 and qh_2
+        vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl);
+
+        vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+        vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl);
+        vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl);
+        vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl);
+
+        vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0);
+        vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1);
+        vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2);
+        vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3);
+
+        // load q5
+        vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl);
+        vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl);
+
+        vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl));
+        vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl));
+        vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl));
+        vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl));
+
+        vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl);
+        vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl);
+        vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl);
+        vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl);
+
+        // load Q8 and multiply it with Q5
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0);
+        int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1);
+        int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2);
+        int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3);
+
+        sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
+
+    }
+
+    *s = sumf;
+
 #else

     int8_t aux8[QK_K];
@@ -4023,6 +4615,91 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64;   qh += 32;   q8 += 128;   is=8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
 #else

     int8_t aux8[QK_K];
@@ -4276,6 +4953,73 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri

     *s = hsum_float_8(acc);

+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = (float)x[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        int32_t isum = 0;
+
+        size_t vl = 16;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // load Q6
+        vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl);
+        vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl);
+
+        // load qh
+        vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl);
+
+        vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+        qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl);
+        vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl);
+
+        vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl);
+        vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl);
+        vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl);
+        vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl);
+
+        vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl);
+        vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl);
+        vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl);
+        vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl);
+
+        // load Q8 and take product
+        vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl);
+        vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl);
+        vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl);
+        vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl);
+
+        vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl);
+        vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl);
+        vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl);
+        vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl);
+
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2];
+        isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3];
+
+        sumf += isum * d_all * y[i].d;
+
+    }
+
+    *s = sumf;
+
 #else

     int8_t aux8[QK_K];
