@@ -17,7 +17,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
     1, 12,
     kWebGpuExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
     MatMul);
 
 ONNX_OPERATOR_KERNEL_EX(
@@ -26,7 +26,7 @@ ONNX_OPERATOR_KERNEL_EX(
     13,
     kWebGpuExecutionProvider,
     (*KernelDefBuilder::Create())
-        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
     MatMul);
 
 static std::string CalcResult(int64_t components, int64_t a_components, int64_t output_number) {
@@ -70,7 +70,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
                             << "let stride1 = uniforms.M / " << output_number_ << ";\n"
                             << "let row = (index1 % stride1) * " << output_number_ << ";\n"
                             << "let batch = index1 / stride1;\n";
-  if (output_size_ != 2) {
+  if (output_rank_ != 2) {
     shader.MainFunctionBody() << "let batch_indices = " << batch_dims.OffsetToIndices("batch") << ";\n";
   }
   shader.MainFunctionBody() << "var a_indices: a_indices_t;\n"
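
For context on the hunk above: with output_number_ equal to 4, the stream statements emit WGSL along these lines (an illustration derived from the code, not part of the diff):

    let stride1 = uniforms.M / 4;
    let row = (index1 % stride1) * 4;
    let batch = index1 / stride1;

The renamed output_rank_ member then gates the extra batch_indices line, which is only needed when the output has more than two dimensions.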
@@ -106,42 +106,40 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
 
   ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
   auto* output_tensor = context.Output(0, helper.OutputShape());
+  bool has_bias = context.InputCount() > 2;
 
-  const uint32_t m = narrow<uint32_t>(helper.M());
-  const uint32_t n = narrow<uint32_t>(helper.N());
-  const uint32_t k = narrow<uint32_t>(helper.K());
+  if (helper.N() < 8 && helper.K() < 8) {  // call MatMulNaiveProgram
 
-  bool has_bias = context.InputCount() > 2;
+    const uint32_t m = narrow<uint32_t>(helper.M());  // left matrix first dimension
+    const uint32_t n = narrow<uint32_t>(helper.N());  // right matrix second dimension
+    const uint32_t k = narrow<uint32_t>(helper.K());  // right matrix first dimension
 
-  if (n < 8 && k < 8) {  // call MatMulNaiveProgram
     const auto components = GetMaxComponents(n);
     const auto a_components = GetMaxComponents(k);
 
     const auto output_number = GetMaxComponents(m);
-    uint32_t output_size = static_cast<uint32_t>(helper.OutputShape().Size() / components / output_number);
+    uint32_t output_size = narrow<uint32_t>(helper.OutputShape().Size() / components / output_number);
 
     const size_t output_rank = helper.OutputShape().NumDimensions();
     TensorShape outer_dims = output_rank > 2 ? helper.OutputShape().Slice(0, output_rank - 2) : TensorShape({});
     const int64_t batch_size = outer_dims.Size();
 
-    const int64_t m_val = a->Shape().NumDimensions() > 2
-                              ? a->Shape()[a->Shape().NumDimensions() - 2]
-                              : helper.M();
-    TensorShape output_shape_shader({batch_size, m_val, helper.N() / components});
+    const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
+    TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});
 
-    MatMulNaiveProgram program{output_size, output_number, has_bias};
+    MatMulNaiveProgram program{output_rank, output_number, has_bias};
 
     program
         .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
-        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(a_components)},
-                    {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
+        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
+                    {b, ProgramTensorMetadataDependency::TypeAndRank, components}});
 
     if (has_bias) {
       const auto* bias = context.Input(2);
       program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
    }
    program
-        .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, static_cast<int>(components)}})
+        .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, components}})
        .SetDispatchGroupSize((output_size + 63) / 64)  // Integer ceiling division
        .AddIndices(outer_dims)
        .AddUniformVariables({{output_size}, {m}, {n}, {k}});
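
Two details of the naive path above are easy to miss. GetMaxComponents presumably selects the widest vector width (4, 2, or 1) that evenly divides the dimension, and the dispatch size is rounded up with integer ceiling division. A minimal sketch under those assumptions (the helper names here are hypothetical, not the ones in the codebase):

    #include <cstdint>

    // Assumed behavior of a components helper: the widest of 4/2/1
    // that evenly divides `size`.
    inline int64_t MaxComponentsSketch(int64_t size) {
      return size % 4 == 0 ? 4 : (size % 2 == 0 ? 2 : 1);
    }

    // Integer ceiling division as used for SetDispatchGroupSize:
    // (output_size + 63) / 64 rounds up to whole 64-thread workgroups,
    // e.g. output_size = 130 -> (130 + 63) / 64 = 3 workgroups.
    inline uint32_t WorkgroupCountSketch(uint32_t output_size) {
      return (output_size + 63) / 64;
    }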
@@ -156,9 +154,9 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
   TensorShape b_shape = b->Shape();
   TensorShape output_shape = helper.OutputShape();
 
-  const int64_t m_value = output_shape[output_shape.NumDimensions() - 2];
+  const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
   // check if A is a batch of vectors (batch is not 1, M is 1) and B is a matrix (batch is 1)
-  if (batchA != 1 && m_value == 1 && batchB == 1) {
+  if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
     // optimization for batched vector-matrix multiplication
     // dimensions of A: [1,`batchA`,K]
     TensorShapeVector dims_a = {1, batchA, helper.K()};
@@ -186,22 +184,22 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
   const int64_t batch_size = outer_dims.Size();
 
   // Get dimensions for matrix multiplication from TensorShape
-  const int32_t dim_a_outer = static_cast<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // M dimension
-  const int32_t dim_inner = static_cast<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // K dimension
-  const int32_t dim_b_outer = static_cast<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // N dimension
+  const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix first dimension (M)
+  const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // left matrix second dimension (K)
+  const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix second dimension (N)
 
   const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;
 
   InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
                                                    ? InlinedVector<int64_t>({4, 1, 1})
                                                    : InlinedVector<int64_t>({4, 4, 1});
 
-  const uint32_t dispatch_x = static_cast<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
-                                                    (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
-  const uint32_t dispatch_y = static_cast<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
-                                                    (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
-  const uint32_t dispatch_z = static_cast<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
-                                                    (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
+                                               (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
+  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
+                                               (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
+  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+                                               (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
 
   const int components = is_vec4 ? 4 : 1;
   const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
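
A note on the static_cast to narrow swaps throughout this diff: a checked narrowing cast converts and then verifies that no information was lost, instead of silently truncating. A minimal sketch in the spirit of gsl::narrow (onnxruntime's narrow may report failures differently; this is an assumption for illustration):

    #include <stdexcept>

    // Checked narrowing cast sketch: convert, then confirm the value
    // survives; throw instead of silently truncating or wrapping.
    template <typename To, typename From>
    To NarrowSketch(From value) {
      const To result = static_cast<To>(value);
      // The round-trip check catches truncation; the sign check catches
      // signed/unsigned conversions that happen to round-trip anyway.
      if (static_cast<From>(result) != value ||
          ((result < To{}) != (value < From{}))) {
        throw std::runtime_error("narrow: value changed by conversion");
      }
      return result;
    }

In this file the checked casts guard shape-derived sizes (M, N, K, dispatch counts) against silent overflow when a dimension exceeds the 32-bit range.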