Skip to content

Commit 45eac04

Browse files
authored
fix: update bmm fp8 test (#487)
Use a per-tensor fp8 scale in to_float8 (ref: sgl-project/sglang#1285).
1 parent 77bff3f commit 45eac04

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

python/flashinfer/gemm.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -241,8 +241,9 @@ def bmm_fp8(
241241
>>> import flashinfer
242242
>>> def to_float8(x, dtype=torch.float8_e4m3fn):
243243
... finfo = torch.finfo(dtype)
244-
... abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
245-
... scale = finfo.max / abs_max
244+
... min_val, max_val = x.aminmax()
245+
... amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
246+
... scale = finfo.max / amax
246247
... x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
247248
... return x_scl_sat.to(dtype), scale.float().reciprocal()
248249
>>>

python/tests/test_bmm_fp8.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,15 @@
11
import pytest
22
import torch
33
import torch.nn.functional as F
4+
45
from flashinfer import bmm_fp8
56

67

78
def to_float8(x, dtype=torch.float8_e4m3fn):
89
finfo = torch.finfo(dtype)
9-
abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
10-
scale = finfo.max / abs_max
10+
min_val, max_val = x.aminmax()
11+
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
12+
scale = finfo.max / amax
1113
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
1214
return x_scl_sat.to(dtype), scale.float().reciprocal()
1315

@@ -32,9 +34,8 @@ def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
3234
bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)
3335

3436
reference = torch.bmm(input, mat2)
35-
3637
cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
37-
assert cos_sim > 0.98
38+
assert cos_sim > 0.99
3839

3940

4041
if __name__ == "__main__":

0 commit comments

Comments
 (0)