diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
index 65807b072bc80..a30eae0aba5c0 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -7,6 +7,39 @@
 namespace onnxruntime {
 namespace contrib {
 namespace webgpu {
+namespace {
+
+constexpr std::string_view commonFunctions = R"ADDNL_FN(
+  fn DequantizedFrom4BitsTo8Bits(in: vec2<u32>) -> vec4<u32>
+  {
+    var out = vec4<u32>(0);
+    var value_lower = vec4<i32>(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+    var value_upper = vec4<i32>(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+    out[0] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
+    out[1] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
+    value_lower = vec4<i32>(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+    value_upper = vec4<i32>(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+    out[2] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
+    out[3] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
+    return out;
+  }
+
+  // Scaled dot product of 8 packed unsigned integers.
+  fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
+  {
+    var local_sum = dot4I8Packed(a1[0], b1[0]);
+    local_sum += dot4I8Packed(a1[1], b1[1]);
+    local_sum += dot4I8Packed(a1[2], b1[2]);
+    local_sum += dot4I8Packed(a1[3], b1[3]);
+    local_sum += dot4I8Packed(a2[0], b2[0]);
+    local_sum += dot4I8Packed(a2[1], b2[1]);
+    local_sum += dot4I8Packed(a2[2], b2[2]);
+    local_sum += dot4I8Packed(a2[3], b2[3]);
+    return output_element_t(local_sum) * scale;
+  }
+)ADDNL_FN";
+
+}  // namespace
 
 Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
   shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
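For reference, the nibble ordering that DequantizedFrom4BitsTo8Bits produces (the lower nibble of each byte first, then the upper nibble, re-biased from [0, 15] to [-8, 7]) can be mirrored on the host roughly as below. This is an illustrative sketch only, not part of the PR; the function name and the use of std::array are assumptions.

    #include <array>
    #include <cstdint>

    // Host-side mirror of the WGSL DequantizedFrom4BitsTo8Bits: each input u32 holds
    // eight 4-bit weights; the lower nibble of byte b becomes element 2*b, the upper
    // nibble becomes element 2*b+1, and every nibble is re-biased by -8 before being
    // repacked as four int8 values per output u32.
    std::array<uint32_t, 4> DequantizeFrom4BitsTo8BitsRef(const std::array<uint32_t, 2>& in) {
      std::array<int8_t, 16> vals{};
      for (int w = 0; w < 2; ++w) {
        for (int b = 0; b < 4; ++b) {
          const uint32_t byte = (in[w] >> (8 * b)) & 0xFFu;
          vals[w * 8 + 2 * b] = static_cast<int8_t>(static_cast<int>(byte & 0x0Fu) - 8);
          vals[w * 8 + 2 * b + 1] = static_cast<int8_t>(static_cast<int>(byte >> 4) - 8);
        }
      }
      std::array<uint32_t, 4> out{};
      for (int i = 0; i < 4; ++i) {
        out[i] = static_cast<uint32_t>(static_cast<uint8_t>(vals[4 * i])) |
                 (static_cast<uint32_t>(static_cast<uint8_t>(vals[4 * i + 1])) << 8) |
                 (static_cast<uint32_t>(static_cast<uint8_t>(vals[4 * i + 2])) << 16) |
                 (static_cast<uint32_t>(static_cast<uint8_t>(vals[4 * i + 3])) << 24);
      }
      return out;
    }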
@@ -65,7 +98,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   // this shader require A to be int8 quantized with block size 64. B is regular
   // matmulnbits input with block size 32.
-  shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";";
+  shader.AdditionalImplementation() << commonFunctions
+                                    << " const block_size = " << block_size_ << ";";
   shader.AdditionalImplementation() << R"ADDNL_FN(
   const tile_size = 64;
@@ -105,34 +139,13 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
       }
       let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
-      var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
-      var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
-      tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
-      tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
-      b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
-      b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
-      tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
-      tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value);
       if (col == 0)
       {
         // kidx_v - each kidx_v covers 16 values of k
         scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)];
       }
     }
-
-    // Scaled dot product of 8 packed unsigned integers.
-    fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
-    {
-      var local_sum = dot4I8Packed(a1[0], b1[0]);
-      local_sum += dot4I8Packed(a1[1], b1[1]);
-      local_sum += dot4I8Packed(a1[2], b1[2]);
-      local_sum += dot4I8Packed(a1[3], b1[3]);
-      local_sum += dot4I8Packed(a2[0], b2[0]);
-      local_sum += dot4I8Packed(a2[1], b2[1]);
-      local_sum += dot4I8Packed(a2[2], b2[2]);
-      local_sum += dot4I8Packed(a2[3], b2[3]);
-      return output_element_t(local_sum) * scale;
-    }
 )ADDNL_FN";
 
   shader.MainFunctionBody() << R"MAIN_FN(
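SDP8AI, now shared via commonFunctions, is a 32-element dot product over packed signed 8-bit values (dot4I8Packed handles four at a time) followed by a single multiply with the combined dequantization scale. A host-side sketch of the same arithmetic, purely illustrative and not part of the PR:

    #include <array>
    #include <cstdint>

    // Reference for the WGSL SDP8AI: a1/a2 and b1/b2 each carry 16 int8 values packed
    // into four u32 lanes; the result is the signed 32-element dot product scaled by
    // the product of the A and B dequantization scales.
    float SDP8AIRef(const std::array<uint32_t, 4>& a1, const std::array<uint32_t, 4>& b1,
                    const std::array<uint32_t, 4>& a2, const std::array<uint32_t, 4>& b2,
                    float scale) {
      auto dot4 = [](uint32_t a, uint32_t b) {
        int32_t sum = 0;
        for (int i = 0; i < 4; ++i) {
          sum += static_cast<int8_t>((a >> (8 * i)) & 0xFFu) *
                 static_cast<int8_t>((b >> (8 * i)) & 0xFFu);
        }
        return sum;
      };
      int32_t local_sum = 0;
      for (int i = 0; i < 4; ++i) {
        local_sum += dot4(a1[i], b1[i]) + dot4(a2[i], b2[i]);
      }
      return static_cast<float>(local_sum) * scale;
    }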
@@ -249,11 +262,122 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }
 
+// scale_A components = 1, b components = 4, output components = 1
+Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform);
+  shader.AddInput("scales_a", ShaderUsage::UseUniform);
+  shader.AddInput("input_b", ShaderUsage::UseUniform);
+  shader.AddInput("scales_b", ShaderUsage::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+  // This algorithm computes the dot product over k in parallel: each step processes k across tile_size_k_vec threads,
+  // and the remaining threads in the workgroup process additional rows of b in parallel (so that the values in shared memory for A can be reused).
+  // For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows,
+  // in order to complete all tile_size b rows in this workgroup while reusing the A values already held in registers.
+
+  // 1. Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (total 512) and num_concurrent_b_rows of matrix B at a time,
+  //    iterating over the columns to compute a partial dot product.
+  // 2. Uses vec4 vectorization where each K represents 32 elements of matrix B.
+  constexpr uint32_t tile_size_k_vec = 16;
+
+  // 1. Workgroup Responsibility:
+  //    - Processes one row of matrix A
+  //    - Handles tile_size rows of matrix B
+  //
+  // 2. Computation Process:
+  //    - Reads [tile_size][tile_size_k_vec] block of B data at a time
+  //    - Each thread within the workgroup computes dot products of 32 A*B elements, since each K represents 32 elements of matrix B
+  //    - Stores intermediate results in shared memory (inter_results)
+  //    - Iterates through columns accumulating results in inter_results
+  //    - Performs a final reduction sum over inter_results for the output
+  shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n"
+                                    << "const tile_size_k_vec = " << tile_size_k_vec << "u;\n"
+                                    // sub_tile_size is the number of concurrent b rows processed by the workgroup.
+                                    << "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n";
+  shader.AdditionalImplementation() << commonFunctions
+                                    << R"ADDNL_FN(
+  // Shared memory
+  // Need 2 * tile_size_k_vec (32) entries to store tile_A, since b is quantized to 4 bits and a is quantized to 8 bits.
+  var<workgroup> tile_A : array<vec4<u32>, 32>;
+  // Need 4 scale values, since each tile_A covers 512 (4x4x32) scalars and the block_size is 128.
+  var<workgroup> scale_A : array<output_element_t, 4>;
+  var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
+  fn loadSHMA(a_global: u32, kidx_v: u32, col: u32)
+  {
+    let k_offset = kidx_v + col;
+    if (k_offset >= uniforms.K16) {
+      return;
+    }
+
+    tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
+    if (col < 4)
+    {
+      // kidx_v - covers 16 values of k in input_a
+      scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col];
+    }
+  }
+)ADDNL_FN";
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+  let a_global = u32(workgroup_idx / uniforms.num_N_tile);
+  let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
+  // Treat the workgroup threads as a block of [sub_tile_size][tile_size_k_vec].
+  let local_col = local_idx % tile_size_k_vec;
+  let local_row = local_idx / tile_size_k_vec;
+  for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec)
+  {
+    // Load Phase: Populate shared memory for the workgroup.
+    if (local_idx < 32)
+    {
+      loadSHMA(a_global, kidx_v * 2, local_idx);
+    }
+    workgroupBarrier();
+    var own_a: vec4<u32> = tile_A[local_col * 2];
+    var own_a1: vec4<u32> = tile_A[local_col * 2 + 1];
+    var own_scale_a = scale_A[local_col / 4];
+    var own_b = vec4<u32>(0);
+    var own_b1 = vec4<u32>(0);
+    let k_offset = kidx_v + local_col;
+    // Accumulate intermediate results into inter_results.
+    for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_size) {
+      let b_global = b_global_base + row_offset + local_row;
+      if (b_global < uniforms.N && k_offset < uniforms.K32)
+      {
+        let b_offset = b_global * uniforms.K32 + k_offset;
+        let b_value = input_b[b_offset];
+        own_b = DequantizedFrom4BitsTo8Bits(b_value.xy);
+        own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw);
+
+        // k_offset - covers 32 values of k in input_b
+        let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size];
+        inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
+      }
+    }
+    workgroupBarrier();
+  }
+
+  if (local_idx < tile_size) {
+    // Reduce the partial sums to get the final output.
+    var output_value = output_element_t(0);
+    for (var b = 0u; b < tile_size_k_vec; b++) {
+      output_value += inter_results[local_idx][b];
+    }
+    let b_global = b_global_base + local_idx;
+    let output_idx = a_global * uniforms.N + b_global;
+    if (b_global < uniforms.N) {
+      output[output_idx] = output_value;
+    }
+  }
+)MAIN_FN";
+
+  return Status::OK();
+}
+
 Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                   uint32_t M,
                                   uint32_t N,
                                   uint32_t K,
                                   uint32_t block_size,
+                                  uint32_t min_M_for_tile_optimization,
                                   onnxruntime::webgpu::ComputeContext& context,
                                   Tensor* y) {
   constexpr uint32_t kVec4Components = 4;
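A quick sizing check for the shared memory declared in the shader above: one loop iteration stages 32 vec4<u32> values of A (512 int8 elements), and with the A-side scale block of 128 stated in the shader comments, that span needs exactly the 4 entries of scale_A. The constant names below are illustrative, not from the PR.

    #include <cstdint>

    constexpr uint32_t kAVecsPerIteration = 32;   // 2 * tile_size_k_vec entries in tile_A
    constexpr uint32_t kInt8PerVec4U32 = 16;      // 4 u32 lanes x 4 packed int8 each
    constexpr uint32_t kAScaleBlockSize = 128;    // per the shader comment above
    static_assert(kAVecsPerIteration * kInt8PerVec4U32 / kAScaleBlockSize == 4,
                  "scale_A holds one scale per 128 A values, i.e. 4 per 512-value tile");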
@@ -273,6 +397,21 @@
                                 {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
   ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
 
+  if (M < min_M_for_tile_optimization) {
+    constexpr uint32_t kTileSize = 32;
+    DP4AMatMulNBitsSmallMProgram mul_program{kTileSize};
+    uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
+    mul_program.SetWorkgroupSize(128);
+    mul_program.SetDispatchGroupSize(M * num_N_tile);
+    mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
+                           {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
+                           {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components * kU32Components)},
+                           {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
+        .AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile})
+        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1});
+    return context.RunProgram(mul_program);
+  }
+
   constexpr uint32_t kTileSize = 64;
   TensorShape reshaped_y_shape{1, M, N / kVec4Components};
   DP4AMatMulNBitsProgram mul_program{block_size};
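The dispatch sizing for the new small-M path above can be summarized as follows. This is a sketch; the struct and function names are assumptions, only the constants and uniform values mirror the PR.

    #include <cstdint>

    struct SmallMDispatch {
      uint32_t workgroup_size;   // threads per workgroup
      uint32_t num_workgroups;   // one 1 x 32 output tile per workgroup
      uint32_t K16, K32;         // K pre-divided for the vec4 views of A and B
    };

    // Mirrors the SetWorkgroupSize/SetDispatchGroupSize/uniform setup above.
    SmallMDispatch SizeSmallMDispatch(uint32_t M, uint32_t N, uint32_t K) {
      constexpr uint32_t kTileSize = 32;  // B rows (output columns) per workgroup
      const uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
      return {128u, M * num_N_tile, K / 16, K / 32};
    }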
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
index 7e4a8f5d68437..7b2df9b3520d1 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
@@ -33,11 +33,29 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
   uint32_t block_size_;
 };
 
+class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
+ public:
+  DP4AMatMulNBitsSmallMProgram(uint32_t tile_size) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size) {}
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"M", ProgramUniformVariableDataType::Uint32},
+      {"N", ProgramUniformVariableDataType::Uint32},
+      {"K", ProgramUniformVariableDataType::Uint32},
+      {"K16", ProgramUniformVariableDataType::Uint32},
+      {"K32", ProgramUniformVariableDataType::Uint32},
+      {"block_size", ProgramUniformVariableDataType::Uint32},
+      {"num_N_tile", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  uint32_t tile_size_;
+};
+
 Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
                                   uint32_t M,
                                   uint32_t N,
                                   uint32_t K,
                                   uint32_t block_size,
+                                  uint32_t min_M_for_tile_optimization,
                                   onnxruntime::webgpu::ComputeContext& context,
                                   Tensor* y);
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index cce10a59fbd4b..b4e47b9186265 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -574,9 +574,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
     return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
   }
 
-  if (M >= kMinMForTileOptimization &&
-      CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
-    return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y);
+  // On FP32-only GPUs, integer math is faster than FP32, so always use DP4A regardless of the length of M.
+  if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>()) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
+    return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y);
   }
 
   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
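Condensed, the updated selection logic in MatMulNBits::ComputeInternal reads as the predicate below: the DP4A path is now taken even for small M whenever the output type is FP32, since the integer DP4A math outperforms the FP32 shaders on FP32-only GPUs. Illustrative sketch; only kMinMForTileOptimization is an existing identifier, and its value is not shown in this diff.

    #include <cstdint>

    bool ShouldUseDP4A(uint32_t M, uint32_t kMinMForTileOptimization,
                       bool output_is_fp32, bool can_apply_dp4a) {
      return (M >= kMinMForTileOptimization || output_is_fp32) && can_apply_dp4a;
    }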