avoid propagation of NaN #3723
Conversation
This pull request was exported from Phabricator. Differential Revision: D69522001
Summary:
X-link: facebookresearch/FBGEMM#806

as title

Introduce padding in the dequantization kernel to avoid passing NaNs to the output of FA3 in the prefill stage.

Reviewed By: jianyuh

Differential Revision: D69522001
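For context, a minimal CUDA sketch of the idea (the kernel name, arguments, and layout below are illustrative assumptions, not the actual FBGEMM code): dequantize the valid timesteps of a row and explicitly zero-fill the padded tail, so the FA3 prefill pass never reads uninitialized memory that may hold NaN bit patterns.

#include <cuda_bf16.h>

// Hypothetical sketch, not the FBGEMM kernel itself: valid positions are
// dequantized, padded positions are written as explicit zeros.
__global__ void dequantize_row_with_tail_padding(
    const float* src,        // stand-in for the dequantized FP8 values
    __nv_bfloat16* row_k_dq, // output row later consumed by FA3
    int valid_t,             // number of valid timesteps in this row
    int padded_t) {          // row length rounded up to the tile size
  int t = blockIdx.x * blockDim.x + threadIdx.x;
  if (t < valid_t) {
    row_k_dq[t] = __float2bfloat16(src[t]);
  } else if (t < padded_t) {
    // Padding: store zero instead of leaving whatever garbage (possibly NaN
    // bit patterns) the buffer already contained.
    row_k_dq[t] = __float2bfloat16(0.0f);
  }
}

The real kernel operates on FP8 KV-cache blocks and uses vectorized stores, but the key point is the same: the tail of the last tile is written explicitly rather than left uninitialized.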
Force-pushed from 4afe692 to 6984077.
Force-pushed from 6984077 to 481519e.
Force-pushed from 481519e to 9e31670.
@@ -1919,23 +1928,27 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
     block_tables_b_stride = block_tables.value().stride(0);
   }

-  constexpr int32_t kMaxBlocks = 256;
+  constexpr int32_t kMaxBlocks = 512;
is this change for better performance (increased parallelism)?
      // each thread writes 4 elements of type bf16
      *reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) =
          *reinterpret_cast<uint2*>(&kv_dq.vals[0]);
      *reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) =
          *reinterpret_cast<uint2*>(&kv_dq.vals[2]);
    }
nit, maybe add some comments to explain why we do the padding of the last tile?
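To make the nit concrete, a hedged sketch of how the padded slots of the last tile could be zeroed with the same vectorized uint2 stores shown above; the guard t_is_padding is a hypothetical name, and the actual condition in the kernel may differ. This continues the snippet above rather than standing alone.

// Sketch only: for padded positions of the last tile, store packed BF16 zeros
// (bit pattern 0x0000 per element) so FA3 never consumes NaNs from the tail.
if (t_is_padding) {                  // hypothetical guard for padded slots
  uint2 zeros = make_uint2(0u, 0u);  // four bf16 zeros packed into 8 bytes
  *reinterpret_cast<uint2*>(&row_k_dq[4 * threadIdx.x]) = zeros;
  *reinterpret_cast<uint2*>(&row_v_dq[4 * threadIdx.x]) = zeros;
}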
         or not HAS_XFORMERS,
-        "Skip when H100 is not available or MI300 is not available",
+        "Skip when H100 is not available",
     )
     def test_fp8_kv_cache(self, MAX_T: int, N_KVH_L: int) -> None:
nit, update the test to check the padding logic?
This pull request has been merged in e97b388.
Summary:
X-link: pytorch#3723
Pull Request resolved: https://github.com/facebookresearch/FBGEMM/pull/806

as title

Introduce padding in the dequantization kernel to avoid passing NaNs to the output of FA3 in the prefill stage.

Reviewed By: jianyuh

Differential Revision: D69522001

fbshipit-source-id: 9ce8c1840be75c78727e952feb1fbb962c57543a