Skip to content

Commit 5c13776

Browse files
coconutruben
authored and facebook-github-bot committed
adjust interface (pytorch#747)
Summary: X-link: pytorch#3669 Pull Request resolved: facebookresearch/FBGEMM#747 # Why 1. we're extracting the size wrong after the latest changes, rather than extracting it from w1 we need to get it from w2 as w1 is treated as 2x intermediate size on `gate_only=False` 2. we're hard-coding the weights dtype when we should be extracting it 3. we're using the default stream instead of the current stream # What 1. get intermediate size from w2 2. do not hard-code the `topk_weights` dtype 3. use current stream Reviewed By: sijiac Differential Revision: D69341443 fbshipit-source-id: cf7a908c6a78d3ecb9d030491722967fbe0d097b
1 parent 37ea0d2 commit 5c13776

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip

+12-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include <ATen/ATen.h>
66
#include <torch/library.h>
77

8+
#include <c10/hip/HIPStream.h>
9+
810
#include <atomic>
911
#include <cassert>
1012
#include <cmath>
@@ -40,7 +42,10 @@ at::Tensor fused_moe_impl(
4042
auto tokens = input.size(0);
4143
auto hidden_size = input.size(1);
4244
auto experts = gate_up_weight.size(0);
43-
auto intermediate_size = gate_up_weight.size(1);
45+
// Interface requires that you pass intermediate size. On |gate_only| = False,
46+
// |gate_up_weight| might be 2 * intermediate size, so extract the size from
47+
// |down_weight|
48+
auto intermediate_size = down_weight.size(2);
4449
auto topk = topk_ids.size(1);
4550
auto stride = input.stride(0);
4651

@@ -70,6 +75,7 @@ at::Tensor fused_moe_impl(
7075
auto prec_i = get_prec_str(input);
7176
auto prec_w = get_prec_str(gate_up_weight);
7277
auto prec_o = get_prec_str(output);
78+
auto prec_tkw = get_prec_str(topk_weights);
7379

7480
// Set up traits structure
7581
fused_moe_traits traits{
@@ -79,8 +85,9 @@ at::Tensor fused_moe_impl(
7985
"fp32", // prec_st (token scale)
8086
"fp32", // prec_sw (weight scale)
8187
"fp32", // prec_sq (smooth quant)
82-
"fp32", // prec_kw (topk weight)
88+
prec_tkw, // prec_kw (topk weight)
8389
static_cast<int>(block_m),
90+
1,
8491
static_cast<int>(gate_only),
8592
static_cast<int>(fused_quant)};
8693

@@ -109,10 +116,9 @@ at::Tensor fused_moe_impl(
109116
static_cast<int>(stride)};
110117

111118
// Call kernel with default stream config
112-
ck_tile::stream_config stream_cfg{nullptr, true, 0, 0, 1};
113-
float time_ms = fused_moe(traits, args, stream_cfg);
114-
115-
TORCH_CHECK(time_ms >= 0, "Fused MoE kernel execution failed");
119+
auto stream = at::cuda::getCurrentHIPStream().stream();
120+
ck_tile::stream_config stream_cfg{stream};
121+
fused_moe(traits, args, stream_cfg);
116122

117123
return output;
118124
}

0 commit comments

Comments
 (0)