
Commit ce65e25

[Native WebGPU] Add Matmul (#24046)
### Description

Add native MatMul shader programs (`MatMulNaive`, `MatMulPacked`, and `MatMulPackedVec4`).
1 parent afaf4a5 commit ce65e25

8 files changed: +812 −3 lines changed
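For context, ONNX `MatMul` follows numpy.matmul semantics: the last two dimensions multiply as matrices and any leading (batch) dimensions broadcast. Below is a minimal sketch of the output-shape rule the new kernels rely on; `MatMulOutputShape` is a hypothetical stand-in for ORT's `MatMulComputeHelper`, not its actual code, and it assumes both inputs have rank >= 2.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the shape logic MatMulComputeHelper::Compute
// performs: A is [...batchA, M, K], B is [...batchB, K, N]; batch dims
// broadcast numpy-style and the output is [...batch, M, N].
std::vector<int64_t> MatMulOutputShape(const std::vector<int64_t>& a,
                                       const std::vector<int64_t>& b) {
  const size_t a_batch = a.size() - 2, b_batch = b.size() - 2;
  const size_t batch_rank = std::max(a_batch, b_batch);
  std::vector<int64_t> out(batch_rank + 2);
  for (size_t i = 0; i < batch_rank; ++i) {
    // Right-align the batch dims; a missing or size-1 dim broadcasts.
    const int64_t da = (i + a_batch >= batch_rank) ? a[i + a_batch - batch_rank] : 1;
    const int64_t db = (i + b_batch >= batch_rank) ? b[i + b_batch - batch_rank] : 1;
    out[i] = std::max(da, db);
  }
  out[batch_rank] = a[a.size() - 2];      // M
  out[batch_rank + 1] = b[b.size() - 1];  // N
  return out;
}

int main() {
  // A [2, 3, 4, 16] x B [3, 16, 8] -> output [2, 3, 4, 8]
  for (int64_t d : MatMulOutputShape({2, 3, 4, 16}, {3, 16, 8})) std::cout << d << " ";
  std::cout << "\n";
}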
onnxruntime/core/providers/webgpu/math/matmul.cc (new file)

@@ -0,0 +1,228 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/math/matmul.h"
#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

#include "core/providers/webgpu/data_transfer.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    MatMul,
    kOnnxDomain,
    1, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    MatMul);

ONNX_OPERATOR_KERNEL_EX(
    MatMul,
    kOnnxDomain,
    13,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    MatMul);

static std::string CalcResult(int64_t components, int64_t a_components, int64_t output_number) {
  std::ostringstream oss;
  oss << "var a_data: a_value_t;\n";
  for (int i = 0; i < a_components; ++i) {
    oss << "let b_data" << i << " = b[(b_offset + (k + " << i << ") * uniforms.N + col) / " << components << "];\n";
  }
  for (int i = 0; i < output_number; ++i) {
    oss << "a_data = a[(a_offset + (row + " << i << ") * uniforms.K + k) / " << a_components << "];\n";

    for (int j = 0; j < a_components; j++) {
      oss << "values[" << i << "] = fma(b_value_t(a_data" << (a_components == 1 ? "" : "[" + std::to_string(j) + "]") << "), b_data" << j << ", values[" << i << "]);\n";
    }
  }
  return oss.str();
}
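To make the unrolling concrete, here is a standalone re-run of `CalcResult` (copied verbatim from above) for `components = 2`, `a_components = 2`, `output_number = 2`; the expected generated WGSL is shown in the comment in `main`.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Verbatim copy of CalcResult above, made runnable standalone.
static std::string CalcResult(int64_t components, int64_t a_components, int64_t output_number) {
  std::ostringstream oss;
  oss << "var a_data: a_value_t;\n";
  for (int i = 0; i < a_components; ++i) {
    oss << "let b_data" << i << " = b[(b_offset + (k + " << i << ") * uniforms.N + col) / " << components << "];\n";
  }
  for (int i = 0; i < output_number; ++i) {
    oss << "a_data = a[(a_offset + (row + " << i << ") * uniforms.K + k) / " << a_components << "];\n";
    for (int j = 0; j < a_components; j++) {
      oss << "values[" << i << "] = fma(b_value_t(a_data" << (a_components == 1 ? "" : "[" + std::to_string(j) + "]") << "), b_data" << j << ", values[" << i << "]);\n";
    }
  }
  return oss.str();
}

int main() {
  // For components = 2, a_components = 2, output_number = 2 this prints:
  //   var a_data: a_value_t;
  //   let b_data0 = b[(b_offset + (k + 0) * uniforms.N + col) / 2];
  //   let b_data1 = b[(b_offset + (k + 1) * uniforms.N + col) / 2];
  //   a_data = a[(a_offset + (row + 0) * uniforms.K + k) / 2];
  //   values[0] = fma(b_value_t(a_data[0]), b_data0, values[0]);
  //   values[0] = fma(b_value_t(a_data[1]), b_data1, values[0]);
  //   a_data = a[(a_offset + (row + 1) * uniforms.K + k) / 2];
  //   values[1] = fma(b_value_t(a_data[0]), b_data0, values[1]);
  //   values[1] = fma(b_value_t(a_data[1]), b_data1, values[1]);
  std::cout << CalcResult(2, 2, 2);
}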
Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
                                           ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
                                           ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  std::string process_bias;
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
    process_bias = "value += output_value_t(bias[row + i]);";
  }

  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform |
                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& batch_dims = shader.AddIndices("batch_dims");

  int a_components = a.NumComponents();
  int components = b.NumComponents();  // components of N

  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << "let col = (global_idx % (uniforms.N / " << components << ")) * " << components << ";\n"
                            << "var index1 = global_idx / (uniforms.N / " << components << ");\n"
                            << "let stride1 = uniforms.M / " << output_number_ << ";\n"
                            << "let row = (index1 % stride1) * " << output_number_ << ";\n"
                            << "let batch = index1 / stride1;\n";
  if (output_rank_ != 2) {
    shader.MainFunctionBody() << "let batch_indices = " << batch_dims.OffsetToIndices("batch") << ";\n";
  }
  shader.MainFunctionBody() << "var a_indices: a_indices_t;\n"
                            << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices")
                            << a.IndicesSet("a_indices", a.Rank() - 2, 0) << "\n"
                            << a.IndicesSet("a_indices", a.Rank() - 1, 0) << "\n"
                            << "let a_offset = " << a.IndicesToOffset("a_indices") << " * " << a_components << ";\n"
                            << "var b_indices: b_indices_t;\n"
                            << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices")
                            << b.IndicesSet("b_indices", b.Rank() - 2, 0) << "\n"
                            << b.IndicesSet("b_indices", b.Rank() - 1, 0) << "\n"
                            << "let b_offset = " << b.IndicesToOffset("b_indices") << " * " << components << ";\n"
                            << "var values: array<output_value_t, " << output_number_ << ">;\n"
                            << "for (var k: u32 = 0u; k < uniforms.K; k = k + " << a_components << ") {\n"
                            << CalcResult(components, a_components, output_number_) << "\n"
                            << "}\n"
                            << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n"
                            << "  var value = values[i];\n"
                            << process_bias << "\n"
                            << "  let cur_indices = output_indices_t(batch, row + i, col / " << components << ");\n"
                            << "  let offset = " << output.IndicesToOffset("cur_indices") << ";\n"
                            << output.SetByOffset("offset", "value")
                            << "}\n";

  return Status::OK();
}
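The shader body above derives `(batch, row, col)` for each invocation from the flat `global_idx`. Here is a host-side C++ mirror of exactly those formulas with a small worked example; `MapGlobalIdx` is illustrative, not part of the change.

#include <cstdint>
#include <iostream>

// Mirrors the WGSL emitted above: each invocation owns `output_number` rows
// and `components` columns of one batch's output matrix.
struct Cell { uint32_t batch, row, col; };

Cell MapGlobalIdx(uint32_t global_idx, uint32_t M, uint32_t N,
                  uint32_t components, uint32_t output_number) {
  const uint32_t col = (global_idx % (N / components)) * components;
  const uint32_t index1 = global_idx / (N / components);
  const uint32_t stride1 = M / output_number;
  const uint32_t row = (index1 % stride1) * output_number;
  const uint32_t batch = index1 / stride1;
  return {batch, row, col};
}

int main() {
  // M = 4, N = 4, components = 2, output_number = 2:
  // four invocations cover one 4x4 output matrix in 2x2 tiles.
  for (uint32_t idx = 0; idx < 4; ++idx) {
    Cell c = MapGlobalIdx(idx, 4, 4, 2, 2);
    std::cout << idx << " -> batch " << c.batch << ", rows " << c.row
              << "..+1, cols " << c.col << "..+1\n";
  }
}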
Status MatMul::ComputeInternal(ComputeContext& context) const {
  // calculate output shape
  MatMulComputeHelper helper;
  const auto* a = context.Input(0);
  const auto* b = context.Input(1);

  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
  auto* output_tensor = context.Output(0, helper.OutputShape());
  bool has_bias = context.InputCount() > 2;

  if (helper.N() < 8 && helper.K() < 8) {  // call MatMulNaiveProgram
    const uint32_t m = narrow<uint32_t>(helper.M());  // left matrix first dimension
    const uint32_t n = narrow<uint32_t>(helper.N());  // right matrix second dimension
    const uint32_t k = narrow<uint32_t>(helper.K());  // right matrix first dimension

    const auto components = GetMaxComponents(n);
    const auto a_components = GetMaxComponents(k);

    const auto output_number = GetMaxComponents(m);
    uint32_t output_size = narrow<uint32_t>(helper.OutputShape().Size() / components / output_number);

    const size_t output_rank = helper.OutputShape().NumDimensions();
    TensorShape outer_dims = output_rank > 2 ? helper.OutputShape().Slice(0, output_rank - 2) : TensorShape({});
    const int64_t batch_size = outer_dims.Size();

    const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
    TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});

    MatMulNaiveProgram program{output_rank, output_number, has_bias};

    program
        .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, components}});

    if (has_bias) {
      const auto* bias = context.Input(2);
      program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
    }
    program
        .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, components}})
        .SetDispatchGroupSize((output_size + 63) / 64)  // Integer ceiling division
        .AddIndices(outer_dims)
        .AddUniformVariables({{output_size}, {m}, {n}, {k}});

    return context.RunProgram(program);
  }
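A sizing note on the naive path above: `GetMaxComponents` (from `webgpu_utils.h`) picks a vector width for each dimension, and `output_size` counts shader invocations, each writing `output_number` rows by `components` columns. A sketch of the arithmetic follows; the 4/2/1 divisibility rule is my assumption about the helper, not code copied from it.

#include <cstdint>
#include <iostream>

// Assumed behavior of GetMaxComponents: widest of 4/2/1 that divides `size`.
int64_t GetMaxComponentsSketch(int64_t size) {
  if (size % 4 == 0) return 4;
  if (size % 2 == 0) return 2;
  return 1;
}

int main() {
  // Example inside the naive-path threshold (N < 8 && K < 8):
  const int64_t batch = 2, M = 8, N = 4, K = 4;
  const int64_t components = GetMaxComponentsSketch(N);     // 4
  const int64_t output_number = GetMaxComponentsSketch(M);  // 4
  const int64_t output_size = batch * M * N / components / output_number;  // 4
  const int64_t workgroups = (output_size + 63) / 64;       // ceil(4 / 64) = 1
  std::cout << output_size << " invocations, " << workgroups << " workgroup(s), K step "
            << GetMaxComponentsSketch(K) << "\n";
}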
  int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2);
  int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2);

  TensorShape a_shape = a->Shape();
  TensorShape b_shape = b->Shape();
  TensorShape output_shape = helper.OutputShape();

  const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
  // check if A is a batch of vectors (batch is not 1, M is 1) and B is a matrix (batch is 1)
  if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
    // optimization for batched vector-matrix multiplication
    // dimensions of A: [1, `batchA`, K]
    TensorShapeVector dims_a = {1, batchA, helper.K()};
    // dimensions of B: [1, K, N]
    TensorShapeVector dims_b = {1, helper.K(), helper.N()};

    a_shape = TensorShape(dims_a);
    b_shape = TensorShape(dims_b);
    output_shape = {1, batchA, helper.N()};
  }
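Concretely, with example shapes of my choosing: for A of shape [8, 1, 16] (batchA = 8, M = 1) and B of shape [16, 4] (batchB = 1), the reshape above yields A' = [1, 8, 16], B' = [1, 16, 4], and output [1, 8, 4], so the eight independent vector-matrix products execute as a single 8x16 by 16x4 matmul instead of eight tiny ones.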
  // helpful dimension variables
  TensorShape outer_dims_a = a_shape.NumDimensions() > 2
                                 ? a_shape.Slice(0, a_shape.NumDimensions() - 2)
                                 : TensorShape({});

  TensorShape outer_dims_b = b_shape.NumDimensions() > 2
                                 ? b_shape.Slice(0, b_shape.NumDimensions() - 2)
                                 : TensorShape({});

  TensorShape outer_dims = output_shape.NumDimensions() > 2
                               ? output_shape.Slice(0, output_shape.NumDimensions() - 2)
                               : TensorShape({});

  const int64_t batch_size = outer_dims.Size();

  // Get dimensions for matrix multiplication from TensorShape
  const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix first dimension (M)
  const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // left matrix second dimension (K)
  const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix second dimension (N)

  const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;

  InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
                                                   ? InlinedVector<int64_t>({4, 1, 1})
                                                   : InlinedVector<int64_t>({4, 4, 1});

  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
                                               (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
                                               (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
                                               (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));

  const int components = is_vec4 ? 4 : 1;
  const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
  const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
  const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

  MatMulProgram program{has_bias, is_vec4, elements_per_thread};
  program
      .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
      .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
                  {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
      .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
      .AddIndices(outer_dims)
      .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
      .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z);

  if (has_bias) {
    const auto* bias = context.Input(2);
    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
  }
  return context.RunProgram(program);
}

}  // namespace webgpu
}  // namespace onnxruntime
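Tiling note for the packed path: with the 8x8x1 workgroup from matmul.h and `elements_per_thread = {4, 4, 1}`, each workgroup covers a 32x32 tile of the output. A worked re-run of the dispatch arithmetic above, with example shapes of my choosing:

#include <cstdint>
#include <iostream>

int main() {
  // Workgroup sizes from matmul.h.
  const uint32_t WG_X = 8, WG_Y = 8, WG_Z = 1;
  // dim_a_outer > 8, so elements_per_thread = {4, 4, 1}.
  const uint32_t ept_x = 4, ept_y = 4, ept_z = 1;

  // Example GEMM: batch = 2, M (dim_a_outer) = 100, K (dim_inner) = 64,
  // N (dim_b_outer) = 48; K % 4 == 0 && N % 4 == 0, so is_vec4 is true.
  const uint32_t batch = 2, dim_a_outer = 100, dim_b_outer = 48;

  // Same ceiling divisions as in ComputeInternal:
  const uint32_t dispatch_x = (dim_b_outer + WG_X * ept_x - 1) / (WG_X * ept_x);  // ceil(48/32)  = 2
  const uint32_t dispatch_y = (dim_a_outer + WG_Y * ept_y - 1) / (WG_Y * ept_y);  // ceil(100/32) = 4
  const uint32_t dispatch_z = (batch + WG_Z * ept_z - 1) / (WG_Z * ept_z);        // ceil(2/1)    = 2

  std::cout << dispatch_x << " x " << dispatch_y << " x " << dispatch_z
            << " workgroups, each covering a 32x32 output tile\n";  // 2 x 4 x 2
}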
onnxruntime/core/providers/webgpu/math/matmul.h (new file)

@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/webgpu/math/matmul_utils.h"
#include "core/providers/webgpu/math/matmul_packed.h"
#include "core/providers/webgpu/webgpu_utils.h"

namespace onnxruntime {
namespace webgpu {

class MatMul final : public WebGpuKernel {
 public:
  MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {}

  Status ComputeInternal(ComputeContext& context) const override;

  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;
  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1;
};

class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
 public:
  MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias)
      : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"M", ProgramUniformVariableDataType::Uint32},
                                          {"N", ProgramUniformVariableDataType::Uint32},
                                          {"K", ProgramUniformVariableDataType::Uint32});

 private:
  const size_t output_rank_;
  const int64_t output_number_;
  const bool has_bias_;
};

}  // namespace webgpu
}  // namespace onnxruntime
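Note how the pieces line up: the uniforms declared by `WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES` here are supplied positionally by `AddUniformVariables({{output_size}, {m}, {n}, {k}})` in matmul.cc, and the generated shader reads them back as `uniforms.output_size`, `uniforms.M`, `uniforms.N`, and `uniforms.K` (the positional matching is my understanding of the Program framework, not something this diff states explicitly).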
