@@ -267,6 +267,244 @@ std::vector<GemmTestParams> CreateTests1(
267
267
return gemm_tests;
268
268
}
269
269
270
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
// Builds the list of parameterized GEMM/IGEMM test cases for a RISC-V Vector
// microkernel whose output tile width scales with the hardware vector length.
//
// Parameters:
//   k_block, adj_k_block - reduction-dimension step sizes used to derive the
//                          k values (and k ranges) that are exercised.
//   mr, nr, kr, sr       - microkernel tile parameters forwarded to the
//                          tester. `nr` is rescaled at runtime (see below).
//   is_igemm             - selects the IGEMM-only cases (small_kernel,
//                          a_offset, zero) and disables the GEMM-only
//                          strided_a cases.
//   unsigned_inputs,
//   planes               - forwarded unchanged to the tester.
//   test_func            - callback stored in every GemmTestParams entry.
//   isa_check            - optional callback stored in every entry.
//
// Returns the generated test cases; each entry's name is used by the test
// suite instantiation as the GoogleTest parameter name, so names must consist
// of alphanumerics and underscores only.
std::vector<GemmTestParams> CreateTests2(
    size_t k_block, size_t adj_k_block,
    size_t mr, size_t nr, size_t kr, size_t sr,
    bool is_igemm,
    bool unsigned_inputs,
    uint8_t planes,
    std::function<void(GemmMicrokernelTester& tester)> test_func,
    std::function<void()> isa_check = nullptr) {
  const std::string kbs = std::to_string(k_block);
  const std::string akbs = std::to_string(adj_k_block);
  // The incoming nr is multiplied by vlenb / sizeof(int32_t) to obtain the
  // actual tile width used by the tests.
  // NOTE(review): presumably nr is a count of vector registers and vlenb is
  // the vector register size in bytes - confirm against the hardware config.
  nr = nr * xnn_init_hardware_config()->vlenb / sizeof(int32_t);
  const std::string nrs = std::to_string(nr);

  // Base tester shared by all cases; every case starts from a clone of it.
  const GemmMicrokernelTester tester = GemmMicrokernelTester()
      .mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs).planes(planes);

  std::vector<GemmTestParams> gemm_tests;
  gemm_tests.reserve(42);

  // --- k == k_block ---------------------------------------------------------
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs,
      tester.clone()
          .m(mr).n(nr).k(k_block)
      , test_func, isa_check));
  if (!is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "k_eq_" + kbs + "_strided_a",
        tester.clone()
            .m(mr).n(nr).k(k_block)
            .a_stride(xnnpack::NextPrime(k_block + 1))
        , test_func, isa_check));
  }
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs + "_subtile",
      tester.clone()
          .k(k_block)
      , test_func, isa_check)
      .loop_n(1, nr)
      .loop_m(1, mr));
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs + "_subtile_m",
      tester.clone()
          .n(nr).k(k_block)
      , test_func, isa_check)
      .loop_m(1, mr));
  gemm_tests.push_back(GemmTestParams(
      "k_eq_" + kbs + "_subtile_n",
      tester.clone()
          .m(mr).k(k_block)
      , test_func, isa_check)
      .loop_n(1, nr));

  // --- k < adj_k_block (only meaningful when k_block > 1) -------------------
  if (k_block > 1) {
    gemm_tests.push_back(GemmTestParams(
        "k_lt_" + akbs,
        tester.clone()
            .m(mr).n(nr)
        , test_func, isa_check)
        .loop_k(1, adj_k_block - 1));
    if (!is_igemm) {
      gemm_tests.push_back(GemmTestParams(
          "k_lt_" + akbs + "_strided_a",
          tester.clone()
              .m(mr).n(nr)
              .a_stride(xnnpack::NextPrime(adj_k_block + 1))
          , test_func, isa_check)
          .loop_k(1, adj_k_block - 1));
    }
    gemm_tests.push_back(GemmTestParams(
        "k_lt_" + akbs + "_subtile",
        tester.clone()
        , test_func, isa_check)
        .loop_k(1, adj_k_block - 1)
        .loop_n(1, nr)
        .loop_m(1, mr));
  }

  // --- k > adj_k_block ------------------------------------------------------
  gemm_tests.push_back(GemmTestParams(
      "k_gt_" + akbs,
      tester.clone()
          .m(mr).n(nr)
      , test_func, isa_check)
      .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
  // The strided_a variants are gated on !is_igemm everywhere else in this
  // function; this guard previously read `is_igemm`, which skipped the case
  // for GEMM kernels.
  if (!is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "k_gt_" + akbs + "_strided_a",
        tester.clone()
            .m(mr).n(nr)
            .a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1))
        , test_func, isa_check)
        .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
  }
  gemm_tests.push_back(GemmTestParams(
      "k_gt_" + akbs + "_subtile",
      tester.clone()
      , test_func, isa_check)
      .loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)
      .loop_n(1, nr)
      .loop_m(1, mr));

  // --- k a multiple of k_block (only meaningful when k_block > 1) -----------
  if (k_block > 1) {
    gemm_tests.push_back(GemmTestParams(
        "k_div_" + kbs,
        tester.clone()
            .m(mr).n(nr)
        , test_func, isa_check)
        .loop_k(adj_k_block + k_block, k_block * 5, k_block));
    // Same guard correction as for k_gt_*_strided_a above.
    if (!is_igemm) {
      gemm_tests.push_back(GemmTestParams(
          "k_div_" + kbs + "_strided_a",
          tester.clone()
              .m(mr).n(nr)
              .a_stride(xnnpack::NextPrime(k_block * 3 + 1))
          , test_func, isa_check)
          .loop_k(adj_k_block + k_block, k_block * 3, k_block));
    }
    gemm_tests.push_back(GemmTestParams(
        "k_div_" + kbs + "_subtile",
        tester.clone()
        , test_func, isa_check)
        .loop_k(adj_k_block + k_block, k_block * 5, k_block)
        .loop_n(1, nr)
        .loop_m(1, mr));
  }

  // --- n > nr ---------------------------------------------------------------
  gemm_tests.push_back(GemmTestParams(
      "n_gt_" + nrs,
      tester.clone()
          .m(mr)
      , test_func, isa_check)
      .loop_n(nr + 1, nr * 2 - 1, 4)
      .loop_k(1, k_block * 3, k_block + 1));
  if (!is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "n_gt_" + nrs + "_strided_a",
        tester.clone()
            .m(mr)
            .a_stride(xnnpack::NextPrime(k_block * 3 + 1))
        , test_func, isa_check)
        .loop_n(nr + 1, nr * 2 - 1, 4)
        .loop_k(1, k_block * 3, k_block));
  }
  gemm_tests.push_back(GemmTestParams(
      "n_gt_" + nrs + "_subtile",
      tester.clone()
      , test_func, isa_check)
      .loop_n(nr + 1, nr * 2 - 1, 4)
      .loop_k(1, k_block * 3, k_block + 1)
      .loop_m(1, mr));

  // --- n a multiple of nr ---------------------------------------------------
  gemm_tests.push_back(GemmTestParams(
      "n_div_" + nrs,
      tester.clone()
          .m(mr)
      , test_func, isa_check)
      .loop_n(nr * 2, nr * 3, nr)
      .loop_k(1, k_block * 3, k_block + 1));
  if (!is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "n_div_" + nrs + "_strided_a",
        tester.clone()
            .m(mr)
            .a_stride(xnnpack::NextPrime(k_block * 3 + 1))
        , test_func, isa_check)
        .loop_n(nr * 2, nr * 3, nr)
        .loop_k(1, k_block * 3, k_block));
  }
  gemm_tests.push_back(GemmTestParams(
      "n_div_" + nrs + "_subtile",
      tester.clone()
      , test_func, isa_check)
      .loop_n(nr * 2, nr * 3, nr)
      .loop_k(1, k_block * 3, k_block + 1)
      .loop_m(1, mr));

  // --- IGEMM only: ks == 3 --------------------------------------------------
  if (is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "small_kernel",
        tester.clone()
            .m(mr).n(nr).ks(3)
        , test_func, isa_check)
        .loop_k(1, k_block * 3, k_block + 1));
    gemm_tests.push_back(GemmTestParams(
        "small_kernel_subtile",
        tester.clone()
            .ks(3)
        , test_func, isa_check)
        .loop_k(1, k_block * 3, k_block + 1)
        .loop_n(1, nr)
        .loop_m(1, mr));
    gemm_tests.push_back(GemmTestParams(
        "n_gt_" + nrs + "_small_kernel",
        tester.clone()
            .m(mr).ks(3)
        , test_func, isa_check)
        .loop_n(nr + 1, nr * 2 - 1, 4)
        .loop_k(1, k_block * 3, k_block + 1));
    gemm_tests.push_back(GemmTestParams(
        "n_div_" + nrs + "_small_kernel",
        tester.clone()
            .m(mr).ks(3)
        , test_func, isa_check)
        .loop_n(nr * 2, nr * 3, nr)
        .loop_k(1, k_block * 3, k_block + 1));
  }

  // --- Strided output rows, all sub-tiles -----------------------------------
  gemm_tests.push_back(GemmTestParams(
      "strided_cm_subtile",
      tester.clone()
          .mr(mr).nr(nr).kr(kr).sr(sr)
          .cm_stride(xnnpack::NextPrime(nr + 1))
      , test_func, isa_check)
      .loop_k(1, k_block * 3, k_block + 1)
      .loop_n(1, nr)
      .loop_m(1, mr));

  // --- IGEMM only: input offset and zero-index handling ---------------------
  if (is_igemm) {
    gemm_tests.push_back(GemmTestParams(
        "a_offset",
        tester.clone()
            .m(mr).n(nr).ks(3)
            .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
        , test_func, isa_check)
        .loop_k(1, k_block * 3, k_block + 1));
    gemm_tests.push_back(GemmTestParams(
        "zero",
        tester.clone()
            .m(mr).n(nr).ks(3)
            .a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
        , test_func, isa_check)
        .loop_k(1, k_block * 3, k_block + 1)
        .loop_zi(0, mr - 1));
  }

  // --- Strided output rows, full tile ---------------------------------------
  gemm_tests.push_back(GemmTestParams(
      "strided_cm",
      tester.clone()
          .m(mr).n(nr).k(k_block)
          .cm_stride(xnnpack::NextPrime(nr + 1))
      , test_func, isa_check));

  return gemm_tests;
}
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
507
+
270
508
} // namespace
271
509
272
510
@@ -3010,3 +3248,28 @@ INSTANTIATE_TEST_SUITE_P(
3010
3248
return info.param .test_name ;
3011
3249
});
3012
3250
3251
+
3252
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
// Instantiates GemmTest for the QS8/QC8W 1x4v RVV GEMM microkernel using the
// cases produced by CreateTests2. Each case is named after its test_name.
INSTANTIATE_TEST_SUITE_P(
    QS8_QC8W_GEMM_MINMAX_FP32_1X4V__RVV, GemmTest,
    testing::ValuesIn(CreateTests2(
        /*k_block=*/1,
        /*adj_k_block=*/1,
        /*mr=*/1, /*nr=*/4, /*kr=*/1, /*sr=*/1,
        /*is_igemm=*/false,
        /*unsigned_inputs=*/false,
        /*planes=*/1,
        // Runs the microkernel under test with its params initializer,
        // weight packing routine and requantization reference.
        [](GemmMicrokernelTester& kernel_tester) {
          kernel_tester.Test(
              xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4v__rvv,
              xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params,
              xnn_pack_qs8_gemm_goi_w,
              xnn_qs8_requantize_fp32);
        },
        // ISA check callback passed to every generated case.
        []() { TEST_REQUIRES_RISCV_VECTOR; })),
    [](const testing::TestParamInfo<GemmTest::ParamType>& param_info) {
      return param_info.param.test_name;
    });
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
3275
+
0 commit comments