Commit 89f2c4a
feat: non-contiguous query with paged kv cache (#553)
## Motivation

Previously, only the ragged version of the prefill kernel supported non-contiguous query tensors (#404). With a paged KV cache, the query tensor had to be made contiguous first: libraries like vLLM and SGLang must copy the query tensor with `.contiguous()` before calling flashinfer kernels ([vLLM call of flashinfer](https://github.com/vllm-project/vllm/blob/b7df53cd42f3eab007b4f287c151960858e949df/vllm/attention/backends/flashinfer.py#L839), [SGLang call of flashinfer](https://github.com/sgl-project/sglang/blob/87a7cfa080cec3f123618c1429b5f998bf5d99cb/python/sglang/srt/layers/attention/flashinfer_backend.py#L236)). This PR fixes that, ensuring that the prefill/decode kernels with paged KV cache support non-contiguous query tensors.

## Main Changes

1. Add the query tensor's strides to `BatchPrefillPagedParams` and `BatchDecodeParams`.
2. Set the stride parameters before calling those kernels.
3. Modify the JIT compilation templates to support the new kernel parameters.
4. Add tests.

The Python interfaces remain the same; the only change is that they now accept non-contiguous query tensors.

---------

Signed-off-by: LinHeLurking <[email protected]>
1 parent: f6e0010 · commit: 89f2c4a

10 files changed: +196 −17
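To make the motivation concrete, here is a minimal, self-contained sketch (illustrative shapes, not taken from the diff) of how a non-contiguous query tensor typically arises from a fused QKV projection; before this change, such a view had to be copied with `.contiguous()` before being handed to the paged-KV kernels:

```python
import torch

# Hypothetical shapes for illustration only.
nnz, num_qo_heads, num_kv_heads, head_dim = 16, 8, 8, 128

# A query sliced out of a fused QKV projection output is a non-contiguous view.
qkv = torch.randn(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim)
q, _, _ = qkv.split(
    [num_qo_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)
q = q.view(-1, num_qo_heads, head_dim)

assert not q.is_contiguous()
# stride(0) spans the whole packed QKV row, not num_qo_heads * head_dim:
assert q.stride() == ((num_qo_heads + 2 * num_kv_heads) * head_dim, head_dim, 1)

# Before this PR, callers had to materialize a copy (q.contiguous()) before
# wrapper.run(...); after it, the view can be passed to the kernels directly.
```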

flashinfer-aot/csrc_aot/batch_decode.cu (+6 −2)

```diff
@@ -128,6 +128,10 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
   auto q_scalar_type = q.scalar_type();
   auto kv_scalar_type = paged_k_cache.scalar_type();
 
+  // get q_stride_n and q_stride_h
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
+
   // get kv_cache_strides
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -157,8 +161,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
   ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                  /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                  /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
-                 /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
-                 sm_scale, rope_scale, rope_theta);
+                 /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
+                 logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
   DTypeO* tmp_v = nullptr;
   float* tmp_s = nullptr;
```

flashinfer-aot/csrc_aot/batch_prefill.cu (+7 −5)

```diff
@@ -237,6 +237,10 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
   auto q_scalar_type = q.scalar_type();
   auto kv_scalar_type = paged_k_cache.scalar_type();
 
+  // get q_stride_n and q_stride_h
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
+
   // get kv_cache_strides
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -254,8 +258,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
   paged_kv_t<DTypeKV, IdType> paged_kv(
       num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
       static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
-      static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
-      kv_cache_strides,
+      static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
       static_cast<IdType*>(paged_kv_indices.data_ptr()),
       static_cast<IdType*>(paged_kv_indptr.data_ptr()),
       static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
@@ -266,7 +269,6 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
       get_variant_code(/*use_custom_mask=*/MASK_MODE == MaskMode::kCustom,
                        /*use_sliding_window=*/true, USE_LOGITS_SOFT_CAP,
                        /*use_alibi_slopes=*/false)>;
-
   PagedParamsT params(
       static_cast<DTypeQ*>(q.data_ptr()), paged_kv,
       maybe_custom_mask.has_value() ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr())
@@ -276,8 +278,8 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
                                     : nullptr,
       /*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
      /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
-      /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap, sm_scale,
-      rope_scale, rope_theta);
+      /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
+      logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
   DTypeO* tmp_v = nullptr;
   float* tmp_s = nullptr;
```

include/flashinfer/attention/decode.cuh (+4 −2)

```diff
@@ -439,6 +439,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
   vec_t<float, vec_size> q_vec;
   vec_t<float, vec_size> freq;
   int32_t q_offset_val = q_offset == nullptr ? (kv_len - 1) : q_offset[batch_idx];
+  const uint32_t q_stride_n = params.q_stride_n;
+  const uint32_t q_stride_h = params.q_stride_h;
   if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) {
     const float rope_rcp_scale = params.rope_rcp_scale;
     const float rope_rcp_theta = params.rope_rcp_theta;
@@ -450,10 +452,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
     }
     // apply rotary embedding to q matrix
     q_vec = vec_apply_llama_rope<vec_size, bdx>(
-        q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, q_offset_val);
+        q + batch_idx * q_stride_n + qo_head_idx * q_stride_h, freq, q_offset_val);
   } else {
     // do not apply rotary embedding to q matrix
-    q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
+    q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size);
   }
 #pragma unroll
   for (uint32_t i = 0; i < vec_size; ++i) {
```
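The kernel change above swaps the hard-coded contiguous offset `(batch_idx * num_qo_heads + qo_head_idx) * head_dim` for the stride-based form. For a contiguous query the two are identical, which is why existing callers are unaffected; a quick sanity check of that equivalence (illustrative sizes only):

```python
# For a contiguous (nnz, num_qo_heads, head_dim) query,
# q_stride_n == num_qo_heads * head_dim and q_stride_h == head_dim,
# so the new stride-based offset reduces to the old hard-coded one.
num_qo_heads, head_dim = 8, 128
q_stride_n, q_stride_h = num_qo_heads * head_dim, head_dim

for batch_idx in range(4):
    for qo_head_idx in range(num_qo_heads):
        old_offset = (batch_idx * num_qo_heads + qo_head_idx) * head_dim
        new_offset = batch_idx * q_stride_n + qo_head_idx * q_stride_h
        assert old_offset == new_offset
```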

include/flashinfer/attention/decode_params.cuh (+7 −2)

```diff
@@ -119,6 +119,8 @@ struct BatchDecodeParams {
   float* alibi_slopes;
   uint32_t padded_batch_size;
   uint32_t num_qo_heads;
+  IdType q_stride_n;
+  IdType q_stride_h;
   int32_t window_left;
   float logits_soft_cap;
   float sm_scale;
@@ -135,8 +137,9 @@ struct BatchDecodeParams {
   __device__ __host__ BatchDecodeParams(DTypeQ* q, IdType* q_offset,
                                         paged_kv_t<DTypeKV, IdType> paged_kv, DTypeO* o, float* lse,
                                         float* alibi_slopes, uint32_t num_qo_heads,
-                                        int32_t window_left, float logits_soft_cap, float sm_scale,
-                                        float rope_scale, float rope_theta)
+                                        IdType q_stride_n, IdType q_stride_h, int32_t window_left,
+                                        float logits_soft_cap, float sm_scale, float rope_scale,
+                                        float rope_theta)
       : q(q),
         q_offset(q_offset),
         paged_kv(paged_kv),
@@ -145,6 +148,8 @@ struct BatchDecodeParams {
         alibi_slopes(alibi_slopes),
         padded_batch_size(0),
         num_qo_heads(num_qo_heads),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
         window_left(window_left),
         logits_soft_cap(logits_soft_cap),
         sm_scale(sm_scale),
```

include/flashinfer/attention/prefill.cuh (+1 −1)

```diff
@@ -1867,7 +1867,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
   const uint32_t qo_packed_idx_base =
       (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q<NUM_WARPS_Q, NUM_WARPS_KV>()) * NUM_FRAGS_Q *
       16;
-  const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim;
+  const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h;
   constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B;
   smem_t<swizzle_mode_q> qo_smem(smem);
   DTypeQ* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size,
```
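In the prefill kernel, the per-request base pointer is formed from `q_indptr[request_idx]` plus head offsets, now scaled by the runtime strides rather than an assumed contiguous layout. A rough sketch of that element-offset arithmetic (`q_elem_offset` is a hypothetical helper standing in for the role of `get_elem_offset_impl`, whose exact signature is not shown in this diff; values are illustrative):

```python
# Element offset of (row, head) in q: row * q_stride_n + head * q_stride_h.
num_qo_heads, head_dim, seq_len, batch_size = 8, 128, 7, 3
q_stride_n = 3 * num_qo_heads * head_dim  # e.g. a view into a packed QKV buffer
q_stride_h = head_dim
q_indptr = [i * seq_len for i in range(batch_size + 1)]  # row ranges per request

def q_elem_offset(request_idx: int, head_idx: int) -> int:
    """Hypothetical helper mirroring the role of get_elem_offset_impl."""
    return q_indptr[request_idx] * q_stride_n + head_idx * q_stride_h

print(q_elem_offset(1, 2))  # base element offset for request 1, head 2
```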

include/flashinfer/attention/prefill_params.cuh (+7 −3)

```diff
@@ -212,6 +212,8 @@ struct BatchPrefillPagedParams {
   float* lse;
   float* alibi_slopes;
   uint32_t num_qo_heads;
+  IdType q_stride_n;
+  IdType q_stride_h;
   int32_t window_left;
   float logits_soft_cap;
   float sm_scale;
@@ -232,9 +234,9 @@ struct BatchPrefillPagedParams {
   __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t<DTypeKV, IdType> paged_kv,
                                    uint8_t* custom_mask, IdType* q_indptr, IdType* qk_indptr,
                                    IdType* q_offset, DTypeO* o, float* lse, float* alibi_slopes,
-                                   uint32_t num_qo_heads, int32_t window_left,
-                                   float logits_soft_cap, float sm_scale, float rope_scale,
-                                   float rope_theta)
+                                   uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
+                                   int32_t window_left, float logits_soft_cap, float sm_scale,
+                                   float rope_scale, float rope_theta)
       : q(q),
         paged_kv(paged_kv),
         custom_mask(custom_mask),
@@ -245,6 +247,8 @@ struct BatchPrefillPagedParams {
         lse(lse),
         alibi_slopes(alibi_slopes),
         num_qo_heads(num_qo_heads),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
         window_left(window_left),
         logits_soft_cap(logits_soft_cap),
         sm_scale(sm_scale),
```

python/flashinfer/jit/batch_decode_templ.py (+4 −1)

```diff
@@ -100,6 +100,9 @@
 
   void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
   void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());
+
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
 
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -121,7 +124,7 @@
       /*q_offset=*/nullptr, paged_kv, static_cast<{{ dtype_o }}*>(o.data_ptr()),
       /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
       {% if use_alibi == "true" %}static_cast<float*>(alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
-      num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
+      num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
   {{ dtype_o }}* tmp_v = nullptr;
   float* tmp_s = nullptr;
```
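The `{{ ... }}` and `{% if ... %}` markers in this template are Jinja-style substitutions performed at JIT-compile time. A minimal sketch of how the `alibi_slopes` branch above renders, assuming a jinja2-compatible engine (the choice of renderer is an assumption, not stated in the diff):

```python
from jinja2 import Template

# Render the alibi_slopes argument from the template above for both branches.
tpl = Template(
    '{% if use_alibi == "true" %}'
    "static_cast<float*>(alibi_slopes->data_ptr())"
    "{% else %}nullptr{% endif %},"
)
print(tpl.render(use_alibi="true"))   # static_cast<float*>(alibi_slopes->data_ptr()),
print(tpl.render(use_alibi="false"))  # nullptr,
```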

python/flashinfer/jit/batch_prefill_templ.py (+4 −1)

```diff
@@ -195,6 +195,9 @@
 
   void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
   void* int_buffer_ptr = static_cast<void*>(int_workspace_buffer.data_ptr());
+
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
 
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -221,7 +224,7 @@
       static_cast<{{ dtype_o }}*>(o.data_ptr()),
       /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
       {% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
-      num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
+      num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
   {{ dtype_o }}* tmp_v = nullptr;
   float* tmp_s = nullptr;
```

tests/test_non_contiguous_decode.py (new file, +77)

```python
import torch
import pytest
import flashinfer


@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("page_size", [1, 5])
@pytest.mark.parametrize("seq_len", [1])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_batch_paged_decode_packed_input(
    batch_size,
    page_size,
    seq_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")
    nnz = batch_size * seq_len
    num_pages_per_req = (seq_len + page_size - 1) // page_size
    num_pages = batch_size * num_pages_per_req
    last_page_len = (seq_len - 1) % page_size + 1
    k_cache = torch.randn(
        size=(num_pages, page_size, num_kv_heads, head_dim),
        dtype=torch.float16,
        device="cuda:0",
    )
    v_cache = torch.randn_like(k_cache)
    paged_kv_cache = (k_cache, v_cache)
    workspace_buffer = torch.empty(
        (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
    )
    paged_kv_indptr = torch.tensor(
        [i * num_pages_per_req for i in range(batch_size + 1)],
        dtype=torch.int32,
        device="cuda:0",
    )
    paged_kv_indices = torch.tensor(
        list(range(num_pages)), dtype=torch.int32, device="cuda:0"
    )
    paged_kv_last_page_len = torch.tensor(
        [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
    )

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)
    wrapper.plan(
        indptr=paged_kv_indptr,
        indices=paged_kv_indices,
        last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        page_size=page_size,
    )

    qkv_packed = torch.randn(
        size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
        dtype=torch.float16,
        device="cuda:0",
    )
    qkv_split_idx = (
        num_qo_heads * head_dim,
        num_kv_heads * head_dim,
        num_kv_heads * head_dim,
    )
    q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
    q = q.view(-1, num_qo_heads, head_dim)
    o_packed = wrapper.run(q, paged_kv_cache)
    o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    test_batch_paged_decode_packed_input(37, 127, 1, 4, 64, 128)
```

tests/test_non_contiguous_prefill.py (+79)

```diff
@@ -96,6 +96,85 @@ def test_batch_ragged_prefill_packed_input(
     torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.parametrize("page_size", [1, 5])
+@pytest.mark.parametrize("seq_len", [1, 7, 127, 257])
+@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
+@pytest.mark.parametrize("num_qo_heads", [4, 8])
+@pytest.mark.parametrize("head_dim", [64, 128, 256])
+@pytest.mark.parametrize("causal", [True, False])
+def test_batch_paged_prefill_packed_input(
+    batch_size,
+    page_size,
+    seq_len,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    causal,
+):
+    if num_qo_heads % num_kv_heads != 0:
+        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")
+
+    nnz = batch_size * seq_len
+    num_pages_per_req = (seq_len + page_size - 1) // page_size
+    num_pages = batch_size * num_pages_per_req
+    last_page_len = (seq_len - 1) % page_size + 1
+    k_cache = torch.randn(
+        size=(num_pages, page_size, num_kv_heads, head_dim),
+        dtype=torch.float16,
+        device="cuda:0",
+    )
+    v_cache = torch.randn_like(k_cache)
+    paged_kv_cache = (k_cache, v_cache)
+    workspace_buffer = torch.empty(
+        (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
+    )
+    qo_indptr = torch.tensor(
+        [i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
+    )
+    paged_kv_indptr = torch.tensor(
+        [i * num_pages_per_req for i in range(batch_size + 1)],
+        dtype=torch.int32,
+        device="cuda:0",
+    )
+    paged_kv_indices = torch.tensor(
+        list(range(num_pages)), dtype=torch.int32, device="cuda:0"
+    )
+    paged_kv_last_page_len = torch.tensor(
+        [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
+    )
+    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer)
+    wrapper.plan(
+        qo_indptr=qo_indptr,
+        paged_kv_indptr=paged_kv_indptr,
+        paged_kv_indices=paged_kv_indices,
+        paged_kv_last_page_len=paged_kv_last_page_len,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        page_size=page_size,
+        causal=causal,
+    )
+
+    qkv_packed = torch.randn(
+        size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
+        dtype=torch.float16,
+        device="cuda:0",
+    )
+    qkv_split_idx = (
+        num_qo_heads * head_dim,
+        num_kv_heads * head_dim,
+        num_kv_heads * head_dim,
+    )
+    q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
+    # pretend that we have already appended k/v to paged_kv table
+    q = q.view(-1, num_qo_heads, head_dim)
+    o_packed = wrapper.run(q, paged_kv_cache)
+    o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
+    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     test_single_prefill_packed_input(127, 4, 4, 64, True)
     test_batch_ragged_prefill_packed_input(37, 127, 4, 4, 64, True)
+    test_batch_paged_prefill_packed_input(37, 5, 127, 4, 4, 64, True)
```
