Fix cuda compile error with bf16 #2122
base: main
Conversation
Summary: T222166791

Differential Revision: D73562284
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2122

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures as of commit 66d86a6 with merge base 31d17c0.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D73562284
@@ -70,9 +70,9 @@ constexpr float power_of_two(int n) {
  return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1);
}

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
Why does this look reversed to me? Should this be "if defined" and ">= 800"?
I thought that was weird as well, but I see it defined this way in multiple files in torchao (they could all be wrong).

What I'm doing in this PR is matching the if/else macro on the include (lines 27-29):

#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

with this function. __nv_bfloat16 is defined in cuda_bf16.h, but that include is currently guarded by #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800.
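A minimal sketch of that pattern (hypothetical helper name; assumes a CUDA version where __bfloat162float is callable from host code):

#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

// Guard the function with the same condition as the include above, so it only
// exists in compilation passes where __nv_bfloat16 itself exists.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
__host__ __device__ inline float bf16_to_float(__nv_bfloat16 x) {
  return __bfloat162float(x);
}
#endif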
I see, this does look weird
cc @gau-nernst can you take a look, is this intentional?
Also cc @tobiasvanderwerff
What was the original error? (i.e. why is this PR needed?)

__CUDA_ARCH__ is only defined for device (CUDA) code. #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 means it is true if it is host code OR CUDA code with sm >= 80. There are some rules about how we should use it (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-arch); otherwise it will lead to undefined behavior.
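To spell out how that guard resolves in each compilation pass (a sketch restating the rule above, not code from this PR):

// host pass:          __CUDA_ARCH__ undefined -> condition is TRUE
// device pass, sm_90: __CUDA_ARCH__ == 900    -> condition is TRUE
// device pass, sm_80: __CUDA_ARCH__ == 800    -> condition is TRUE
// device pass, sm_75: __CUDA_ARCH__ == 750    -> condition is FALSE
#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>  // skipped only when compiling device code for < sm_80
#endif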
For the example you mentioned above, we also need to include the cuda_bf16.h header in host code; otherwise host code won't have access to the BF16 typedef. Similarly, __global__ functions (CUDA kernels) must be visible in both host code and device code (and have the same signature). Hence, in many cases, it's easier to have all the functions defined, and leave the implementation empty if __CUDA_ARCH__ < xxx.
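A hypothetical sketch of that pattern (the kernel name and body are illustrative, not from this PR): the signature uses only plain float pointers, so the host pass and every device pass see the same declaration, while the BF16 work is compiled out below sm_80.

#include <cuda_fp16.h>
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

// Declared unconditionally: host and all device passes see the same signature.
__global__ void roundtrip_bf16(const float* in, float* out, int n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // BF16 body is only compiled for sm_80+, where cuda_bf16.h was included.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __nv_bfloat16 v = __float2bfloat16(in[i]);
    out[i] = __bfloat162float(v);
  }
#else
  // Empty implementation on older architectures, as described above.
#endif
}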