Commit b53a46f

misc: add device guard for kernels (#611)

## plan

- [x] Check all kernels and add device guard
- [x] Complete the tests

FIX: #452

1 parent a3360ff · commit b53a46f

19 files changed: +54 −5 lines
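Every hunk below has the same shape: resolve the input tensor's device, construct an `at::cuda::OptionalCUDAGuard` on it, and only then query the current stream and launch work. A minimal sketch of that pattern, assuming a hypothetical entry point `MyOp` (the guard and stream calls are the real c10/ATen APIs used in the diffs):

```cpp
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

void MyOp(torch::Tensor x) {
  auto device = x.device();
  // Without the guard, kernels launch on whatever device is current for the
  // calling thread (often cuda:0), which is wrong for tensors on other GPUs.
  const at::cuda::OptionalCUDAGuard device_guard(device);
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device.index());
  // ... dispatch on dtype and launch the kernel on `stream` ...
}
```

The guard is RAII: it restores the previous current device when it goes out of scope, so a single line at the top of each entry point suffices. This is what fixes #452 for inputs placed on a non-default GPU.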

python/csrc/bmm_fp8.cu (+1)

@@ -54,6 +54,7 @@ void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
   auto workspace_buffer = torch::empty(
       {32 * 1024 * 1024}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
   auto lt_handle = reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
+  const at::cuda::OptionalCUDAGuard device_guard(A.device());
   auto stream = at::cuda::getCurrentCUDAStream();

   // PyTorch is row major by default. cuBLASLt is column major by default.

python/csrc/cascade.cu (+5)

@@ -42,6 +42,8 @@ std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, tor
   unsigned int seq_len = v_a.size(0);
   unsigned int num_heads = v_a.size(1);
   unsigned int head_dim = v_a.size(2);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto v_merged = torch::empty_like(v_a, v_a.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s_a.options());
@@ -91,6 +93,8 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe
   unsigned int seq_len = v.size(0);
   unsigned int num_heads = v.size(1);
   unsigned int head_dim = v.size(2);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] {
@@ -121,6 +125,7 @@ std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s) {
   unsigned int num_heads = v.size(2);
   unsigned int head_dim = v.size(3);
   s = s.to(torch::kFloat32);
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options());
   auto s_merged = torch::empty({seq_len, num_heads}, s.options());

python/csrc/group_gemm.cu (+2)

@@ -25,6 +25,8 @@ void CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor all_proble
                         torch::Tensor empty_x_data, bool weight_column_major) {
   unsigned int batch_size = x_ptr.size(0);
   auto device = workspace_buffer.device();
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_x_data.scalar_type(), c_type, [&] {

python/csrc/group_gemm_sm90.cu (+2)

@@ -27,6 +27,8 @@ void CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer,
                             bool weight_column_major) {
   unsigned int batch_size = x_ptr.size(0);
   auto device = float_workspace_buffer.device();
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

   DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_x_data.scalar_type(), c_type, [&] {

python/csrc/norm.cu (+4)

@@ -32,6 +32,7 @@ void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight,
   CHECK_EQ(output.size(0), batch_size);
   CHECK_EQ(output.size(1), hidden_size);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
@@ -61,6 +62,7 @@ void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Ten
   unsigned int batch_size = input.size(0);
   unsigned int hidden_size = input.size(1);

+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()),
@@ -86,6 +88,7 @@ void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& w
   CHECK_EQ(output.size(0), batch_size);
   CHECK_EQ(output.size(1), hidden_size);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
@@ -115,6 +118,7 @@ void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torc
   unsigned int batch_size = input.size(0);
   unsigned int hidden_size = input.size(1);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
     cudaError_t status = norm::GemmaFusedAddRMSNorm(

python/csrc/page.cu (+2 −1)

@@ -92,7 +92,8 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
   CHECK_EQ(append_key.size(2), head_dim);
   CHECK_EQ(append_value.size(1), num_heads);
   CHECK_EQ(append_value.size(2), head_dim);
-
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

   auto kv_scalar_dtype = paged_k_cache.scalar_type();

python/csrc/pytorch_extension_utils.h (+1)

@@ -15,6 +15,7 @@
  */
 #pragma once
 #include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_fp8.h>

python/csrc/quantization.cu (+3)

@@ -24,6 +24,8 @@ torch::Tensor packbits(torch::Tensor x, const std::string& bitorder) {
   auto device = x.device();
   TORCH_CHECK(bitorder == "big" || bitorder == "little", "bitorder must be 'big' or 'little'");
   x = x.to(torch::kBool);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());

   int64_t num_elements = x.numel();
@@ -57,6 +59,7 @@ torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
   int64_t output_nnz = output_indptr[batch_size].item<int64_t>();
   auto y = torch::empty({output_nnz}, x.options().dtype(torch::kUInt8));

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaError_t status = quantization::SegmentPackBits(
       static_cast<bool*>(x.data_ptr()), static_cast<uint8_t*>(y.data_ptr()),
       static_cast<int32_t*>(input_indptr.data_ptr()),

python/csrc/rope.cu (+6 −1)

@@ -50,7 +50,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
   size_t k_rope_stride_h = k_rope.stride(1);
   indptr = indptr.to(torch::kInt32);
   offsets = offsets.to(torch::kInt32);
-
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     cudaError_t status = BatchQKApplyRotary(
@@ -93,6 +94,7 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
   size_t k_rope_stride_h = k_rope.stride(1);
   pos_ids = pos_ids.to(torch::kInt32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     cudaError_t status = BatchQKApplyRotaryPosIds(
@@ -145,6 +147,7 @@ void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::T
   size_t k_rope_stride_h = k_rope.stride(1);
   pos_ids = pos_ids.to(torch::kInt32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
@@ -195,6 +198,7 @@ void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
   indptr = indptr.to(torch::kInt32);
   offsets = offsets.to(torch::kInt32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     cudaError_t status = BatchQKApplyLlama31Rotary(
@@ -240,6 +244,7 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
   size_t k_rope_stride_h = k_rope.stride(1);
   pos_ids = pos_ids.to(torch::kInt32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
     cudaError_t status = BatchQKApplyLlama31RotaryPosIds(

python/csrc/sampling.cu (+10 −1)

@@ -33,6 +33,7 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
   probs = probs.to(torch::kFloat32);
   uniform_samples = uniform_samples.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));

@@ -71,6 +72,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
   uniform_samples = uniform_samples.to(torch::kFloat32);
   top_p_arr = top_p_arr.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
   auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -112,6 +114,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
   uniform_samples = uniform_samples.to(torch::kFloat32);
   top_k_arr = top_k_arr.to(torch::kInt32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
   auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -153,6 +156,7 @@ std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
   probs = probs.to(torch::kFloat32);
   uniform_samples = uniform_samples.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
   auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -203,6 +207,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
   probs = probs.to(torch::kFloat32);
   uniform_samples = uniform_samples.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
   auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
@@ -236,7 +241,8 @@ torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
   }
   top_p_arr = top_p_arr.to(torch::kFloat32);
   probs = probs.to(torch::kFloat32);
-
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto renorm_probs =
       torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
@@ -268,6 +274,7 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
   top_k_arr = top_k_arr.to(torch::kInt32);
   probs = probs.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto renorm_probs =
       torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
@@ -300,6 +307,7 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
   top_k_arr = top_k_arr.to(torch::kInt32);
   logits = logits.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto mask_logits =
       torch::empty({batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
@@ -348,6 +356,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
   uniform_samples = uniform_samples.to(torch::kFloat32);
   target_probs = target_probs.to(torch::kFloat32);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
                                        torch::dtype(torch::kInt32).device(device));

python/csrc_aot/batch_decode.cu (+2)

@@ -42,6 +42,7 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
   size_t int_workspace_size_in_bytes =
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

@@ -112,6 +113,7 @@ torch::Tensor BatchDecodeWithPagedKVCacheRun(
   }
   uint32_t head_dim = q.size(2);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q);
   if (maybe_lse) {

python/csrc_aot/batch_prefill.cu (+1)

@@ -50,6 +50,7 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

   auto device = float_workspace_buffer.device();
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
   TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");

python/csrc_aot/single_decode.cu (+1)

@@ -60,6 +60,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
     kv_len = k.size(1);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q);

python/csrc_aot/single_prefill.cu (+1)

@@ -56,6 +56,7 @@ torch::Tensor single_prefill_with_kv_cache(
     kv_stride_h = k.stride(0);
     kv_stride_n = k.stride(1);
   }
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
   if (maybe_lse) {

python/flashinfer/jit/batch_decode_mla_templ.py (+3 −1)

@@ -40,6 +40,7 @@
   size_t int_workspace_size_in_bytes =
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   indptr = indptr.to(torch::kCPU);
@@ -83,8 +84,9 @@
   auto device = q_nope.device();
   int64_t batch_size = q_nope.size(0);
   int64_t num_qo_heads = q_nope.size(1);
-  int64_t page_size = paged_ckv_cache.size(1);;
+  int64_t page_size = paged_ckv_cache.size(1);

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q_nope);
   torch::Tensor lse;

python/flashinfer/jit/batch_decode_templ.py (+3 −1)

@@ -41,6 +41,7 @@
   size_t int_workspace_size_in_bytes =
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();
   auto device = float_workspace_buffer.device();
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   TORCH_CHECK(indptr.device() == torch::kCPU, "indptr must be on CPU");

@@ -93,7 +94,8 @@
     page_size = paged_k_cache.size(1);
     num_kv_heads = paged_k_cache.size(2);
   }
-
+
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   torch::Tensor o = torch::empty_like(q);
   if (maybe_lse) {

python/flashinfer/jit/batch_prefill_templ.py (+3)

@@ -45,6 +45,7 @@
       int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

   auto device = float_workspace_buffer.device();
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   TORCH_CHECK(qo_indptr.device() == torch::kCPU, "qo_indptr must be on CPU");
   TORCH_CHECK(kv_indptr.device() == torch::kCPU, "kv_indptr must be on CPU");
@@ -92,6 +93,7 @@
   }

   auto device = float_workspace_buffer.device();
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
   if (maybe_lse) {
@@ -187,6 +189,7 @@
     num_kv_heads = paged_k_cache.size(2);
   }

+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q, q.options());
   if (maybe_lse) {

python/flashinfer/jit/single_decode_templ.py (+2)

@@ -106,6 +106,7 @@
     num_kv_heads = k.size(0);
     kv_len = k.size(1);
   }
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q);

@@ -157,6 +158,7 @@
     num_kv_heads = k.size(0);
     kv_len = k.size(1);
   }
+  const at::cuda::OptionalCUDAGuard device_guard(device);
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
   auto o = torch::empty_like(q);
python/flashinfer/jit/single_prefill_templ.py

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
kv_stride_h = k.stride(0);
103103
kv_stride_n = k.stride(1);
104104
}
105+
const at::cuda::OptionalCUDAGuard device_guard(device);
105106
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
106107
auto o = torch::empty_like(q, q.options());
107108
if (maybe_lse) {
@@ -177,6 +178,7 @@
177178
kv_stride_h = k.stride(0);
178179
kv_stride_n = k.stride(1);
179180
}
181+
const at::cuda::OptionalCUDAGuard device_guard(device);
180182
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
181183
auto o = torch::empty_like(q, q.options());
182184
if (maybe_lse) {
