@@ -883,6 +883,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        non_blocking: bool = False,
     ) -> None:
         r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.
 
@@ -952,6 +953,9 @@ def plan(
             The data type of the query tensor, defaults torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        non_blocking : bool
+            Whether to copy the input tensors to the device asynchronously, defaults to ``False``.
+            If ``True``, the user should synchronize the device before calling :meth:`run` or CUDA graph replay.
 
         Note
         ----
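
The contract added to the docstring deserves a concrete illustration: with ``non_blocking=True``, the host-to-device copies issued by ``plan`` may still be in flight when the call returns, so the caller must synchronize before the planned buffers are consumed. A minimal usage sketch under the usual flashinfer setup (the wrapper construction, tensor contents, and head counts below are illustrative assumptions, not part of this diff):

```python
import torch
import flashinfer

# Illustrative setup (assumed, not from this diff): workspace + prefill wrapper.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

# Pin the host-side metadata so non_blocking copies can actually be asynchronous.
qo_indptr = torch.tensor([0, 16, 48, 64], dtype=torch.int32).pin_memory()
paged_kv_indptr = torch.tensor([0, 2, 5, 7], dtype=torch.int32).pin_memory()
paged_kv_indices = torch.arange(7, dtype=torch.int32).pin_memory()
paged_kv_last_page_len = torch.tensor([16, 8, 4], dtype=torch.int32).pin_memory()

wrapper.plan(
    qo_indptr,
    paged_kv_indptr,
    paged_kv_indices,
    paged_kv_last_page_len,
    num_qo_heads=8,
    num_kv_heads=8,
    head_dim=128,
    page_size=16,
    non_blocking=True,  # copies are merely enqueued here
)
torch.cuda.synchronize()  # per the new docstring: sync before run() or graph replay
```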
@@ -1003,13 +1007,13 @@ def plan(
                     "The length of paged_kv_indices exceeds the allocated buffer size."
                 )
 
-            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
-            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=non_blocking)
+            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=non_blocking)
             self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
-                paged_kv_indices, non_blocking=True
+                paged_kv_indices, non_blocking=non_blocking
             )
             self._paged_kv_last_page_len_buf.copy_(
-                paged_kv_last_page_len, non_blocking=True
+                paged_kv_last_page_len, non_blocking=non_blocking
             )
 
             if packed_custom_mask is not None:
@@ -1022,26 +1026,28 @@ def plan(
                         "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                     )
                 self._custom_mask_buf[: len(packed_custom_mask)].copy_(
-                    packed_custom_mask, non_blocking=True
+                    packed_custom_mask, non_blocking=non_blocking
                 )
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
-                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
+                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=non_blocking)
         else:
-            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
             self._paged_kv_indptr_buf = paged_kv_indptr.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             self._paged_kv_indices_buf = paged_kv_indices.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
-                self.device, non_blocking=True
+                self.device, non_blocking=non_blocking
             )
             if packed_custom_mask is not None:
                 self._custom_mask_buf = packed_custom_mask.to(
-                    self.device, non_blocking=True
+                    self.device, non_blocking=non_blocking
+                )
+                self._qk_indptr_buf = qk_indptr.to(
+                    self.device, non_blocking=non_blocking
                 )
-                self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
 
         # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
         qo_indptr_host = qo_indptr.to("cpu")
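
One PyTorch detail worth noting alongside this change: for host-to-device transfers, ``copy_(..., non_blocking=True)`` and ``.to(device, non_blocking=True)`` are only truly asynchronous when the source tensor lives in pinned (page-locked) memory; with ordinary pageable memory the transfer is effectively synchronous regardless of the flag. A standalone sketch of that behavior, independent of flashinfer (the buffer name mirrors the diff purely for illustration):

```python
import torch

device = torch.device("cuda")

# Preallocated device buffer, standing in for e.g. self._qo_indptr_buf.
qo_indptr_buf = torch.zeros(8, dtype=torch.int32, device=device)

pageable_src = torch.arange(8, dtype=torch.int32)             # ordinary host memory
pinned_src = torch.arange(8, dtype=torch.int32).pin_memory()  # page-locked host memory

# Effectively synchronous despite the flag: the source is pageable.
qo_indptr_buf.copy_(pageable_src, non_blocking=True)

# Genuinely asynchronous: returns immediately; the copy runs on the current stream.
qo_indptr_buf.copy_(pinned_src, non_blocking=True)
torch.cuda.synchronize()  # wait for the copy before reading the buffer or replaying a graph
```

So callers passing ``non_blocking=True`` to ``plan`` get the advertised overlap only when they also pin the indptr/indices tensors they hand in.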