Commit fa38b5e

feat: add accept num, emit num metric for ChainSpeculativeSampling (#450)

1 parent 86c9e55

5 files changed: +107 −25

include/flashinfer/sampling.cuh (+33, −10)

```diff
@@ -1154,8 +1154,10 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           typename DType, typename IdType>
 __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                          DType* uniform_samples, DType* target_probs,
-                                         IdType* output_token_ids, uint32_t num_speculative_tokens,
-                                         uint32_t d) {
+                                         IdType* output_token_ids,
+                                         IdType* output_accepted_token_num,
+                                         IdType* output_emitted_token_num,
+                                         uint32_t num_speculative_tokens, uint32_t d) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;

@@ -1165,20 +1167,38 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
   auto& temp_storage = reinterpret_cast<
       SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);

-  uint32_t pos = 0;
-  for (pos = 0; pos < num_speculative_tokens; ++pos) {
-    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
-    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + pos) * d + draft_id],
-          p = draft_probs[(row_idx * num_speculative_tokens + pos) * d + draft_id];
-    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + pos];
+  uint32_t pos = num_speculative_tokens;
+  for (uint32_t i = 0; i < num_speculative_tokens; ++i) {
+    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
+    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
+          p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
+    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
     if (u * p < q) {
       // accept the draft models output
-      output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = draft_id;
+      output_token_ids[row_idx * (num_speculative_tokens + 1) + i] = draft_id;
     } else {
+      pos = i;
       break;
     }
   }

+  uint32_t emitted_token_num = pos;
+  uint32_t accepted_token_num = pos;
+  for (uint32_t i = pos; i < num_speculative_tokens; ++i) {
+    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
+    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
+          p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
+    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
+    if (u * p < q) {
+      ++accepted_token_num;
+    }
+  }
+
+  if (tx == 0) {
+    output_accepted_token_num[row_idx] += accepted_token_num;
+    output_emitted_token_num[row_idx] += emitted_token_num;
+  }
+
   // sample from relu(target_probs - draft_probs)
   DType sum_relu_q_minus_p(0);
   vec_t<DType, VEC_SIZE> q_vec, p_vec;

@@ -1284,7 +1304,8 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
 template <typename DType, typename IdType>
 cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                      DType* uniform_samples, DType* target_probs,
-                                     IdType* output_token_ids, uint32_t batch_size,
+                                     IdType* output_token_ids, IdType* output_accepted_token_num,
+                                     IdType* output_emitted_token_num, uint32_t batch_size,
                                      uint32_t num_speculative_tokens, uint32_t d,
                                      bool deterministic, cudaStream_t stream = 0) {
   constexpr uint32_t BLOCK_THREADS = 1024;

@@ -1299,6 +1320,8 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids
                   &uniform_samples,
                   &target_probs,
                   &output_token_ids,
+                  &output_accepted_token_num,
+                  &output_emitted_token_num,
                   &num_speculative_tokens,
                   &d};
   DISPATCH_ALIGNED_VEC_SIZE(
```
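The kernel change splits the counting into two passes over the draft tokens. The first loop writes accepted tokens and records in `pos` the index of the first rejection (defaulting to `num_speculative_tokens` when the whole chain survives); that prefix length is the emitted count. The second loop keeps applying the same `u * p < q` test past the rejection point, so the accepted count measures how many draft tokens would pass independently and is always at least the emitted count. A minimal CPU sketch of the same two metrics for a single request (the helper name is ours, not part of the commit):

```python
import torch

def count_accept_emit(draft_probs: torch.Tensor,
                      draft_token_ids: torch.Tensor,
                      target_probs: torch.Tensor,
                      uniform_samples: torch.Tensor):
    """CPU reference for the two metrics of a single request (one kernel block).

    draft_probs:     (num_spec, vocab)      draft-model probabilities p
    draft_token_ids: (num_spec,)            tokens proposed by the draft model
    target_probs:    (num_spec + 1, vocab)  target-model probabilities q
    uniform_samples: (num_spec + 1,)        u ~ U(0, 1) used by rejection sampling
    """
    num_spec = draft_token_ids.shape[0]
    emitted = num_spec  # default: the whole draft chain survives
    accepted = 0
    for i in range(num_spec):
        tok = int(draft_token_ids[i])
        # Same comparison as the kernel: u * p < q.
        passes = (float(uniform_samples[i]) * float(draft_probs[i, tok])
                  < float(target_probs[i, tok]))
        accepted += int(passes)          # each token judged independently
        if not passes and emitted == num_spec:
            emitted = i                  # the chain stops at the first rejection
    return accepted, emitted
```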

python/csrc/flashinfer_ops.h (+4, −3)

```diff
@@ -67,9 +67,10 @@ torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor
 torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
                                 unsigned int top_k_val, double eps);

-torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
-                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
-                                         bool deterministic);
+std::vector<torch::Tensor> chain_speculative_sampling(
+    torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
+    torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
+    std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);

 torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
```

python/csrc/sampling.cu (+21, −6)

```diff
@@ -315,9 +315,10 @@ torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tenso
   return mask_logits;
 }

-torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
-                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
-                                         bool deterministic) {
+std::vector<torch::Tensor> chain_speculative_sampling(
+    torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
+    torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
+    std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic) {
   CHECK_INPUT(draft_probs);
   CHECK_INPUT(draft_token_ids);
   CHECK_INPUT(uniform_samples);

@@ -349,14 +350,28 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
   auto output_token_ids = torch::empty({batch_size, num_speculate_tokens + 1},
                                        torch::dtype(torch::kInt32).device(device));

+  bool has_output_accepted_token_num = maybe_output_accepted_token_num.has_value();
+  bool has_output_emitted_token_num = maybe_output_emitted_token_num.has_value();
+  auto output_accepted_token_num = maybe_output_accepted_token_num.value_or(
+      torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
+  auto output_emitted_token_num = maybe_output_emitted_token_num.value_or(
+      torch::zeros({batch_size}, torch::dtype(torch::kInt32).device(device)));
+  if (has_output_accepted_token_num) {
+    CHECK_EQ(has_output_emitted_token_num, true);
+    CHECK_EQ(batch_size, output_accepted_token_num.size(0));
+    CHECK_EQ(batch_size, output_emitted_token_num.size(0));
+  }
+
   cudaError_t status = sampling::ChainSpeculativeSampling<float, int>(
       static_cast<float*>(draft_probs.data_ptr()), static_cast<int*>(draft_token_ids.data_ptr()),
       static_cast<float*>(uniform_samples.data_ptr()), static_cast<float*>(target_probs.data_ptr()),
-      static_cast<int*>(output_token_ids.data_ptr()), batch_size, num_speculate_tokens, vocab_size,
-      deterministic, torch_current_stream);
+      static_cast<int*>(output_token_ids.data_ptr()),
+      static_cast<int*>(output_accepted_token_num.data_ptr()),
+      static_cast<int*>(output_emitted_token_num.data_ptr()), batch_size, num_speculate_tokens,
+      vocab_size, deterministic, torch_current_stream);

   TORCH_CHECK(status == cudaSuccess, "ChainSpeculativeSampling failed with error code " +
                                          std::string(cudaGetErrorString(status)));

-  return output_token_ids;
+  return {output_token_ids, output_accepted_token_num, output_emitted_token_num};
 }
```
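On the binding side, counters omitted by the caller are replaced with fresh zero-filled ``(batch_size,)`` int32 tensors via `value_or`, and supplying `maybe_output_accepted_token_num` also requires `maybe_output_emitted_token_num` with a matching batch size. Since the kernel accumulates with `+=`, caller-provided tensors carry running totals across calls. A Python sketch mirroring that contract (hypothetical helper, not in the commit):

```python
from typing import Optional, Tuple
import torch

def resolve_counters(batch_size: int,
                     accepted: Optional[torch.Tensor],
                     emitted: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Python mirror of the binding's optional-output handling."""
    has_accepted = accepted is not None
    has_emitted = emitted is not None
    # Omitted counters fall back to fresh zero tensors, as with value_or().
    accepted = accepted if has_accepted else torch.zeros(batch_size, dtype=torch.int32)
    emitted = emitted if has_emitted else torch.zeros(batch_size, dtype=torch.int32)
    if has_accepted:
        # Supplying accepted requires supplying emitted, with matching batch size.
        assert has_emitted
        assert accepted.size(0) == batch_size
        assert emitted.size(0) == batch_size
    return accepted, emitted
```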

python/flashinfer/sampling.py (+18, −1)

```diff
@@ -592,6 +592,8 @@ def chain_speculative_sampling(
     draft_token_ids,
     uniform_samples,
     target_probs,
+    maybe_output_accepted_token_num: torch.Tensor = None,
+    maybe_output_emitted_token_num: torch.Tensor = None,
     deterministic: bool = True,
 ) -> torch.Tensor:
     r"""Fused-GPU kernel for speculative sampling for sequence generation (proposed in

@@ -614,6 +616,15 @@ def chain_speculative_sampling(
         Compared to input :attr:`draft_probs`, the target model's probability has an additional
         slot at the end because the target model will generate one more token than the draft model.
         Shape: ``(batch_size, num_speculate_tokens + 1, vocab_size)``
+    maybe_output_accepted_token_num: torch.Tensor
+        The number of tokens that can be accepted if each token is considered independently for
+        each request. This metric does not consider the fact that rejection sampling will stop
+        at the first token that does not satisfy the probability requirement r < p/q.
+        It only evaluates the alignment of the draft model and the target model.
+        Shape: ``(batch_size)``
+    maybe_output_emitted_token_num: torch.Tensor
+        The number of tokens that are finally emitted/generated for each request.
+        Shape: ``(batch_size)``
     deterministic: bool
         Whether to use deterministic kernel implementation, default is ``True``.

@@ -628,5 +639,11 @@ def chain_speculative_sampling(
         Shape: (batch_size, num_speculate_tokens + 1)
     """
     return _kernels.chain_speculative_sampling(
-        draft_probs, draft_token_ids, uniform_samples, target_probs, deterministic
+        draft_probs,
+        draft_token_ids,
+        uniform_samples,
+        target_probs,
+        maybe_output_accepted_token_num,
+        maybe_output_emitted_token_num,
+        deterministic,
     )
```
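With the new optional arguments, the Python API returns a triple instead of a single tensor. A hedged end-to-end sketch on random toy distributions (requires a CUDA device and the compiled extension; the shapes, seed, and variable names are illustrative, not from the commit):

```python
import torch
import flashinfer

torch.manual_seed(42)
batch_size, num_speculate_tokens, vocab_size = 4, 3, 32
device = "cuda:0"

# Illustrative random, normalized draft/target distributions.
draft_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens, vocab_size, device=device), dim=-1
)
target_probs = torch.softmax(
    torch.randn(batch_size, num_speculate_tokens + 1, vocab_size, device=device), dim=-1
)
draft_token_ids = (
    torch.multinomial(draft_probs.view(-1, vocab_size), 1)
    .view(batch_size, num_speculate_tokens)
    .int()
)
uniform_samples = torch.rand(batch_size, num_speculate_tokens + 1, device=device)

# Optional running counters; the kernel accumulates into them with +=.
accepted_num = torch.zeros(batch_size, dtype=torch.int32, device=device)
emitted_num = torch.zeros(batch_size, dtype=torch.int32, device=device)

output_token_ids, accepted_num, emitted_num = (
    flashinfer.sampling.chain_speculative_sampling(
        draft_probs,
        draft_token_ids,
        uniform_samples,
        target_probs,
        accepted_num,
        emitted_num,
    )
)

# Per-request draft/target alignment, independent of where the chain stopped.
acceptance_rate = accepted_num.float() / num_speculate_tokens
```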

python/tests/test_sampling.py (+31, −5)

```diff
@@ -339,11 +339,17 @@ def test_chain_speculative_sampling(
     # NOTE(Zihao): this is a very simple test that only checks whether output is valid or not.
     for trials in range(10):
         uniform_samples.uniform_()
-        output_token_ids = flashinfer.sampling.chain_speculative_sampling(
-            normalized_draft_prob,
-            draft_token_ids,
-            uniform_samples,
-            target_onehot_prob,
+        accepted_num = torch.zeros(batch_size, dtype=torch.int32).to(0)
+        emitted_num = torch.zeros(batch_size, dtype=torch.int32).to(0)
+        output_token_ids, accepted_num, emitted_num = (
+            flashinfer.sampling.chain_speculative_sampling(
+                normalized_draft_prob,
+                draft_token_ids,
+                uniform_samples,
+                target_onehot_prob,
+                accepted_num,
+                emitted_num,
+            )
         )
         if onehot_target:
             assert torch.all(output_token_ids == target_token_ids)

@@ -359,6 +365,26 @@ def test_chain_speculative_sampling(
             # from the second mismatched token on, the output tokens should be -1
             assert torch.all(output_token_ids[row, mismatch_idx[0] + 1 :] == -1)

+        assert torch.all(emitted_num + 1 == (output_token_ids != -1).sum(dim=1))
+        batch_indices = torch.arange(batch_size, device=normalized_draft_prob.device)[
+            :, None
+        ]
+        probs_indicies = torch.arange(
+            num_speculate_tokens, device=normalized_draft_prob.device
+        )
+        selected_draft_probs = normalized_draft_prob[
+            batch_indices, probs_indicies, draft_token_ids
+        ]
+        selected_target_probs = target_onehot_prob[
+            batch_indices, probs_indicies, draft_token_ids
+        ]
+        capped_ratio = torch.minimum(
+            selected_target_probs / selected_draft_probs,
+            torch.full((1,), 1, device=normalized_draft_prob.device),
+        )
+        ref_accepted = (uniform_samples[:, :-1] < capped_ratio).sum(dim=1)
+        assert torch.all(accepted_num == ref_accepted)
+

 if __name__ == "__main__":
     test_sampling(1, 111)
```
