Commit 560af6f

feat: add an option non_blocking to plan function (#622)

Use non-blocking memcpy in the plan functions only when this option is turned on.

1 parent f236f70 · commit 560af6f
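For context, a minimal sketch of how the new option is meant to be used, based on the docstrings added in this commit. The problem sizes, tensor shapes, and the exact positional arguments passed to plan() below are illustrative assumptions, not taken from the commit itself:

    import torch
    import flashinfer

    # Illustrative sizes; any realistic decode batch works the same way.
    batch_size, page_size = 4, 16
    num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
    pages_per_req, max_num_pages = 2, 64

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

    # Page-table metadata built on the host.
    kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32) * pages_per_req
    kv_indices = torch.arange(batch_size * pages_per_req, dtype=torch.int32)
    kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32)

    # With non_blocking=True, plan() issues asynchronous host-to-device copies
    # instead of blocking ones.
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        q_data_type=torch.float16,
        non_blocking=True,
    )

    q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
    kv_cache = torch.randn(
        max_num_pages, 2, page_size, num_kv_heads, head_dim,
        dtype=torch.float16, device="cuda",
    )

    # Per the new docstring: synchronize before run() (or CUDA graph replay) so the
    # asynchronous copies issued by plan() have finished.
    torch.cuda.synchronize()
    out = wrapper.run(q, kv_cache)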

File tree: 4 files changed (+57, -26 lines)
python/flashinfer/decode.py (+22 -7)
@@ -646,6 +646,7 @@ def plan(
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
         rope_theta: Optional[float] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Plan batch decode for given problem specification.
 
@@ -687,6 +688,10 @@ def plan(
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
            data_type is deprecated, please use q_data_type and kv_data_type instead.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+
 
         Note
         ----
@@ -717,16 +722,26 @@ def plan(
                 raise ValueError(
                     "The size of indices should be less than or equal to the allocated buffer"
                 )
-            self._paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
-            self._paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
-            self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)
+            self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking)
+            self._paged_kv_indices_buf[: len(indices)].copy_(
+                indices, non_blocking=non_blocking
+            )
+            self._paged_kv_last_page_len_buf.copy_(
+                last_page_len, non_blocking=non_blocking
+            )
         else:
-            self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
-            self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+            self._paged_kv_indptr_buf = indptr.to(
+                self.device, non_blocking=non_blocking
+            )
+            self._paged_kv_indices_buf = indices.to(
+                self.device, non_blocking=non_blocking
+            )
             self._paged_kv_last_page_len_buf = last_page_len.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
+            )
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
             )
-            self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=True)
 
         indptr_host = indptr.to("cpu")
         if data_type is not None:
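One note on the semantics above (standard PyTorch/CUDA behaviour, not something introduced by this commit): ``non_blocking=True`` only gives a truly asynchronous host-to-device copy when the source CPU tensor lives in pinned memory; from ordinary pageable memory the copy typically behaves like a blocking one. A small sketch of the pattern, with made-up tensor contents:

    import torch

    # Keep the page-table metadata in pinned host memory so that
    # .to("cuda", non_blocking=True) can overlap with other GPU work.
    kv_indptr = torch.tensor([0, 2, 4, 6, 8], dtype=torch.int32).pin_memory()
    kv_indices = torch.arange(8, dtype=torch.int32).pin_memory()

    # These copies are enqueued on the current CUDA stream and return immediately.
    kv_indptr_dev = kv_indptr.to("cuda", non_blocking=True)
    kv_indices_dev = kv_indices.to("cuda", non_blocking=True)

    # Kernels launched later on the same stream see the copied data in order;
    # anything else (host reads, replaying a captured graph) should synchronize first.
    torch.cuda.synchronize()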

python/flashinfer/prefill.py (+18 -12)
@@ -883,6 +883,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
 
@@ -952,6 +953,9 @@ def plan(
             The data type of the query tensor, defaults torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
 
         Note
         ----
@@ -1003,13 +1007,13 @@ def plan(
                     "The length of paged_kv_indices exceeds the allocated buffer size."
                 )
 
-            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
-            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
+            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking)
             self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
-                paged_kv_indices, non_blocking=True
+                paged_kv_indices, non_blocking=non_blocking
             )
             self._paged_kv_last_page_len_buf.copy_(
-                paged_kv_last_page_len, non_blocking=True
+                paged_kv_last_page_len, non_blocking=non_blocking
             )
 
             if packed_custom_mask is not None:
@@ -1022,26 +1026,28 @@ def plan(
                         "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                     )
                 self._custom_mask_buf[: len(packed_custom_mask)].copy_(
-                    packed_custom_mask, non_blocking=True
+                    packed_custom_mask, non_blocking=non_blocking
                 )
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
-                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
+                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=non_blocking)
         else:
-            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
             self._paged_kv_indptr_buf = paged_kv_indptr.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             self._paged_kv_indices_buf = paged_kv_indices.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
            )
             if packed_custom_mask is not None:
                 self._custom_mask_buf = packed_custom_mask.to(
-                    self.device, non_blocking=True
+                    self.device, non_blocking=non_blocking
+                )
+                self._qk_indptr_buf = qk_indptr.to(
+                    self.device, non_blocking=non_blocking
                 )
-                self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
 
         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
         qo_indptr_host = qo_indptr.to("cpu")

python/flashinfer/sparse.py (+16 -7)
@@ -184,6 +184,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Create auxiliary data structures for block sparse attention.
 
@@ -241,6 +242,10 @@ def plan(
             The data type of the query tensor.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay.
+
 
         The :meth:`plan` method should be called before any :meth:`run` or
         :meth:`run_return_lse` calls, auxiliary data structures will be created
@@ -261,7 +266,7 @@ def plan(
         num_blocks_row = len(indptr) - 1
         qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
         qo_indptr_host[-1] = M
-        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=True)
+        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=non_blocking)
         if indices.max().item() * C > N:
             raise ValueError("indices out of bound")
         last_block_len = torch.full(
@@ -283,13 +288,17 @@ def plan(
                 mask.contiguous().view(-1), qk_indptr, bitorder="little"
             )
 
-        self._qo_indptr = qo_indptr.to(self.device, non_blocking=True)
-        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
-        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
-        self._paged_kv_last_page_len = last_block_len.to(self.device, non_blocking=True)
+        self._qo_indptr = qo_indptr.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=non_blocking)
+        self._paged_kv_last_page_len = last_block_len.to(
+            self.device, non_blocking=non_blocking
+        )
         if packed_mask is not None:
-            self._packed_mask_buf = packed_mask.to(self.device, non_blocking=True)
-            self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
+            self._packed_mask_buf = packed_mask.to(
+                self.device, non_blocking=non_blocking
+            )
+            self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=non_blocking)
             mask_mode = MaskMode.CUSTOM.value
         else:
             self._packed_mask_buf = None

tests/test_batch_decode_kernels.py (+1 -0)
@@ -294,6 +294,7 @@ def test_batch_decode_with_tuple_paged_kv_cache(
 @pytest.mark.parametrize(
     "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
 )
+@pytest.mark.parametrize("contiguous_kv", [True, False])
 def test_cuda_graph_batch_decode_with_paged_kv_cache(
     batch_size,
     kv_len,
