[webgpu] Apply dp4a for generation shader #24064

Open · wants to merge 10 commits into base: main
Conversation

qjia7 (Contributor) commented Mar 17, 2025

This PR applies DP4A to the generation shader, and also supports any block_size with block_size % 32 == 0.

@qjia7 qjia7 requested a review from sushraja-msft March 17, 2025 08:54
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Mar 17, 2025
@qjia7 qjia7 requested a review from guschmue March 18, 2025 01:54
guschmue (Contributor) commented:

CI is unhappy because of an unrelated error; a PR to fix it is on the way:
1: [ FAILED ] ReductionOpTest.ReduceMaxAxesInitializerOpset18
1: [ FAILED ] ReductionOpTest.ReduceMinAxesInitializerOpset18

shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_size = 16u; // tile_size = tile_size_vec * output components
Contributor:

Please leave a comment explaining why "tile_size = tile_size_vec * output components" should hold.
Maybe explaining the approach of this shader in a comment would be good as well.

sushraja-msft (Contributor) commented Mar 19, 2025:

In the comment, please explain what tile_size is: tile_size is the number of columns of b (or rows, because MatMulNBits stores B transposed) that each workgroup processes.

qjia7 (Contributor, Author):
I updated the code to make it more flexible and easier to read.
Now you can adjust the workgroup size and tile size on the API side very flexibly, without changing any code in the shader.
With the latest changes, tested on an NV RTX 2000 Ada, it's faster (45 tokens/s -> 47 tokens/s) than the non-dp4a path for the f16 type. Please help check the impact on FP32-only GPUs.

const tile_size_vec = 4u;
const tile_size_k_vec = 16u; // tile_size_vec * tile_size_k_vec = workgroup size
// Shared memory
var<workgroup> tile_A : array<vec4<u32>, 32>; // 512 scalars
Contributor:
512 bytes

Contributor:
Saying how many bytes this is would be useful for understanding whether we ever get into a situation where we run out of workgroup shared memory.

qjia7 (Contributor, Author):
Here I mean that tile_A holds 512 scalars (4 x 4 x 32). This needs 4 scales, since the block_size is 128 (4 = 512 / 128). The code has been updated.

if (col < 4)
{
// kidx_v - covers 16 values of k
scale_A[col] = scales_a[a_global*(uniforms.K/128) + k_offset/8 + col];
Contributor:
scales_a[a_global*(uniforms.K/128) + k_offset/8 + col];

I think this should be

scales_a[a_global*(uniforms.K/128) + k_offset/8];

because k_offset already contains col.

qjia7 (Contributor, Author):
It should be scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col]; because I need 4 scales of A's data.

@qjia7 qjia7 marked this pull request as draft March 19, 2025 06:51
@qjia7 qjia7 marked this pull request as ready for review March 19, 2025 08:57
@qjia7 qjia7 requested a review from sushraja-msft March 19, 2025 08:57
let local_row = local_idx / tile_size_k_vec;
for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += 16)
{
// Load Phase: Populate shared memory for the workgroup.
Contributor:
A workgroupBarrier() is needed here, because we are changing shared memory and the previous iteration needs to have completed for all threads.


constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b.
// tile_size is the number of rows of b each workgroup process.
// tile_size_k_vec is the number of columns of b each workgroup process and each element is a vec4<u32>.
sushraja-msft (Contributor) commented Mar 19, 2025:
// Each workgroup processes tile_size_k_vec columns of b at a time, iterating over the columns to compute the full dot product. Vectorization here is vec4, i.e. each K unit is 32 items of B.

constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b.
// tile_size is the number of rows of b each workgroup process.
// tile_size_k_vec is the number of columns of b each workgroup process and each element is a vec4<u32>.
// In each workgroup, we read a block [tile_size][tile_size_k_vec] of b data and calculate the corresponding intermediate results of a * b. Then store them into inter_results.
sushraja-msft (Contributor) commented Mar 19, 2025:
// Each workgroup is responsible for a single row of A and tile_size rows of B.
// Each XYZ set of threads within the workgroup is responsible for a row of A and a row of B, performing a dot product over tile_size_k_vec (vectorization of 16, therefore 512 items of A & B).

Labels: ep:WebGPU (ort-web webgpu provider)
3 participants