Skip to content

Commit 23d6042

Browse files
committed
Apply feedback, add comments
1 parent 6de4d8a commit 23d6042

File tree

3 files changed

+32
-37
lines changed

3 files changed

+32
-37
lines changed

onnxruntime/core/providers/webgpu/math/matmul.cc

+26-28
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1717
1, 12,
1818
kWebGpuExecutionProvider,
1919
(*KernelDefBuilder::Create())
20-
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
20+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
2121
MatMul);
2222

2323
ONNX_OPERATOR_KERNEL_EX(
@@ -26,7 +26,7 @@ ONNX_OPERATOR_KERNEL_EX(
2626
13,
2727
kWebGpuExecutionProvider,
2828
(*KernelDefBuilder::Create())
29-
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
29+
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
3030
MatMul);
3131

3232
static std::string CalcResult(int64_t components, int64_t a_components, int64_t output_number) {
@@ -70,7 +70,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
7070
<< "let stride1 = uniforms.M / " << output_number_ << ";\n"
7171
<< "let row = (index1 % stride1) * " << output_number_ << ";\n"
7272
<< "let batch = index1 / stride1;\n";
73-
if (output_size_ != 2) {
73+
if (output_rank_ != 2) {
7474
shader.MainFunctionBody() << "let batch_indices = " << batch_dims.OffsetToIndices("batch") << ";\n";
7575
}
7676
shader.MainFunctionBody() << "var a_indices: a_indices_t;\n"
@@ -106,42 +106,40 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
106106

107107
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
108108
auto* output_tensor = context.Output(0, helper.OutputShape());
109+
bool has_bias = context.InputCount() > 2;
109110

110-
const uint32_t m = narrow<uint32_t>(helper.M());
111-
const uint32_t n = narrow<uint32_t>(helper.N());
112-
const uint32_t k = narrow<uint32_t>(helper.K());
111+
if (helper.N() < 8 && helper.K() < 8) { // call MatMulNaiveProgram
113112

114-
bool has_bias = context.InputCount() > 2;
113+
const uint32_t m = narrow<uint32_t>(helper.M()); // left matrix first dimension
114+
const uint32_t n = narrow<uint32_t>(helper.N()); // right matrix second dimension
115+
const uint32_t k = narrow<uint32_t>(helper.K()); // right matrix first dimension
115116

116-
if (n < 8 && k < 8) { // call MatMulNaiveProgram
117117
const auto components = GetMaxComponents(n);
118118
const auto a_components = GetMaxComponents(k);
119119

120120
const auto output_number = GetMaxComponents(m);
121-
uint32_t output_size = static_cast<uint32_t>(helper.OutputShape().Size() / components / output_number);
121+
uint32_t output_size = narrow<uint32_t>(helper.OutputShape().Size() / components / output_number);
122122

123123
const size_t output_rank = helper.OutputShape().NumDimensions();
124124
TensorShape outer_dims = output_rank > 2 ? helper.OutputShape().Slice(0, output_rank - 2) : TensorShape({});
125125
const int64_t batch_size = outer_dims.Size();
126126

127-
const int64_t m_val = a->Shape().NumDimensions() > 2
128-
? a->Shape()[a->Shape().NumDimensions() - 2]
129-
: helper.M();
130-
TensorShape output_shape_shader({batch_size, m_val, helper.N() / components});
127+
const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
128+
TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});
131129

132-
MatMulNaiveProgram program{output_size, output_number, has_bias};
130+
MatMulNaiveProgram program{output_rank, output_number, has_bias};
133131

134132
program
135133
.CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
136-
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(a_components)},
137-
{b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
134+
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
135+
{b, ProgramTensorMetadataDependency::TypeAndRank, components}});
138136

139137
if (has_bias) {
140138
const auto* bias = context.Input(2);
141139
program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
142140
}
143141
program
144-
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, static_cast<int>(components)}})
142+
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, components}})
145143
.SetDispatchGroupSize((output_size + 63) / 64) // Integer ceiling division
146144
.AddIndices(outer_dims)
147145
.AddUniformVariables({{output_size}, {m}, {n}, {k}});
@@ -156,9 +154,9 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
156154
TensorShape b_shape = b->Shape();
157155
TensorShape output_shape = helper.OutputShape();
158156

159-
const int64_t m_value = output_shape[output_shape.NumDimensions() - 2];
157+
const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
160158
// check if A is a batch of vectors (batch is not 1, M is 1) and B is a matrix (batch is 1)
161-
if (batchA != 1 && m_value == 1 && batchB == 1) {
159+
if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
162160
// optimization for batched vector matrix multiplication
163161
// dimensions of A: [1,`batchA`,K]
164162
TensorShapeVector dims_a = {1, batchA, helper.K()};
@@ -186,22 +184,22 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
186184
const int64_t batch_size = outer_dims.Size();
187185

188186
// Get dimensions for matrix multiplication from TensorShape
189-
const int32_t dim_a_outer = static_cast<int32_t>(a_shape[a_shape.NumDimensions() - 2]); // M dimension
190-
const int32_t dim_inner = static_cast<int32_t>(a_shape[a_shape.NumDimensions() - 1]); // K dimension
191-
const int32_t dim_b_outer = static_cast<int32_t>(b_shape[b_shape.NumDimensions() - 1]); // N dimension
187+
const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix first dimension (M)
188+
const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);  // left matrix second dimension (K)
189+
const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix second dimension (N)
192190

193191
const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;
194192

195193
InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
196194
? InlinedVector<int64_t>({4, 1, 1})
197195
: InlinedVector<int64_t>({4, 4, 1});
198196

199-
const uint32_t dispatch_x = static_cast<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
200-
(MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
201-
const uint32_t dispatch_y = static_cast<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
202-
(MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
203-
const uint32_t dispatch_z = static_cast<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
204-
(MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
197+
const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
198+
(MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
199+
const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
200+
(MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
201+
const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
202+
(MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
205203

206204
const int components = is_vec4 ? 4 : 1;
207205
const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);

onnxruntime/core/providers/webgpu/math/matmul.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class MatMul final : public WebGpuKernel {
2626

2727
class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
2828
public:
29-
MatMulNaiveProgram(const int64_t output_size, int64_t output_number, bool has_bias)
30-
: Program{"MatMulNaive"}, output_size_(output_size), output_number_(output_number), has_bias_{has_bias} {
29+
MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias)
30+
: Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} {
3131
}
3232

3333
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -38,7 +38,7 @@ class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
3838
{"K", ProgramUniformVariableDataType::Uint32});
3939

4040
private:
41-
const int64_t output_size_;
41+
const size_t output_rank_;
4242
const int64_t output_number_;
4343
const bool has_bias_;
4444
};

onnxruntime/test/providers/cpu/math/matmul_test.cc

+3-6
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ std::vector<MatMulTestData<T>> GenerateTestCases() {
158158
// clang-format on
159159
})});
160160

161-
#ifdef USE_WEBGPU
162161
test_cases.push_back(
163162
{"test 3D tensors with M = 1",
164163
{6, 1, 8},
@@ -263,7 +262,7 @@ std::vector<MatMulTestData<T>> GenerateTestCases() {
263262
{1, 2, 8, 1},
264263
{2, 2, 2, 1},
265264
real_expected_vals({140, 364, 364, 1100, 588, 812, 1836, 2572})});
266-
#endif
265+
267266
return test_cases;
268267
}
269268

@@ -295,9 +294,7 @@ void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant
295294
excluded_providers.insert(kNnapiExecutionProvider);
296295
}
297296

298-
// WebGPU: test right 1D and left 1D
299-
// set of excluded test cases for WebGPU
300-
297+
// TODO: Change the MatMulNaive shader to support these test cases on WebGPU
301298
std::unordered_set<std::string> webgpu_excluded_test_cases{
302299
"test left 1D",
303300
"test right 1D",
@@ -390,7 +387,7 @@ void RunMatMulZeroKTest() {
390387
// No special case is implemented.
391388
test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
392389
kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
393-
kOpenVINOExecutionProvider})
390+
kOpenVINOExecutionProvider, kWebGpuExecutionProvider})
394391
.Config(run_with_tunable_op)
395392
.RunWithConfig();
396393
}

0 commit comments

Comments
 (0)