Hi torchao developers,

Recently, while experimenting with the w4a8 functionality in torchao, I noticed that the marlin_qqq check function requires
```python
input_tensor.dtype == torch.float16
```
This seems potentially problematic: most modern models use bf16 or fp32 activations, and forcing a conversion to float16 can introduce precision loss, or even NaNs when bf16 values fall outside float16's representable range.
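For illustration, here is a minimal, plain-PyTorch sketch of the failure mode I'm worried about (it does not call torchao or the marlin_qqq path itself): bf16 values beyond float16's maximum (~65504) overflow to inf on conversion, and those infinities can turn into NaNs in later ops.

```python
import torch

# A bf16 activation with values outside float16's representable range (~±65504).
x_bf16 = torch.tensor([1.0, 70000.0, -1e6], dtype=torch.bfloat16)

# Casting to float16 overflows the out-of-range elements to ±inf.
x_fp16 = x_bf16.to(torch.float16)
print(x_fp16)                     # tensor([1., inf, -inf], dtype=torch.float16)
print(torch.isinf(x_fp16).any())  # tensor(True)

# Downstream arithmetic on the infinities produces NaNs (inf - inf -> nan).
print(x_fp16 - x_fp16)            # tensor([0., nan, nan], dtype=torch.float16)
```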
Could you clarify if this dtype check is strictly necessary? Are there specific constraints or optimizations that depend on float16 here?
Thank you for your insights!