[webgpu] Apply dp4a for generation shader #24064
base: main

Conversation
CI is unhappy because of an unrelated error; a PR to fix it is on the way:
onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_size = 16u; // tile_size = tile_size_vec * output components
Please leave a comment explaining why `tile_size = tile_size_vec * output components` should hold. Maybe explaining the approach of this shader in a comment would be good as well.
In the comment, please explain what tile_size is: tile_size is the number of columns of B (or rows, because MatMulNBits stores B transposed) that each workgroup processes.
I updated the code to make it more flexible and easier to read.
Now you can adjust the workgroup size and tile size on the API side very flexibly, without needing to change any shader code.
With the latest changes, I tested it on an NVIDIA RTX 2000 Ada; it's faster (45 tokens/s -> 47 tokens/s) than the non-dp4a path for the f16 type. Please help check the impact on FP32-only GPUs.
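For readers unfamiliar with DP4A: the instruction (`dot4I8Packed` in WGSL, `__dp4a` in CUDA) computes a dot product of four packed signed 8-bit integers with a 32-bit accumulate. A minimal scalar emulation, as a sketch rather than the actual shader code, could look like:

```cpp
#include <cstdint>

// Scalar emulation of DP4A: interpret each u32 as four packed signed
// 8-bit lanes, multiply lane-wise, and accumulate into a 32-bit int.
int32_t dp4a(uint32_t a, uint32_t b, int32_t acc) {
  for (int lane = 0; lane < 4; ++lane) {
    const int8_t av = static_cast<int8_t>((a >> (8 * lane)) & 0xFFu);
    const int8_t bv = static_cast<int8_t>((b >> (8 * lane)) & 0xFFu);
    acc += static_cast<int32_t>(av) * static_cast<int32_t>(bv);
  }
  return acc;
}
```

For example, lanes (1,2,3,4)·(5,6,7,8) give 5 + 12 + 21 + 32 = 70. On supporting hardware this is a single instruction, which is why quantizing the A matrix to 8 bits pays off in the generation (matrix-vector) path.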
const tile_size_vec = 4u;
const tile_size_k_vec = 16u; // tile_size_vec * tile_size_k_vec = workgroup size
// Shared memory
var<workgroup> tile_A : array<vec4<u32>, 32>; // 512 scalars
512 bytes
Saying how many bytes it is would be useful for understanding whether we ever run out of workgroup shared memory.
Here I meant that tile_A holds 512 scalars, calculated as 4 x 4 x 32. This needs 4 scales, since block_size is 128 (4 = 512 / 128). The code has been updated.
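The two readings above ("512 bytes" and "512 scalars") are consistent, since each quantized A value occupies one byte. A sketch of the accounting, with sizes taken from the declarations quoted in this thread:

```cpp
#include <cstdint>

// tile_A is declared as array<vec4<u32>, 32> in the shader.
constexpr uint32_t kVec4U32Bytes = 4 * sizeof(uint32_t);       // 16 bytes per element
constexpr uint32_t kTileABytes   = 32 * kVec4U32Bytes;         // 512 bytes total
// Each byte holds one quantized 8-bit scalar of A: 512 bytes = 512 scalars.
constexpr uint32_t kTileAScalars = kTileABytes;
// With block_size = 128, those scalars span 512 / 128 = 4 quantization
// blocks, hence the 4 scales the author mentions.
constexpr uint32_t kBlockSize    = 128;
constexpr uint32_t kScalesNeeded = kTileAScalars / kBlockSize; // 4
```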
if (col < 4)
{
  // kidx_v - covers 16 values of k
  scale_A[col] = scales_a[a_global*(uniforms.K/128) + k_offset/8 + col];
I think `scales_a[a_global*(uniforms.K/128) + k_offset/8 + col]` should be `scales_a[a_global*(uniforms.K/128) + k_offset/8]`, because k_offset already contains col.
It should be `scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col];`, since I need all 4 of A's scales.
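One way to read the agreed indexing, as a hypothetical helper (not the actual shader): it assumes block_size = 128 and that kidx_v counts 16-scalar vec4 units, so that kidx_v / 8 converts into 128-scalar block units.

```cpp
#include <cstdint>

// Hypothetical mirror of the shader's scale_A index math.
// scales_a stores one scale per 128-value quantization block of A,
// i.e. K / 128 scales per row of A.
uint32_t scale_a_index(uint32_t a_global, uint32_t K,
                       uint32_t kidx_v, uint32_t col) {
  return a_global * (K / 128) + kidx_v / 8 + col;
}
```

With col in [0, 4), the four participating threads fetch 4 consecutive scales, matching the 4 blocks of A staged per loop iteration.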
let local_row = local_idx / tile_size_k_vec;
for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += 16)
{
  // Load Phase: Populate shared memory for the workgroup.
Add a workgroupBarrier() here, because we are changing shared memory and the previous iteration needs to have completed for all threads.
constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b.
// tile_size is the number of rows of b each workgroup processes.
// tile_size_k_vec is the number of columns of b each workgroup processes, and each element is a vec4<u32>.
// Each workgroup processes tile_size_k_vec columns of b at a time, iterating over the columns to compute the full dot product. Vectorization here is vec4, i.e. each k covers 32 items of B.
constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b.
// tile_size is the number of rows of b each workgroup processes.
// tile_size_k_vec is the number of columns of b each workgroup processes, and each element is a vec4<u32>.
// In each workgroup, we read a block [tile_size][tile_size_k_vec] of b data and calculate the corresponding intermediate results of a * b, then store them into inter_results.
// Each workgroup is responsible for a single row of A and tile_size rows of B.
// Each XYZ set of threads within the workgroup is responsible for one row of A and one row of B, performing a dot product over tile_size_k_vec (vectorization of 16, therefore 512 items of A & B).
This PR applies DP4A to the generation shader. It also supports any block_size with block_size % 32 == 0.
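A plausible host-side eligibility guard for that block_size claim, as a hypothetical sketch (the actual check in dp4a_matmul_nbits.cc may differ):

```cpp
#include <cstdint>

// DP4A consumes A and B in 32-value chunks (one vec4<u32> of packed
// 8-bit data per thread step), so quantization blocks must align to 32.
bool CanUseDp4aPath(uint32_t block_size) {
  return block_size >= 32 && block_size % 32 == 0;
}
```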