@@ -281,10 +281,22 @@ class MultiLevelCascadeAttentionWrapper:
     ...
     >>> outputs[0].shape
     torch.Size([7, 64, 128])
+
+    See Also
+    --------
+    BatchPrefillWithPagedKVCacheWrapper
     """

     def __init__(
-        self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+        self,
+        num_levels,
+        float_workspace_buffer: torch.Tensor,
+        kv_layout: str = "NHD",
+        use_cuda_graph: bool = False,
+        qo_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indices_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_last_page_len_buf_arr: Optional[list[torch.Tensor]] = None,
     ) -> None:
         r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.

@@ -298,14 +310,59 @@ def __init__(
             buffer should be the same as the device of the input tensors.
         kv_layout : str
             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        use_cuda_graph : bool
+            Whether to use CUDA graph to capture the kernels; if enabled, the auxiliary
+            data structures will be stored in the provided buffers.
+        qo_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of qo indptr buffers, one for each level; the array length should be
+            equal to the number of levels.
+            The last element of each tensor should be the total number of queries/outputs.
+        paged_kv_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indptr buffers, one for each level; the array length
+            should be equal to the number of levels.
+        paged_kv_indices_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indices buffers, one for each level; the array length
+            should be equal to the number of levels.
+        paged_kv_last_page_len_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache last page length buffers, one for each level; the
+            array length should be equal to the number of levels.
         """
-        self._batch_prefill_wrappers = [
-            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
-            for _ in range(num_levels)
-        ]
+        self._use_cuda_graph = use_cuda_graph
+        if use_cuda_graph:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(
+                    float_workspace_buffer,
+                    kv_layout,
+                    use_cuda_graph=True,
+                    qo_indptr_buf=qo_indptr_buf,
+                    paged_kv_indptr_buf=paged_kv_indptr_buf,
+                    paged_kv_indices_buf=paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf=paged_kv_last_page_len_buf,
+                )
+                for (
+                    qo_indptr_buf,
+                    paged_kv_indptr_buf,
+                    paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf,
+                ) in zip(
+                    qo_indptr_buf_arr,
+                    paged_kv_indptr_buf_arr,
+                    paged_kv_indices_buf_arr,
+                    paged_kv_last_page_len_buf_arr,
+                )
+            ]
+        else:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+                for _ in range(num_levels)
+            ]
         self._num_levels = num_levels
         self._kv_layout = kv_layout

+    @property
+    def is_cuda_graph_enabled(self) -> bool:
+        return self._use_cuda_graph
+
     def reset_workspace_buffer(
         self,
         float_workspace_buffer: torch.Tensor,
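
For reference, a minimal usage sketch of the new constructor arguments (not part of this diff). It assumes the class is exposed as flashinfer.MultiLevelCascadeAttentionWrapper and that each per-level buffer follows the CUDA-graph sizing conventions of BatchPrefillWithPagedKVCacheWrapper: indptr buffers of length batch_size + 1, an indices buffer sized to the maximum number of pages, and a last-page-length buffer of length batch_size.

# Hedged sketch: module path and buffer sizes are assumptions, not taken from this diff.
import torch
import flashinfer

num_levels = 2          # e.g. one shared-prefix level + one unique-suffix level
batch_size = 7
max_num_pages = 128
device = "cuda:0"

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
    num_levels,
    workspace,
    "NHD",
    use_cuda_graph=True,
    # One pre-allocated buffer per level, so graph capture sees fixed addresses.
    qo_indptr_buf_arr=[
        torch.empty(batch_size + 1, dtype=torch.int32, device=device)
        for _ in range(num_levels)
    ],
    paged_kv_indptr_buf_arr=[
        torch.empty(batch_size + 1, dtype=torch.int32, device=device)
        for _ in range(num_levels)
    ],
    paged_kv_indices_buf_arr=[
        torch.empty(max_num_pages, dtype=torch.int32, device=device)
        for _ in range(num_levels)
    ],
    paged_kv_last_page_len_buf_arr=[
        torch.empty(batch_size, dtype=torch.int32, device=device)
        for _ in range(num_levels)
    ],
)
assert wrapper.is_cuda_graph_enabled  # property added in this diff
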
@@ -912,7 +969,7 @@ def forward(
         k_shared: torch.Tensor,
         v_shared: torch.Tensor,
         unique_kv_cache: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         allow_fp16_qk_reduction: bool = False,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
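
Since the default for causal flips from True to False here, call sites that relied on the old default need to pass it explicitly. A hypothetical call site, with argument names taken from the signature above and the wrapper/tensor variables purely illustrative:

# Preserve the pre-change behavior by passing causal=True explicitly.
o = wrapper.forward(
    q,                # query tensor
    k_shared,         # shared-prefix keys
    v_shared,         # shared-prefix values
    unique_kv_cache,  # per-request paged kv-cache
    causal=True,
)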