Commit 51236c9
perf: Dense and sparse customizable flashattention-3 template (#667)
This PR adds a FlashAttention-3 template to improve prefill performance on Hopper. The block/vector-sparse support from earlier FlashInfer versions is ported to the FA3 template using a CustomStride abstraction in CuTe, so we can support PagedAttention with any page size. The programming interface of the FA3 template is exactly the same as our previous FA2 template, with a new `backend` argument that lets users select a backend. Functionality still missing from the current template includes custom masks, which we plan to support via JIT compilation rather than AOT.

Reference H100 performance on variable-length dense and sparse attention kernels (exposed through the [BatchPrefillWithRaggedKVCacheWrapper](https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper) and [BatchDecodeWithPagedKVCacheWrapper](https://docs.flashinfer.ai/api/decode.html#flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper) APIs, respectively; for the sparse attention workload we use PagedAttention with `page_size=1`):

![image](https://github.com/user-attachments/assets/7e989f8c-8b0f-4c99-ad11-6102c2dc5090)

FlashInfer's vector-sparse (`page_size=1`) attention implementation reaches 90% of the performance of its dense equivalent; reference benchmark: https://github.com/flashinfer-ai/flashinfer/blob/04ee9bceb5ab0a66c612c1abaee8fa28de2b2349/benchmarks/bench_hopper_attention.

JIT support is deferred to the next PR because this PR is already large. For fp8 support, we will incorporate the SageAttention-2 algorithm for numerical stability; that is deferred to v0.2.1. There is currently some discrepancy between the attention-variant interfaces of the FA2 and FA3 templates, and we will gradually close the gap.

cc @merrymercy @zhyncs @youkaichao @WoosukKwon @jason-huang03
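As a usage sketch (not part of this commit's diff): the commit only states that a `backend` argument was added, so the constructor keyword, the `"fa3"` value, and the v0.2-style `plan`/`run` method names below are assumptions based on the linked docs:

```python
import torch
import flashinfer

# Hedged sketch: `backend="fa3"` and its placement on the constructor are
# assumed; the commit message only says a `backend` argument was added.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace, kv_layout="NHD", backend="fa3"
)

batch_size, seq_len = 4, 512
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
qo_indptr = torch.arange(
    0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda"
)
kv_indptr = qo_indptr.clone()

# Plan once per batch layout, then run on the actual tensors.
wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=True)
total = batch_size * seq_len
q = torch.randn(total, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(total, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn_like(k)
out = wrapper.run(q, k, v)  # [total, num_qo_heads, head_dim]
```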
1 parent d9d8eb1 · commit 51236c9


52 files changed (+5730 −132 lines)

LICENSE (+22)
@@ -199,3 +199,25 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
+
+-------------------------------------------------------------------------------------------------
+Some of the code in this project are adapted from other open-source projects with different
+licenses. This product also bundles some third-party components under other open source licenses.
+This section summarizes those components and their licenses.
+See licenses/ for text of these licenses.
+
+BSD 3-Clause License
+--------------------
+
+include/flashinfer/attention/hopper/epilogue.cuh
+include/flashinfer/attention/hopper/mainloop.cuh
+include/flashinfer/attention/hopper/kernel_traits.cuh
+include/flashinfer/attention/hopper/named_barrier.cuh
+include/flashinfer/attention/hopper/tile_scheduler.cuh
+include/flashinfer/attention/hopper/utils.cuh
+
+BSD 3-Clause "New" License
+--------------------------
+
+3rdparty/cutlass
+include/flashinfer/attention/hopper/block_sparse_gather.cuh
New file: generator for `batch_paged_prefill_*_sm90.cu` FA3 instantiations (+96)

```python
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
    dtype_literal,
    idtype_literal,
    mask_mode_literal,
    pos_encoding_mode_literal,
)


def get_cu_file_str(
    head_dim,
    pos_encoding_mode,
    allow_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
    idtype,
):
    # Emit explicit template instantiations for both SWA settings of the
    # given attention variant.
    def get_insts(attention_variant):
        return "\n".join(
            [
                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
    Params& params,
    cudaStream_t stream);

template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
    Params& params,
    cudaStream_t stream);
""".format(
                    head_dim=head_dim,
                    # pos_encoding_mode and allow_fp16_qk_reduction are passed
                    # for uniformity but unused by this template string
                    # (str.format ignores extra keyword arguments).
                    pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
                    allow_fp16_qk_reduction=allow_fp16_qk_reduction,
                    mask_mode=mask_mode_literal[int(mask_mode)],
                    attention_variant=attention_variant,
                )
            ]
        )

    dtype_q = dtype_literal[dtype_q]
    dtype_kv = dtype_literal[dtype_kv]
    dtype_out = dtype_literal[dtype_out]
    idtype = idtype_literal[idtype]

    content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>


namespace flashinfer {{

using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;

using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;

{get_insts("LogitsSoftCap")}

{get_insts("StandardAttention")}

}}"""
    return content


if __name__ == "__main__":
    # The output filename encodes every dispatch parameter; parse them out
    # and write the generated source to that path.
    pattern = (
        r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
    )
    compiled_pattern = re.compile(pattern)
    path = Path(sys.argv[1])
    fname = path.name
    match = compiled_pattern.match(fname)

    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))
```
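For illustration, a standalone check of the filename pattern above; the concrete dtype/idtype tokens such as `f16` and `i32` are assumed keys of `literal_map`, not taken from this diff:

```python
import re

# The generator derives every dispatch parameter from the output filename.
pattern = re.compile(
    r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
    r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)"
    r"_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
# Hypothetical filename; "f16"/"i32" are assumed literal_map keys.
name = (
    "batch_paged_prefill_head_128_posenc_0_fp16qkred_false_mask_1_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32_sm90.cu"
)
print(pattern.match(name).groups())
# -> ('128', '0', 'false', '1', 'f16', 'f16', 'f16', 'i32')
```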
New file: generator for `batch_ragged_prefill_*_sm90.cu` FA3 instantiations (+97)

```python
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
    dtype_literal,
    idtype_literal,
    mask_mode_literal,
    pos_encoding_mode_literal,
)


def get_cu_file_str(
    head_dim,
    pos_encoding_mode,
    allow_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
    idtype,
):
    # Same structure as the paged-KV generator above, but instantiating the
    # ragged-KV dispatch entry points.
    def get_insts(attention_variant):
        return "\n".join(
            [
                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
    Params& params,
    cudaStream_t stream);

template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
    Params& params,
    cudaStream_t stream);
""".format(
                    head_dim=head_dim,
                    pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
                    allow_fp16_qk_reduction=allow_fp16_qk_reduction,
                    mask_mode=mask_mode_literal[int(mask_mode)],
                    attention_variant=attention_variant,
                )
            ]
        )

    dtype_q = dtype_literal[dtype_q]
    dtype_kv = dtype_literal[dtype_kv]
    dtype_out = dtype_literal[dtype_out]
    idtype = idtype_literal[idtype]

    content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>


namespace flashinfer {{

using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;

using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;

{get_insts("LogitsSoftCap")}

{get_insts("StandardAttention")}

}}
"""
    return content


if __name__ == "__main__":
    pattern = (
        r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
    )
    compiled_pattern = re.compile(pattern)
    path = Path(sys.argv[1])
    fname = path.name
    match = compiled_pattern.match(fname)
    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))
```
New file: generator for `single_prefill_*_sm90.cu` FA3 instantiations (+85)

```python
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal


def get_cu_file_str(
    head_dim,
    pos_encoding_mode,
    allow_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
):
    # Instantiate both SWA settings for both attention variants directly.
    content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

namespace flashinfer {{

using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;

using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
    Params& params,
    cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
    Params& params,
    cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
    Params& params,
    cudaStream_t stream);

template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
    Params& params,
    cudaStream_t stream);
}}
""".format(
        head_dim=head_dim,
        # pos_encoding_mode, allow_fp16_qk_reduction, and use_custom_mask are
        # accepted for filename uniformity but unused by this template string
        # (str.format ignores extra keyword arguments).
        pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
        allow_fp16_qk_reduction=allow_fp16_qk_reduction,
        mask_mode=mask_mode_literal[int(mask_mode)],
        dtype_q=dtype_literal[dtype_q],
        dtype_kv=dtype_literal[dtype_kv],
        dtype_out=dtype_literal[dtype_out],
        use_custom_mask="true" if int(mask_mode) == 2 else "false",
    )
    return content


if __name__ == "__main__":
    pattern = (
        r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_"
        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu"
    )

    compiled_pattern = re.compile(pattern)
    path = Path(sys.argv[1])
    fname = path.name
    match = compiled_pattern.match(fname)
    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))
```
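As a sketch of how a build step might drive one of these generators: the module path below is hypothetical (the real AOT build wires this up itself), and because the script uses a relative import of `literal_map`, it has to run as a module rather than as a bare file:

```python
import subprocess
import sys

# Hypothetical invocation: the generator writes the .cu file at the path it
# is given, deriving all dispatch parameters from the filename itself.
out = (
    "single_prefill_head_128_posenc_0_fp16qkred_false_mask_0_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_sm90.cu"
)
subprocess.run(
    # "aot_build_utils.generate_single_prefill_sm90_inst" is an assumed name.
    [sys.executable, "-m", "aot_build_utils.generate_single_prefill_sm90_inst", out],
    check=True,  # raise if the generator fails
)
```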
