
Commit 41ebe6d

perf: improve plan performance by using non-blocking memcpy (#547)
cc @merrymercy
1 parent 021b585 commit 41ebe6d

File tree

7 files changed · +76 −56 lines changed

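In short, the commit removes implicit synchronization points from the attention plan() path: the C++ plan entry points now require indptr-style tensors to already live on the CPU (enforced with TORCH_CHECK) instead of copying them device-to-host themselves, and the Python wrappers fill their metadata buffers with non-blocking copies. A minimal, self-contained sketch of the blocking-versus-non-blocking pattern (buffer names here are illustrative, not the library's):

import torch

device = torch.device("cuda")

# Host-side metadata; pinning the memory is what lets the copy below overlap.
indptr = torch.tensor([0, 3, 7, 12], dtype=torch.int32, pin_memory=True)
indptr_buf = torch.zeros_like(indptr, device=device)  # preallocated device buffer

# Before: the copy stalls the host thread until the transfer finishes.
indptr_buf.copy_(indptr)

# After: the copy is only enqueued on the current CUDA stream, so plan() can
# keep issuing work while the transfer is in flight.
indptr_buf.copy_(indptr, non_blocking=True)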

flashinfer-aot/csrc_aot/batch_decode.cu (+2 −3)

@@ -43,7 +43,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  indptr = indptr.to(torch::kCPU);
+  TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

   DecodePlanInfo plan_info;

@@ -150,8 +150,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
   paged_kv_t<DTypeKV, IdType> paged_kv(
       num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
       static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
-      static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
-      kv_cache_strides,
+      static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
       static_cast<IdType*>(paged_kv_indices.data_ptr()),
       static_cast<IdType*>(paged_kv_indptr.data_ptr()),
       static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
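The plan-side edit above is a contract change rather than a behavior tweak: BatchDecodeWithPagedKVCachePlan no longer moves indptr to the CPU inside the op (a blocking device-to-host copy on every plan call) and instead asserts it is already a CPU tensor, leaving the host copy to the Python caller where it can be avoided or overlapped. A hedged sketch of what a caller might do to satisfy that check (the helper name is hypothetical):

import torch

def indptr_for_plan(indptr: torch.Tensor) -> torch.Tensor:
    """Return a CPU tensor that passes the op's TORCH_CHECK (illustrative helper)."""
    if indptr.device.type == "cpu":
        return indptr  # no transfer needed; the check passes as-is
    # Otherwise enqueue a device-to-host copy; the values must have landed
    # before the C++ planner reads them, hence the conservative sync here.
    indptr_host = indptr.to("cpu", non_blocking=True)
    torch.cuda.current_stream(indptr.device).synchronize()
    return indptr_host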

flashinfer-aot/csrc_aot/batch_prefill.cu (+2 −2)

@@ -51,8 +51,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(

   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  qo_indptr = qo_indptr.to(torch::kCPU);
-  kv_indptr = kv_indptr.to(torch::kCPU);
+  TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+  TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

   PrefillPlanInfo plan_info;

python/flashinfer/decode.py (+24 −22)

@@ -53,7 +53,8 @@ def compile_single_decode_module(
 ):
     uri, path = gen_single_decode_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )

@@ -64,7 +65,8 @@ def compile_batch_decode_module(
 ):
     uri, path = gen_batch_decode_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )

@@ -114,6 +116,7 @@ def get_batch_decode_module(*args):
         _batch_decode_modules[args] = compile_batch_decode_module(*args)
     return _batch_decode_modules[args]

+
 def single_decode_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -123,8 +126,10 @@ def single_decode_with_kv_cache_with_jit_module(
     kv_layout: str = "NHD",
     window_left: int = -1,
 ):
-    tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
-    return jit_module.run(q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args)
+    tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
+    return jit_module.run(
+        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, *args
+    )


 def single_decode_with_kv_cache(
@@ -444,6 +449,7 @@ def __init__(

         if use_tensor_cores:
             if use_cuda_graph:
+                # NOTE(Zihao): if once created, no need to update it in plan/run
                 self._qo_indptr_buf = torch.arange(
                     self._fixed_batch_size + 1,
                     dtype=torch.int32,
@@ -555,8 +561,7 @@ def plan(
         if logits_soft_cap is None:
             logits_soft_cap = 0.0

-        qo_indptr = _get_range_buf(batch_size + 1, indptr.device)
-
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
         if self.is_cuda_graph_enabled:
             if batch_size != self._fixed_batch_size:
                 raise ValueError(
@@ -569,21 +574,18 @@ def plan(
                 raise ValueError(
                     "The size of indices should be less than or equal to the allocated buffer"
                 )
-            self._paged_kv_indptr_buf.copy_(indptr)
-            self._paged_kv_indices_buf[: len(indices)] = indices
-            self._paged_kv_last_page_len_buf.copy_(last_page_len)
-            if self.use_tensor_cores:
-                self._qo_indptr_buf.copy_(qo_indptr)
+            self._paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
+            self._paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
+            self._paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)
         else:
-            self._paged_kv_indptr_buf = indptr.to(self.device)
-            self._paged_kv_indices_buf = indices.to(self.device)
-            self._paged_kv_last_page_len_buf = last_page_len.to(self.device)
-            if self.use_tensor_cores:
-                self._qo_indptr_buf = qo_indptr.to(self.device)
-
-        qo_indptr = qo_indptr.to("cpu", non_blocking=True)
-        indptr = indptr.to("cpu", non_blocking=True)
+            self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+            self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+            self._paged_kv_last_page_len_buf = last_page_len.to(
+                self.device, non_blocking=True
+            )
+            self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=True)

+        indptr_host = indptr.to("cpu", non_blocking=True)
         if data_type is not None:
             q_data_type = data_type
             kv_data_type = data_type
@@ -612,8 +614,8 @@ def plan(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
-                qo_indptr,
-                indptr,
+                qo_indptr_host,
+                indptr_host,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
@@ -635,7 +637,7 @@ def plan(
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
-                indptr,
+                indptr_host,
                 batch_size,
                 num_qo_heads,
                 num_kv_heads,
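In the decode wrapper the same idea shows up twice: plan() keeps host-resident copies (qo_indptr_host, indptr_host) to feed the C++ planner, and in CUDA-graph mode it refreshes the preallocated device buffers with copy_(..., non_blocking=True) instead of blocking copies and slice assignment. A minimal sketch of that buffer-refresh pattern under assumed sizes and names (they are illustrative, not the wrapper's actual attributes):

import torch

device = torch.device("cuda")
fixed_batch_size = 8

# Preallocated device-side buffers, as used when CUDA graphs are enabled.
paged_kv_indptr_buf = torch.zeros(fixed_batch_size + 1, dtype=torch.int32, device=device)
paged_kv_indices_buf = torch.zeros(4096, dtype=torch.int32, device=device)
paged_kv_last_page_len_buf = torch.zeros(fixed_batch_size, dtype=torch.int32, device=device)

# Per-call metadata staged in pinned host memory so the copies can overlap.
indptr = torch.arange(fixed_batch_size + 1, dtype=torch.int32, pin_memory=True)
indices = torch.arange(128, dtype=torch.int32, pin_memory=True)
last_page_len = torch.ones(fixed_batch_size, dtype=torch.int32, pin_memory=True)

# Each copy is enqueued on the current stream; the host thread does not wait.
paged_kv_indptr_buf.copy_(indptr, non_blocking=True)
paged_kv_indices_buf[: len(indices)].copy_(indices, non_blocking=True)
paged_kv_last_page_len_buf.copy_(last_page_len, non_blocking=True)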

python/flashinfer/jit/batch_decode_templ.py (+1 −1)

@@ -42,7 +42,7 @@
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  indptr = indptr.to(torch::kCPU);
+  TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

   DecodePlanInfo plan_info;

python/flashinfer/jit/batch_prefill_templ.py (+2 −2)

@@ -49,8 +49,8 @@

   auto device = float_workspace_buffer.device();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
-  qo_indptr = qo_indptr.to(torch::kCPU);
-  kv_indptr = kv_indptr.to(torch::kCPU);
+  TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
+  TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

   PrefillPlanInfo plan_info;

python/flashinfer/prefill.py (+38 −19)

@@ -57,7 +57,8 @@ def compile_single_prefill_module(
 ):
     uri, path = gen_single_prefill_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )

@@ -68,7 +69,8 @@ def compile_batch_prefill_module(
 ):
     uri, path = gen_batch_prefill_cu(*args)
     return load_cuda_ops(
-        uri, [path],
+        uri,
+        [path],
         verbose=verbose,
     )

@@ -125,6 +127,7 @@ def get_batch_prefill_module(*args):
         _batch_prefill_modules[args] = compile_batch_prefill_module(*args)
     return _batch_prefill_modules[args]

+
 def single_prefill_with_kv_cache_with_jit_module(
     jit_module: Any,
     q: torch.Tensor,
@@ -137,7 +140,8 @@ def single_prefill_with_kv_cache_with_jit_module(
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device)
     out = jit_module.run(
-        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args)
+        q, k, v, tmp, TensorLayout[kv_layout].value, window_left, return_lse, *args
+    )
     return out if return_lse else out[0]


@@ -726,10 +730,14 @@ def plan(
                 "The length of paged_kv_indices exceeds the allocated buffer size."
             )

-            self._qo_indptr_buf.copy_(qo_indptr)
-            self._paged_kv_indptr_buf.copy_(paged_kv_indptr)
-            self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices
-            self._paged_kv_last_page_len_buf.copy_(paged_kv_last_page_len)
+            self._qo_indptr_buf.copy_(qo_indptr, non_blocking=True)
+            self._paged_kv_indptr_buf.copy_(paged_kv_indptr, non_blocking=True)
+            self._paged_kv_indices_buf[: len(paged_kv_indices)].copy_(
+                paged_kv_indices, non_blocking=True
+            )
+            self._paged_kv_last_page_len_buf.copy_(
+                paged_kv_last_page_len, non_blocking=True
+            )

             if packed_custom_mask is not None:
                 if not torch.is_tensor(self._custom_mask_buf):
@@ -740,20 +748,31 @@ def plan(
                 raise ValueError(
                     "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation."
                 )
-                self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
+                self._custom_mask_buf[: len(packed_custom_mask)].copy_(
+                    packed_custom_mask, non_blocking=True
+                )
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
-                self._qk_indptr_buf.copy_(qk_indptr)
+                self._qk_indptr_buf.copy_(qk_indptr, non_blocking=True)
         else:
-            self._qo_indptr_buf = qo_indptr.to(self.device)
-            self._paged_kv_indptr_buf = paged_kv_indptr.to(self.device)
-            self._paged_kv_indices_buf = paged_kv_indices.to(self.device)
-            self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(self.device)
+            self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=True)
+            self._paged_kv_indptr_buf = paged_kv_indptr.to(
+                self.device, non_blocking=True
+            )
+            self._paged_kv_indices_buf = paged_kv_indices.to(
+                self.device, non_blocking=True
+            )
+            self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to(
+                self.device, non_blocking=True
+            )
             if packed_custom_mask is not None:
-                self._custom_mask_buf = packed_custom_mask.to(self.device)
-                self._qk_indptr_buf = qk_indptr.to(self.device)
+                self._custom_mask_buf = packed_custom_mask.to(
+                    self.device, non_blocking=True
+                )
+                self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)

-        qo_indptr = qo_indptr.to("cpu", non_blocking=True)
-        paged_kv_indptr = paged_kv_indptr.to("cpu", non_blocking=True)
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu", non_blocking=True)
+        paged_kv_indptr_host = paged_kv_indptr.to("cpu", non_blocking=True)

         if packed_custom_mask is not None:
             mask_mode = MaskMode.CUSTOM.value
@@ -781,8 +800,8 @@ def plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
-            qo_indptr,
-            paged_kv_indptr,
+            qo_indptr_host,
+            paged_kv_indptr_host,
             batch_size,
             num_qo_heads,
             num_kv_heads,
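One detail worth noting in the prefill diff is the switch from slice assignment to an explicit copy_ on the slice (self._paged_kv_indices_buf[: n] = x becomes ...[: n].copy_(x, non_blocking=True)); both perform a copy, but only copy_ exposes the non_blocking flag. A small sketch of the difference, with made-up buffer sizes:

import torch

device = torch.device("cuda")
custom_mask_buf = torch.zeros(1 << 20, dtype=torch.uint8, device=device)
packed_custom_mask = torch.randint(0, 256, (1 << 16,), dtype=torch.uint8).pin_memory()

# Slice assignment copies synchronously; there is no way to ask for overlap.
custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask

# copy_ on the same slice can enqueue the transfer without blocking the host,
# provided the source lives in pinned memory.
custom_mask_buf[: len(packed_custom_mask)].copy_(packed_custom_mask, non_blocking=True)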

python/flashinfer/sparse.py (+7 −7)

@@ -257,7 +257,7 @@ def plan(
         num_blocks_row = len(indptr) - 1
         qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32)
         qo_indptr_host[-1] = M
-        qo_indptr = qo_indptr_host.to(indptr.device)
+        qo_indptr = qo_indptr_host.to(indptr.device, non_blocking=True)
         if indices.max().item() * C > N:
             raise ValueError("indices out of bound")
         last_block_len = torch.full(
@@ -279,13 +279,13 @@ def plan(
             mask.contiguous().view(-1), qk_indptr, bitorder="little"
         )

-        self._qo_indptr = qo_indptr.to(self.device)
-        self._paged_kv_indptr_buf = indptr.to(self.device)
-        self._paged_kv_indices_buf = indices.to(self.device)
-        self._paged_kv_last_page_len = last_block_len.to(self.device)
+        self._qo_indptr = qo_indptr.to(self.device, non_blocking=True)
+        self._paged_kv_indptr_buf = indptr.to(self.device, non_blocking=True)
+        self._paged_kv_indices_buf = indices.to(self.device, non_blocking=True)
+        self._paged_kv_last_page_len = last_block_len.to(self.device, non_blocking=True)
         if packed_mask is not None:
-            self._packed_mask_buf = packed_mask.to(self.device)
-            self._qk_indptr_buf = qk_indptr.to(self.device)
+            self._packed_mask_buf = packed_mask.to(self.device, non_blocking=True)
+            self._qk_indptr_buf = qk_indptr.to(self.device, non_blocking=True)
             mask_mode = MaskMode.CUSTOM.value
         else:
             self._packed_mask_buf = None
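sparse.py gets the same treatment: every .to(self.device) in plan() now passes non_blocking=True. The caveat, per PyTorch's documentation, is that the flag only makes host-to-device copies asynchronous when the source tensor is in pinned memory; with ordinary pageable tensors it has no effect. A short sketch illustrating that distinction (variable names are made up):

import torch

indptr_pageable = torch.arange(1025, dtype=torch.int32)   # ordinary host memory
indptr_pinned = indptr_pageable.pin_memory()               # page-locked copy

a = indptr_pageable.to("cuda", non_blocking=True)  # effectively synchronous
b = indptr_pinned.to("cuda", non_blocking=True)    # enqueued on the current stream
torch.cuda.synchronize()  # ensure both transfers are done before timing or reuse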

0 commit comments

Comments
 (0)