[webgpu] Apply dp4a for generation shader #24064

Open
wants to merge 12 commits into main
185 changes: 162 additions & 23 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -7,6 +7,39 @@
namespace onnxruntime {
namespace contrib {
namespace webgpu {
namespace {

constexpr std::string_view commonFunctions = R"ADDNL_FN(
fn DequantizedFrom4BitsTo8Bits(in: vec2<u32>) -> vec4<u32>
{
var out = vec4<u32>(0);
var value_lower = vec4<i32>(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
var value_upper = vec4<i32>(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
out[0] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
out[1] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
value_lower = vec4<i32>(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
value_upper = vec4<i32>(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
out[2] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
out[3] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
return out;
}

// Scaled dot product of 8 u32 words of a against 8 u32 words of b, each word packing four signed 8-bit integers.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
var local_sum = dot4I8Packed(a1[0], b1[0]);
local_sum += dot4I8Packed(a1[1], b1[1]);
local_sum += dot4I8Packed(a1[2], b1[2]);
local_sum += dot4I8Packed(a1[3], b1[3]);
local_sum += dot4I8Packed(a2[0], b2[0]);
local_sum += dot4I8Packed(a2[1], b2[1]);
local_sum += dot4I8Packed(a2[2], b2[2]);
local_sum += dot4I8Packed(a2[3], b2[3]);
return output_element_t(local_sum) * scale;
}
)ADDNL_FN";

} // namespace
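// Illustrative sketch (not part of this change): a scalar C++ analogue of the WGSL
// DequantizedFrom4BitsTo8Bits above, for a single u32 word (the shader processes a
// vec2<u32> at a time). DequantizeWordSketch is a hypothetical helper that assumes only
// <cstdint>; each u32 holds eight 4-bit weights, the zero point 8 is subtracted, and the
// lower/upper nibbles of each byte are interleaved into two words of packed signed 8-bit values.
static void DequantizeWordSketch(uint32_t in, uint32_t out[2]) {
  int8_t vals[8];
  for (int i = 0; i < 4; ++i) {
    vals[2 * i] = static_cast<int8_t>((in >> (8 * i)) & 0x0Fu) - 8;          // lower nibble of byte i
    vals[2 * i + 1] = static_cast<int8_t>((in >> (8 * i + 4)) & 0x0Fu) - 8;  // upper nibble of byte i
  }
  for (int w = 0; w < 2; ++w) {
    out[w] = 0;
    for (int b = 0; b < 4; ++b) {
      out[w] |= static_cast<uint32_t>(static_cast<uint8_t>(vals[4 * w + b])) << (8 * b);
    }
  }
}
// Example: DequantizeWordSketch(0x0F0F0F0Fu, out) yields out[0] == out[1] == 0xF807F807u
// (lower nibble 15 -> +7 == 0x07, upper nibble 0 -> -8 == 0xF8).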

Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -65,7 +98,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// This shader requires A to be int8 quantized with block size 64. B is a regular
// MatMulNBits input with block size 32.

shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";";
shader.AdditionalImplementation() << commonFunctions
<< " const block_size = " << block_size_ << ";";

shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_size = 64;
Expand Down Expand Up @@ -105,34 +139,13 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value);
if (col == 0)
{
// kidx_v - each kidx_v covers 16 values of k
scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)];
}
}

// Scaled dot product of 8 packed unsigned integers.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
var local_sum = dot4I8Packed(a1[0], b1[0]);
local_sum += dot4I8Packed(a1[1], b1[1]);
local_sum += dot4I8Packed(a1[2], b1[2]);
local_sum += dot4I8Packed(a1[3], b1[3]);
local_sum += dot4I8Packed(a2[0], b2[0]);
local_sum += dot4I8Packed(a2[1], b2[1]);
local_sum += dot4I8Packed(a2[2], b2[2]);
local_sum += dot4I8Packed(a2[3], b2[3]);
return output_element_t(local_sum) * scale;
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
Expand Down Expand Up @@ -249,11 +262,122 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

// scale_A components = 1, b components = 4, output components = 1
Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform);
shader.AddInput("scales_a", ShaderUsage::UseUniform);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
// This algorithm computes the dot product over K in parallel: each step processes a slice of K across tile_size_k_vec threads,
// while the remaining threads in the workgroup process additional rows of B in parallel (so the A values in shared memory can be reused).
// For each loaded slice of K, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products
// for the other B rows, completing all tile_size B rows in this workgroup while reusing the A values already held in registers.

// 1. Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (512 in total) and num_concurrent_b_rows rows of matrix B at a time,
//    iterating over the columns to compute a partial dot product.
// 2. Uses vec4 vectorization, where each packed K element represents 32 elements of matrix B.
constexpr uint32_t tile_size_k_vec = 16;

// 1. Workgroup Responsibility:
// - Processes one row of matrix A
// - Handles tile_size rows of matrix B
//
// 2. Computation Process:
// - Reads [tile_size][tile_size_k_vec] block of B data at a time
// - Each thread within the workgroup computes dot products of 32 A*B elements, since each packed K element represents 32 elements of matrix B
// - Stores intermediate results in shared memory (inter_results)
// - Iterates through columns accumulating results in inter_results
// - Performs final reduction sum in inter_results for output
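//
// With the values used by the small-M dispatch below (workgroup size 128, tile_size 32):
// sub_tile_size = 128 / 16 = 8 concurrent B rows; each kidx_v iteration covers
// 16 * 32 = 512 values of K, for which tile_A holds 32 vec4<u32> (512 int8 A values)
// and scale_A holds 512 / 128 = 4 A scales.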
shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n"
<< "const tile_size_k_vec = " << tile_size_k_vec << "u;\n"
// sub_tile_size is the number of concurrent b rows processed by the workgroup.
<< "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n";
shader.AdditionalImplementation() << commonFunctions
<< R"ADDNL_FN(
// Shared memory
// Need 2 * tile_size_k_vec (32) entries to store tile_A, since B is quantized to 4 bits and A is quantized to 8 bits.
var<workgroup> tile_A : array<vec4<u32>, 32>;
// Need 4 scale values, since each tile_A holds 512 (4x4x32) scalars and the block_size is 128.
var<workgroup> scale_A : array<output_element_t, 4>;
var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
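// inter_results[b_row_in_tile][k_slice] holds per-thread partial sums; they are reduced
// over k_slice at the end of the main function to produce one output per B row.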
fn loadSHMA(a_global: u32, kidx_v: u32, col: u32)
{
let k_offset = kidx_v + col;
if (k_offset >= uniforms.K16) {
return;
}

tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
if (col < 4)
{
// kidx_v - covers 16 values of k in input_a
scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col];
}
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
let a_global = u32(workgroup_idx / uniforms.num_N_tile);
let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
// Handle each workgroup threads as a block of [sub_tile_size][tile_size_k_vec]
let local_col = local_idx % tile_size_k_vec;
let local_row = local_idx / tile_size_k_vec;
for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec)
{
// Load Phase: Populate shared memory for the workgroup.
if (local_idx < 32)
{
loadSHMA(a_global, kidx_v * 2, local_idx);
}
workgroupBarrier();
var own_a: vec4<u32> = tile_A[local_col * 2];
var own_a1: vec4<u32> = tile_A[local_col * 2 + 1];
var own_scale_a = scale_A[local_col / 4];
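// scale_A index: each A scale covers 128 int8 values, i.e. 8 vec4<u32> entries of tile_A,
// which is the span read by 4 consecutive local_col values (2 entries each).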
var own_b = vec4<u32>(0);
var own_b1 = vec4<u32>(0);
let k_offset = kidx_v + local_col;
// Accumulate intermediate results into inter_results.
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_size) {
let b_global = b_global_base + row_offset + local_row;
if (b_global < uniforms.N && k_offset < uniforms.K32)
{
let b_offset = b_global * uniforms.K32 + k_offset;
let b_value = input_b[b_offset];
own_b = DequantizedFrom4BitsTo8Bits(b_value.xy);
own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw);

// k_offset - covers 32 values of k in input_b
let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size];
inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
}
}
workgroupBarrier();
}

if (local_idx < tile_size) {
// Reduce-sum over inter_results to get the final output.
var output_value = output_element_t(0);
for (var b = 0u; b < tile_size_k_vec; b++) {
output_value += inter_results[local_idx][b];
}
let b_global = b_global_base + local_idx;
let output_idx = a_global * uniforms.N + b_global;
if (b_global < uniforms.N) {
output[output_idx] = output_value;
}
}
)MAIN_FN";

return Status::OK();
}

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t M,
uint32_t N,
uint32_t K,
uint32_t block_size,
uint32_t min_M_for_tile_optimization,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y) {
constexpr uint32_t kVec4Components = 4;
@@ -273,6 +397,21 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

if (M < min_M_for_tile_optimization) {
constexpr uint32_t kTileSize = 32;
DP4AMatMulNBitsSmallMProgram mul_program{kTileSize};
uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
mul_program.SetWorkgroupSize(128);
mul_program.SetDispatchGroupSize(M * num_N_tile);
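// Dispatch sketch (for illustration): M * num_N_tile workgroups of 128 threads; each
// workgroup produces one output row for a tile of kTileSize (32) columns of N. The
// uniforms below pass K at three granularities: K, K/16 (A is read as vec4<u32>, i.e.
// 16 int8 values per element) and K/32 (B is read as vec4<u32>, i.e. 32 4-bit values per element).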
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
{b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1});
return context.RunProgram(mul_program);
}

constexpr uint32_t kTileSize = 64;
TensorShape reshaped_y_shape{1, M, N / kVec4Components};
DP4AMatMulNBitsProgram mul_program{block_size};
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
@@ -33,11 +33,29 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
uint32_t block_size_;
};

class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
public:
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K16", ProgramUniformVariableDataType::Uint32},
{"K32", ProgramUniformVariableDataType::Uint32},
{"block_size", ProgramUniformVariableDataType::Uint32},
{"num_N_tile", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
};

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t M,
uint32_t N,
uint32_t K,
uint32_t block_size,
uint32_t min_M_for_tile_optimization,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y);

6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -574,9 +574,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
}

if (M >= kMinMForTileOptimization &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y);
// On FP32-only GPUs, integer math is faster than FP32, so always use DP4A regardless of the length of M.
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>()) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y);
}

// TODO: Support output_number > 1. Some cases are failed when output_number > 1.