Skip to content

Commit 2179a5c

Browse files
dsharletgxnnpack-bot
authored andcommitted
Fix missing RVV gemm kernel descriptions and regenerate tests/benchmarks
PiperOrigin-RevId: 731115584
1 parent 0bd7180 commit 2179a5c

10 files changed

+1803
-7
lines changed

bench/qs8-qc8w-gemm-fp32.cc

+91
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,97 @@
1919
#include "src/xnnpack/packw.h"
2020

2121

22+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
23+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv(benchmark::State& state, const char* net) {
24+
GEMMBenchmark(state,
25+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv,
26+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
27+
xnn_pack_qs8_gemm_goi_w,
28+
/*mr=*/1, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
29+
benchmark::utils::CheckRVV);
30+
}
31+
32+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv)
33+
34+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_2x4v__rvv(benchmark::State& state, const char* net) {
35+
GEMMBenchmark(state,
36+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4v__rvv,
37+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
38+
xnn_pack_qs8_gemm_goi_w,
39+
/*mr=*/2, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
40+
benchmark::utils::CheckRVV);
41+
}
42+
43+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_2x4v__rvv)
44+
45+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_3x4v__rvv(benchmark::State& state, const char* net) {
46+
GEMMBenchmark(state,
47+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_3x4v__rvv,
48+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
49+
xnn_pack_qs8_gemm_goi_w,
50+
/*mr=*/3, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
51+
benchmark::utils::CheckRVV);
52+
}
53+
54+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_3x4v__rvv)
55+
56+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_4x2v__rvv(benchmark::State& state, const char* net) {
57+
GEMMBenchmark(state,
58+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_4x2v__rvv,
59+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
60+
xnn_pack_qs8_gemm_goi_w,
61+
/*mr=*/4, /*nr=*/2 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
62+
benchmark::utils::CheckRVV);
63+
}
64+
65+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_4x2v__rvv)
66+
67+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_5x4v__rvv(benchmark::State& state, const char* net) {
68+
GEMMBenchmark(state,
69+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_5x4v__rvv,
70+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
71+
xnn_pack_qs8_gemm_goi_w,
72+
/*mr=*/5, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
73+
benchmark::utils::CheckRVV);
74+
}
75+
76+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_5x4v__rvv)
77+
78+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_6x4v__rvv(benchmark::State& state, const char* net) {
79+
GEMMBenchmark(state,
80+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_6x4v__rvv,
81+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
82+
xnn_pack_qs8_gemm_goi_w,
83+
/*mr=*/6, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
84+
benchmark::utils::CheckRVV);
85+
}
86+
87+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_6x4v__rvv)
88+
89+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_7x4v__rvv(benchmark::State& state, const char* net) {
90+
GEMMBenchmark(state,
91+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_7x4v__rvv,
92+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
93+
xnn_pack_qs8_gemm_goi_w,
94+
/*mr=*/7, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
95+
benchmark::utils::CheckRVV);
96+
}
97+
98+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_7x4v__rvv)
99+
100+
static void qs8_qc8w_gemm_minmax_fp32_ukernel_8x4v__rvv(benchmark::State& state, const char* net) {
101+
GEMMBenchmark(state,
102+
xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_8x4v__rvv,
103+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
104+
xnn_pack_qs8_gemm_goi_w,
105+
/*mr=*/8, /*nr=*/4 * xnn_init_hardware_config()->vlenb / sizeof(int32_t), /*kr=*/1, /*sr=*/1,
106+
benchmark::utils::CheckRVV);
107+
}
108+
109+
BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_8x4v__rvv)
110+
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
111+
112+
22113
#if XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
23114
static void qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c2__wasmsimd_dot16x2_ld64(benchmark::State& state, const char* net) {
24115
GEMMBenchmark(state,

test/qs8-qc8w-gemm-minmax-fp32-2.cc

+263
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,244 @@ std::vector<GemmTestParams> CreateTests1(
267267
return gemm_tests;
268268
}
269269

270+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
271+
std::vector<GemmTestParams> CreateTests2(
272+
size_t k_block, size_t adj_k_block,
273+
size_t mr, size_t nr, size_t kr, size_t sr,
274+
bool is_igemm,
275+
bool unsigned_inputs,
276+
uint8_t planes,
277+
std::function<void(GemmMicrokernelTester& tester)> test_func,
278+
std::function<void()> isa_check = nullptr) {
279+
std::string kbs = std::to_string(k_block);
280+
std::string kb2s = std::to_string(k_block * 2);
281+
std::string akbs = std::to_string(adj_k_block);
282+
nr = nr * xnn_init_hardware_config()->vlenb / sizeof(int32_t);
283+
std::string nrs = std::to_string(nr);
284+
285+
const GemmMicrokernelTester tester = GemmMicrokernelTester()
286+
.mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs).planes(planes);
287+
288+
std::vector<GemmTestParams> gemm_tests;
289+
gemm_tests.reserve(42);
290+
291+
gemm_tests.push_back(GemmTestParams(
292+
"k_eq_" + kbs,
293+
tester.clone()
294+
.m(mr).n(nr).k(k_block)
295+
, test_func, isa_check));
296+
if (!is_igemm) {
297+
gemm_tests.push_back(GemmTestParams(
298+
"k_eq_" + kbs + "_strided_a",
299+
tester.clone()
300+
.m(mr).n(nr).k(k_block)
301+
.a_stride(xnnpack::NextPrime(k_block + 1))
302+
, test_func, isa_check));
303+
}
304+
gemm_tests.push_back(GemmTestParams(
305+
"k_eq_" + kbs + "_subtile",
306+
tester.clone()
307+
.k(k_block)
308+
, test_func, isa_check)
309+
.loop_n(1, nr)
310+
.loop_m(1, mr));
311+
gemm_tests.push_back(GemmTestParams(
312+
"k_eq_" + kbs + "_subtile_m",
313+
tester.clone()
314+
.n(nr).k(k_block)
315+
, test_func, isa_check)
316+
.loop_m(1, mr));
317+
gemm_tests.push_back(GemmTestParams(
318+
"k_eq_" + kbs + "_subtile_n",
319+
tester.clone()
320+
.m(mr).k(k_block)
321+
, test_func, isa_check)
322+
.loop_n(1, nr));
323+
if (k_block > 1) {
324+
gemm_tests.push_back(GemmTestParams(
325+
"k_lt_" + akbs,
326+
tester.clone()
327+
.m(mr).n(nr)
328+
, test_func, isa_check)
329+
.loop_k(1, adj_k_block - 1));
330+
if (!is_igemm) {
331+
gemm_tests.push_back(GemmTestParams(
332+
"k_lt_" + akbs + "_strided_a",
333+
tester.clone()
334+
.m(mr).n(nr)
335+
.a_stride(xnnpack::NextPrime(adj_k_block + 1))
336+
, test_func, isa_check)
337+
.loop_k(1, adj_k_block - 1));
338+
}
339+
gemm_tests.push_back(GemmTestParams(
340+
"k_lt_" + akbs + "_subtile",
341+
tester.clone()
342+
, test_func, isa_check)
343+
.loop_k(1, adj_k_block - 1)
344+
.loop_n(1, nr)
345+
.loop_m(1, mr));
346+
}
347+
gemm_tests.push_back(GemmTestParams(
348+
"k_gt_" + akbs,
349+
tester.clone()
350+
.m(mr).n(nr)
351+
, test_func, isa_check)
352+
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
353+
if (is_igemm) {
354+
gemm_tests.push_back(GemmTestParams(
355+
"k_gt_" + akbs + "_strided_a",
356+
tester.clone()
357+
.m(mr).n(nr)
358+
.a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1))
359+
, test_func, isa_check)
360+
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
361+
}
362+
gemm_tests.push_back(GemmTestParams(
363+
"k_gt_" + akbs + "_subtile",
364+
tester.clone()
365+
, test_func, isa_check)
366+
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)
367+
.loop_n(1, nr)
368+
.loop_m(1, mr));
369+
if (k_block > 1) {
370+
gemm_tests.push_back(GemmTestParams(
371+
"k_div_" + kbs,
372+
tester.clone()
373+
.m(mr).n(nr)
374+
, test_func, isa_check)
375+
.loop_k(adj_k_block + k_block, k_block * 5, k_block));
376+
if (is_igemm) {
377+
gemm_tests.push_back(GemmTestParams(
378+
"k_div_" + kbs + "_strided_a",
379+
tester.clone()
380+
.m(mr).n(nr)
381+
.a_stride(xnnpack::NextPrime(k_block * 3 + 1))
382+
, test_func, isa_check)
383+
.loop_k(adj_k_block + k_block, k_block * 3, k_block));
384+
}
385+
gemm_tests.push_back(GemmTestParams(
386+
"k_div_" + kbs + "_subtile",
387+
tester.clone()
388+
, test_func, isa_check)
389+
.loop_k(adj_k_block + k_block, k_block * 5, k_block)
390+
.loop_n(1, nr)
391+
.loop_m(1, mr));
392+
}
393+
gemm_tests.push_back(GemmTestParams(
394+
"n_gt_" + nrs,
395+
tester.clone()
396+
.m(mr)
397+
, test_func, isa_check)
398+
.loop_n(nr + 1, nr * 2 - 1, 4)
399+
.loop_k(1, k_block * 3, k_block + 1));
400+
if (!is_igemm) {
401+
gemm_tests.push_back(GemmTestParams(
402+
"n_gt_" + nrs + "_strided_a",
403+
tester.clone()
404+
.m(mr)
405+
.a_stride(xnnpack::NextPrime(k_block * 3 + 1))
406+
, test_func, isa_check)
407+
.loop_n(nr + 1, nr * 2 - 1, 4)
408+
.loop_k(1, k_block * 3, k_block));
409+
}
410+
gemm_tests.push_back(GemmTestParams(
411+
"n_gt_" + nrs + "_subtile",
412+
tester.clone()
413+
, test_func, isa_check)
414+
.loop_n(nr + 1, nr * 2 - 1, 4)
415+
.loop_k(1, k_block * 3, k_block + 1)
416+
.loop_m(1, mr));
417+
gemm_tests.push_back(GemmTestParams(
418+
"n_div_" + nrs,
419+
tester.clone()
420+
.m(mr)
421+
, test_func, isa_check)
422+
.loop_n(nr * 2, nr * 3, nr)
423+
.loop_k(1, k_block * 3, k_block + 1));
424+
if (!is_igemm) {
425+
gemm_tests.push_back(GemmTestParams(
426+
"n_div_" + nrs + "_strided_a",
427+
tester.clone()
428+
.m(mr)
429+
.a_stride(xnnpack::NextPrime(k_block * 3 + 1))
430+
, test_func, isa_check)
431+
.loop_n(nr * 2, nr * 3, nr)
432+
.loop_k(1, k_block * 3, k_block));
433+
}
434+
gemm_tests.push_back(GemmTestParams(
435+
"n_div_" + nrs + "_subtile",
436+
tester.clone()
437+
, test_func, isa_check)
438+
.loop_n(nr * 2, nr * 3, nr)
439+
.loop_k(1, k_block * 3, k_block + 1)
440+
.loop_m(1, mr));
441+
if (is_igemm) {
442+
gemm_tests.push_back(GemmTestParams(
443+
"small_kernel",
444+
tester.clone()
445+
.m(mr).n(nr).ks(3)
446+
, test_func, isa_check)
447+
.loop_k(1, k_block * 3, k_block + 1));
448+
gemm_tests.push_back(GemmTestParams(
449+
"small_kernel_subtile",
450+
tester.clone()
451+
.ks(3)
452+
, test_func, isa_check)
453+
.loop_k(1, k_block * 3, k_block + 1)
454+
.loop_n(1, nr)
455+
.loop_m(1, mr));
456+
gemm_tests.push_back(GemmTestParams(
457+
"n_gt_" + nrs + "_small_kernel",
458+
tester.clone()
459+
.m(mr).ks(3)
460+
, test_func, isa_check)
461+
.loop_n(nr + 1, nr * 2 - 1, 4)
462+
.loop_k(1, k_block * 3, k_block + 1));
463+
gemm_tests.push_back(GemmTestParams(
464+
"n_div_" + nrs + "_small_kernel",
465+
tester.clone()
466+
.m(mr).ks(3)
467+
, test_func, isa_check)
468+
.loop_n(nr * 2, nr * 3, nr)
469+
.loop_k(1, k_block * 3, k_block + 1));
470+
}
471+
gemm_tests.push_back(GemmTestParams(
472+
"strided_cm_subtile",
473+
tester.clone()
474+
.mr(mr).nr(nr).kr(kr).sr(sr)
475+
.cm_stride(xnnpack::NextPrime(nr + 1))
476+
, test_func, isa_check)
477+
.loop_k(1, k_block * 3, k_block + 1)
478+
.loop_n(1, nr)
479+
.loop_m(1, mr));
480+
if (is_igemm) {
481+
gemm_tests.push_back(GemmTestParams(
482+
"a_offset",
483+
tester.clone()
484+
.m(mr).n(nr).ks(3)
485+
.a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
486+
, test_func, isa_check)
487+
.loop_k(1, k_block * 3, k_block + 1));
488+
gemm_tests.push_back(GemmTestParams(
489+
"zero",
490+
tester.clone()
491+
.m(mr).n(nr).ks(3)
492+
.a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
493+
, test_func, isa_check)
494+
.loop_k(1, k_block * 3, k_block + 1)
495+
.loop_zi(0, mr - 1));
496+
}
497+
gemm_tests.push_back(GemmTestParams(
498+
"strided_cm",
499+
tester.clone()
500+
.m(mr).n(nr).k(k_block)
501+
.cm_stride(xnnpack::NextPrime(nr + 1))
502+
, test_func, isa_check));
503+
504+
return gemm_tests;
505+
}
506+
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
507+
270508
} // namespace
271509

272510

@@ -3010,3 +3248,28 @@ INSTANTIATE_TEST_SUITE_P(
30103248
return info.param.test_name;
30113249
});
30123250

3251+
3252+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
3253+
INSTANTIATE_TEST_SUITE_P(
3254+
QS8_QC8W_GEMM_MINMAX_FP32_1X4V__RVV, GemmTest,
3255+
testing::ValuesIn(CreateTests2(
3256+
/*k_block=*/1,
3257+
/*adj_k_block=*/1,
3258+
/*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1,
3259+
/*is_igemm=*/false,
3260+
/*unsigned_inputs=*/false,
3261+
/*planes=*/1,
3262+
[](GemmMicrokernelTester& tester) {
3263+
tester.Test(xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv,
3264+
xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
3265+
xnn_pack_qs8_gemm_goi_w,
3266+
xnn_qs8_requantize_fp32);
3267+
},
3268+
[]() {
3269+
TEST_REQUIRES_RISCV_VECTOR;
3270+
})),
3271+
[](const testing::TestParamInfo<GemmTest::ParamType>& info) {
3272+
return info.param.test_name;
3273+
});
3274+
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
3275+

0 commit comments

Comments
 (0)