
Commit 1e37989

feat: add MultiLevelCascadeAttentionWrapper API (#462)
Our existing cascade inference APIs all assume the shared-prefix KV-Cache is stored in standalone tensors, which is not the case in real-world LLM serving. This PR adds a more general `MultiLevelCascadeAttentionWrapper` API that not only supports multi-level cascade inference but also stores the KV-Cache of all levels in a unified paged KV-Cache, so it can seamlessly integrate with existing LLM serving frameworks. Tutorials, tests and examples are updated correspondingly. The old `BatchDecodeWithSharedPrefixPagedKVCacheWrapper` and `BatchPrefillWithSharedPrefixPagedKVCacheWrapper` will be deprecated starting from 0.2.0.
1 parent c1f576a commit 1e37989

File tree

6 files changed: +306 -55 lines changed


docs/api/python/cascade.rst

+4
@@ -25,6 +25,10 @@ Cascade Attention
 Cascade Attention Wrapper Classes
 ---------------------------------

+.. autoclass:: MultiLevelCascadeAttentionWrapper
+    :members:
+
+
 .. autoclass:: BatchDecodeWithSharedPrefixPagedKVCacheWrapper
     :members:


docs/conf.py

+2 -2
@@ -18,8 +18,8 @@
 author = "FlashInfer Contributors"
 copyright = "2023-2024, {}".format(author)

-version = "0.1.4"
-release = "0.1.4"
+version = "0.1.5"
+release = "0.1.5"

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

docs/tutorials/kv_layout.rst

+18
@@ -41,6 +41,24 @@ shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.

 We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.

+.. _cascade-qo-indptr-layout:
+
+Multi-level Cascade Inference Query/Output Layout
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When using multi-level `cascade inference <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_,
+the query and output of each level are stored in ragged tensors, and each level's ``qo_indptr`` array stores
+the interval information of each node in the cascade tree at that level. The figure below shows the
+``qo_indptr`` for each level in cascade inference:
+
+.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/cascade_qo_indptr.png
+  :width: 800
+  :align: center
+  :alt: The ``qo_indptr`` for each level in cascade inference.
+
+Note that each level's ``qo_indptr`` array should start from 0, and its last element
+should be equal to the sum of the lengths of all query/output tensors.
+
 FlashInfer APIs
 ~~~~~~~~~~~~~~~

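To make the layout concrete, here is a short editorial sketch (not part of this commit) of how the per-level ``qo_indptr`` arrays described above could be built for a two-level cascade; the batch size and query lengths are made-up illustrations:

import torch

# Hypothetical query lengths for a batch of 4 requests that share one common prefix.
qo_lens = torch.tensor([2, 3, 1, 5], dtype=torch.int32)

# Level 0 (shared prefix): all 11 query tokens belong to the single shared node,
# so the indptr holds one interval covering the whole ragged query/output tensor.
qo_indptr_top = torch.tensor([0, int(qo_lens.sum())], dtype=torch.int32, device="cuda:0")

# Level 1 (unique suffixes): one interval per request, i.e. an exclusive prefix sum.
qo_indptr_bottom = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(qo_lens, 0).int()]
).to("cuda:0")

# qo_indptr_top    -> [0, 11]
# qo_indptr_bottom -> [0, 2, 5, 6, 11]
# Each array starts at 0 and ends at the total query length (11), as required above.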

python/flashinfer/__init__.py

+1
@@ -15,6 +15,7 @@
 """

 from .cascade import (
+    MultiLevelCascadeAttentionWrapper,
     BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
     BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
     merge_state,

python/flashinfer/cascade.py

+241 -2
@@ -177,12 +177,238 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.
     return _kernels.merge_states(v, s)


+class MultiLevelCascadeAttentionWrapper:
+    r"""Attention wrapper for memory-efficient multi-level cascade inference; this API assumes
+    the KV-Cache of all levels is stored in a unified paged table.
+
+    Check :ref:`our tutorial<page-layout>` for the page table layout, and
+    :ref:`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for the query/output layout.
+
+    The idea of cascade inference is introduced in our `blog post <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_.
+
+    Example
+    -------
+    >>> import torch
+    >>> import flashinfer
+    >>> num_layers = 32
+    >>> num_qo_heads = 64
+    >>> num_kv_heads = 8
+    >>> head_dim = 128
+    >>> page_size = 16
+    >>> # allocate 128MB workspace buffer
+    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
+    ...     2, workspace_buffer, "NHD"
+    ... )
+    >>> batch_size = 7
+    >>> shared_kv_num_pages = 512
+    >>> unique_kv_num_pages = 128
+    >>> total_num_pages = shared_kv_num_pages + unique_kv_num_pages
+    >>> shared_kv_page_indices = torch.arange(shared_kv_num_pages).int().to("cuda:0")
+    >>> shared_kv_page_indptr = torch.tensor([0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0")
+    >>> unique_kv_page_indices = torch.arange(shared_kv_num_pages, total_num_pages).int().to("cuda:0")
+    >>> unique_kv_page_indptr = torch.tensor(
+    ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
+    ... )
+    >>> shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda:0")
+    >>> # 1 <= kv_last_page_len <= page_size
+    >>> unique_kv_last_page_len = torch.tensor(
+    ...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
+    ... )
+    >>> kv_cache_at_layer = [
+    ...     torch.randn(
+    ...         total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
+    ...     ) for _ in range(num_layers)
+    ... ]
+    >>> qo_indptr_arr = [
+    ...     torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"),  # top-level for shared KV-Cache
+    ...     torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0")  # bottom-level for unique KV-Cache
+    ... ]
+    >>> # create auxiliary data structures for batch decode attention
+    >>> wrapper.begin_forward(
+    ...     qo_indptr_arr,
+    ...     [shared_kv_page_indptr, unique_kv_page_indptr],
+    ...     [shared_kv_page_indices, unique_kv_page_indices],
+    ...     [shared_kv_last_page_len, unique_kv_last_page_len],
+    ...     num_qo_heads,
+    ...     num_kv_heads,
+    ...     head_dim,
+    ...     page_size,
+    ... )
+    >>> outputs = []
+    >>> for i in range(num_layers):
+    ...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
+    ...     # compute batch decode attention, reuse auxiliary data structures for all layers
+    ...     o = wrapper.forward(q, kv_cache_at_layer[i])
+    ...     outputs.append(o)
+    ...
+    >>> # clear auxiliary data structures
+    >>> wrapper.end_forward()
+    >>> outputs[0].shape
+    torch.Size([7, 64, 128])
+    """
+
+    def __init__(
+        self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+    ) -> None:
+        r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
+
+        Parameters
+        ----------
+        num_levels : int
+            The number of levels in the cascade attention.
+        float_workspace_buffer : torch.Tensor
+            The user reserved float workspace buffer used to store intermediate attention results
+            in the split-k algorithm. The recommended size is 128MB; the device of the workspace
+            buffer should be the same as the device of the input tensors.
+        kv_layout : str
+            The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        """
+        self._batch_prefill_wrappers = [
+            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+            for _ in range(num_levels)
+        ]
+        self._kv_layout = kv_layout
+
+    def reset_workspace_buffer(
+        self,
+        float_workspace_buffer: torch.Tensor,
+        int_workspace_buffers: list[torch.Tensor],
+    ) -> None:
+        r"""Reset the workspace buffer.
+
+        Parameters
+        ----------
+        float_workspace_buffer : torch.Tensor
+            The new float workspace buffer; the device of the new float workspace buffer should
+            be the same as the device of the input tensors.
+
+        int_workspace_buffers : list[torch.Tensor]
+            The new int workspace buffers, one per level; the device of the new int workspace
+            buffers should be the same as the device of the input tensors.
+        """
+        for wrapper, int_workspace_buffer in zip(
+            self._batch_prefill_wrappers, int_workspace_buffers
+        ):
+            wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)
+
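Editorial note (not part of this commit): because the multi-level wrapper owns one BatchPrefillWithPagedKVCacheWrapper per level, reset_workspace_buffer takes a list of int workspace buffers, one per level. A minimal sketch, reusing the ``wrapper`` built with ``num_levels=2`` in the docstring example above; the buffer sizes below are assumptions, not prescribed values:

new_float_workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
# one int workspace buffer per cascade level (two levels in the example above)
new_int_workspace_buffers = [
    torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") for _ in range(2)
]
wrapper.reset_workspace_buffer(new_float_workspace_buffer, new_int_workspace_buffers)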
+    def begin_forward(
+        self,
+        qo_indptr_arr: list[torch.Tensor],
+        paged_kv_indptr_arr: list[torch.Tensor],
+        paged_kv_indices_arr: list[torch.Tensor],
+        paged_kv_last_page_len: list[torch.Tensor],
+        num_qo_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        page_size: int,
+    ):
+        r"""Create auxiliary data structures for multi-level cascade attention for multiple
+        forward calls within the same decode step.
+
+        Parameters
+        ----------
+        qo_indptr_arr : list[torch.Tensor]
+            An array of qo indptr tensors for each level; the array length should be equal to
+            the number of levels. Check
+            :ref:`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for the query/output layout.
+            The last element of each tensor should be the total number of queries/outputs.
+        paged_kv_indptr_arr : list[torch.Tensor]
+            An array of paged kv-cache indptr tensors for each level; the array length should be
+            equal to the number of levels.
+        paged_kv_indices_arr : list[torch.Tensor]
+            An array of paged kv-cache indices tensors for each level; the array length should be
+            equal to the number of levels.
+        paged_kv_last_page_len : list[torch.Tensor]
+            An array of paged kv-cache last page length tensors for each level; the array length
+            should be equal to the number of levels.
+        num_qo_heads : int
+            The number of query/output heads.
+        num_kv_heads : int
+            The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.
+        page_size : int
+            The page size of the paged kv-cache.
+        """
+        for (
+            wrapper,
+            qo_indptr,
+            paged_kv_indptr,
+            paged_kv_indices,
+            paged_kv_last_page_len,
+        ) in zip(
+            self._batch_prefill_wrappers,
+            qo_indptr_arr,
+            paged_kv_indptr_arr,
+            paged_kv_indices_arr,
+            paged_kv_last_page_len,
+        ):
+            wrapper.begin_forward(
+                qo_indptr,
+                paged_kv_indptr,
+                paged_kv_indices,
+                paged_kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                page_size,
+            )
+
+    def end_forward(self):
+        r"""Clear auxiliary data structures created by :meth:`begin_forward`."""
+        for wrapper in self._batch_prefill_wrappers:
+            wrapper.end_forward()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        paged_kv_cache: torch.Tensor,
+        **kwargs,
+    ):
+        r"""Compute multi-level cascade attention.
+
+        Parameters
+        ----------
+        q : torch.Tensor
+            The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
+        paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+            The paged KV-Cache stored as a tuple of tensors or a single tensor:
+
+            * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape:
+              ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
+              and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``.
+
+            * a single 5-D tensor with shape:
+              ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+              :attr:`kv_layout` is ``NHD``, and
+              ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+              :attr:`kv_layout` is ``HND``, where ``paged_kv_cache[:, 0]`` is the key-cache and
+              ``paged_kv_cache[:, 1]`` is the value-cache.
+        """
+        out, lse = self._batch_prefill_wrappers[-1].forward_return_lse(
+            q, paged_kv_cache, **kwargs
+        )
+        # NOTE(Zihao): causal mask should be False for all levels except the last level
+        kwargs["causal"] = False
+        for wrapper in self._batch_prefill_wrappers[:-1]:
+            out_i, lse_i = wrapper.forward_return_lse(q, paged_kv_cache, **kwargs)
+            merge_state_in_place(out, lse, out_i, lse_i)
+
+        return out
+
+
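Editorial note (not part of this commit): the loop in ``forward`` above is the heart of cascade inference. Each level contributes a partial attention output together with its log-sum-exp (LSE), and the partial states are combined with ``merge_state`` / ``merge_state_in_place``. The following standalone sketch exercises the same merge primitive on random placeholder data; the shapes are illustrative only:

import torch
import flashinfer

seq_len, num_qo_heads, head_dim = 7, 64, 128
# partial attention states from two cascade levels (random placeholders)
v_shared = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
s_shared = torch.randn(seq_len, num_qo_heads, dtype=torch.float32, device="cuda:0")
v_unique = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
s_unique = torch.randn(seq_len, num_qo_heads, dtype=torch.float32, device="cuda:0")
# merge the two partial states into a single attention state
v_merged, s_merged = flashinfer.merge_state(v_shared, s_shared, v_unique, s_unique)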
 class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch
-    of requests.
+    of requests. The shared-prefix KV-Cache is stored in standalone tensors, and the
+    unique KV-Cache of each request is stored in a paged KV-Cache data structure.

     Check :ref:`our tutorial<page-layout>` for page table layout.

+    It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
+    multi-level cascade inference, where the KV-Cache of each level is stored in a unified
+    page table. This API will be deprecated in the future.
+
     Example
     -------
     >>> import torch
@@ -328,6 +554,11 @@ def begin_forward(
         The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
         is not equal to ``num_kv_heads``, the function will use
         `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
+
+
+        See Also
+        --------
+        MultiLevelCascadeAttentionWrapper
         """
         self._batch_decode_wrapper.begin_forward(
             unique_kv_indptr,
@@ -433,6 +664,10 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:

     Check :ref:`our tutorial<page-layout>` for paged kv-cache layout.

+    It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
+    multi-level cascade inference, where the KV-Cache of each level is stored in a unified
+    page table. This API will be deprecated in the future.
+
     Example
     -------
     >>> import torch
@@ -533,7 +768,7 @@ def __init__(
         self._kv_layout = kv_layout

     def reset_workspace_buffer(
-        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
+        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
     ) -> None:
         r"""Reset the workspace buffer.

@@ -671,6 +906,10 @@ def forward(
         -------
         V : torch.Tensor
             The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``.
+
+        See Also
+        --------
+        MultiLevelCascadeAttentionWrapper
         """
         V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
             q,
