[CUDA] Support slide window in cutlass fused attention #24072

Merged 7 commits on Mar 20, 2025
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -87,9 +87,9 @@ struct GroupQueryAttentionParameters : AttentionParameters {
int seqlen_present_kv_cache; // sequence length of present kv tensor
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
int rotary_dim; // rotary embedding dimension
int local_window_size;
int num_splits; // number of splits for splitkv
int rotary_dim; // rotary embedding dimension
int local_window_size;  // The window size excludes the current token; it only includes tokens to its left.
bool kv_share_buffer;
bool is_packed_qkv;
bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -270,7 +270,8 @@ class GQAAttentionBase {
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;

const bool should_apply_local_window = local_window_size_ > 0 &&
// local_window_size does not include the current query token, while window_size includes it.
const bool should_apply_local_window = local_window_size_ >= 0 &&
seq_causal_length > static_cast<size_t>(local_window_size_) + 1;

const size_t start_offset = should_apply_local_window ? seq_causal_length - local_window_size_ - 1 : 0;
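To make the masking rule above concrete, here is a minimal standalone sketch (not the ONNX Runtime code itself; the loop and names are simplified) that reproduces the `start_offset` computation and prints which key positions each query token attends:

```cpp
#include <cstdio>

// For the query at step `seq` (0-based within the new chunk) with `past_seqlen`
// cached tokens, causal attention covers keys [0, seq_causal_length).
// local_window_size counts only tokens to the left of the current token, so the
// window keeps keys [seq_causal_length - local_window_size - 1, seq_causal_length)
// once the causal length exceeds local_window_size + 1.
void print_attended_keys(int past_seqlen, int sequence_length, int local_window_size) {
  for (int seq = 0; seq < sequence_length; ++seq) {
    const int seq_causal_length = past_seqlen + seq + 1;
    const bool apply_window = local_window_size >= 0 &&
                              seq_causal_length > local_window_size + 1;
    const int start_offset = apply_window ? seq_causal_length - local_window_size - 1 : 0;
    std::printf("query %d attends keys [%d, %d)\n", seq, start_offset, seq_causal_length);
  }
}

int main() {
  // 6 cached tokens, 2 new tokens, window of 3 tokens to the left:
  // query 0 attends keys [3, 7), query 1 attends keys [4, 8).
  print_attended_keys(/*past_seqlen=*/6, /*sequence_length=*/2, /*local_window_size=*/3);
  return 0;
}
```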
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -512,7 +512,7 @@ Status EfficientAttention(
p.seqstart_q_ptr = nullptr;
p.seqstart_k_ptr = nullptr;
} else {
p.seqlen_k_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
p.seqlen_k_ptr = reinterpret_cast<const int32_t*>(data.mask_index);
p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size;
p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1;
}
@@ -222,6 +222,9 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
}

p.use_smooth_softmax = params.use_smooth_softmax;

// local_window_size in GQA does not include the current query token, while window_size in this kernel includes it.
p.window_size = params.local_window_size + 1;
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;
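The `+ 1` above bridges the two conventions: GQA's `local_window_size` counts only tokens to the left of the current query, while the kernel's `window_size` also counts the current token. A hedged sketch of the resulting visibility predicate (standalone, illustrative names rather than the kernel's own):

```cpp
// A key at absolute position `key_pos` is visible to the query at absolute
// causal position `query_pos` (both 0-based) when it lies inside the band
// (query_pos - window_size, query_pos]: up to window_size - 1 tokens to the
// left of the query, plus the query token itself.
inline bool is_visible(int query_pos, int key_pos, int window_size) {
  if (key_pos > query_pos) return false;     // causal mask
  if (window_size <= 0) return true;         // no sliding window
  return key_pos > query_pos - window_size;  // sliding-window band
}

// With local_window_size = 3 (GQA convention) window_size is 4 here, so the
// query at position 6 sees keys 3, 4, 5 and 6, matching the CPU sketch above.
```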
65 changes: 60 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
@@ -174,10 +174,9 @@ struct AttentionKernel {
scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
int32_t* seqstart_q_ptr = nullptr;
int32_t* seqstart_k_ptr = nullptr;

int32_t* seqlen_k_ptr = nullptr;
const int32_t* seqstart_q_ptr = nullptr;
const int32_t* seqstart_k_ptr = nullptr;
const int32_t* seqlen_k_ptr = nullptr;
uint32_t causal_diagonal_offset = 0;

// Output tensors
@@ -187,6 +186,8 @@
// [num_heads, num_queries] - can be null
lse_scalar_t* logsumexp_ptr = nullptr;

int32_t window_size = -1;

// Scale
accum_t scale = 0.0;

@@ -651,6 +652,12 @@
XFORMERS_CHECK(
p.custom_mask_type < NumCustomMaskTypes,
"invalid value for `custom_mask_type`");
if (p.window_size > 0) {
XFORMERS_CHECK(
p.custom_mask_type == CausalFromTopLeft ||
p.custom_mask_type == CausalFromBottomRight,
"invalid value for custom_mask_type");
}
return true;
}

@@ -726,6 +733,13 @@
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
if (p.window_size > 0) {
// don't compute anything if below attention band
if (iter_key_start + kKeysPerBlock <
static_cast<int32_t>(query_start + p.causal_diagonal_offset) - p.window_size) {
continue;
}
}
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n = cutlass::fast_min(
@@ -894,6 +908,38 @@
},
[&](int accum_m) {});
}

// Mask out the lower-left corner of the block if window_size > 0.
// This is only required if the current block intersects that corner:
// the block starts at x_lowerleft = iter_key_start, and at row
// y = query_start + kQueriesPerBlock the first non-masked column is
// x_first = query_start + kQueriesPerBlock - window_size.
// Mask if x_first > x_lowerleft.

if (p.window_size > 0 &&
(query_start + p.causal_diagonal_offset +
cutlass::fast_min(
static_cast<int32_t>(kQueriesPerBlock), static_cast<int32_t>(p.num_queries)) -
p.window_size >=
iter_key_start)) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
my_lane_id, my_warp_id, iteratorC_tile_offset);
int32_t first_col;
const int32_t offset = query_start + p.causal_diagonal_offset -
p.window_size - iter_key_start;
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { first_col = accum_m + offset; },
[&](int accum_m, int accum_n, int idx) {
if (accum_n <= first_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}

// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
@@ -1036,9 +1082,18 @@
}

if (!kKeepOutputInRF) {
int first_key = 0;
if (p.window_size > 0) {
first_key = (cutlass::fast_max(
static_cast<int>(query_start + p.causal_diagonal_offset) -
p.window_size + 1,
0) /
kKeysPerBlock) *
kKeysPerBlock;
}
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
iter_key_start == first_key, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys,
kIsLast,
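Taken together, the window-related changes in kernel_forward.h do three things: whole key blocks that lie strictly below the attention band are skipped, the block that straddles the lower-left edge of the band has its out-of-window entries set to negative infinity before the softmax update, and `first_key` points at the first block actually visited so the `kIsFirst` epilogue path still fires on the first computed iteration. A rough host-side sketch of the per-block decision (assumptions: single sequence, `causal_diagonal_offset` = 0, illustrative names):

```cpp
#include <algorithm>

enum class BlockAction { Skip, MaskLowerLeft, FullCompute };

// Decide how a key block [iter_key_start, iter_key_start + kKeysPerBlock) is
// handled for the query block starting at query_start.
BlockAction classify_key_block(int query_start, int iter_key_start,
                               int kQueriesPerBlock, int kKeysPerBlock,
                               int num_queries, int window_size) {
  // Skip: the whole key block ends before the band of the first query row.
  if (window_size > 0 &&
      iter_key_start + kKeysPerBlock < query_start - window_size) {
    return BlockAction::Skip;
  }
  // Mask: the left edge of the band for the last query row of this block falls
  // inside (or beyond) this key block, so some entries are outside the window
  // and must be overwritten with -infinity.
  const int last_row = query_start + std::min(kQueriesPerBlock, num_queries);
  if (window_size > 0 && last_row - window_size >= iter_key_start) {
    return BlockAction::MaskLowerLeft;
  }
  return BlockAction::FullCompute;
}

// first_key: start of the first key block that is not skipped, so the epilogue
// treats that block (rather than block 0) as the first iteration.
int first_key_block(int query_start, int window_size, int kKeysPerBlock) {
  if (window_size <= 0) return 0;
  return (std::max(query_start - window_size + 1, 0) / kKeysPerBlock) * kKeysPerBlock;
}
```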
@@ -14,42 +14,39 @@ namespace cuda {
constexpr int kEfficientAttentionMaxHeadSize = 1024;

struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
int32_t sm = 50;
bool is_half = false;
bool is_kv_bsnh = true;
int32_t batch_size;
int32_t num_heads;
int32_t sequence_length;
int32_t kv_sequence_length;
int32_t max_sequence_length;
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
bool use_smooth_softmax;

float scale;
int32_t batch_size = 0;
int32_t num_heads = 0;
int32_t sequence_length = 0;
int32_t kv_sequence_length = 0;
int32_t max_sequence_length = 0;
int32_t qk_head_size = 0;
int32_t v_head_size = 0;
int32_t local_window_size = -1;
bool causal = false;
bool use_smooth_softmax = false;
bool broadcast_attn_bias_dim_0 = false;
bool broadcast_attn_bias_dim_1 = false;
bool has_custom_right_padding = false;
float scale = 1.0f;
float softcap = 0.0;

int32_t* seqstart_q_ptr;
int32_t* seqstart_k_ptr;
int32_t* seqlen_k_ptr;

const void* query; // [B, S, N, H]
const void* key; // [B, L, N, H], where L is kv_sequence_length
const void* value; // [B, L, N, H_v]
const void* attn_bias; // [B or 1, N or 1, S, L] or null
bool broadcast_attn_bias_dim_0;
bool broadcast_attn_bias_dim_1;

void* output; // [B, S, N, H_v]
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
cudaStream_t stream;
cudaStream_t stream = nullptr;
const int32_t* seqstart_q_ptr = nullptr; // [B + 1], cumulated sequence lengths of queries
const int32_t* seqstart_k_ptr = nullptr; // [B + 1], cumulated sequence lengths of keys
const int32_t* seqlen_k_ptr = nullptr; // [B], sequence lengths of keys
const void* query = nullptr; // [B, S, N, H]
const void* key = nullptr; // [B, L, N, H], where L is kv_sequence_length
const void* value = nullptr; // [B, L, N, H_v]
const void* attn_bias = nullptr; // [B or 1, N or 1, S, L] or null
void* workspace = nullptr; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
void* output = nullptr; // [B, S, N, H_v]

static bool need_workspace(size_t v_head_size, bool is_float) {
return (v_head_size > 128 && !is_float);
}

bool has_custom_right_padding = false;
};

void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
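The new comments on `seqstart_q_ptr` / `seqstart_k_ptr` describe the usual variable-length ("ragged") layout: [B + 1] cumulative sequence lengths, with `seqlen_k_ptr` holding the raw per-batch key lengths. A small host-side sketch of how such arrays are typically built (illustrative helper, not part of the ONNX Runtime API):

```cpp
#include <cstdint>
#include <vector>

// Build the [B + 1] cumulative-sequence-length array expected by
// seqstart_q_ptr / seqstart_k_ptr from per-batch lengths.
std::vector<int32_t> build_seqstart(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> seqstart(seqlens.size() + 1, 0);
  for (size_t b = 0; b < seqlens.size(); ++b) {
    seqstart[b + 1] = seqstart[b] + seqlens[b];
  }
  return seqstart;
}

// For per-batch key lengths {3, 5, 2}: seqstart = {0, 3, 8, 10};
// seqlen_k_ptr would point at the raw {3, 5, 2} array itself.
```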
7 changes: 1 addition & 6 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -156,13 +156,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool use_memory_efficient_attention =
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.head_size);
if (!use_flash_attention && !use_memory_efficient_attention && local_window_size_ != -1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Local attention UNSUPPORTED for sm < 80 on CUDA.");
}

// allocate buffers
size_t kv_buffer_bytes = 0;
// need a buffer if we must ungroup kv
@@ -635,6 +635,7 @@ Status EfficientAttention(
p.stream = stream;
p.has_custom_right_padding = true;
p.use_smooth_softmax = parameters.use_smooth_softmax;
p.local_window_size = parameters.local_window_size;
run_memory_efficient_attention(p);

DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
@@ -698,8 +698,8 @@ Status FusedAttentionCutlass(
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;
p.seqstart_q_ptr = const_cast<int32_t*>(data.cumulative_sequence_length);
p.seqstart_k_ptr = const_cast<int32_t*>(data.cumulative_sequence_length);
p.seqstart_q_ptr = data.cumulative_sequence_length;
p.seqstart_k_ptr = data.cumulative_sequence_length;
p.query = data.no_qkv_workspace ? data.query : data.workspace;
p.key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk);
p.value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk);