
Commit 2332e8a

feat: CUDAGraph compatibility of multi-level cascade inference APIs (#586)
This PR adds CUDAGraph compatibility to `MultiLevelCascadeAttentionWrapper`. cc @raywanb @pavanimajety @comaniac
1 parent 83e541d commit 2332e8a
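For orientation, here is a minimal sketch of how the new constructor parameters introduced by this commit might be used. The wrapper name, parameter names, and the `is_cuda_graph_enabled` property come from the diff below; the two-level setup, batch size, buffer capacities, dtypes, and workspace size are illustrative assumptions, not values prescribed by the commit.

import torch
import flashinfer

# Illustrative sizes -- assumptions for this sketch, not values from the commit.
num_levels = 2            # e.g. one shared-prefix level + one unique-suffix level
batch_size = 8            # must stay fixed while CUDA graphs are enabled
max_num_pages = 1024      # assumed upper bound on paged kv-cache pages per level

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# One set of pre-allocated index buffers per cascade level; the wrapper stores its
# auxiliary data structures in these buffers so captured kernels can be replayed.
qo_indptr_buf_arr = [
    torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
paged_kv_indptr_buf_arr = [
    torch.empty(batch_size + 1, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
paged_kv_indices_buf_arr = [
    torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
paged_kv_last_page_len_buf_arr = [
    torch.empty(batch_size, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]

wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
    num_levels,
    workspace,
    kv_layout="NHD",
    use_cuda_graph=True,
    qo_indptr_buf_arr=qo_indptr_buf_arr,
    paged_kv_indptr_buf_arr=paged_kv_indptr_buf_arr,
    paged_kv_indices_buf_arr=paged_kv_indices_buf_arr,
    paged_kv_last_page_len_buf_arr=paged_kv_last_page_len_buf_arr,
)
assert wrapper.is_cuda_graph_enabled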

File tree

3 files changed (+70 −13 lines)

python/flashinfer/cascade.py

+63 −6
@@ -281,10 +281,22 @@ class MultiLevelCascadeAttentionWrapper:
     ...
     >>> outputs[0].shape
     torch.Size([7, 64, 128])
+
+    See Also
+    --------
+    BatchPrefillWithPagedKVCacheWrapper
     """
 
     def __init__(
-        self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+        self,
+        num_levels,
+        float_workspace_buffer: torch.Tensor,
+        kv_layout: str = "NHD",
+        use_cuda_graph: bool = False,
+        qo_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indices_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_last_page_len_buf_arr: Optional[list[torch.Tensor]] = None,
     ) -> None:
         r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
 
@@ -298,14 +310,59 @@ def __init__(
             buffer should be the same as the device of the input tensors.
         kv_layout : str
             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        use_cuda_graph : bool
+            Whether to use CUDA graph to capture the kernels, if enabled, the auxiliary data structures
+            will be stored in provided buffers.
+        qo_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of qo indptr buffers for each level, the array length should be equal to
+            the number of levels.
+            The last element of each tensor should be the total number of queries/outputs.
+        paged_kv_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indptr buffers for each level, the array length should be
+            equal to the number of levels.
+        paged_kv_indices_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indices buffers for each level, the array length should be
+            equal to the number of levels.
+        paged_kv_last_page_len_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache last page length buffers for each level, the array length
+            should be equal to the number of levels.
         """
-        self._batch_prefill_wrappers = [
-            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
-            for _ in range(num_levels)
-        ]
+        self._use_cuda_graph = use_cuda_graph
+        if use_cuda_graph:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(
+                    float_workspace_buffer,
+                    kv_layout,
+                    use_cuda_graph=True,
+                    qo_indptr_buf=qo_indptr_buf,
+                    paged_kv_indptr_buf=paged_kv_indptr_buf,
+                    paged_kv_indices_buf=paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf=paged_kv_last_page_len_buf,
+                )
+                for (
+                    qo_indptr_buf,
+                    paged_kv_indptr_buf,
+                    paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf,
+                ) in zip(
+                    qo_indptr_buf_arr,
+                    paged_kv_indptr_buf_arr,
+                    paged_kv_indices_buf_arr,
+                    paged_kv_last_page_len_buf_arr,
+                )
+            ]
+        else:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+                for _ in range(num_levels)
+            ]
         self._num_levels = num_levels
         self._kv_layout = kv_layout
 
+    @property
+    def is_cuda_graph_enabled(self) -> bool:
+        return self._use_cuda_graph
+
     def reset_workspace_buffer(
         self,
         float_workspace_buffer: torch.Tensor,
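The constructor change above fans the per-level buffer arrays out to one BatchPrefillWithPagedKVCacheWrapper per cascade level. The reason these index tensors must be caller-provided is the usual CUDA graph constraint: a captured graph replays kernels against fixed device addresses, so planning may only refill pre-allocated buffers, never allocate new ones. Below is a generic, plain-PyTorch capture/replay sketch of that pattern; it is not this wrapper's API, and the softmax stand-in is purely illustrative.

import torch

# Generic CUDA graph capture/replay pattern (assumes a CUDA-capable device).
# Captured kernels replay against the same device addresses, which is why the
# wrapper's indptr/indices tensors must be pre-allocated and merely refilled.
x = torch.zeros(8, 64, 128, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)

def step():
    # Stand-in for an attention call: reads and writes only pre-allocated tensors.
    out.copy_(torch.softmax(x.float(), dim=-1).to(x.dtype))

# Warm up on a side stream before capture, as recommended for torch.cuda.graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    step()

# To run on new inputs: copy into the captured buffers, then replay the graph.
x.copy_(torch.randn_like(x))
g.replay()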
@@ -912,7 +969,7 @@ def forward(
         k_shared: torch.Tensor,
         v_shared: torch.Tensor,
         unique_kv_cache: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         allow_fp16_qk_reduction: bool = False,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,

python/flashinfer/prefill.py

+6 −6
@@ -747,7 +747,7 @@ def __init__(
 
         use_cuda_graph : bool
             Whether to enable CUDA graph capture for the prefill kernels, if enabled, the
-            auxiliary data structures will be stored as provided buffers. The ``batch_size``
+            auxiliary data structures will be stored in provided buffers. The ``batch_size``
             cannot change during the lifecycle of this wrapper when CUDAGraph is enabled.
 
         qo_indptr_buf : Optional[torch.Tensor]
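Per level, the construction the cascade wrapper performs internally reduces to the following single-wrapper sketch. The parameter names match the diff; the batch size, page count, dtypes, and workspace size are illustrative assumptions.

import torch
from flashinfer import BatchPrefillWithPagedKVCacheWrapper

batch_size, max_num_pages = 8, 1024   # assumed fixed sizes for the captured graph
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
    workspace,
    "NHD",
    use_cuda_graph=True,
    qo_indptr_buf=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
    paged_kv_indptr_buf=torch.empty(batch_size + 1, dtype=torch.int32, device="cuda"),
    paged_kv_indices_buf=torch.empty(max_num_pages, dtype=torch.int32, device="cuda"),
    paged_kv_last_page_len_buf=torch.empty(batch_size, dtype=torch.int32, device="cuda"),
)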
@@ -1095,7 +1095,7 @@ def forward(
         self,
         q: torch.Tensor,
         paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         k_scale: Optional[float] = None,
@@ -1240,7 +1240,7 @@ def forward_return_lse(
         self,
         q: torch.Tensor,
         paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         k_scale: Optional[float] = None,
@@ -1491,7 +1491,7 @@ def plan(
         head_dim: int,
         custom_mask: Optional[torch.Tensor] = None,
         packed_custom_mask: Optional[torch.Tensor] = None,
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         window_left: int = -1,
@@ -1683,7 +1683,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         window_left: int = -1,
@@ -1812,7 +1812,7 @@ def forward_return_lse(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         pos_encoding_mode: str = "NONE",
         allow_fp16_qk_reduction: bool = False,
         window_left: int = -1,
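Beyond the CUDA graph plumbing, the hunks above also flip the default of `causal` from `True` to `False` in the cascade and prefill `forward`/`forward_return_lse`/`plan` signatures, so callers relying on the old default now need to pass `causal=True` explicitly. As a reminder of what the flag controls, here is a plain-PyTorch reference sketch of the masking semantics; it is illustrative only, the fused flashinfer kernels never materialize this mask, and the shapes are arbitrary.

import torch

q = torch.randn(7, 4, 64)   # [seq_len, num_heads, head_dim]
k = torch.randn(7, 4, 64)
v = torch.randn(7, 4, 64)

scores = torch.einsum("qhd,khd->hqk", q, k) / 64 ** 0.5

# causal=True: query i may only attend to keys 0..i (lower-triangular mask).
causal_mask = torch.tril(torch.ones(7, 7, dtype=torch.bool))
o_causal = torch.einsum(
    "hqk,khd->qhd",
    torch.softmax(scores.masked_fill(~causal_mask, float("-inf")), dim=-1),
    v,
)

# causal=False (the new default): every query attends to every key.
o_full = torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), v)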

tests/test_shared_prefix_kernels.py

+1 −1
@@ -29,7 +29,7 @@ def ceil_div(a, b):
 @pytest.mark.parametrize("unique_kv_len", [37, 17])
 @pytest.mark.parametrize("shared_kv_len", [128, 512, 2048])
 @pytest.mark.parametrize("num_heads", [8, 16])
-@pytest.mark.parametrize("causal", [False, True])
+@pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("head_dim", [128, 256])
 @pytest.mark.parametrize("page_size", [1, 16])
 def test_batch_attention_with_shared_prefix_paged_kv_cache(

0 commit comments
