@@ -1360,34 +1360,20 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
1360
1360
const int8x16_t v1_1hs = vsubq_s8 (v1_1h , s8b );
1361
1361
1362
1362
// dot product into int16x8_t
1363
- const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0ls ), vget_low_s8 (v1_0ls ));
1364
- const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0ls ), vget_high_s8 (v1_0ls ));
1363
+ // NOTE: this assumes vdotq_s32 is always available; if it is not, this path should be guarded by a check for __ARM_FEATURE_DOTPROD
1364
+ int32x4_t p_0 = vdotq_s32 (vdupq_n_s32 (0 ), v0_0ls , v1_0ls );
1365
+ int32x4_t p_1 = vdotq_s32 (vdupq_n_s32 (0 ), v0_1ls , v1_1ls );
1365
1366
1366
- const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hs ), vget_low_s8 (v1_0hs ));
1367
- const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hs ), vget_high_s8 (v1_0hs ));
1368
-
1369
- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1ls ), vget_low_s8 (v1_1ls ));
1370
- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1ls ), vget_high_s8 (v1_1ls ));
1371
-
1372
- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hs ), vget_low_s8 (v1_1hs ));
1373
- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hs ), vget_high_s8 (v1_1hs ));
1374
-
1375
- const int16x8_t pl_0 = vaddq_s16 (pl0l , pl0h );
1376
- const int16x8_t ph_0 = vaddq_s16 (ph0l , ph0h );
1377
-
1378
- const int16x8_t pl_1 = vaddq_s16 (pl1l , pl1h );
1379
- const int16x8_t ph_1 = vaddq_s16 (ph1l , ph1h );
1380
-
1381
- const int16x8_t p_0 = vaddq_s16 (pl_0 , ph_0 );
1382
- const int16x8_t p_1 = vaddq_s16 (pl_1 , ph_1 );
1367
+ p_0 = vdotq_s32 (p_0 , v0_0hs , v1_0hs );
1368
+ p_1 = vdotq_s32 (p_1 , v0_1hs , v1_1hs );
1383
1369
1384
1370
// scalar
1385
1371
#if defined(__ARM_FEATURE_QRDMX )
1386
- sum0 += d0_0 * d1_0 * vaddvq_s16 (p_0 );
1387
- sum1 += d0_1 * d1_1 * vaddvq_s16 (p_1 );
1372
+ sum0 += d0_0 * d1_0 * vaddvq_s32 (p_0 );
1373
+ sum1 += d0_1 * d1_1 * vaddvq_s32 (p_1 );
1388
1374
#else
1389
- sum0 += d0_0 * d1_0 * (vgetq_lane_s16 (p_0 , 0 ) + vgetq_lane_s16 (p_0 , 1 ) + vgetq_lane_s16 (p_0 , 2 ) + vgetq_lane_s16 (p_0 , 3 ) + vgetq_lane_s16 ( p_0 , 4 ) + vgetq_lane_s16 ( p_0 , 5 ) + vgetq_lane_s16 ( p_0 , 6 ) + vgetq_lane_s16 ( p_0 , 7 ));
1390
- sum1 += d0_1 * d1_1 * (vgetq_lane_s16 (p_1 , 0 ) + vgetq_lane_s16 (p_1 , 1 ) + vgetq_lane_s16 (p_1 , 2 ) + vgetq_lane_s16 (p_1 , 3 ) + vgetq_lane_s16 ( p_1 , 4 ) + vgetq_lane_s16 ( p_1 , 5 ) + vgetq_lane_s16 ( p_1 , 6 ) + vgetq_lane_s16 ( p_1 , 7 ));
1375
+ sum0 += d0_0 * d1_0 * (vgetq_lane_s32 (p_0 , 0 ) + vgetq_lane_s32 (p_0 , 1 ) + vgetq_lane_s32 (p_0 , 2 ) + vgetq_lane_s32 (p_0 , 3 ));
1376
+ sum1 += d0_1 * d1_1 * (vgetq_lane_s32 (p_1 , 0 ) + vgetq_lane_s32 (p_1 , 1 ) + vgetq_lane_s32 (p_1 , 2 ) + vgetq_lane_s32 (p_1 , 3 ));
1391
1377
#endif
1392
1378
}
1393
1379
0 commit comments