Describe the bug
The goal is to figure out why this kernel compiles but fails at run time.
Running it reports "Got cutlass error: Error Internal at: ", even though compilation succeeds.
The config is cutlass3x_sm90_tensorop_s64x96x16gemm_bf16_bf16_f32_void_bf16_256x96x64_4x2x1_0_ttn_align8_stream_k_warpspecialized_cooperative_epi_tma
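As a rough sketch, assuming the usual CUTLASS 3.x profiler naming (where the purely numeric `AxBxC` fields after the element types are the CTA tile shape followed by the cluster shape), the two shapes can be read off the kernel name. `shape_tokens` is a hypothetical helper for illustration, not part of CUTLASS. For this config it returns 256x96x64 and 4x2x1, and the second one matches the ClusterDims = (4, 2, 1) that appears in the debug log.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper (not a CUTLASS API): returns every '_'-separated token
// of the form "AxBxC" whose three components are all decimal digits.
std::vector<std::string> shape_tokens(const std::string& name) {
  std::vector<std::string> out;
  std::stringstream ss(name);
  std::string tok;
  while (std::getline(ss, tok, '_')) {
    int parts = 1;            // number of numeric components seen so far
    bool ok = !tok.empty();   // false as soon as a non-shape character appears
    bool prev_digit = false;  // an 'x' is only valid right after a digit
    for (char c : tok) {
      if (c >= '0' && c <= '9') {
        prev_digit = true;
      } else if (c == 'x' && prev_digit) {
        ++parts;
        prev_digit = false;
      } else {
        ok = false;
        break;
      }
    }
    // Keep only complete three-component shapes that end in a digit.
    if (ok && prev_digit && parts == 3) out.push_back(tok);
  }
  return out;
}
```

Tokens such as `s64x96x16gemm` or `align8` are skipped because they contain letters, so only the standalone shape fields remain.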
When I turn on the debug flags, I see:
```
cutlass error: Error Internal at: 247
cutlass/gemm/device/gemm_universal_adapter.h:264 GemmUniversal::maximum_active_blocks()
cutlass/gemm/device/gemm_universal_adapter.h:271 Setting smem size to 197632
cutlass/gemm/device/gemm_universal_adapter.h:300 max_active_blocks: 1
cutlass/gemm/device/gemm_universal_adapter.h:312 GemmUniversal::initialize() - workspace 0, stream: null
cutlass/gemm/kernel/tile_scheduler_params.h:300 get_grid_shape(): Proposed GridDims by the scheduler using heuristics = (4, 33, 1)
cutlass/gemm/kernel/tile_scheduler_params.h:300 get_grid_shape(): Proposed GridDims by the scheduler using heuristics = (4, 33, 1)
cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:183 to_underlying_arguments():
cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:200 to_underlying_arguments(): Setting persistent grid SM count to 132
cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp:206 WARNING: Arguments do not include a valid max cluster count.
  For optimal performance, populate the arguments KernelHardwareInfo struct with the max_active_clusters.
cutlass/gemm/kernel/tile_scheduler_params.h:300 get_grid_shape(): Proposed GridDims by the scheduler using heuristics = (4, 33, 1)
cutlass/gemm/kernel/tile_scheduler_params.h:300 get_grid_shape(): Proposed GridDims by the scheduler using heuristics = (4, 33, 1)
cutlass/gemm/device/gemm_universal_adapter.h:336 Setting smem size to 197632
cutlass/gemm/device/gemm_universal_adapter.h:372 GemmUniversal::run()
cutlass/gemm/kernel/tile_scheduler_params.h:300 get_grid_shape(): Proposed GridDims by the scheduler using heuristics = (4, 33, 1)
cutlass/cluster_launch.hpp:177 ClusterLauncher: Setting ClusterDims = (4, 2, 1)
cutlass/cluster_launch.hpp:108 ClusterLauncher: Invalid cluster configuration -- aborting launch.
cutlass/cluster_launch.hpp:231 ClusterLauncher: check_cluster_dims() failed. Aborting.
cutlass/gemm/device/gemm_universal_adapter.h:555 Kernel launch failed. Reason: no error
```
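The last few log lines point at the cause: the scheduler proposes a grid of (4, 33, 1) while the cluster shape is (4, 2, 1), and ClusterLauncher rejects the configuration before launching. CUDA requires each grid dimension of a cluster launch to be a whole multiple of the corresponding cluster dimension, and 33 is not divisible by 2. A minimal sketch of that condition follows; `Dim3` and `grid_is_cluster_aligned` are hypothetical stand-ins (for CUDA's `dim3` and for one of the conditions CUTLASS's `check_cluster_dims()` presumably enforces), not CUTLASS code.

```cpp
// Hypothetical stand-in for CUDA's dim3, so the sketch compiles without CUDA headers.
struct Dim3 {
  unsigned x, y, z;
};

// Returns true only if every grid dimension is a whole multiple of the
// matching cluster dimension (and no cluster dimension is zero).
bool grid_is_cluster_aligned(Dim3 grid, Dim3 cluster) {
  if (cluster.x == 0 || cluster.y == 0 || cluster.z == 0) return false;
  return grid.x % cluster.x == 0 &&
         grid.y % cluster.y == 0 &&
         grid.z % cluster.z == 0;
}
```

With the values from the log, `grid_is_cluster_aligned({4, 33, 1}, {4, 2, 1})` is false because 33 % 2 == 1, whereas a grid of (4, 32, 1) would pass.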
Steps/Code to reproduce bug
Compile and run the following:
https://gist.github.com/henrylhtsang/fd34c103db83cba379980210bb88658b
Expected behavior
Figure out why the kernel can't run, and whether the fix belongs in the config or in the code.
Environment details
CUDA 12.4
CUTLASS 3.8
Additional context
NA
Thanks for bringing this up. CUTLASS should not be using a grid shape of (4,33,1):
cutlass/gemm/kernel/tile_scheduler_params.h:300 get_grid_shape(): Proposed GridDims by the scheduler using heuristics = (4, 33, 1)
when the cluster shape is 4x2x1.
I'll dig in further.
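For context on the numbers: the log sets the persistent grid SM count to 132, and the proposed grid (4, 33, 1) covers exactly 4 * 33 = 132 CTAs. A 4x2x1 cluster holds 8 CTAs, and 132 is not a multiple of 8, so no 132-CTA grid can be tiled by these clusters. One possible adjustment, sketched here under the assumption that shrinking the persistent grid is acceptable, rounds each grid dimension down to a multiple of the cluster dimension. That gives (4, 32, 1), i.e. 128 CTAs in 16 clusters. The names below are hypothetical, and this is an illustration rather than the fix CUTLASS adopted.

```cpp
#include <algorithm>

// Hypothetical stand-in for CUDA's dim3, so the sketch compiles without CUDA headers.
struct Dim3 {
  unsigned x, y, z;
};

// Largest multiple of `multiple` that is <= value. If value is smaller than
// `multiple`, returns `multiple` so that at least one cluster is launched.
// Assumes multiple > 0.
unsigned round_down_to_multiple(unsigned value, unsigned multiple) {
  return std::max(multiple, (value / multiple) * multiple);
}

// Shrinks a proposed grid so that each dimension is divisible by the cluster shape.
Dim3 align_grid_to_cluster(Dim3 grid, Dim3 cluster) {
  return {round_down_to_multiple(grid.x, cluster.x),
          round_down_to_multiple(grid.y, cluster.y),
          round_down_to_multiple(grid.z, cluster.z)};
}
```

Rounding down keeps the persistent grid within the available SMs at the cost of leaving a few idle (4 of 132 here); the other direction, choosing a cluster shape that divides the grid, would be a change to the config instead.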