Skip to content

Commit faa0e69

Browse files
authored
ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (ggml-org#7433)
* Add SVE support for q4_0_q8_0 q8_0_q8_0
* remove ifdef
1 parent 9791f40 commit faa0e69

File tree

7 files changed

+85
-2
lines changed

7 files changed

+85
-2
lines changed

CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ else()
7272
set(INS_ENB ON)
7373
endif()
7474

75+
option(LLAMA_SVE "llama: enable SVE" OFF)
7576
option(LLAMA_AVX "llama: enable AVX" ${INS_ENB})
7677
option(LLAMA_AVX2 "llama: enable AVX2" ${INS_ENB})
7778
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
@@ -1040,6 +1041,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
10401041
# Raspberry Pi 3, 4, Zero 2 (32-bit)
10411042
list(APPEND ARCH_FLAGS -mno-unaligned-access)
10421043
endif()
1044+
if (LLAMA_SVE)
1045+
list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
1046+
endif()
10431047
endif()
10441048
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
10451049
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND

common/common.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -2844,6 +2844,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
28442844
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
28452845
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
28462846
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
2847+
fprintf(stream, "cpu_has_sve: %s\n", ggml_cpu_has_sve() ? "true" : "false");
28472848
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
28482849
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
28492850
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");

ggml-impl.h

+4
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,10 @@ extern "C" {
144144
#endif
145145
#endif
146146

147+
#if defined(__ARM_FEATURE_SVE)
148+
#include <arm_sve.h>
149+
#endif
150+
147151
// 16-bit float
148152
// on Arm, we use __fp16
149153
// on x86, we use uint16_t

ggml-quants.c

+64-2
Original file line numberDiff line numberDiff line change
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
38133813
return;
38143814
}
38153815
#endif
3816-
#if defined(__ARM_NEON)
3816+
#if defined(__ARM_FEATURE_SVE)
3817+
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
3818+
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
3819+
3820+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
3821+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
3822+
3823+
assert(nb % 2 == 0); // TODO: handle odd nb
3824+
3825+
for (int i = 0; i < nb; i += 2) {
3826+
const block_q4_0 * restrict x0 = &x[i + 0];
3827+
const block_q4_0 * restrict x1 = &x[i + 1];
3828+
const block_q8_0 * restrict y0 = &y[i + 0];
3829+
const block_q8_0 * restrict y1 = &y[i + 1];
3830+
3831+
// load x
3832+
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
3833+
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
3834+
3835+
// 4-bit -> 8-bit
3836+
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
3837+
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
3838+
3839+
// sub 8
3840+
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
3841+
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
3842+
3843+
// load y
3844+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
3845+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
3846+
3847+
// dot product
3848+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
3849+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
3850+
}
3851+
3852+
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
3853+
#elif defined(__ARM_NEON)
38173854
float32x4_t sumv0 = vdupq_n_f32(0.0f);
38183855
float32x4_t sumv1 = vdupq_n_f32(0.0f);
38193856

@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
53845421
return;
53855422
}
53865423
#endif
5387-
#if defined(__ARM_NEON)
5424+
#if defined(__ARM_FEATURE_SVE)
5425+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
5426+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
5427+
5428+
assert(nb % 2 == 0); // TODO: handle odd nb
5429+
5430+
for (int i = 0; i < nb; i += 2) {
5431+
const block_q8_0 * restrict x0 = &x[i + 0];
5432+
const block_q8_0 * restrict x1 = &x[i + 1];
5433+
const block_q8_0 * restrict y0 = &y[i + 0];
5434+
const block_q8_0 * restrict y1 = &y[i + 1];
5435+
5436+
// load x
5437+
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
5438+
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
5439+
5440+
// load y
5441+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
5442+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
5443+
5444+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
5445+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
5446+
}
5447+
5448+
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
5449+
#elif defined(__ARM_NEON)
53885450
float32x4_t sumv0 = vdupq_n_f32(0.0f);
53895451
float32x4_t sumv1 = vdupq_n_f32(0.0f);
53905452

ggml.c

+10
Original file line numberDiff line numberDiff line change
@@ -22742,6 +22742,16 @@ int ggml_cpu_has_neon(void) {
2274222742
#endif
2274322743
}
2274422744

22745+
// Reports whether this build was compiled with Arm SVE enabled.
// This is a compile-time check on __ARM_FEATURE_SVE, not a runtime CPU probe.
// Returns 1 when __ARM_FEATURE_SVE is defined, 0 otherwise.
// NOTE(review): in SVE builds this query also runs GGML_ASSERT on the vector
// length, so calling it is not side-effect free. Presumably it aborts on a
// mismatch; confirm against the GGML_ASSERT definition.
int ggml_cpu_has_sve(void) {
22746+
#if defined(__ARM_FEATURE_SVE)
22747+
// TODO: Currently, SVE 256 bit is only supported.
22748+
// svcntb() is the ACLE intrinsic for the SVE vector length in bytes. QK8_0 is
// defined elsewhere; per the TODO above it should correspond to 256 bits
// (32 bytes) -- TODO confirm.
GGML_ASSERT(svcntb() == QK8_0);
22749+
return 1;
22750+
#else
22751+
// Built without SVE: report the feature as absent.
return 0;
22752+
#endif
22753+
}
22754+
2274522755
int ggml_cpu_has_arm_fma(void) {
2274622756
#if defined(__ARM_FEATURE_FMA)
2274722757
return 1;

ggml.h

+1
Original file line numberDiff line numberDiff line change
@@ -2404,6 +2404,7 @@ extern "C" {
24042404
GGML_API int ggml_cpu_has_avx512_bf16(void);
24052405
GGML_API int ggml_cpu_has_fma (void);
24062406
GGML_API int ggml_cpu_has_neon (void);
2407+
GGML_API int ggml_cpu_has_sve (void);
24072408
GGML_API int ggml_cpu_has_arm_fma (void);
24082409
GGML_API int ggml_cpu_has_metal (void);
24092410
GGML_API int ggml_cpu_has_f16c (void);

llama.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -18337,6 +18337,7 @@ const char * llama_print_system_info(void) {
1833718337
s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
1833818338
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
1833918339
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
18340+
s += "SVE = " + std::to_string(ggml_cpu_has_sve()) + " | ";
1834018341
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
1834118342
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
1834218343
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";

0 commit comments

Comments
 (0)