
Commit 3fbf028

perf: use cuda-core implementation for io-bound block-sparse attention (#560)
When operational intensity is low (i.e. the kernel is IO-bound rather than compute-bound), select the cuda-core implementation for block-sparse attention instead of the tensor-core one.
1 parent: ea86f81
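
As a rough illustration of the heuristic this commit introduces (a minimal sketch, not part of the library: the helper name and the example numbers are invented here, while the R * (num_qo_heads // num_kv_heads) < 4 condition and the mask restriction come from the plan() change below):

# Minimal sketch of the dispatch rule (hypothetical helper, not flashinfer API).
# With block size (R, C) and grouped query heads, each loaded KV block serves
# roughly R * (num_qo_heads // num_kv_heads) query vectors; when that number is
# small the kernel is IO-bound, so the cheaper cuda-core (decode) kernels win.
def prefers_cuda_core(R: int, num_qo_heads: int, num_kv_heads: int, has_mask: bool) -> bool:
    group_size = num_qo_heads // num_kv_heads
    # the threshold of 4 and the "no mask" restriction mirror the condition in plan()
    return R * group_size < 4 and not has_mask

# e.g. 1x1 blocks with MHA (group_size=1) -> 1 < 4, cuda-core path;
# 16-row blocks or a provided mask -> tensor-core path
assert prefers_cuda_core(R=1, num_qo_heads=32, num_kv_heads=32, has_mask=False)
assert not prefers_cuda_core(R=16, num_qo_heads=32, num_kv_heads=32, has_mask=False)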

File tree

1 file changed: 109 additions, 47 deletions

python/flashinfer/sparse.py (+109, -47)

@@ -18,6 +18,7 @@
 from typing import Optional, Union, Tuple
 import logging
 import torch
+from .decode import get_batch_decode_module
 from .prefill import _compute_page_qk_indptr, get_batch_prefill_module
 from .quantization import segment_packbits
 from .utils import (
@@ -299,31 +300,65 @@ def plan(
 
         kv_indptr_host = indptr.to("cpu", non_blocking=True)
 
-        self._cached_module = get_batch_prefill_module(
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            indptr.dtype,
-            head_dim,
-            PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
-            False,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            allow_fp16_qk_reduction,
-        )
+        # NOTE(Zihao): we haven't supported mask in cuda-core implementations but it should
+        # be easy to add support for it if needed, leave it as a future work.
+        # at this moment, when mask is provided, we use the tensor-core implementation
+        if (
+            R * (num_qo_heads // num_kv_heads) < 4
+            and mask_mode == MaskMode.NON_CAUSAL.value
+        ):
+            # If the operation is not compute-bound, we use the cuda-core implementation
+            self._use_tensor_cores = False
+            self._cached_module = get_batch_decode_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+            )
 
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            kv_indptr_host,
-            num_blocks_row,
-            num_qo_heads,
-            num_kv_heads,
-            C,
-            False,  # is_cuda_graph_enabled
-        )
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )
+        else:
+            # if the operation is compute-bound, we use the tensor-core implementation
+            self._use_tensor_cores = True
+            self._cached_module = get_batch_prefill_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                mask_mode,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                allow_fp16_qk_reduction,
+            )
+
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )
 
         self._pos_encoding_mode = pos_encoding_mode
         self._allow_fp16_qk_reduction = allow_fp16_qk_reduction
@@ -404,30 +439,57 @@ def run(
         k = k.reshape(-1, self.C, *k.shape[-2:]).contiguous()
         v = v.reshape(-1, self.C, *v.shape[-2:]).contiguous()
 
-        out = self._cached_module.paged_run(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k,
-            v,
-            self._packed_mask_buf,
-            _get_cache_alibi_slopes_buf(q.shape[1], self.device),
-            self._qo_indptr,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len,
-            self._qk_indptr_buf,
-            TensorLayout[self._kv_layout].value,
-            -1,  # window_left
-            logits_soft_cap,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            return_lse,
-        )
+        lse = None
+        if return_lse:
+            lse = torch.empty(
+                (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+            )
+
+        if self._use_tensor_cores:
+            out = self._cached_module.paged_run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._packed_mask_buf,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                self._qo_indptr,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                self._qk_indptr_buf,
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )
+        else:
+            out = self._cached_module.run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )
 
-        return out if return_lse else out[0]
+        return (out, lse) if return_lse else out
 
     def end_forward(self) -> None:
         r"""Warning: This method is deprecated and has no effect."""
