@@ -227,7 +227,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
     @given(
         B_T=st.sampled_from([0, 2048, 4096]),
         D=st.sampled_from([128, 256]),
-        HD_L=st.sampled_from([256, 512]),
+        HD_L=st.sampled_from([256, 512, 4096, 8192]),
         Mode=st.sampled_from(
             ["rowwise", "blockwise"]
             + (["tensorwise_broadcast", "tensorwise"] if torch.version.cuda else [])
@@ -236,6 +236,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         Bias=st.sampled_from([True, False]),
         CudaGraph=st.sampled_from([True, False]),
         UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
+        UseFastAccum=st.booleans(),
         InputMultiDim=st.booleans(),
     )
     def test_quantize_fp8_matmul(
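Note: the hunk above registers `UseFastAccum` as a Hypothesis strategy, so the existing matmul test now also sweeps the fast-accumulation flag. A minimal, self-contained sketch (not part of this PR) of how a keyword strategy like `st.booleans()` feeds a test parameter:

```python
# Toy example only: Hypothesis draws a concrete boolean for each generated case
# and passes it to the keyword parameter of the same name.
from hypothesis import given, settings, strategies as st


@settings(deadline=None, max_examples=4)
@given(UseFastAccum=st.booleans())
def test_toggle(UseFastAccum: bool) -> None:
    assert UseFastAccum in (True, False)
```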
@@ -248,8 +249,13 @@ def test_quantize_fp8_matmul(
         Bias: bool,
         CudaGraph: bool,
         UseTriton: bool,
+        UseFastAccum: bool,
         InputMultiDim: bool,
     ) -> None:
+        # Fast accumulation is only supported on Nvidia.
+        if torch.version.hip:
+            UseFastAccum = False
+        # Setup input shapes.
         if InputMultiDim and not torch.version.hip:
             x = torch.randn(size=(3, B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
         else:
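The guard added above keys off the build flavor: on ROCm wheels `torch.version.hip` is a version string and `torch.version.cuda` is `None`, so fast accumulation is forced off. A small sketch of the same gate as a helper (illustrative, not code from this PR):

```python
# Sketch assuming stock PyTorch: torch.version.hip is a string on ROCm builds
# and None on CUDA builds, so the requested flag is only honored on Nvidia.
import torch


def effective_fast_accum(requested: bool) -> bool:
    return False if torch.version.hip else requested
```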
@@ -285,12 +291,16 @@ def test_quantize_fp8_matmul(
             if CudaGraph:
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
+                        xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
+                    )
                     if bias is not None:
                         zq += bias
                 g.replay()
             else:
-                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
+                    xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
+                )
                 if bias is not None:
                     zq += bias
         elif Mode == "rowwise":
@@ -299,7 +309,9 @@ def test_quantize_fp8_matmul(
                 xq, x_scale = quantize_fp8_row(x)
                 wq, w_scale = quantize_fp8_row(w)
                 if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
+                    zq = matmul_fp8_row(
+                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
+                    )
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
                     if torch.version.cuda:
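The same toggle reaches the Triton rowwise helper as `fp8_fast_accum`. A hedged sketch of that path; the import path is an assumption based on fbgemm_gpu's experimental Triton GEMM layout and may differ from what the test file actually imports:

```python
# Sketch, not PR code. quantize_fp8_row returns the FP8 tensor plus per-row
# scales; fp8_fast_accum trades accumulation precision for GEMM throughput.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (  # assumed path
    matmul_fp8_row,
    quantize_fp8_row,
)


def rowwise_fp8_matmul(x, w, use_fast_accum: bool, bias=None):
    xq, x_scale = quantize_fp8_row(x)
    wq, w_scale = quantize_fp8_row(w)
    zq = matmul_fp8_row(xq, wq, x_scale, w_scale, fp8_fast_accum=use_fast_accum)
    if bias is not None:
        zq += bias
    return zq
```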
@@ -321,6 +333,7 @@ def test_quantize_fp8_matmul(
                             x_scale,
                             w_scale,
                             bias=bias if torch.version.cuda else None,
+                            use_fast_accum=UseFastAccum,
                         )
                         # Bias fusion not yet supported on AMD.
                         if bias is not None and torch.version.hip:
@@ -336,7 +349,9 @@ def test_quantize_fp8_matmul(
                 xq, x_scale = quantize_fp8_row(x)
                 wq, w_scale = quantize_fp8_row(w)
                 if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
+                    zq = matmul_fp8_row(
+                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
+                    )
                     if bias is not None:
                         zq += bias
                 else:
@@ -346,6 +361,7 @@ def test_quantize_fp8_matmul(
                         x_scale,
                         w_scale,
                         bias=bias if torch.version.cuda else None,
+                        use_fast_accum=UseFastAccum,
                     )
                     # Bias fusion not yet supported on AMD.
                     if bias is not None and torch.version.hip:
@@ -369,7 +385,7 @@ def test_quantize_fp8_matmul(
                         block_m,
                         block_n,
                         block_k,
-                        fp8_fast_accum=True,
+                        fp8_fast_accum=UseFastAccum,
                     )
                 else:
                     zq = torch.ops.fbgemm.f8f8bf16_blockwise(
@@ -393,7 +409,7 @@ def test_quantize_fp8_matmul(
                             block_m,
                             block_n,
                             block_k,
-                            fp8_fast_accum=True,
+                            fp8_fast_accum=UseFastAccum,
                         )
                     else:
                         zq = torch.ops.fbgemm.f8f8bf16_blockwise(
@@ -416,7 +432,7 @@ def test_quantize_fp8_matmul(
                         block_m,
                         block_n,
                         block_k,
-                        fp8_fast_accum=True,
+                        fp8_fast_accum=UseFastAccum,
                     )
                 else:
                     zq = torch.ops.fbgemm.f8f8bf16_blockwise(