@@ -57,7 +57,8 @@ def compile_single_prefill_module(
 ):
     uri, path = gen_single_prefill_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )

@@ -68,7 +69,8 @@ def compile_batch_prefill_module(
 ):
     uri, path = gen_batch_prefill_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )

@@ -125,6 +127,7 @@ def get_batch_prefill_module(*args):
         _batch_prefill_modules[args] = compile_batch_prefill_module(*args)
     return _batch_prefill_modules[args]

+
 def single_prefill_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -137,7 +140,8 @@ def single_prefill_with_kv_cache_with_jit_module(
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
     out = jit_module.run(
-        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args)
+        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args
+    )
     return out if return_lse else out[0]

@@ -726,10 +730,14 @@ def plan(
                     "The length of paged_kv_indices exceeds the allocated buffer size."
                 )

-            self._qo_indptr_buf.copy_(qo_indptr)
-            self._paged_kv_indptr_buf.copy_(paged_kv_indptr)
-            self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices
-            self._paged_kv_last_page_len_buf.copy_(paged_kv_last_page_len)
+            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
+            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+                paged_kv_indices, non_blocking=True
+            )
+            self._paged_kv_last_page_len_buf.copy_(
+                paged_kv_last_page_len, non_blocking=True
+            )

             if packed_custom_mask is not None:
                 if not torch.is_tensor(self._custom_mask_buf):
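The hunk above switches the host-to-device copies in `plan()` to `non_blocking=True`. As a hedged aside (not part of this diff): an asynchronous `copy_` only truly overlaps with host work when the source tensor lives in pinned (page-locked) host memory; from pageable memory the transfer silently degrades to a blocking copy. A minimal standalone sketch of the pattern, with illustrative names (`indptr_buf`, `indptr_host`, `n`):

```python
import torch

n = 4096  # hypothetical size; any 1-D index buffer works the same way

# Preallocated device buffer, like the fixed buffers plan() keeps for CUDA-graph mode.
indptr_buf = torch.empty(n, dtype=torch.int32, device="cuda")

# Pinned host staging tensor: required for the copy to be truly asynchronous.
indptr_host = torch.empty(n, dtype=torch.int32, pin_memory=True)
indptr_host.copy_(torch.arange(n, dtype=torch.int32))

# Enqueue the H2D copy without blocking the Python thread ...
indptr_buf.copy_(indptr_host, non_blocking=True)
# ... so the host can do other work here while the DMA runs.

# Synchronize before anything on the host depends on the device buffer.
torch.cuda.current_stream().synchronize()
```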
@@ -740,20 +748,31 @@ def plan(
                     raise ValueError(
                         "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                     )
-                self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
+                self._custom_mask_buf[: len(packed_custom_mask)].copy_(
+                    packed_custom_mask, non_blocking=True
+                )
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
-                self._qk_indptr_buf.copy_(qk_indptr)
+                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
         else:
-            self._qo_indptr_buf = qo_indptr.to(self.device)
-            self._paged_kv_indptr_buf = paged_kv_indptr.to(self.device)
-            self._paged_kv_indices_buf = paged_kv_indices.to(self.device)
-            self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(self.device)
+            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+            self._paged_kv_indptr_buf = paged_kv_indptr.to(
+                self.device, non_blocking=True
+            )
+            self._paged_kv_indices_buf = paged_kv_indices.to(
+                self.device, non_blocking=True
+            )
+            self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
+                self.device, non_blocking=True
+            )
             if packed_custom_mask is not None:
-                self._custom_mask_buf = packed_custom_mask.to(self.device)
-                self._qk_indptr_buf = qk_indptr.to(self.device)
+                self._custom_mask_buf = packed_custom_mask.to(
+                    self.device, non_blocking=True
+                )
+                self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)

-        qo_indptr = qo_indptr.to("cpu", non_blocking=True)
-        paged_kv_indptr = paged_kv_indptr.to("cpu", non_blocking=True)
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)
+        paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True)

         if packed_custom_mask is not None:
             mask_mode = MaskMode.CUSTOM.value
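The rename to `qo_indptr_host`/`paged_kv_indptr_host` keeps the device tensors bound to their original names instead of clobbering them with CPU copies; the later `plan` call (next hunk) then reads the host copies explicitly. If the inputs are already CPU tensors, `.to("cpu")` is a no-op, which is what the added NOTE refers to. One caveat worth hedging: a device-to-host `.to("cpu", non_blocking=True)` can return before the copy has finished, so the host tensor must not be read until the issuing stream is synchronized. A sketch of the pitfall and the safe pattern (variable names are illustrative):

```python
import torch

qo_indptr = torch.arange(9, dtype=torch.int32, device="cuda")

# Potentially unsafe: the returned CPU tensor may still be in flight.
qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)

# Safe: synchronize the stream that issued the copy before reading the data.
torch.cuda.current_stream().synchronize()
batch_size = qo_indptr_host.numel() - 1  # now guaranteed to see valid values
```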
@@ -781,8 +800,8 @@ def plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
-            qo_indptr,
-            paged_kv_indptr,
+            qo_indptr_host,
+            paged_kv_indptr_host,
             batch_size,
             num_qo_heads,
             num_kv_heads,
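For context on why the CUDA-graph branch of `plan()` copies into fixed, preallocated buffers rather than rebinding attributes: a captured graph replays with device pointers baked in at capture time, so new inputs must be written in place. A minimal sketch of that constraint, independent of this diff and assuming a CUDA-capable PyTorch build (`static_in`/`static_out` are illustrative names):

```python
import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.empty_like(static_in)

# Warm up on a side stream before capture, as the CUDA-graph docs advise.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)  # buffer addresses are frozen here

# New inputs must be copied into the captured buffer, never rebound.
static_in.copy_(torch.arange(8, dtype=torch.float32), non_blocking=True)
g.replay()
torch.cuda.synchronize()
print(static_out)  # doubled inputs
```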