Skip to content

Commit f30dbf9

Browse files
committed
ggml : speed-up q4_2
- 4 threads: ~100ms -> ~90ms
- 8 threads: ~55ms -> ~50ms
1 parent 99092f2 commit f30dbf9

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

ggml.c

+24 -11
Original file line number | Diff line number | Diff line change
@@ -3058,8 +3058,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
30583058
float sumf = 0.0;
30593059

30603060
#if defined(__ARM_NEON)
3061-
float sum0 = 0.0f;
3062-
float sum1 = 0.0f;
3061+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
3062+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
30633063

30643064
for (int i = 0; i < nb; i += 2) {
30653065
const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
@@ -3100,10 +3100,21 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
31003100
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
31013101

31023102
#if defined(__ARM_FEATURE_DOTPROD)
3103-
sum0 += (GGML_FP16_TO_FP32(x0_0->d)*y0->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l));
3104-
sum0 += (GGML_FP16_TO_FP32(x0_1->d)*y0->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h));
3105-
sum1 += (GGML_FP16_TO_FP32(x1_0->d)*y1->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l));
3106-
sum1 += (GGML_FP16_TO_FP32(x1_1->d)*y1->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h));
3103+
const float32x4_t x0_0d = vdupq_n_f32(GGML_FP16_TO_FP32(x0_0->d));
3104+
const float32x4_t x0_1d = vdupq_n_f32(GGML_FP16_TO_FP32(x0_1->d));
3105+
const float32x4_t x1_0d = vdupq_n_f32(GGML_FP16_TO_FP32(x1_0->d));
3106+
const float32x4_t x1_1d = vdupq_n_f32(GGML_FP16_TO_FP32(x1_1->d));
3107+
3108+
const float32x4_t y0d = vdupq_n_f32(y0->d);
3109+
const float32x4_t y1d = vdupq_n_f32(y1->d);
3110+
3111+
sumv0 = vaddq_f32(sumv0, vmulq_f32(y0d, vaddq_f32(
3112+
vmulq_f32(x0_0d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l))),
3113+
vmulq_f32(x0_1d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h))))));
3114+
3115+
sumv1 = vaddq_f32(sumv1, vmulq_f32(y1d, vaddq_f32(
3116+
vmulq_f32(x1_0d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l))),
3117+
vmulq_f32(x1_1d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h))))));
31073118
#else
31083119
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
31093120
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
@@ -3120,14 +3131,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
31203131
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
31213132
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
31223133

3123-
sum0 += (GGML_FP16_TO_FP32(x0_0->d)*y0->d)*vaddvq_s32(pl0);
3124-
sum0 += (GGML_FP16_TO_FP32(x0_1->d)*y0->d)*vaddvq_s32(ph0);
3125-
sum1 += (GGML_FP16_TO_FP32(x1_0->d)*y1->d)*vaddvq_s32(pl1);
3126-
sum1 += (GGML_FP16_TO_FP32(x1_1->d)*y1->d)*vaddvq_s32(ph1);
3134+
sumv0 = vaddq_f32(sumv0, vmulq_f32(vdupq_n_f32(y0->d), vaddq_f32(
3135+
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x0_0->d)), vcvtq_f32_s32(pl0)),
3136+
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x0_1->d)), vcvtq_f32_s32(ph0)))));
3137+
sumv1 = vaddq_f32(sumv1, vmulq_f32(vdupq_n_f32(y1->d), vaddq_f32(
3138+
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x1_0->d)), vcvtq_f32_s32(pl1)),
3139+
vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x1_1->d)), vcvtq_f32_s32(ph1)))));
31273140
#endif
31283141
}
31293142

3130-
sumf = sum0 + sum1;
3143+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
31313144
#else
31323145
// scalar
31333146
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)