@@ -33,6 +33,7 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
33
33
probs = probs.to (torch::kFloat32 );
34
34
uniform_samples = uniform_samples.to (torch::kFloat32 );
35
35
36
+ const at::cuda::OptionalCUDAGuard device_guard (device);
36
37
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
37
38
auto samples = torch::empty ({batch_size}, torch::dtype (torch::kInt32 ).device (device));
38
39
@@ -71,6 +72,7 @@ std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
71
72
uniform_samples = uniform_samples.to (torch::kFloat32 );
72
73
top_p_arr = top_p_arr.to (torch::kFloat32 );
73
74
75
+ const at::cuda::OptionalCUDAGuard device_guard (device);
74
76
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
75
77
auto samples = torch::empty ({batch_size}, torch::dtype (torch::kInt32 ).device (device));
76
78
auto success = torch::empty ({batch_size}, torch::dtype (torch::kBool ).device (device));
@@ -112,6 +114,7 @@ std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
112
114
uniform_samples = uniform_samples.to (torch::kFloat32 );
113
115
top_k_arr = top_k_arr.to (torch::kInt32 );
114
116
117
+ const at::cuda::OptionalCUDAGuard device_guard (device);
115
118
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
116
119
auto samples = torch::empty ({batch_size}, torch::dtype (torch::kInt32 ).device (device));
117
120
auto success = torch::empty ({batch_size}, torch::dtype (torch::kBool ).device (device));
@@ -153,6 +156,7 @@ std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
153
156
probs = probs.to (torch::kFloat32 );
154
157
uniform_samples = uniform_samples.to (torch::kFloat32 );
155
158
159
+ const at::cuda::OptionalCUDAGuard device_guard (device);
156
160
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
157
161
auto samples = torch::empty ({batch_size}, torch::dtype (torch::kInt32 ).device (device));
158
162
auto success = torch::empty ({batch_size}, torch::dtype (torch::kBool ).device (device));
@@ -203,6 +207,7 @@ std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
203
207
probs = probs.to (torch::kFloat32 );
204
208
uniform_samples = uniform_samples.to (torch::kFloat32 );
205
209
210
+ const at::cuda::OptionalCUDAGuard device_guard (device);
206
211
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
207
212
auto samples = torch::empty ({batch_size}, torch::dtype (torch::kInt32 ).device (device));
208
213
auto success = torch::empty ({batch_size}, torch::dtype (torch::kBool ).device (device));
@@ -236,7 +241,8 @@ torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
236
241
}
237
242
top_p_arr = top_p_arr.to (torch::kFloat32 );
238
243
probs = probs.to (torch::kFloat32 );
239
-
244
+
245
+ const at::cuda::OptionalCUDAGuard device_guard (device);
240
246
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
241
247
auto renorm_probs =
242
248
torch::empty ({batch_size, vocab_size}, torch::dtype (torch::kFloat32 ).device (device));
@@ -268,6 +274,7 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tenso
268
274
top_k_arr = top_k_arr.to (torch::kInt32 );
269
275
probs = probs.to (torch::kFloat32 );
270
276
277
+ const at::cuda::OptionalCUDAGuard device_guard (device);
271
278
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
272
279
auto renorm_probs =
273
280
torch::empty ({batch_size, vocab_size}, torch::dtype (torch::kFloat32 ).device (device));
@@ -300,6 +307,7 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
300
307
top_k_arr = top_k_arr.to (torch::kInt32 );
301
308
logits = logits.to (torch::kFloat32 );
302
309
310
+ const at::cuda::OptionalCUDAGuard device_guard (device);
303
311
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
304
312
auto mask_logits =
305
313
torch::empty ({batch_size, vocab_size}, torch::dtype (torch::kFloat32 ).device (device));
@@ -348,6 +356,7 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
348
356
uniform_samples = uniform_samples.to (torch::kFloat32 );
349
357
target_probs = target_probs.to (torch::kFloat32 );
350
358
359
+ const at::cuda::OptionalCUDAGuard device_guard (device);
351
360
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream (device.index ());
352
361
auto output_token_ids = torch::empty ({batch_size, num_speculate_tokens + 1 },
353
362
torch::dtype (torch::kInt32 ).device (device));
0 commit comments