Question about dtype check in marlin_qqq validation for w4a8 functionality #2115

Open

xxw11 opened this issue Apr 23, 2025 · 1 comment

xxw11 commented Apr 23, 2025

Hi torchao developers,

Recently, while experimenting with the w4a8 functionality in torchao, I noticed that the marlin_qqq check function requires

input_tensor.dtype == torch.float16

This seems potentially problematic, as most modern models typically use bf16 or fp32 for activations. Forcing a conversion to float16 might introduce precision loss, or even NaN issues in some cases due to fp16's much smaller dynamic range.
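For reference, a minimal sketch of the kind of guard and cast this check implies (the helper name is hypothetical and this is not torchao's actual code), assuming the Marlin QQQ kernel path only accepts fp16 activations:

```python
import torch

def _cast_activation_for_marlin_qqq(input_tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the constraint: the kernel path is gated
    # on fp16 activations, so bf16/fp32 inputs would have to be cast first.
    if input_tensor.dtype == torch.float16:
        return input_tensor
    # Casting down to fp16 shrinks the exponent range relative to bf16/fp32
    # (max finite value ~65504), so large activations can overflow to inf/NaN;
    # fp32 -> fp16 additionally loses mantissa precision.
    return input_tensor.to(torch.float16)

# Example: bf16 activations would be silently narrowed before hitting the kernel.
x = torch.randn(2, 4096, dtype=torch.bfloat16)
x_fp16 = _cast_activation_for_marlin_qqq(x)  # now torch.float16
```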

Could you clarify if this dtype check is strictly necessary? Are there specific constraints or optimizations that depend on float16 here?

Thank you for your insights!

[Screenshot: the float16 dtype check in the marlin_qqq input validation code]

@jerryzh168 (Contributor) commented

cc @HandH1998: can the Marlin QQQ kernel be extended to support bfloat16 as well?
