@@ -168,9 +168,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
168
168
auto D_H = (D_HQ - int4_qparam_offset) * 2 ;
169
169
170
170
auto cache_K_dq =
171
- at::empty ({B, MAX_T, N_KVH, D_H}, cache_K.options ().dtype (at::kBFloat16 ));
171
+ at::zeros ({B, MAX_T, N_KVH, D_H}, cache_K.options ().dtype (at::kBFloat16 ));
172
172
auto cache_V_dq =
173
- at::empty ({B, MAX_T, N_KVH, D_H}, cache_K.options ().dtype (at::kBFloat16 ));
173
+ at::zeros ({B, MAX_T, N_KVH, D_H}, cache_K.options ().dtype (at::kBFloat16 ));
174
174
175
175
if (B == 0 ) {
176
176
return {cache_K_dq, cache_V_dq};
@@ -625,7 +625,14 @@ DEVICE_INLINE fx4 rope_xpos(
625
625
}
626
626
627
627
template <int KVQuantNumGroups = 1 >
628
- DEVICE_INLINE void quantize_int4_kv (fx4 dst, uint8_t * dst_row_q) {
628
+ DEVICE_INLINE void quantize_int4_kv (fx4 dst, uint8_t * dst_row_q, bool do_norm) {
629
+ if (do_norm) {
630
+ float sum = fx4_dot (dst, dst);
631
+ // Warp reduce sum
632
+ sum = warpReduceSum (sum);
633
+ float rsqr = rsqrtf (sum / D_H);
634
+ dst = fx4_scale (dst, rsqr);
635
+ }
629
636
auto thread_min = fminf (fminf (fminf (dst.x , dst.y ), dst.z ), dst.w );
630
637
auto thread_max = fmaxf (fmaxf (fmaxf (dst.x , dst.y ), dst.z ), dst.w );
631
638
@@ -961,7 +968,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
961
968
quantize_fp8_kv (dst, dst_row_q, qparam_row, (qkv == QKV::K && k_norm));
962
969
} else if (kCacheDtype == CacheLogicalDtype::INT4) {
963
970
CUDA_KERNEL_ASSERT (D_H_q - D_H / 2 == 4 * KVQuantNumGroups);
964
- quantize_int4_kv<KVQuantNumGroups>(dst, dst_row_q);
971
+ quantize_int4_kv<KVQuantNumGroups>(
972
+ dst, dst_row_q, (qkv == QKV::K && k_norm));
965
973
}
966
974
}
967
975
}
@@ -977,6 +985,8 @@ at::Tensor nope_qkv_varseq_prefill(
977
985
std::optional<at::Tensor> block_tables,
978
986
int64_t page_size,
979
987
std::optional<at::Tensor> varseq_cache_seqpos,
988
+ int64_t cache_logical_dtype_int,
989
+ std::optional<int64_t > num_groups,
980
990
std::optional<at::Tensor> qparam_k = std::nullopt,
981
991
std::optional<at::Tensor> qparam_v = std::nullopt,
982
992
bool k_norm = false ) {
@@ -1005,7 +1015,8 @@ at::Tensor nope_qkv_varseq_prefill(
1005
1015
block_tables_ptr = static_cast <int32_t *>(block_tables.value ().data_ptr ());
1006
1016
block_tables_b_stride = block_tables.value ().stride (0 );
1007
1017
}
1008
-
1018
+ CacheLogicalDtype cache_logical_dtype =
1019
+ static_cast <CacheLogicalDtype>(cache_logical_dtype_int);
1009
1020
if (cache_K.dtype () == at::kBFloat16 ) {
1010
1021
nope_qkv_varseq_prefill_kernel<<<
1011
1022
blocks,
@@ -1029,7 +1040,7 @@ at::Tensor nope_qkv_varseq_prefill(
1029
1040
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1030
1041
return XQ_O;
1031
1042
} else {
1032
- // TODO: Pass Logical datatype to differentiate INT4 and FP8
1043
+ auto num_groups_ = num_groups ? num_groups. value () : 1 ;
1033
1044
int32_t * qparam_k_ptr = nullptr ;
1034
1045
int32_t * qparam_v_ptr = nullptr ;
1035
1046
if (qparam_k.has_value ()) {
@@ -1039,33 +1050,66 @@ at::Tensor nope_qkv_varseq_prefill(
1039
1050
auto varseq_batch_ = varseq_batch.data_ptr <int32_t >();
1040
1051
auto varseq_seqpos_ =
1041
1052
varseq_seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>();
1042
-
1043
- CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
1044
- 1 ,
1045
- CacheLogicalDtype::FP8,
1046
- PositionEmbeddingMode::NOPE,
1047
- varseq_batch_,
1048
- varseq_seqpos_,
1049
- 0 ,
1050
- 0 ,
1051
- 0 ,
1052
- 0 ,
1053
- block_tables_ptr,
1054
- page_size,
1055
- block_tables_b_stride,
1056
- (varseq_cache_seqpos_
1057
- .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>()),
1058
- nullptr ,
1059
- false ,
1060
- 0 ,
1061
- 0 ,
1062
- 0 ,
1063
- 0 ,
1064
- false ,
1065
- k_norm);
1066
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1067
- return XQ_O;
1053
+ if (cache_logical_dtype == CacheLogicalDtype::FP8) {
1054
+ #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
1055
+ (defined (CUDA_VERSION) && CUDA_VERSION >= 12000 )
1056
+ CUDA_KERNEL_ASSERT (num_groups_ == 1 );
1057
+ CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
1058
+ 1 ,
1059
+ CacheLogicalDtype::FP8,
1060
+ PositionEmbeddingMode::NOPE,
1061
+ varseq_batch_,
1062
+ varseq_seqpos_,
1063
+ 0 ,
1064
+ 0 ,
1065
+ 0 ,
1066
+ 0 ,
1067
+ block_tables_ptr,
1068
+ page_size,
1069
+ block_tables_b_stride,
1070
+ (varseq_cache_seqpos_
1071
+ .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>()),
1072
+ nullptr ,
1073
+ false ,
1074
+ 0 ,
1075
+ 0 ,
1076
+ 0 ,
1077
+ 0 ,
1078
+ false ,
1079
+ k_norm);
1080
+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
1081
+ #else
1082
+ throw std::runtime_error (" CUDA version is older than 12.0" );
1083
+ #endif
1084
+ } else {
1085
+ CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK (
1086
+ CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL,
1087
+ num_groups_,
1088
+ CacheLogicalDtype::INT4,
1089
+ PositionEmbeddingMode::NOPE,
1090
+ varseq_batch_,
1091
+ varseq_seqpos_,
1092
+ 0 ,
1093
+ 0 ,
1094
+ 0 ,
1095
+ 0 ,
1096
+ block_tables_ptr,
1097
+ page_size,
1098
+ block_tables_b_stride,
1099
+ (varseq_cache_seqpos_
1100
+ .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>()),
1101
+ nullptr ,
1102
+ false ,
1103
+ 0 ,
1104
+ 0 ,
1105
+ 0 ,
1106
+ 0 ,
1107
+ false ,
1108
+ k_norm);
1109
+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
1110
+ }
1068
1111
}
1112
+ return XQ_O;
1069
1113
}
1070
1114
1071
1115
at::Tensor nope_qkv_decoding (
@@ -1080,6 +1124,8 @@ at::Tensor nope_qkv_decoding(
1080
1124
std::optional<at::Tensor> actual_batch_size,
1081
1125
std::optional<at::Tensor> batch,
1082
1126
std::optional<at::Tensor> cache_seqpos,
1127
+ int64_t cache_logical_dtype_int,
1128
+ std::optional<int64_t > num_groups,
1083
1129
std::optional<at::Tensor> qparam_k = std::nullopt,
1084
1130
std::optional<at::Tensor> qparam_v = std::nullopt,
1085
1131
bool k_norm = false ) {
@@ -1107,7 +1153,8 @@ at::Tensor nope_qkv_decoding(
1107
1153
static_cast <int64_t *>(actual_batch_size.value ().data_ptr ());
1108
1154
}
1109
1155
auto cache_seqpos_ = cache_seqpos.value_or (seqpos);
1110
-
1156
+ CacheLogicalDtype cache_logical_dtype =
1157
+ static_cast <CacheLogicalDtype>(cache_logical_dtype_int);
1111
1158
if (cache_K.dtype () == at::kBFloat16 ) {
1112
1159
nope_qkv_varseq_prefill_kernel<<<
1113
1160
blocks,
@@ -1129,9 +1176,8 @@ at::Tensor nope_qkv_decoding(
1129
1176
actual_batch_size_ptr);
1130
1177
1131
1178
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1132
- return XQ_O;
1133
1179
} else {
1134
- // TODO: Pass KV logical Dtype
1180
+ auto num_groups_ = num_groups ? num_groups. value () : 1 ;
1135
1181
int32_t * qparam_k_ptr = nullptr ;
1136
1182
int32_t * qparam_v_ptr = nullptr ;
1137
1183
if (qparam_k.has_value ()) {
@@ -1142,32 +1188,67 @@ at::Tensor nope_qkv_decoding(
1142
1188
batch.has_value () ? batch.value ().data_ptr <int32_t >() : nullptr ;
1143
1189
auto seqpos_ =
1144
1190
seqpos.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>();
1145
- CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
1146
- 1 ,
1147
- CacheLogicalDtype::FP8,
1148
- PositionEmbeddingMode::NOPE,
1149
- batch_,
1150
- seqpos_,
1151
- 0 ,
1152
- 0 ,
1153
- 0 ,
1154
- 0 ,
1155
- block_tables_ptr,
1156
- page_size,
1157
- block_tables_b_stride,
1158
- (cache_seqpos_.packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>()),
1159
- actual_batch_size_ptr,
1160
- false ,
1161
- 0 ,
1162
- 0 ,
1163
- 0 ,
1164
- 0 ,
1165
- false ,
1166
- k_norm);
1191
+ if (cache_logical_dtype == CacheLogicalDtype::FP8) {
1192
+ #if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
1193
+ (defined (CUDA_VERSION) && CUDA_VERSION >= 12000 )
1194
+ CUDA_KERNEL_ASSERT (num_groups_ == 1 );
1195
+ CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL (
1196
+ 1 ,
1197
+ CacheLogicalDtype::FP8,
1198
+ PositionEmbeddingMode::NOPE,
1199
+ batch_,
1200
+ seqpos_,
1201
+ 0 ,
1202
+ 0 ,
1203
+ 0 ,
1204
+ 0 ,
1205
+ block_tables_ptr,
1206
+ page_size,
1207
+ block_tables_b_stride,
1208
+ (cache_seqpos_
1209
+ .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>()),
1210
+ actual_batch_size_ptr,
1211
+ false ,
1212
+ 0 ,
1213
+ 0 ,
1214
+ 0 ,
1215
+ 0 ,
1216
+ false ,
1217
+ k_norm);
1167
1218
1168
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
1169
- return XQ_O;
1219
+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
1220
+ #else
1221
+ throw std::runtime_error (" CUDA version is older than 12.0" );
1222
+ #endif
1223
+ } else {
1224
+ CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK (
1225
+ CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL,
1226
+ num_groups_,
1227
+ CacheLogicalDtype::INT4,
1228
+ PositionEmbeddingMode::NOPE,
1229
+ batch_,
1230
+ seqpos_,
1231
+ 0 ,
1232
+ 0 ,
1233
+ 0 ,
1234
+ 0 ,
1235
+ block_tables_ptr,
1236
+ page_size,
1237
+ block_tables_b_stride,
1238
+ (cache_seqpos_
1239
+ .packed_accessor32 <int32_t , 1 , at::RestrictPtrTraits>()),
1240
+ actual_batch_size_ptr,
1241
+ false ,
1242
+ 0 ,
1243
+ 0 ,
1244
+ 0 ,
1245
+ 0 ,
1246
+ false ,
1247
+ k_norm);
1248
+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
1249
+ }
1170
1250
}
1251
+ return XQ_O;
1171
1252
}
1172
1253
1173
1254
at::Tensor rope_qkv_varseq_prefill (
@@ -1316,7 +1397,7 @@ at::Tensor rope_qkv_varseq_prefill(
1316
1397
lo_freq_factor,
1317
1398
hi_freq_factor,
1318
1399
write_k_back,
1319
- false );
1400
+ k_norm );
1320
1401
1321
1402
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1322
1403
}
@@ -1621,7 +1702,7 @@ at::Tensor rope_qkv_decoding(
1621
1702
lo_freq_factor,
1622
1703
hi_freq_factor,
1623
1704
false ,
1624
- false );
1705
+ k_norm );
1625
1706
1626
1707
C10_CUDA_KERNEL_LAUNCH_CHECK ();
1627
1708
}
0 commit comments