Skip to content

gpu: nvidia: matmul: fix issues with scaling #2564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/common/memory_tracking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,6 @@ enum {
key_matmul_wei_trans,
key_matmul_dst_trans,
key_matmul_dst_cast_acc,
key_matmul_lt_src_scale,
key_matmul_lt_wei_scale,
key_matmul_sparse_tmp_ptr,
key_pool_dst_bf16cvt,
key_pool_dst_plain2blocked_cvt,
Expand Down
5 changes: 4 additions & 1 deletion src/gpu/generic/sycl/ref_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,10 @@ struct ref_matmul_t : public gpu::generic::sycl::primitive_t {
const auto &scales = attr()->scales_;
bool dt_ok = true;
for (auto arg : supported_args) {
dt_ok = dt_ok && is_supported_type(scales.get_data_type(arg));
if (!scales.get(arg).has_default_values()) {
dt_ok = dt_ok
&& is_supported_type(scales.get_data_type(arg));
}
}
return dt_ok && attr_scales_ok(supported_args);
}
Expand Down
56 changes: 0 additions & 56 deletions src/gpu/nvidia/cudnn_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,66 +66,10 @@ status_t cudnn_matmul_lt_t::execute(const exec_ctx_t &ctx) const {
nvidia::stream_t *cuda_stream
= utils::downcast<nvidia::stream_t *>(ctx.stream());

const bool has_src_scales
= ctx.args().find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)
!= ctx.args().end();
const bool has_wei_scales
= ctx.args().find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS)
!= ctx.args().end();
const bool has_dst_scales
= ctx.args().find(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)
!= ctx.args().end();

if (has_src_scales
&& (pd()->params_->multi_src_scale_
|| pd()->params_->acc_type_ == CUDA_R_32I)) {
// src scale sycl binary
exec_args_t src_scale_binary_args;
src_scale_binary_args[DNNL_ARG_SRC_0]
= memory_arg_t {ctx.args().at(DNNL_ARG_SRC).mem, true};
src_scale_binary_args[DNNL_ARG_SRC_1] = memory_arg_t {
ctx.args().at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC).mem, true};

std::unique_ptr<memory_t, memory_deleter_t> scratch_mem;
auto scratchpad_storage
= ctx.get_scratchpad_grantor().get_memory_storage(
memory_tracking::names::key_matmul_lt_src_scale);
safe_ptr_assign(scratch_mem,
new memory_t(ctx.stream()->engine(), pd()->src_md(),
std::move(scratchpad_storage)));
src_scale_binary_args[DNNL_ARG_DST]
= memory_arg_t {scratch_mem.get(), false};

exec_ctx_t binary_ctx(ctx, std::move(src_scale_binary_args));

CHECK(src_scale_binary_->execute(binary_ctx));
}
if (has_wei_scales
&& (pd()->params_->multi_wei_scale_
|| pd()->params_->acc_type_ == CUDA_R_32I)) {
// wei scale sycl binary
exec_args_t wei_scale_binary_args;
wei_scale_binary_args[DNNL_ARG_SRC_0]
= memory_arg_t {ctx.args().at(DNNL_ARG_WEIGHTS).mem, true};
wei_scale_binary_args[DNNL_ARG_SRC_1] = memory_arg_t {
ctx.args().at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS).mem,
true};

std::unique_ptr<memory_t, memory_deleter_t> scratch_mem;
auto scratchpad_storage
= ctx.get_scratchpad_grantor().get_memory_storage(
memory_tracking::names::key_matmul_lt_wei_scale);
safe_ptr_assign(scratch_mem,
new memory_t(ctx.stream()->engine(), pd()->weights_md(0),
std::move(scratchpad_storage)));
wei_scale_binary_args[DNNL_ARG_DST]
= memory_arg_t {scratch_mem.get(), false};

exec_ctx_t binary_ctx(ctx, std::move(wei_scale_binary_args));

CHECK(wei_scale_binary_->execute(binary_ctx));
}

CHECK(executor_->execute(ctx, ctx.stream()->engine(), matmul_impl_,
pd()->params_, src_d, weights_d, dst_d));

Expand Down
71 changes: 9 additions & 62 deletions src/gpu/nvidia/cudnn_matmul_executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,11 @@ struct cudnn_matmul_lt_base_exec_t {
xpu::sycl::interop_memory_arg_t<scratch_m> arg_block_a_scratch,
xpu::sycl::interop_memory_arg_t<scratch_m> arg_block_b_scratch,
xpu::sycl::interop_memory_arg_t<scratch_m> arg_block_c_scratch,
xpu::sycl::interop_memory_arg_t<scratch_m> scaled_arg_src,
xpu::sycl::interop_memory_arg_t<scratch_m> scaled_arg_wt,
xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read>
arg_src_scale,
xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read>
arg_wei_scale,
xpu::sycl::interop_memory_arg_t<::sycl::access::mode::read>
arg_dst_scale,
uint8_t *algo_scratch_ptr, uint8_t *bias_scratch_ptr,
uint8_t *block_a_scratch_ptr, uint8_t *block_b_scratch_ptr,
uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
uint8_t *wei_scale_scratch_ptr) {
uint8_t *block_c_scratch_ptr) {

compat::host_task(cgh,
[= WA_THIS_COPY_CAPTURE](const compat::interop_handle &ih) {
Expand All @@ -282,29 +275,22 @@ struct cudnn_matmul_lt_base_exec_t {
void *block_c_scratch
= arg_block_c_scratch.get_native_pointer(ih);

void *scaled_src = scaled_arg_src.get_native_pointer(ih);
void *scaled_wt = scaled_arg_wt.get_native_pointer(ih);

void *bias = arg_bias.get_native_pointer(ih);
void *weights = arg_weights.get_native_pointer(ih);
void *src = arg_src.get_native_pointer(ih);
void *dst = arg_dst.get_native_pointer(ih);

void *src_scale = arg_src_scale.get_native_pointer(ih);
void *wei_scale = arg_wei_scale.get_native_pointer(ih);
void *dst_scale = arg_dst_scale.get_native_pointer(ih);

matmul_impl_->execute(cublas_handle, params, weights, src,
dst, bias, algo_scratch, reorder_scratch,
block_a_scratch, block_b_scratch, block_c_scratch,
scaled_src, scaled_wt, src_scale, wei_scale,
dst_scale);
nullptr, nullptr, dst_scale);

free_runtime_scratch(params->has_runtime_params_,
cublas_handle, cuda_stream, algo_scratch_ptr,
bias_scratch_ptr, block_a_scratch_ptr,
block_b_scratch_ptr, block_c_scratch_ptr,
src_scale_scratch_ptr, wei_scale_scratch_ptr);
block_b_scratch_ptr, block_c_scratch_ptr);
if (params->has_runtime_params_) { params->rt_cleanup(); }
});
}
Expand All @@ -314,8 +300,7 @@ struct cudnn_matmul_lt_base_exec_t {
cublasHandle_t cublas_handle, nvidia::stream_t *cuda_stream,
uint8_t *algo_scratch_ptr, uint8_t *bias_scratch_ptr,
uint8_t *block_a_scratch_ptr, uint8_t *block_b_scratch_ptr,
uint8_t *block_c_scratch_ptr, uint8_t *src_scale_scratch_ptr,
uint8_t *wei_scale_scratch_ptr) {
uint8_t *block_c_scratch_ptr) {
if (has_runtime_params || bias_scratch_ptr) {
cudaStream_t streamId;
cublasGetStream(cublas_handle, &streamId);
Expand All @@ -335,12 +320,6 @@ struct cudnn_matmul_lt_base_exec_t {
if (block_c_scratch_ptr) {
::sycl::free(block_c_scratch_ptr, cuda_stream->queue());
}
if (src_scale_scratch_ptr) {
::sycl::free(src_scale_scratch_ptr, cuda_stream->queue());
}
if (wei_scale_scratch_ptr) {
::sycl::free(wei_scale_scratch_ptr, cuda_stream->queue());
}
}
}

Expand Down Expand Up @@ -375,11 +354,6 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
auto arg_bias = CTX_IN_SYCL_MEMORY(DNNL_ARG_BIAS);
auto arg_dst = CTX_OUT_SYCL_MEMORY(DNNL_ARG_DST);

auto arg_src_scale
= CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);

auto arg_wei_scale = CTX_IN_SYCL_MEMORY(
DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
auto arg_dst_scale
= CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
auto arg_algo_scratch = params->algo_scratch_size_ != 0
Expand Down Expand Up @@ -407,23 +381,12 @@ struct cudnn_matmul_lt_exec_t final : public cudnn_matmul_lt_base_exec_t {
memory_tracking::names::key_matmul_lt_block_c)
: xpu::sycl::interop_memory_arg_t<
::sycl::access::mode::read_write>();
auto scaled_arg_src = params->src_scale_size_ != 0
? CTX_SCRATCH_SYCL_MEMORY(
memory_tracking::names::key_matmul_lt_src_scale)
: xpu::sycl::interop_memory_arg_t<
::sycl::access::mode::read_write>();
auto scaled_arg_wt = params->wei_scale_size_ != 0
? CTX_SCRATCH_SYCL_MEMORY(
memory_tracking::names::key_matmul_lt_wei_scale)
: xpu::sycl::interop_memory_arg_t<
::sycl::access::mode::read_write>();

interop_task(matmul_impl_, params, engine, cgh, cuda_stream, arg_wt,
arg_src, arg_dst, arg_bias, arg_algo_scratch,
arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch,
arg_block_c_scratch, scaled_arg_src, scaled_arg_wt,
arg_src_scale, arg_wei_scale, arg_dst_scale, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr);
arg_block_c_scratch, arg_dst_scale, nullptr, nullptr,
nullptr, nullptr, nullptr);
});
}

Expand Down Expand Up @@ -465,12 +428,6 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
uint8_t *block_c_scratch_ptr
= alloc_ptr(matmul_params->dest_size_, cuda_stream->queue());

uint8_t *src_scale_scratch_ptr = alloc_ptr(
matmul_params->src_scale_size_, cuda_stream->queue());

uint8_t *wei_scale_scratch_ptr = alloc_ptr(
matmul_params->wei_scale_size_, cuda_stream->queue());

return cuda_stream->interop_task([= WA_THIS_COPY_CAPTURE](
::sycl::handler &cgh) {
auto arg_src = CTX_IN_SYCL_MEMORY(DNNL_ARG_SRC);
Expand All @@ -488,26 +445,16 @@ struct cudnn_matmul_lt_runtime_args_exec_t final
matmul_params->weight_size_, block_b_scratch_ptr);
auto arg_block_c_scratch = init_scratch_from_ptr(
matmul_params->dest_size_, block_c_scratch_ptr);
auto scaled_arg_src = init_scratch_from_ptr(
matmul_params->src_scale_size_, src_scale_scratch_ptr);
auto scaled_arg_wt = init_scratch_from_ptr(
matmul_params->wei_scale_size_, wei_scale_scratch_ptr);

auto arg_src_scale
= CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
auto arg_wei_scale = CTX_IN_SYCL_MEMORY(
DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
auto arg_dst_scale
= CTX_IN_SYCL_MEMORY(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);

interop_task(matmul_impl_, matmul_params, engine, cgh, cuda_stream,
arg_wt, arg_src, arg_dst, arg_bias, arg_algo_scratch,
arg_bias_scratch, arg_block_a_scratch, arg_block_b_scratch,
arg_block_c_scratch, scaled_arg_src, scaled_arg_wt,
arg_src_scale, arg_wei_scale, arg_dst_scale,
algo_scratch_ptr, bias_scratch_ptr, block_a_scratch_ptr,
block_b_scratch_ptr, block_c_scratch_ptr,
src_scale_scratch_ptr, wei_scale_scratch_ptr);
arg_block_c_scratch, arg_dst_scale, algo_scratch_ptr,
bias_scratch_ptr, block_a_scratch_ptr, block_b_scratch_ptr,
block_c_scratch_ptr);
});
}

Expand Down
Loading
Loading