From d9430217f7fed2a9f20a4ece51397c1e372e16d6 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 17 Mar 2025 10:16:55 +0800 Subject: [PATCH 01/11] [webgpu] Apply dp4a for generation shader --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 128 ++++++++++++++++++ .../webgpu/quantization/dp4a_matmul_nbits.h | 13 ++ .../webgpu/quantization/matmul_nbits.cc | 5 +- 3 files changed, 143 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 05cbfb1f99c48..bca60c90ae58e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -258,11 +258,120 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } +// tile_N size = 16, workgroup size = 64, scale_A components = 1, b components = 4, output components = 4 +Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform); + shader.AddInput("scales_a", ShaderUsage::UseUniform); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_size = 16u; // tile_size = tile_size_vec * output components + const tile_size_vec = 4u; + const tile_size_k_vec = 16u; // tile_size_vec * tile_size_k_vec = workgroup size + // Shared memory + var tile_A : array, 32>; // 512 scalars + var scale_A : array; // 4 + var inter_results: array, tile_size_k_vec>, tile_size_vec>; + fn loadSHMA(a_global:u32, kidx_v:u32, col: u32) + { + let k_offset = kidx_v + col; + if (k_offset >= uniforms.K16) { + return; + } + + tile_A[col] = input_a[a_global*uniforms.K16+k_offset]; + if (col < 4) + { + // kidx_v - covers 16 values of k + scale_A[col] = scales_a[a_global*(uniforms.K/128) + k_offset/8 + col]; + } + } + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; + } + )ADDNL_FN"; + + shader.MainFunctionBody() << R"MAIN_FN( + let a_global = workgroup_id.y; + let b_global_base = workgroup_id.x * tile_size; + let idx = local_idx % tile_size_k_vec; + let idy = local_idx / tile_size_k_vec; + for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v+=16) + { + // Load Phase: Populate shared memory for the workgroup. 
+ if (local_idx < 32) + { + loadSHMA(a_global, kidx_v * 2, local_idx); + } + workgroupBarrier(); + var own_a: vec4 = tile_A[idx*2]; + var own_a1: vec4 = tile_A[idx*2 + 1]; + var own_scale_a: output_element_t = scale_A[idx / 4]; + var own_b = vec4(0); + var own_b1 = vec4(0); + let k_offset = kidx_v+idx; + for (var i = 0u; i < 4u; i++) { + let b_global = b_global_base + idy * 4 + i; + if (b_global < uniforms.N && k_offset < uniforms.K32) + { + let b_offset = b_global*uniforms.K32+k_offset; + let b_value = input_b[b_offset]; + var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); + var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + own_b[0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + own_b[1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + own_b[2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + own_b[3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + own_b1[0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + own_b1[1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + own_b1[2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + own_b1[3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + let own_scale_b = scales_b[b_offset]; + inter_results[idy][idx][i] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); + } + } + } + workgroupBarrier(); + if (local_idx < tile_size_vec) { + var output_value = vec4(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + output_value += inter_results[local_idx][b]; + } + let b_global = b_global_base + local_idx * 4; + let output_idx = (a_global * uniforms.N + b_global)/4; + if (b_global < uniforms.N) { + output[output_idx] = output_value; + } + } + )MAIN_FN"; + + return Status::OK(); +} + Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, uint32_t M, uint32_t N, uint32_t K, uint32_t block_size, + uint32_t min_M_for_tile_optimization, onnxruntime::webgpu::ComputeContext& context, Tensor* y) { constexpr uint32_t kVec4Components = 4; @@ -283,6 +392,25 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor .AddUniformVariable({static_cast(M * K / kVec4Components)}); ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + if (M < min_M_for_tile_optimization) { + constexpr uint32_t kTileSize = 16; + DP4AMatMulNBitsSmallMProgram mul_program; + mul_program.SetWorkgroupSize(64); + mul_program.SetDispatchGroupSize( + (N + kTileSize - 1) / kTileSize, M, 1); + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, + {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, + {scales, 
ProgramTensorMetadataDependency::TypeAndRank, 1}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 16)}, + {static_cast(K / 32)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}); + return context.RunProgram(mul_program); + } + constexpr uint32_t kTileSize = 64; TensorShape reshaped_y_shape{1, M, N / kVec4Components}; DP4AMatMulNBitsProgram mul_program{block_size}; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 15b86d78301ad..a73a48368b4ae 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -34,11 +34,24 @@ class DP4AMatMulNBitsProgram final : public Program { uint32_t block_size_; }; +class DP4AMatMulNBitsSmallMProgram final : public Program { + public: + DP4AMatMulNBitsSmallMProgram() : Program{"DP4AMatMulNBitsSmallMProgram"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}, + {"K32", ProgramUniformVariableDataType::Uint32}); +}; + Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, uint32_t M, uint32_t N, uint32_t K, uint32_t block_size, + uint32_t min_M_for_tile_optimization, onnxruntime::webgpu::ComputeContext& context, Tensor* y); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index cce10a59fbd4b..ca3825df20610 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -574,9 +574,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } - if (M >= kMinMForTileOptimization && - CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { - return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y); + if (CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { + return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } // TODO: Support output_number > 1. Some cases are failed when output_number > 1. 
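A note on the packing trick the new shader relies on, before the follow-up patches: MatMulNBits stores eight 4-bit weights per u32, and the inner loop above widens each nibble, subtracts the default 4-bit zero point of 8, and repacks pairs of nibbles into u32 words of four signed 8-bit lanes so that dot4I8Packed can consume them (dot4I8Packed interprets each packed byte as a signed 8-bit integer). A minimal standalone WGSL sketch of that step, with an illustrative function name that does not appear in the patch:

    // Sketch only: `w` packs eight 4-bit weights; the result is two u32
    // words, each holding four signed 8-bit lanes ready for dot4I8Packed.
    fn dequant_8x4bit(w: u32) -> vec2<u32> {
      // Even-position and odd-position nibbles, widened to i32 and
      // re-centered around zero by subtracting the zero point 8.
      let lower = vec4<i32>(unpack4xU8(w & 0x0F0F0F0Fu)) - vec4<i32>(8);
      let upper = vec4<i32>(unpack4xU8((w >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
      // Interleave back into source element order and repack 4 x i8 per u32.
      return vec2<u32>(pack4xI8(vec4<i32>(lower[0], upper[0], lower[1], upper[1])),
                       pack4xI8(vec4<i32>(lower[2], upper[2], lower[3], upper[3])));
    }

Each SDP8AI call then performs 32 multiply-accumulates (eight dot4I8Packed calls over four lanes each) before applying the combined A and B scales; patch 09 below factors this same dequantization pattern into a shared DequantizedFrom4BitsTo8Bits helper.
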
From 356410a5aa0504f5e1e1461ac310f288f9791638 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 17 Mar 2025 16:28:32 +0800 Subject: [PATCH 02/11] support any block_size % 32 = 0 --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 10 +++------- .../webgpu/quantization/dp4a_matmul_nbits.h | 3 ++- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index bca60c90ae58e..2dcc66758376d 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -344,7 +344,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co b_value_upper = vec4(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4(8); own_b1[2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); own_b1[3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - let own_scale_b = scales_b[b_offset]; + let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size]; inter_results[idy][idx][i] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); } } @@ -402,12 +402,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({{static_cast(M)}, - {static_cast(N)}, - {static_cast(K)}, - {static_cast(K / 16)}, - {static_cast(K / 32)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(4)}); + .AddUniformVariables({M, N, K, K / 16, K / 32, block_size}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 4}); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index a73a48368b4ae..56656c514891a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -43,7 +43,8 @@ class DP4AMatMulNBitsSmallMProgram final : public Program Date: Tue, 18 Mar 2025 09:51:28 +0800 Subject: [PATCH 03/11] apply it only for float type --- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index ca3825df20610..e9bd06e8870f1 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -574,7 +574,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } - if (CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { + if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType()) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } From 
4631638ee0b87d4c7aefbb333c3fe39c2f95083a Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 19 Mar 2025 10:56:59 +0800 Subject: [PATCH 04/11] use 1D dispatch group size --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 10 +++++----- .../webgpu/quantization/dp4a_matmul_nbits.h | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 71d98396ebe54..1c2eafc27752e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -295,8 +295,8 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( - let a_global = workgroup_id.y; - let b_global_base = workgroup_id.x * tile_size; + let a_global = u32(workgroup_idx / uniforms.num_N_tile); + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; let idx = local_idx % tile_size_k_vec; let idy = local_idx / tile_size_k_vec; for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v+=16) @@ -385,14 +385,14 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor if (M < min_M_for_tile_optimization) { constexpr uint32_t kTileSize = 16; DP4AMatMulNBitsSmallMProgram mul_program; + uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; mul_program.SetWorkgroupSize(64); - mul_program.SetDispatchGroupSize( - (N + kTileSize - 1) / kTileSize, M, 1); + mul_program.SetDispatchGroupSize(M * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) - .AddUniformVariables({M, N, K, K / 16, K / 32, block_size}) + .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 4}); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index d6ca9167bb09d..34062083c5253 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -43,7 +43,8 @@ class DP4AMatMulNBitsSmallMProgram final : public Program Date: Wed, 19 Mar 2025 13:33:00 +0800 Subject: [PATCH 05/11] Adjust the code to make it more flexible --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 46 ++++++++++--------- .../webgpu/quantization/dp4a_matmul_nbits.h | 5 +- .../webgpu/quantization/matmul_nbits.cc | 1 + 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 1c2eafc27752e..24a414f6ac44e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -249,7 +249,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -// tile_N size = 16, workgroup size = 64, scale_A components = 1, b components = 4, output components = 4 +// scale_A components = 1, b components = 4, output components = 1 Status 
DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("input_a", ShaderUsage::UseUniform); shader.AddInput("scales_a", ShaderUsage::UseUniform); @@ -257,14 +257,15 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.AddInput("scales_b", ShaderUsage::UseUniform); shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b. + shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n" + << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" + << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; shader.AdditionalImplementation() << R"ADDNL_FN( - const tile_size = 16u; // tile_size = tile_size_vec * output components - const tile_size_vec = 4u; - const tile_size_k_vec = 16u; // tile_size_vec * tile_size_k_vec = workgroup size // Shared memory var tile_A : array, 32>; // 512 scalars var scale_A : array; // 4 - var inter_results: array, tile_size_k_vec>, tile_size_vec>; + var inter_results: array, tile_size>; fn loadSHMA(a_global:u32, kidx_v:u32, col: u32) { let k_offset = kidx_v + col; @@ -275,7 +276,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co tile_A[col] = input_a[a_global*uniforms.K16+k_offset]; if (col < 4) { - // kidx_v - covers 16 values of k + // kidx_v - covers 16 values of k in input_a scale_A[col] = scales_a[a_global*(uniforms.K/128) + k_offset/8 + col]; } } @@ -297,8 +298,8 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.MainFunctionBody() << R"MAIN_FN( let a_global = u32(workgroup_idx / uniforms.num_N_tile); let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; - let idx = local_idx % tile_size_k_vec; - let idy = local_idx / tile_size_k_vec; + let local_col = local_idx % tile_size_k_vec; + let local_row = local_idx / tile_size_k_vec; for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v+=16) { // Load Phase: Populate shared memory for the workgroup. 
@@ -307,17 +308,17 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co loadSHMA(a_global, kidx_v * 2, local_idx); } workgroupBarrier(); - var own_a: vec4 = tile_A[idx*2]; - var own_a1: vec4 = tile_A[idx*2 + 1]; - var own_scale_a: output_element_t = scale_A[idx / 4]; + var own_a: vec4 = tile_A[local_col * 2]; + var own_a1: vec4 = tile_A[local_col * 2 + 1]; + var own_scale_a = scale_A[local_col / 4]; var own_b = vec4(0); var own_b1 = vec4(0); - let k_offset = kidx_v+idx; - for (var i = 0u; i < 4u; i++) { - let b_global = b_global_base + idy * 4 + i; + let k_offset = kidx_v + local_col; + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_size) { + let b_global = b_global_base + row_offset + local_row; if (b_global < uniforms.N && k_offset < uniforms.K32) { - let b_offset = b_global*uniforms.K32+k_offset; + let b_offset = b_global * uniforms.K32 + k_offset; let b_value = input_b[b_offset]; var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); @@ -335,19 +336,20 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co b_value_upper = vec4(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4(8); own_b1[2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); own_b1[3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + // k_offset - covers 32 values of k in input_b let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size]; - inter_results[idy][idx][i] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); + inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); } } } workgroupBarrier(); - if (local_idx < tile_size_vec) { - var output_value = vec4(0); + if (local_idx < tile_size) { + var output_value = output_element_t(0); for (var b = 0u; b < tile_size_k_vec; b++) { output_value += inter_results[local_idx][b]; } - let b_global = b_global_base + local_idx * 4; - let output_idx = (a_global * uniforms.N + b_global)/4; + let b_global = b_global_base + local_idx; + let output_idx = a_global * uniforms.N + b_global; if (b_global < uniforms.N) { output[output_idx] = output_value; } @@ -384,7 +386,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor if (M < min_M_for_tile_optimization) { constexpr uint32_t kTileSize = 16; - DP4AMatMulNBitsSmallMProgram mul_program; + DP4AMatMulNBitsSmallMProgram mul_program{kTileSize}; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; mul_program.SetWorkgroupSize(64); mul_program.SetDispatchGroupSize(M * num_N_tile); @@ -393,7 +395,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components * kU32Components)}, {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 4}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 34062083c5253..7b2df9b3520d1 100644 --- 
a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -35,7 +35,7 @@ class DP4AMatMulNBitsProgram final : public Program { class DP4AMatMulNBitsSmallMProgram final : public Program { public: - DP4AMatMulNBitsSmallMProgram() : Program{"DP4AMatMulNBitsSmallMProgram"} {} + DP4AMatMulNBitsSmallMProgram(uint32_t tile_size) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -45,6 +45,9 @@ class DP4AMatMulNBitsSmallMProgram final : public Program= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType()) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y); } From d96de51c37ca812147f18dec81ccedaefa87bbd5 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 19 Mar 2025 13:54:16 +0800 Subject: [PATCH 06/11] Use workgroup size = 128 --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 24a414f6ac44e..a833bc9c0c0ee 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -385,10 +385,10 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); if (M < min_M_for_tile_optimization) { - constexpr uint32_t kTileSize = 16; + constexpr uint32_t kTileSize = 32; DP4AMatMulNBitsSmallMProgram mul_program{kTileSize}; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; - mul_program.SetWorkgroupSize(64); + mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(M * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1}, From 36db69d4286713d24e808906fea4e28e4de38f57 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 19 Mar 2025 14:31:03 +0800 Subject: [PATCH 07/11] Add more annotations --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index a833bc9c0c0ee..3420cf638d87b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -258,15 +258,21 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b. + // tile_size is the number of rows of b each workgroup process. + // tile_size_k_vec is the number of columns of b each workgroup process and each element is a vec4. + // In each workgroup, we read a block [tile_size][tile_size_k_vec] of b data and calculate the corresponding intermediate results of a * b. Then store them into inter_results. 
+ // Finally, do a reduce sum in inter_results to get the final results. shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n" << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; shader.AdditionalImplementation() << R"ADDNL_FN( // Shared memory - var tile_A : array, 32>; // 512 scalars - var scale_A : array; // 4 + // Need 2 * tile_size_k_vec (32) to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. + var tile_A : array, 32>; + // Need 4 scales value since each tile_A includes 512 (4x4x32) scalars and the block_size is 128. + var scale_A : array; var inter_results: array, tile_size>; - fn loadSHMA(a_global:u32, kidx_v:u32, col: u32) + fn loadSHMA(a_global: u32, kidx_v: u32, col: u32) { let k_offset = kidx_v + col; if (k_offset >= uniforms.K16) { @@ -298,9 +304,10 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.MainFunctionBody() << R"MAIN_FN( let a_global = u32(workgroup_idx / uniforms.num_N_tile); let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + // Handle each workgroup threads as a block of [sub_tile_size][tile_size_k_vec] let local_col = local_idx % tile_size_k_vec; let local_row = local_idx / tile_size_k_vec; - for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v+=16) + for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += 16) { // Load Phase: Populate shared memory for the workgroup. if (local_idx < 32) @@ -314,6 +321,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co var own_b = vec4(0); var own_b1 = vec4(0); let k_offset = kidx_v + local_col; + // calculate intermediate results into inter_results. for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_size) { let b_global = b_global_base + row_offset + local_row; if (b_global < uniforms.N && k_offset < uniforms.K32) @@ -344,6 +352,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co } workgroupBarrier(); if (local_idx < tile_size) { + // Do reduce sum to get final output. var output_value = output_element_t(0); for (var b = 0u; b < tile_size_k_vec; b++) { output_value += inter_results[local_idx][b]; From 701acbd32e1a71a0fd31fcb434142416cdc8cf75 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 19 Mar 2025 14:58:59 +0800 Subject: [PATCH 08/11] fix error in scale_a --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 3420cf638d87b..0f92cc78f8ebf 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -283,7 +283,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co if (col < 4) { // kidx_v - covers 16 values of k in input_a - scale_A[col] = scales_a[a_global*(uniforms.K/128) + k_offset/8 + col]; + scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col]; } } // Scaled dot product of 8 packed unsigned integers. 
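The fix above switches the scales_a index base from the per-thread k_offset to the tile base kidx_v, making explicit that scale_A[col] holds the col-th 128-element block scale of the staged tile. The addressing scheme the series relies on, written out as a hedged WGSL sketch (function names are illustrative; the constants follow from A being quantized at block size 128, each input_a element being a vec4<u32> of 16 int8 values, and each input_b element being a vec4<u32> of 32 4-bit values):

    // A: one scale per 128 scalars of K; kidx_v counts 16-value units, so
    // 8 consecutive units share one scale. The shader adds col (0..3) to
    // this base for the four blocks staged per iteration.
    fn a_scale_index(a_row: u32, kidx_v: u32, K: u32) -> u32 {
      return a_row * (K / 128u) + kidx_v / 8u;
    }
    // B: one scale per block_size scalars of K; k_offset counts 32-value
    // units, so scale back to scalars before dividing by the block size.
    fn b_scale_index(b_row: u32, k_offset: u32, K: u32, block_size: u32) -> u32 {
      return b_row * (K / block_size) + k_offset * 32u / block_size;
    }

Since each staged tile_A holds 512 scalars, it spans exactly four A-scale blocks (512 / 128), which is why only the threads with col < 4 load scale_A.
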
From e538dd571cd528a48a6bcd1a39246c0fb56e0293 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 19 Mar 2025 15:45:52 +0800 Subject: [PATCH 09/11] Extract common functions for code reuse --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 94 ++++++++----------- 1 file changed, 41 insertions(+), 53 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 0f92cc78f8ebf..2e835c5ad6591 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -7,6 +7,39 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +namespace { + +constexpr std::string_view commonFunctions = R"ADDNL_FN( + fn DequantizedFrom4BitsTo8Bits(in: vec2) -> vec4 + { + var out = vec4(0); + var value_lower = vec4(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4(8); + var value_upper = vec4(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + out[0] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); + out[1] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); + value_lower = vec4(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4(8); + value_upper = vec4(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + out[2] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); + out[3] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); + return out; + } + + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; + } + )ADDNL_FN"; + +} // namespace Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); @@ -65,7 +98,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { // this shader require A to be int8 quantized with block size 64. B is regular // matmulnbits input with block size 32. 
- shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; + shader.AdditionalImplementation() << commonFunctions + << " const block_size = " << block_size_ << ";"; shader.AdditionalImplementation() << R"ADDNL_FN( const tile_size = 64; @@ -105,34 +139,13 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value); if (col == 0) { // kidx_v - each kidx_v covers 16 values of k scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; } } - - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( @@ -265,7 +278,8 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n" << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; - shader.AdditionalImplementation() << R"ADDNL_FN( + shader.AdditionalImplementation() << commonFunctions + << R"ADDNL_FN( // Shared memory // Need 2 * tile_size_k_vec (32) to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. var tile_A : array, 32>; @@ -286,19 +300,6 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col]; } } - // Scaled dot product of 8 packed unsigned integers. 
- fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } )ADDNL_FN"; shader.MainFunctionBody() << R"MAIN_FN( @@ -328,22 +329,9 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co { let b_offset = b_global * uniforms.K32 + k_offset; let b_value = input_b[b_offset]; - var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - own_b[0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - own_b[1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - own_b[2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - own_b[3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - own_b1[0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - own_b1[1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - own_b1[2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - own_b1[3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + own_b = DequantizedFrom4BitsTo8Bits(b_value.xy); + own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw); + // k_offset - covers 32 values of k in input_b let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size]; inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); From f3a93e74aa36b73c5e53fcb05c5684a90cbc42f2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 20 Mar 2025 10:06:46 +0800 Subject: [PATCH 10/11] address comments --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 2e835c5ad6591..1e9d94aa0c672 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -270,11 +270,20 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.AddInput("scales_b", ShaderUsage::UseUniform); shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - constexpr uint32_t tile_size_k_vec = 16; // tile K in input_b. - // tile_size is the number of rows of b each workgroup process. 
- // tile_size_k_vec is the number of columns of b each workgroup process and each element is a vec4. - // In each workgroup, we read a block [tile_size][tile_size_k_vec] of b data and calculate the corresponding intermediate results of a * b. Then store them into inter_results. - // Finally, do a reduce sum in inter_results to get the final results. + // 1. Each workgroup handles tile_size_k_vec (16) columns of matrix B at a time, iterating over the columns to compute a partial dot product. + // 2. Uses vec4 vectorization where each K represents 32 elements of matrix B + constexpr uint32_t tile_size_k_vec = 16; + + // 1. Workgroup Responsibility: + // - Processes one row of matrix A + // - Handles tile_size rows of matrix B + // + // 2. Computation Process: + // - Reads [tile_size][tile_size_k_vec] block of B data at a time + // - Each thread within workgroup computes dot products of 32 A*B elements since each K represents 32 elements of matrix B + // - Stores intermediate results in shared memory (inter_results) + // - Iterates through columns accumulating results in inter_results + // - Performs final reduction sum in inter_results for output shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n" << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; @@ -337,8 +346,9 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); } } + workgroupBarrier(); } - workgroupBarrier(); + if (local_idx < tile_size) { // Do reduce sum to get final output. var output_value = output_element_t(0); From f9ac9ab15252fae5d4bf5b87137db7fb9b3f6021 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 20 Mar 2025 13:21:05 +0800 Subject: [PATCH 11/11] address comments --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 1e9d94aa0c672..a30eae0aba5c0 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -269,8 +269,13 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co shader.AddInput("input_b", ShaderUsage::UseUniform); shader.AddInput("scales_b", ShaderUsage::UseUniform); shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + // This algorithm works to compute dot product of k parallelly, by processing k at each step amongst tile_size_k_vec threads, + // and utilizing the remaining threads in the workgroup to process additional rows of b in parallel (such that the values in shared memory for A can be reused). + // For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows + // in order to complete all tile_size b rows in this workgroup and also reusing the loaded in register values of a. - // 1. Each workgroup handles tile_size_k_vec (16) columns of matrix B at a time, iterating over the columns to compute a partial dot product. + // 1. 
Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (total 512) and num_concurrent_b_rows of matrix B at a time, + // iterating over the columns to compute a partial dot product. // 2. Uses vec4 vectorization where each K represents 32 elements of matrix B constexpr uint32_t tile_size_k_vec = 16; @@ -286,6 +291,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co // - Performs final reduction sum in inter_results for output shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n" << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n" + // sub_tile_size is the number of concurrent b rows processed by the workgroup. << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n"; shader.AdditionalImplementation() << commonFunctions << R"ADDNL_FN( @@ -317,7 +323,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co // Handle each workgroup threads as a block of [sub_tile_size][tile_size_k_vec] let local_col = local_idx % tile_size_k_vec; let local_row = local_idx / tile_size_k_vec; - for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += 16) + for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec) { // Load Phase: Populate shared memory for the workgroup. if (local_idx < 32)
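
Taken together, the series leaves the small-M path with the following work decomposition. The sketch below is for orientation only, with the final defaults from the series plugged in (workgroup size 128, tile_size = 32, tile_size_k_vec = 16); identifiers mirror the shader above:

    // Orientation-only sketch of the small-M thread mapping, assuming the
    // final defaults of this series.
    const tile_size = 32u;        // B rows per workgroup
    const tile_size_k_vec = 16u;  // vec4<u32> columns of B per k-step
    const sub_tile_size = 8u;     // 128 / 16: B rows processed concurrently
    // Returns (local_row, local_col, rows_per_thread) for a thread:
    // local_col picks the vec4<u32> of B along K (32 scalars, so one k-step
    // covers 16 * 32 = 512 scalars of K), local_row picks the concurrent
    // B row, and each thread covers tile_size / sub_tile_size = 4 rows by
    // striding row_offset by sub_tile_size.
    fn thread_mapping(local_idx: u32) -> vec3<u32> {
      let local_col = local_idx % tile_size_k_vec;
      let local_row = local_idx / tile_size_k_vec;
      let rows_per_thread = tile_size / sub_tile_size;
      return vec3<u32>(local_row, local_col, rows_per_thread);
    }

After the k loop, one thread per B row (local_idx < tile_size) reduces the tile_size_k_vec partial sums in inter_results into a single output element, and the host dispatches M * num_N_tile workgroups so that workgroup_idx decodes into one row of A and one tile of B rows.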