
Commit e63e578

jwfromm authored and facebook-github-bot committed
Fix templates for FP8 Rowwise Slow Accumulation (#4037)
Summary: X-link: facebookresearch/FBGEMM#1122. It turns out there are a few tile / cluster configurations for FP8 Rowwise Matmul that work fine for fast accumulation but produce bad outputs when used for slow accumulation. Specifically, tile sizes of [128, 256, 128] seem to be problematic. This would not affect any production use cases, since slow accumulation is only used for debugging. Differential Revision: D73805710
1 parent eeee38e commit e63e578

File tree

2 files changed: +27 -11 lines

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu

+3 -3

@@ -87,7 +87,7 @@ at::Tensor dispatch_fp8_rowwise_kernel(
     } else if (N <= 2048) {
       return f8f8bf16_rowwise_64_128_128_1_1_1_f_f(
           XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
-    } else if (N <= 4096) {
+    } else if (N <= 4096 || use_fast_accum == false) {
       return f8f8bf16_rowwise_64_256_128_1_1_1_f_f(
           XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
     } else {
@@ -98,7 +98,7 @@ at::Tensor dispatch_fp8_rowwise_kernel(
     if (N <= 1024) {
       return f8f8bf16_rowwise_64_128_128_1_1_1_f_f(
           XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
-    } else if (N <= 2048) {
+    } else if (N <= 2048 || use_fast_accum == false) {
       return f8f8bf16_rowwise_64_256_128_1_1_1_f_f(
           XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
     } else {
@@ -109,7 +109,7 @@ at::Tensor dispatch_fp8_rowwise_kernel(
     if (M <= 2048 && N <= 1024) {
       return f8f8bf16_rowwise_64_256_128_2_1_1_f_f(
           XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
-    } else if (K <= 4096) {
+    } else if (K <= 4096 || use_fast_accum == false) {
      return f8f8bf16_rowwise_128_128_128_2_1_1_t_f(
           XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
     } else if (M > 8192 && N > 8192) {
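For context, a rough way to exercise the path these hunks change is to call the rowwise op directly with slow accumulation at a shape large enough to reach the patched branches. The sketch below is illustrative only: the exact op schema (torch.ops.fbgemm.f8f8bf16_rowwise with a use_fast_accum keyword), the quantize_fp8_row import path, the shape, and the tolerances are assumptions on my part, not something stated in this commit.

import torch
# Import path assumed from FBGEMM's experimental Triton GEMM helpers.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

# Shape chosen so the dispatcher lands in one of the large-N branches patched above
# (shape choice is an assumption; exact routing depends on the heuristics in the diff).
M, N, K = 2048, 8192, 256
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda") * 0.01

# Rowwise FP8 quantization of both operands.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = quantize_fp8_row(w)

# bf16 reference result.
z_ref = (x @ w.T).to(torch.bfloat16)

# NOTE: op name/schema assumed (f8f8bf16_rowwise with a use_fast_accum kwarg);
# adjust to the installed FBGEMM build. With this fix, the slow-accumulation
# (debugging) path should track the reference as closely as the fast path does.
z_fast = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, use_fast_accum=True)
z_slow = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, use_fast_accum=False)
torch.testing.assert_close(z_fast, z_ref, atol=9.0e-2, rtol=9.0e-2)
torch.testing.assert_close(z_slow, z_ref, atol=9.0e-2, rtol=9.0e-2)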

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

+24 -8

@@ -227,7 +227,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
    @given(
        B_T=st.sampled_from([0, 2048, 4096]),
        D=st.sampled_from([128, 256]),
-        HD_L=st.sampled_from([256, 512]),
+        HD_L=st.sampled_from([256, 512, 4096, 8192]),
        Mode=st.sampled_from(
            ["rowwise", "blockwise"]
            + (["tensorwise_broadcast", "tensorwise"] if torch.version.cuda else [])
@@ -236,6 +236,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
        Bias=st.sampled_from([True, False]),
        CudaGraph=st.sampled_from([True, False]),
        UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
+        UseFastAccum=st.booleans(),
        InputMultiDim=st.booleans(),
    )
    def test_quantize_fp8_matmul(
@@ -248,8 +249,13 @@ def test_quantize_fp8_matmul(
        Bias: bool,
        CudaGraph: bool,
        UseTriton: bool,
+        UseFastAccum: bool,
        InputMultiDim: bool,
    ) -> None:
+        # Fast accumulation is only supported on Nvidia.
+        if torch.version.hip:
+            UseFastAccum = False
+        # Setup input shapes.
        if InputMultiDim and not torch.version.hip:
            x = torch.randn(size=(3, B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
        else:
@@ -285,12 +291,16 @@ def test_quantize_fp8_matmul(
            if CudaGraph:
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
-                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
+                        xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
+                    )
                    if bias is not None:
                        zq += bias
                g.replay()
            else:
-                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
+                    xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
+                )
                if bias is not None:
                    zq += bias
        elif Mode == "rowwise":
@@ -299,7 +309,9 @@ def test_quantize_fp8_matmul(
                xq, x_scale = quantize_fp8_row(x)
                wq, w_scale = quantize_fp8_row(w)
                if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
+                    zq = matmul_fp8_row(
+                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
+                    )
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    if torch.version.cuda:
@@ -321,6 +333,7 @@ def test_quantize_fp8_matmul(
                            x_scale,
                            w_scale,
                            bias=bias if torch.version.cuda else None,
+                            use_fast_accum=UseFastAccum,
                        )
                        # Bias fusion not yet supported on AMD.
                        if bias is not None and torch.version.hip:
@@ -336,7 +349,9 @@ def test_quantize_fp8_matmul(
                xq, x_scale = quantize_fp8_row(x)
                wq, w_scale = quantize_fp8_row(w)
                if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
+                    zq = matmul_fp8_row(
+                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
+                    )
                    if bias is not None:
                        zq += bias
                else:
@@ -346,6 +361,7 @@ def test_quantize_fp8_matmul(
                        x_scale,
                        w_scale,
                        bias=bias if torch.version.cuda else None,
+                        use_fast_accum=UseFastAccum,
                    )
                    # Bias fusion not yet supported on AMD.
                    if bias is not None and torch.version.hip:
@@ -369,7 +385,7 @@ def test_quantize_fp8_matmul(
                        block_m,
                        block_n,
                        block_k,
-                        fp8_fast_accum=True,
+                        fp8_fast_accum=UseFastAccum,
                    )
                else:
                    zq = torch.ops.fbgemm.f8f8bf16_blockwise(
@@ -393,7 +409,7 @@ def test_quantize_fp8_matmul(
                        block_m,
                        block_n,
                        block_k,
-                        fp8_fast_accum=True,
+                        fp8_fast_accum=UseFastAccum,
                    )
                else:
                    zq = torch.ops.fbgemm.f8f8bf16_blockwise(
@@ -416,7 +432,7 @@ def test_quantize_fp8_matmul(
                        block_m,
                        block_n,
                        block_k,
-                        fp8_fast_accum=True,
+                        fp8_fast_accum=UseFastAccum,
                    )
                else:
                    zq = torch.ops.fbgemm.f8f8bf16_blockwise(
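For the Triton rowwise path, the test threads the same flag through matmul_fp8_row as fp8_fast_accum. A minimal standalone sketch of that path, outside the Hypothesis harness, might look like the following; the import path, the output-dtype handling, and the tolerances are my assumptions, not taken from this diff.

import torch
# Import path assumed; matmul_fp8_row and quantize_fp8_row are the helpers used by the test above.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

x = torch.randn(2048, 256, dtype=torch.bfloat16, device="cuda") * 0.1
w = torch.randn(4096, 256, dtype=torch.bfloat16, device="cuda") * 0.01
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = quantize_fp8_row(w)

# bf16 reference result.
z_ref = (x @ w.T).to(torch.bfloat16)
# Run with slow accumulation, the configuration this commit adds coverage for.
z_slow = matmul_fp8_row(xq, wq, x_scale, w_scale, fp8_fast_accum=False)
# Cast before comparing, since the kernel's output dtype may differ from bf16.
torch.testing.assert_close(z_slow.to(torch.bfloat16), z_ref, atol=9.0e-2, rtol=9.0e-2)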
