 from typing import Optional, Union, Tuple
 import logging
 import torch
+from .decode import get_batch_decode_module
 from .prefill import _compute_page_qk_indptr, get_batch_prefill_module
 from .quantization import segment_packbits
 from .utils import (
@@ -299,31 +300,65 @@ def plan(

         kv_indptr_host = indptr.to("cpu", non_blocking=True)

-        self._cached_module = get_batch_prefill_module(
-            q_data_type,
-            kv_data_type,
-            q_data_type,
-            indptr.dtype,
-            head_dim,
-            PosEncodingMode[pos_encoding_mode].value,
-            mask_mode,
-            False,  # use_sliding_window
-            logits_soft_cap > 0,  # use_logits_soft_cap
-            allow_fp16_qk_reduction,
-        )
+        # NOTE(Zihao): masks are not yet supported in the cuda-core implementation, but it
+        # should be easy to add if needed; we leave it as future work. For now, when a mask
+        # is provided, we use the tensor-core implementation.
+        if (
+            R * (num_qo_heads // num_kv_heads) < 4
+            and mask_mode == MaskMode.NON_CAUSAL.value
+        ):
+            # If the operation is not compute-bound, use the cuda-core implementation.
+            self._use_tensor_cores = False
+            self._cached_module = get_batch_decode_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+            )

-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            kv_indptr_host,
-            num_blocks_row,
-            num_qo_heads,
-            num_kv_heads,
-            C,
-            False,  # is_cuda_graph_enabled
-        )
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )
+        else:
+            # If the operation is compute-bound, use the tensor-core implementation.
+            self._use_tensor_cores = True
+            self._cached_module = get_batch_prefill_module(
+                q_data_type,
+                kv_data_type,
+                q_data_type,
+                indptr.dtype,
+                head_dim,
+                PosEncodingMode[pos_encoding_mode].value,
+                mask_mode,
+                False,  # use_sliding_window
+                logits_soft_cap > 0,  # use_logits_soft_cap
+                allow_fp16_qk_reduction,
+            )
+
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                kv_indptr_host,
+                num_blocks_row,
+                num_qo_heads,
+                num_kv_heads,
+                C,
+                False,  # is_cuda_graph_enabled
+            )

         self._pos_encoding_mode = pos_encoding_mode
         self._allow_fp16_qk_reduction = allow_fp16_qk_reduction
|
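The branch added to plan() above picks the kernel backend from the arithmetic intensity of the workload: each KV block fetched from the paged cache is reused roughly R (the block row size) times the GQA group size (num_qo_heads // num_kv_heads), so when that product is small the kernel is memory-bound and the cuda-core decode module suffices, while larger products, or any mask mode other than non-causal, fall back to the tensor-core prefill module. Below is a minimal standalone sketch of that heuristic; the helper name and the boolean mask flag are illustrative, not part of the library.

def _would_use_tensor_cores(
    R: int, num_qo_heads: int, num_kv_heads: int, has_mask: bool
) -> bool:
    # Reuse of each loaded KV block: R query rows per block-sparse row,
    # multiplied by the number of query heads sharing one KV head (GQA group size).
    reuse = R * (num_qo_heads // num_kv_heads)
    # Low reuse and no mask -> memory-bound, cuda-core decode kernel.
    # High reuse (or any mask) -> tensor-core prefill kernel.
    return reuse >= 4 or has_mask

# R=1 with no head grouping (32 query heads, 32 KV heads): cuda-core path.
assert not _would_use_tensor_cores(1, 32, 32, has_mask=False)
# R=1 with an 8-way GQA group (32 query heads, 4 KV heads): tensor-core path.
assert _would_use_tensor_cores(1, 32, 4, has_mask=False)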
@@ -404,30 +439,57 @@ def run(
|
         k = k.reshape(-1, self.C, *k.shape[-2:]).contiguous()
         v = v.reshape(-1, self.C, *v.shape[-2:]).contiguous()

-        out = self._cached_module.paged_run(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._plan_info,
-            q,
-            k,
-            v,
-            self._packed_mask_buf,
-            _get_cache_alibi_slopes_buf(q.shape[1], self.device),
-            self._qo_indptr,
-            self._paged_kv_indptr_buf,
-            self._paged_kv_indices_buf,
-            self._paged_kv_last_page_len,
-            self._qk_indptr_buf,
-            TensorLayout[self._kv_layout].value,
-            -1,  # window_left
-            logits_soft_cap,
-            sm_scale,
-            rope_scale,
-            rope_theta,
-            return_lse,
-        )
+        lse = None
+        if return_lse:
+            lse = torch.empty(
+                (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
+            )
+
+        if self._use_tensor_cores:
+            out = self._cached_module.paged_run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._packed_mask_buf,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                self._qo_indptr,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                self._qk_indptr_buf,
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )
+        else:
+            out = self._cached_module.run(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._plan_info,
+                q,
+                k,
+                v,
+                self._paged_kv_indptr_buf,
+                self._paged_kv_indices_buf,
+                self._paged_kv_last_page_len,
+                _get_cache_alibi_slopes_buf(q.shape[1], self.device),
+                TensorLayout[self._kv_layout].value,
+                -1,  # window_left
+                logits_soft_cap,
+                sm_scale,
+                rope_scale,
+                rope_theta,
+                lse,
+            )

-        return out if return_lse else out[0]
+        return (out, lse) if return_lse else out

     def end_forward(self) -> None:
         r"""Warning: This method is deprecated and has no effect."""
|
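In run(), the wrapper now allocates the float32 log-sum-exp buffer itself, hands it to whichever kernel was selected in plan(), and assembles the (out, lse) tuple in Python instead of indexing into a packed kernel result. A self-contained sketch of that control flow follows; run_sketch and the tensor shapes are purely illustrative stand-ins for the wrapper method and the real kernel call.

import torch

def run_sketch(q: torch.Tensor, return_lse: bool = False):
    # Allocate the log-sum-exp buffer up front only when the caller asks for it:
    # one float32 entry per (query row, query head), matching the diff above.
    lse = None
    if return_lse:
        lse = torch.empty((q.size(0), q.size(1)), dtype=torch.float32, device=q.device)
    out = torch.empty_like(q)  # stand-in for the kernel's attention output
    # ... the selected paged_run/run kernel would fill `out` (and `lse`) here ...
    return (out, lse) if return_lse else out

q = torch.randn(16, 8, 128)  # [query rows, query heads, head_dim], illustrative
out, lse = run_sketch(q, return_lse=True)
assert lse.shape == (16, 8)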