@@ -1154,8 +1154,10 @@ template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
1154
1154
typename DType, typename IdType>
1155
1155
__global__ void ChainSpeculativeSampling (DType* draft_probs, IdType* draft_token_ids,
1156
1156
DType* uniform_samples, DType* target_probs,
1157
- IdType* output_token_ids, uint32_t num_speculative_tokens,
1158
- uint32_t d) {
1157
+ IdType* output_token_ids,
1158
+ IdType* output_accepted_token_num,
1159
+ IdType* output_emitted_token_num,
1160
+ uint32_t num_speculative_tokens, uint32_t d) {
1159
1161
const uint32_t bx = blockIdx .x , tx = threadIdx .x ;
1160
1162
const uint32_t row_idx = bx;
1161
1163
@@ -1165,20 +1167,38 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
1165
1167
auto & temp_storage = reinterpret_cast <
1166
1168
SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
1167
1169
1168
- uint32_t pos = 0 ;
1169
- for (pos = 0 ; pos < num_speculative_tokens; ++pos ) {
1170
- IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos ];
1171
- float q = target_probs[(row_idx * (num_speculative_tokens + 1 ) + pos ) * d + draft_id],
1172
- p = draft_probs[(row_idx * num_speculative_tokens + pos ) * d + draft_id];
1173
- DType u = uniform_samples[row_idx * (num_speculative_tokens + 1 ) + pos ];
1170
+ uint32_t pos = num_speculative_tokens ;
1171
+ for (uint32_t i = 0 ; i < num_speculative_tokens; ++i ) {
1172
+ IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i ];
1173
+ float q = target_probs[(row_idx * (num_speculative_tokens + 1 ) + i ) * d + draft_id],
1174
+ p = draft_probs[(row_idx * num_speculative_tokens + i ) * d + draft_id];
1175
+ DType u = uniform_samples[row_idx * (num_speculative_tokens + 1 ) + i ];
1174
1176
if (u * p < q) {
1175
1177
// accept the draft model's output
1176
- output_token_ids[row_idx * (num_speculative_tokens + 1 ) + pos ] = draft_id;
1178
+ output_token_ids[row_idx * (num_speculative_tokens + 1 ) + i ] = draft_id;
1177
1179
} else {
1180
+ pos = i;
1178
1181
break ;
1179
1182
}
1180
1183
}
1181
1184
1185
+ uint32_t emitted_token_num = pos;
1186
+ uint32_t accepted_token_num = pos;
1187
+ for (uint32_t i = pos; i < num_speculative_tokens; ++i) {
1188
+ IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
1189
+ float q = target_probs[(row_idx * (num_speculative_tokens + 1 ) + i) * d + draft_id],
1190
+ p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
1191
+ DType u = uniform_samples[row_idx * (num_speculative_tokens + 1 ) + i];
1192
+ if (u * p < q) {
1193
+ ++accepted_token_num;
1194
+ }
1195
+ }
1196
+
1197
+ if (tx == 0 ) {
1198
+ output_accepted_token_num[row_idx] += accepted_token_num;
1199
+ output_emitted_token_num[row_idx] += emitted_token_num;
1200
+ }
1201
+
1182
1202
// sample from relu(target_probs - draft_probs)
1183
1203
DType sum_relu_q_minus_p (0 );
1184
1204
vec_t <DType, VEC_SIZE> q_vec, p_vec;
@@ -1284,7 +1304,8 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o
1284
1304
template <typename DType, typename IdType>
1285
1305
cudaError_t ChainSpeculativeSampling (DType* draft_probs, IdType* draft_token_ids,
1286
1306
DType* uniform_samples, DType* target_probs,
1287
- IdType* output_token_ids, uint32_t batch_size,
1307
+ IdType* output_token_ids, IdType* output_accepted_token_num,
1308
+ IdType* output_emitted_token_num, uint32_t batch_size,
1288
1309
uint32_t num_speculative_tokens, uint32_t d,
1289
1310
bool deterministic, cudaStream_t stream = 0 ) {
1290
1311
constexpr uint32_t BLOCK_THREADS = 1024 ;
@@ -1299,6 +1320,8 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids
1299
1320
&uniform_samples,
1300
1321
&target_probs,
1301
1322
&output_token_ids,
1323
+ &output_accepted_token_num,
1324
+ &output_emitted_token_num,
1302
1325
&num_speculative_tokens,
1303
1326
&d};
1304
1327
DISPATCH_ALIGNED_VEC_SIZE (
0 commit comments