enable arbitrary head_dim when using gpu flash attention #1048

Merged
9 commits merged on Mar 26, 2025
26 changes: 17 additions & 9 deletions axlearn/common/flash_attention/gpu_attention.py
@@ -116,6 +116,7 @@ def _mha_forward_kernel(
dropout_rate: float,
block_q: int,
block_k: int,
head_dim: int,
):
"""Computes attention outputs for the given block.

@@ -139,6 +140,9 @@
index_offset_size_ref: The number of valid blocks for each iteration.
o_ref: Output ref.
*residual_refs: Residual output refs, e.g. softmax statistics.
head_dim: Optional per_head_dim, needed when per_head_dim is not divisible
by the block size of the final dimension. When not provided, defaults to
the final dimension of q_ref.
**kwargs: See `flash_attention`.
"""
kv_seq_len = k_ref.shape[0]
@@ -156,12 +160,13 @@
l_i = jnp.zeros(block_q, dtype=jnp.float32)
# acc is the buffer where we accumulate the output on sram.
o = jnp.zeros((block_q, block_d), dtype=jnp.float32)
d_mask = jnp.arange(block_d)[None] < head_dim

# Load q: it will stay in L1 throughout. Indices form a matrix because we
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
# q tile has shape [block_q, block_d], block_d == head_dim.
# q tile has shape [block_q, block_d], block_d >= head_dim and is a power of 2.
curr_q_slice = pl.dslice(start_q * block_q, block_q)
q = q_ref[...]
q = pl.load(q_ref, (slice(None), slice(None)), mask=d_mask, other=0)
q_segment_ids = None if s_ref is None else pl.load(s_ref, (curr_q_slice,))
# In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
# Bc == block_k here), and fast over blocks of q (size Br == block_q here).
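A minimal sketch of what the new d_mask does, outside the kernel and with hypothetical sizes (pl.load's mask/other semantics are imitated here with jnp.where):

import jax.numpy as jnp

head_dim, block_d = 72, 128                       # block_d is the next power of 2 >= head_dim
d_mask = jnp.arange(block_d)[None] < head_dim     # [1, block_d]; True only for real lanes
tile = jnp.ones((4, block_d))                     # stand-in for a q/k tile padded to block_d
loaded = jnp.where(d_mask, tile, 0.0)             # lanes past head_dim read as 0 (other=0)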
@@ -176,7 +181,7 @@ def body(start_k, carry):
span_k = start_k * block_k + jnp.arange(block_k)
o_prev, m_prev, l_prev = carry
curr_k_slice = pl.dslice(start_k * block_k, block_k)
k = pl.load(k_ref, (curr_k_slice, slice(None)))
k = pl.load(k_ref, (curr_k_slice, slice(None)), mask=d_mask, other=0)
qk = pl.dot(q, k.T, precision=precision) # [block_q, block_k].
if softmax_scale != 1.0:
qk *= softmax_scale # [block_q, block_k].
@@ -203,7 +208,7 @@ def body(start_k, carry):
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
o_prev_corr = correction[:, None] * o_prev
v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d)))
v = pl.load(v_ref, (curr_k_slice, slice(None)), mask=d_mask, other=jnp.nan)
if dropout_rate > 0:
dropout_mask = pl.load(dropout_mask_ref, (slice(None), curr_k_slice))
s_curr = jnp.where(dropout_mask, 0, s_curr / (1 - dropout_rate))
@@ -226,7 +231,7 @@ def body(start_k, carry):
lse_ref = residual_refs[0]
lse_ref[...] = m_i + jnp.log(l_i)
# Write output to dram.
o_ref[...] = o.astype(o_ref.dtype)
pl.store(o_ref, (slice(None), slice(None)), val=o.astype(o_ref.dtype), mask=d_mask)


# pylint: disable=unused-argument
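Why zero-padding the head dimension is safe end to end, shown as a standalone JAX sketch with made-up shapes rather than the Pallas kernel: padded q/k lanes add zero to every q·k dot product, padded v lanes only affect output columns past head_dim, and the masked store above never writes those columns back. (In the kernel, v's padded lanes are loaded as NaN rather than 0, which would make any accidental use of them visible; the final masked pl.store discards them either way.)

import jax
import jax.numpy as jnp

head_dim, block_d, seq = 72, 128, 16                    # hypothetical sizes
q, k, v = (jax.random.normal(key, (seq, head_dim))
           for key in jax.random.split(jax.random.PRNGKey(0), 3))

ref = jax.nn.softmax(q @ k.T, axis=-1) @ v              # attention at the true head_dim

pad = ((0, 0), (0, block_d - head_dim))                 # zero-pad the head dimension to block_d
qp, kp, vp = (jnp.pad(x, pad) for x in (q, k, v))
out = jax.nn.softmax(qp @ kp.T, axis=-1) @ vp           # identical scores; extra columns are padding
assert jnp.allclose(out[:, :head_dim], ref, atol=1e-5)  # keep only the first head_dim columns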
@@ -307,6 +312,7 @@ def _flash_attention_impl(
kv_seq_len = key.shape[1]
block_q = min(block_q, q_seq_len)
block_k = min(block_k, kv_seq_len)
block_d = pl.next_power_of_2(head_dim)
assert q_seq_len % block_q == 0
assert kv_seq_len % block_k == 0
# Heuristics.
@@ -326,12 +332,13 @@
dropout_rate=dropout_rate,
block_q=block_q,
block_k=block_k,
head_dim=head_dim,
)
out_shape = jax.ShapeDtypeStruct(shape=query.shape, dtype=query.dtype) # out
in_specs = [
pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)),
pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, block_q, None, block_d), lambda i, j, k: (j, i, k, 0)),
pl.BlockSpec((None, kv_seq_len, None, block_d), lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, kv_seq_len, None, block_d), lambda _, j, k: (j, 0, k, 0)),
]
if bias is not None:
assert bias.ndim == 4
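With block_d = pl.next_power_of_2(head_dim), the BlockSpecs tile the last dimension at a power-of-2 width even when head_dim is not one; for the head_dim=72 test case added below that width is 128. A plain-Python stand-in for the rounding (illustration only, not the Pallas helper):

def next_power_of_2(x: int) -> int:
    # Smallest power of two >= x.
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

assert next_power_of_2(72) == 128     # non-power-of-2 head dims round up
assert next_power_of_2(128) == 128    # power-of-2 head dims are unchanged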
@@ -379,7 +386,7 @@ def _flash_attention_impl(
)
in_specs.append(index_offset_spec)
in_specs.append(index_offset_size_spec)
out_specs = pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0))
out_specs = pl.BlockSpec((None, block_q, None, block_d), lambda i, j, k: (j, i, k, 0))
if output_activations:
out_specs = [out_specs, pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i))]
out_shape = [
@@ -412,6 +419,7 @@ def _mha_forward(*args: Any):
return _flash_attention_impl(*args, output_activations=True)


# TODO(lezhi): Add support for arbitrary per-head-dim in the backward pass.
def _mha_backward_kernel_dkdv(
# Inputs.
q_ref,
1 change: 1 addition & 0 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -35,6 +35,7 @@
[
(1, 384, 1, 64),
(2, 384, 2, 64),
(2, 384, 2, 72),
(1, 384, 1, 128),
(2, 384, 2, 128),
(1, 384, 8, 128),
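The new (2, 384, 2, 72) case exercises a per-head dimension that is not a power of two (assuming the tuple order batch, seq_len, num_heads, per_head_dim), i.e. the path that depends on the block_d padding and masked loads/stores above.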