@@ -177,12 +177,238 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.
    return _kernels.merge_states(v, s)


+ class MultiLevelCascadeAttentionWrapper:
+     r"""Attention wrapper for memory-efficient multi-level cascade inference. This API
+     assumes that the KV-Cache of all levels is stored in a unified paged table.
+
+     Check :ref:`our tutorial<page-layout>` for the page table layout, and
+     :ref:`Cascade Inference Query/Output Layout<cascade-qo-indptr-layout>` for the query/output layout.
+
+     The idea of cascade inference is introduced in our `blog post <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_.
+
+     Example
+     -------
+     >>> import torch
+     >>> import flashinfer
+     >>> num_layers = 32
+     >>> num_qo_heads = 64
+     >>> num_kv_heads = 8
+     >>> head_dim = 128
+     >>> page_size = 16
+     >>> # allocate 128MB workspace buffer
+     >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+     >>> wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
+     ...     2, workspace_buffer, "NHD"
+     ... )
+     >>> batch_size = 7
+     >>> shared_kv_num_pages = 512
+     >>> unique_kv_num_pages = 128
+     >>> total_num_pages = shared_kv_num_pages + unique_kv_num_pages
+     >>> shared_kv_page_indices = torch.arange(shared_kv_num_pages).int().to("cuda:0")
+     >>> shared_kv_page_indptr = torch.tensor([0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0")
+     >>> unique_kv_page_indices = torch.arange(shared_kv_num_pages, total_num_pages).int().to("cuda:0")
+     >>> unique_kv_page_indptr = torch.tensor(
+     ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
+     ... )
+     >>> shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda:0")
+     >>> # 1 <= kv_last_page_len <= page_size
+     >>> unique_kv_last_page_len = torch.tensor(
+     ...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
+     ... )
+     >>> kv_cache_at_layer = [
+     ...     torch.randn(
+     ...         total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
+     ...     ) for _ in range(num_layers)
+     ... ]
+     >>> qo_indptr_arr = [
+     ...     torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"),  # top level, for shared KV-Cache
+     ...     torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0")  # bottom level, for unique KV-Cache
+     ... ]
+     >>> # create auxiliary data structures for batch decode attention
+     >>> wrapper.begin_forward(
+     ...     qo_indptr_arr,
+     ...     [shared_kv_page_indptr, unique_kv_page_indptr],
+     ...     [shared_kv_page_indices, unique_kv_page_indices],
+     ...     [shared_kv_last_page_len, unique_kv_last_page_len],
+     ...     num_qo_heads,
+     ...     num_kv_heads,
+     ...     head_dim,
+     ...     page_size,
+     ... )
+     >>> outputs = []
+     >>> for i in range(num_layers):
+     ...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
+     ...     # compute batch decode attention, reuse auxiliary data structures for all layers
+     ...     o = wrapper.forward(q, kv_cache_at_layer[i])
+     ...     outputs.append(o)
+     ...
+     >>> # clear auxiliary data structures
+     >>> wrapper.end_forward()
+     >>> outputs[0].shape
+     torch.Size([7, 64, 128])
+     """
+
+     def __init__(
+         self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+     ) -> None:
+         r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
+
+         Parameters
+         ----------
+         num_levels : int
+             The number of levels in the cascade attention.
+         float_workspace_buffer : torch.Tensor
+             The user-reserved float workspace buffer used to store intermediate attention
+             results in the split-k algorithm. The recommended size is 128MB; the workspace
+             buffer should be on the same device as the input tensors.
+         kv_layout : str
+             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+         """
+         self._batch_prefill_wrappers = [
+             BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+             for _ in range(num_levels)
+         ]
+         self._kv_layout = kv_layout
+
+     def reset_workspace_buffer(
+         self,
+         float_workspace_buffer: torch.Tensor,
+         int_workspace_buffers: list[torch.Tensor],
+     ) -> None:
+         r"""Reset the workspace buffer.
+
+         Parameters
+         ----------
+         float_workspace_buffer : torch.Tensor
+             The new float workspace buffer; it should be on the same device as the
+             input tensors.
+
+         int_workspace_buffers : list[torch.Tensor]
+             The new int workspace buffers, one per level; each should be on the same
+             device as the input tensors.
+         """
+         for wrapper, int_workspace_buffer in zip(
+             self._batch_prefill_wrappers, int_workspace_buffers
+         ):
+             wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)
+
+     def begin_forward(
+         self,
+         qo_indptr_arr: list[torch.Tensor],
+         paged_kv_indptr_arr: list[torch.Tensor],
+         paged_kv_indices_arr: list[torch.Tensor],
+         paged_kv_last_page_len: list[torch.Tensor],
+         num_qo_heads: int,
+         num_kv_heads: int,
+         head_dim: int,
+         page_size: int,
+     ):
+         r"""Create auxiliary data structures for multi-level cascade attention for multiple
+         forward calls within the same decode step.
+
+         Parameters
+         ----------
+         qo_indptr_arr : list[torch.Tensor]
+             An array of qo indptr tensors, one per level; the array length should equal
+             the number of levels. Check
+             :ref:`Cascade Inference Query/Output Layout<cascade-qo-indptr-layout>` for the
+             query/output layout. The last element of each tensor should be the total
+             number of queries/outputs.
+         paged_kv_indptr_arr : list[torch.Tensor]
+             An array of paged kv-cache indptr tensors, one per level; the array length
+             should equal the number of levels.
+         paged_kv_indices_arr : list[torch.Tensor]
+             An array of paged kv-cache indices tensors, one per level; the array length
+             should equal the number of levels.
+         paged_kv_last_page_len : list[torch.Tensor]
+             An array of paged kv-cache last page length tensors, one per level; the array
+             length should equal the number of levels.
+         num_qo_heads : int
+             The number of query/output heads.
+         num_kv_heads : int
+             The number of key/value heads.
+         head_dim : int
+             The dimension of the heads.
+         page_size : int
+             The page size of the paged kv-cache.
+         """
+         for (
+             wrapper,
+             qo_indptr,
+             paged_kv_indptr,
+             paged_kv_indices,
+             paged_kv_last_page_len,
+         ) in zip(
+             self._batch_prefill_wrappers,
+             qo_indptr_arr,
+             paged_kv_indptr_arr,
+             paged_kv_indices_arr,
+             paged_kv_last_page_len,
+         ):
+             wrapper.begin_forward(
+                 qo_indptr,
+                 paged_kv_indptr,
+                 paged_kv_indices,
+                 paged_kv_last_page_len,
+                 num_qo_heads,
+                 num_kv_heads,
+                 head_dim,
+                 page_size,
+             )
+
+     def end_forward(self):
+         r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
+         for wrapper in self._batch_prefill_wrappers:
+             wrapper.end_forward()
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         paged_kv_cache: torch.Tensor,
+         **kwargs,
+     ):
+         r"""Compute multi-level cascade attention.
+
+         Parameters
+         ----------
+         q : torch.Tensor
+             The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
+         paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+             The paged KV-Cache stored as a tuple of tensors or a single tensor:
+
+             * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape:
+               ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
+               and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``.
+
+             * a single 5-D tensor with shape:
+               ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+               :attr:`kv_layout` is ``NHD``, and
+               ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+               :attr:`kv_layout` is ``HND``, where ``paged_kv_cache[:, 0]`` is the key cache
+               and ``paged_kv_cache[:, 1]`` is the value cache.
+         """
+         out, lse = self._batch_prefill_wrappers[-1].forward_return_lse(
+             q, paged_kv_cache, **kwargs
+         )
+         # NOTE(Zihao): causal mask should be False for all levels except the last level
+         kwargs["causal"] = False
+         for wrapper in self._batch_prefill_wrappers[:-1]:
+             out_i, lse_i = wrapper.forward_return_lse(q, paged_kv_cache, **kwargs)
+             merge_state_in_place(out, lse, out_i, lse_i)
+
+         return out
+
+
class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
    r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch
-     of requests.
+     of requests. The shared-prefix KV-Cache is stored in standalone tensors, and the
+     unique KV-Cache of each request is stored in a paged KV-Cache data structure.

    Check :ref:`our tutorial<page-layout>` for page table layout.

+     It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
+     multi-level cascade inference, where the KV-Cache of each level is stored in a unified
+     page table. This API will be deprecated in the future.
+
    Example
    -------
    >>> import torch
@@ -328,6 +554,11 @@ def begin_forward(
        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
        is not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
+
+
+         See Also
+         --------
+         MultiLevelCascadeAttentionWrapper
        """
        self._batch_decode_wrapper.begin_forward(
            unique_kv_indptr,
@@ -433,6 +664,10 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:

    Check :ref:`our tutorial<page-layout>` for paged kv-cache layout.

+     It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
+     multi-level cascade inference, where the KV-Cache of each level is stored in a unified
+     page table. This API will be deprecated in the future.
+
    Example
    -------
    >>> import torch
@@ -533,7 +768,7 @@ def __init__(
        self._kv_layout = kv_layout

    def reset_workspace_buffer(
-         self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
+         self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
    ) -> None:
        r"""Reset the workspace buffer.
@@ -671,6 +906,10 @@ def forward(
        -------
        V : torch.Tensor
            The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``.
+
+         See Also
+         --------
+         MultiLevelCascadeAttentionWrapper
        """
        V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
            q,