@@ -96,6 +96,85 @@ def test_batch_ragged_prefill_packed_input(
     torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.parametrize("page_size", [1, 5])
+@pytest.mark.parametrize("seq_len", [1, 7, 127, 257])
+@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
+@pytest.mark.parametrize("num_qo_heads", [4, 8])
+@pytest.mark.parametrize("head_dim", [64, 128, 256])
+@pytest.mark.parametrize("causal", [True, False])
+def test_batch_paged_prefill_packed_input(
+    batch_size,
+    page_size,
+    seq_len,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    causal,
+):
+    if num_qo_heads % num_kv_heads != 0:
+        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")
+
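+    # derive the paged-KV layout: each request occupies ceil(seq_len / page_size)
+    # pages, and only the last page of a request may be partially filled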
+    nnz = batch_size * seq_len
+    num_pages_per_req = (seq_len + page_size - 1) // page_size
+    num_pages = batch_size * num_pages_per_req
+    last_page_len = (seq_len - 1) % page_size + 1
+    k_cache = torch.randn(
+        size=(num_pages, page_size, num_kv_heads, head_dim),
+        dtype=torch.float16,
+        device="cuda:0",
+    )
+    v_cache = torch.randn_like(k_cache)
+    paged_kv_cache = (k_cache, v_cache)
+    workspace_buffer = torch.empty(
+        (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
+    )
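+    # CSR-style index arrays: qo_indptr brackets each request's rows in the
+    # packed q tensor, paged_kv_indptr brackets its pages in paged_kv_indices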
+    qo_indptr = torch.tensor(
+        [i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0"
+    )
+    paged_kv_indptr = torch.tensor(
+        [i * num_pages_per_req for i in range(batch_size + 1)],
+        dtype=torch.int32,
+        device="cuda:0",
+    )
+    paged_kv_indices = torch.tensor(
+        list(range(num_pages)), dtype=torch.int32, device="cuda:0"
+    )
+    paged_kv_last_page_len = torch.tensor(
+        [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
+    )
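+    # plan() precomputes the batch scheduling metadata on the workspace buffer;
+    # the run() calls below reuse it for both query layouts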
+    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer)
+    wrapper.plan(
+        qo_indptr=qo_indptr,
+        paged_kv_indptr=paged_kv_indptr,
+        paged_kv_indices=paged_kv_indices,
+        paged_kv_last_page_len=paged_kv_last_page_len,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        page_size=page_size,
+        causal=causal,
+    )
+
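+    # splitting a fused QKV projection output yields a strided (non-contiguous)
+    # q view; the test checks it matches the result for a contiguous copy of q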
+    qkv_packed = torch.randn(
+        size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
+        dtype=torch.float16,
+        device="cuda:0",
+    )
+    qkv_split_idx = (
+        num_qo_heads * head_dim,
+        num_kv_heads * head_dim,
+        num_kv_heads * head_dim,
+    )
+    q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
+    # pretend that we have already appended k/v to paged_kv table
+    q = q.view(-1, num_qo_heads, head_dim)
+    o_packed = wrapper.run(q, paged_kv_cache)
+    o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
+    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     test_single_prefill_packed_input(127, 4, 4, 64, True)
     test_batch_ragged_prefill_packed_input(37, 127, 4, 4, 64, True)
+    test_batch_paged_prefill_packed_input(37, 5, 127, 4, 4, 64, True)