
Commit 5796fc5

Aya-ZIbra authored and facebook-github-bot committed
Support of INT4 KV (#3878)
Summary:
Pull Request resolved: #3878
X-link: facebookresearch/FBGEMM#968

Enables INT4 KV cache for Llama4 numeric evals.

Changes:
1) k_norm support in the INT4 quantization path
2) Zero-initialized dequantization buffers
3) NoPE support for INT4

Reviewed By: summerdengfb
Differential Revision: D70508737
1 parent eeee38e commit 5796fc5
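Per the kv_cache.cu changes below, k_norm here is an RMS-style normalization of the key row applied just before quantization: the row is scaled by rsqrtf(sum(k*k) / D_H), with the sum of squares accumulated across the warp, and the flag is honored only for key rows (qkv == QKV::K && k_norm). The zero-init change switches the dequantization output buffers from at::empty to at::zeros, so positions the dequantization kernel never writes read back as zeros rather than uninitialized memory.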

File tree: 2 files changed, +152 −63 lines

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp (+10 −2)

@@ -37,6 +37,8 @@ at::Tensor nope_qkv_varseq_prefill(
     std::optional<at::Tensor> block_tables,
     int64_t page_size,
     std::optional<at::Tensor> varseq_cache_seqpos,
+    int64_t cache_logical_dtype_int,
+    std::optional<int64_t> num_groups,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
     bool k_norm);
@@ -53,6 +55,8 @@ at::Tensor nope_qkv_decoding(
     std::optional<at::Tensor> actual_batch_size,
     std::optional<at::Tensor> batch,
     std::optional<at::Tensor> cache_seqpos,
+    int64_t cache_logical_dtype_int,
+    std::optional<int64_t> num_groups,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
     bool k_norm);
@@ -187,9 +191,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
   m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
+      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
   m.def("nope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
+      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
   m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
   m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
@@ -288,6 +292,8 @@ at::Tensor nope_qkv_varseq_prefill_meta(
     std::optional<at::Tensor> /* block_tables */,
     int64_t /* page_size */,
     std::optional<at::Tensor> /* varseq_cache_seqpos */,
+    int64_t /* cache_logical_dtype_int */,
+    std::optional<int64_t> /* num_groups */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
     bool /* k_norm */
@@ -307,6 +313,8 @@ at::Tensor nope_qkv_decoding_meta(
     std::optional<at::Tensor> /* actual_batch_size */,
     std::optional<at::Tensor> /* batch */,
     std::optional<at::Tensor> /* cache_seqpos */,
+    int64_t /* cache_logical_dtype_int */,
+    std::optional<int64_t> /* num_groups */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
     bool /* k_norm */

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu (+142 −61)

@@ -168,9 +168,9 @@ std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
   auto D_H = (D_HQ - int4_qparam_offset) * 2;

   auto cache_K_dq =
-      at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
+      at::zeros({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
   auto cache_V_dq =
-      at::empty({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
+      at::zeros({B, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));

   if (B == 0) {
     return {cache_K_dq, cache_V_dq};
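For reference, the INT4 cache packs each head row as D_H/2 quantized bytes plus 4 bytes of quantization parameters per group: dequantize_int4_cache above recovers D_H as (D_HQ - int4_qparam_offset) * 2, and the quantized write kernel further down asserts D_H_q - D_H / 2 == 4 * KVQuantNumGroups. A minimal host-side sketch of that arithmetic, assuming the qparam region is exactly 4 bytes per group (the helper names are illustrative, not FBGEMM APIs):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Illustrative helpers (not FBGEMM APIs): bytes occupied by one packed
// INT4 KV head row, and the head dim recovered from the packed width.
constexpr int64_t kQparamBytesPerGroup = 4; // e.g. fp16 scale + fp16 shift

int64_t packed_row_bytes(int64_t d_h, int64_t num_groups) {
  // Two 4-bit values per byte, plus per-group quantization parameters.
  return d_h / 2 + kQparamBytesPerGroup * num_groups;
}

int64_t head_dim_from_packed(int64_t d_hq, int64_t num_groups) {
  // Mirrors: auto D_H = (D_HQ - int4_qparam_offset) * 2;
  return (d_hq - kQparamBytesPerGroup * num_groups) * 2;
}

int main() {
  const int64_t d_h = 128, num_groups = 1;
  const int64_t d_hq = packed_row_bytes(d_h, num_groups); // 68 bytes
  assert(head_dim_from_packed(d_hq, num_groups) == d_h);
  std::printf("D_H=%ld packs into D_HQ=%ld bytes\n", (long)d_h, (long)d_hq);
  return 0;
}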
@@ -625,7 +625,14 @@ DEVICE_INLINE fx4 rope_xpos(
 }

 template <int KVQuantNumGroups = 1>
-DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q) {
+DEVICE_INLINE void quantize_int4_kv(fx4 dst, uint8_t* dst_row_q, bool do_norm) {
+  if (do_norm) {
+    float sum = fx4_dot(dst, dst);
+    // Warp reduce sum
+    sum = warpReduceSum(sum);
+    float rsqr = rsqrtf(sum / D_H);
+    dst = fx4_scale(dst, rsqr);
+  }
   auto thread_min = fminf(fminf(fminf(dst.x, dst.y), dst.z), dst.w);
   auto thread_max = fmaxf(fmaxf(fmaxf(dst.x, dst.y), dst.z), dst.w);

@@ -961,7 +968,8 @@ __global__ void rope_xpos_qkv_varseq_prefill_kernel_quantized(
       quantize_fp8_kv(dst, dst_row_q, qparam_row, (qkv == QKV::K && k_norm));
     } else if (kCacheDtype == CacheLogicalDtype::INT4) {
       CUDA_KERNEL_ASSERT(D_H_q - D_H / 2 == 4 * KVQuantNumGroups);
-      quantize_int4_kv<KVQuantNumGroups>(dst, dst_row_q);
+      quantize_int4_kv<KVQuantNumGroups>(
+          dst, dst_row_q, (qkv == QKV::K && k_norm));
     }
   }
 }
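A scalar host-side reference of what the new do_norm path does before quantization, assuming one row of D_H key values: scale the row by rsqrt of the mean square (no epsilon, matching the device code), then quantize each group to 4 bits with a min/max affine scheme. This is a simplified stand-in for quantize_int4_kv, illustrating the numerics only; the real kernel packs two codes per byte and stores fp16 qparams per group.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Simplified reference: RMS-normalize a key row, then quantize each group of
// values to unsigned 4-bit codes with per-group min/max (affine) scaling.
// Assumes row.size() is divisible by num_groups.
std::vector<uint8_t> rms_norm_then_int4(
    std::vector<float> row, int num_groups, bool do_norm) {
  const int d_h = static_cast<int>(row.size());
  if (do_norm) {
    float sum = 0.f;
    for (float v : row) sum += v * v;              // fx4_dot + warpReduceSum
    const float rsqr = 1.f / std::sqrt(sum / d_h); // rsqrtf(sum / D_H)
    for (float& v : row) v *= rsqr;                // fx4_scale(dst, rsqr)
  }
  std::vector<uint8_t> codes(d_h); // one code per value (unpacked for clarity)
  const int group_size = d_h / num_groups;
  for (int g = 0; g < num_groups; ++g) {
    const auto begin = row.begin() + g * group_size;
    const auto end = begin + group_size;
    const float lo = *std::min_element(begin, end);
    const float hi = *std::max_element(begin, end);
    const float scale = (hi - lo) / 15.f; // 4-bit code range [0, 15]
    for (int i = 0; i < group_size; ++i) {
      const float x = scale > 0.f ? (begin[i] - lo) / scale : 0.f;
      codes[g * group_size + i] =
          static_cast<uint8_t>(std::lround(std::clamp(x, 0.f, 15.f)));
    }
  }
  return codes;
}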
@@ -977,6 +985,8 @@ at::Tensor nope_qkv_varseq_prefill(
     std::optional<at::Tensor> block_tables,
     int64_t page_size,
     std::optional<at::Tensor> varseq_cache_seqpos,
+    int64_t cache_logical_dtype_int,
+    std::optional<int64_t> num_groups,
     std::optional<at::Tensor> qparam_k = std::nullopt,
     std::optional<at::Tensor> qparam_v = std::nullopt,
     bool k_norm = false) {
@@ -1005,7 +1015,8 @@ at::Tensor nope_qkv_varseq_prefill(
     block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
     block_tables_b_stride = block_tables.value().stride(0);
   }
-
+  CacheLogicalDtype cache_logical_dtype =
+      static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
   if (cache_K.dtype() == at::kBFloat16) {
     nope_qkv_varseq_prefill_kernel<<<
         blocks,
@@ -1029,7 +1040,7 @@ at::Tensor nope_qkv_varseq_prefill(
     C10_CUDA_KERNEL_LAUNCH_CHECK();
     return XQ_O;
   } else {
-    // TODO: Pass Logical datatype to differentiate INT4 and FP8
+    auto num_groups_ = num_groups ? num_groups.value() : 1;
     int32_t* qparam_k_ptr = nullptr;
     int32_t* qparam_v_ptr = nullptr;
     if (qparam_k.has_value()) {
@@ -1039,33 +1050,66 @@ at::Tensor nope_qkv_varseq_prefill(
     auto varseq_batch_ = varseq_batch.data_ptr<int32_t>();
     auto varseq_seqpos_ =
         varseq_seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>();
-
-    CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
-        1,
-        CacheLogicalDtype::FP8,
-        PositionEmbeddingMode::NOPE,
-        varseq_batch_,
-        varseq_seqpos_,
-        0,
-        0,
-        0,
-        0,
-        block_tables_ptr,
-        page_size,
-        block_tables_b_stride,
-        (varseq_cache_seqpos_
-             .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-        nullptr,
-        false,
-        0,
-        0,
-        0,
-        0,
-        false,
-        k_norm);
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    return XQ_O;
+    if (cache_logical_dtype == CacheLogicalDtype::FP8) {
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
+    (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+      CUDA_KERNEL_ASSERT(num_groups_ == 1);
+      CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
+          1,
+          CacheLogicalDtype::FP8,
+          PositionEmbeddingMode::NOPE,
+          varseq_batch_,
+          varseq_seqpos_,
+          0,
+          0,
+          0,
+          0,
+          block_tables_ptr,
+          page_size,
+          block_tables_b_stride,
+          (varseq_cache_seqpos_
+               .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
+          nullptr,
+          false,
+          0,
+          0,
+          0,
+          0,
+          false,
+          k_norm);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+#else
+      throw std::runtime_error("CUDA version is older than 12.0");
+#endif
+    } else {
+      CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
+          CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL,
+          num_groups_,
+          CacheLogicalDtype::INT4,
+          PositionEmbeddingMode::NOPE,
+          varseq_batch_,
+          varseq_seqpos_,
+          0,
+          0,
+          0,
+          0,
+          block_tables_ptr,
+          page_size,
+          block_tables_b_stride,
+          (varseq_cache_seqpos_
+               .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
+          nullptr,
+          false,
+          0,
+          0,
+          0,
+          0,
+          false,
+          k_norm);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
   }
+  return XQ_O;
 }

 at::Tensor nope_qkv_decoding(
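With the prefill wrapper above, the quantized (non-BF16) path now branches on the logical dtype instead of assuming FP8: FP8 requires num_groups == 1 and a CUDA 12+ / ROCm 6.2+ build, while anything else is routed to the group-wise INT4 kernel. A compact sketch of that decision, with the enum member set, error handling, and helper name being illustrative only:

#include <cstdint>
#include <optional>
#include <stdexcept>

// Illustrative only: mirrors the branch added to nope_qkv_varseq_prefill /
// nope_qkv_decoding for quantized KV caches; the real enum and macros live
// in the kv_cache sources.
enum class CacheLogicalDtype { BF16, FP8, INT4 };
enum class KvKernelPath { Fp8Groupwise, Int4Groupwise };

KvKernelPath pick_quantized_kv_path(
    CacheLogicalDtype cache_logical_dtype,
    std::optional<int64_t> num_groups,
    bool platform_supports_fp8 /* CUDA >= 12.0 or ROCm >= 6.2 */) {
  const int64_t num_groups_ = num_groups.value_or(1);
  if (cache_logical_dtype == CacheLogicalDtype::FP8) {
    if (!platform_supports_fp8) {
      throw std::runtime_error("CUDA version is older than 12.0");
    }
    if (num_groups_ != 1) {
      // The real code enforces this with CUDA_KERNEL_ASSERT(num_groups_ == 1).
      throw std::runtime_error("FP8 KV cache expects num_groups == 1");
    }
    return KvKernelPath::Fp8Groupwise;
  }
  // Everything else is dispatched to the group-wise INT4 kernel.
  return KvKernelPath::Int4Groupwise;
}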
@@ -1080,6 +1124,8 @@ at::Tensor nope_qkv_decoding(
     std::optional<at::Tensor> actual_batch_size,
     std::optional<at::Tensor> batch,
     std::optional<at::Tensor> cache_seqpos,
+    int64_t cache_logical_dtype_int,
+    std::optional<int64_t> num_groups,
     std::optional<at::Tensor> qparam_k = std::nullopt,
     std::optional<at::Tensor> qparam_v = std::nullopt,
     bool k_norm = false) {
@@ -1107,7 +1153,8 @@ at::Tensor nope_qkv_decoding(
         static_cast<int64_t*>(actual_batch_size.value().data_ptr());
   }
   auto cache_seqpos_ = cache_seqpos.value_or(seqpos);
-
+  CacheLogicalDtype cache_logical_dtype =
+      static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
   if (cache_K.dtype() == at::kBFloat16) {
     nope_qkv_varseq_prefill_kernel<<<
         blocks,
@@ -1129,9 +1176,8 @@ at::Tensor nope_qkv_decoding(
         actual_batch_size_ptr);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
-    return XQ_O;
   } else {
-    // TODO: Pass KV logical Dtype
+    auto num_groups_ = num_groups ? num_groups.value() : 1;
     int32_t* qparam_k_ptr = nullptr;
     int32_t* qparam_v_ptr = nullptr;
     if (qparam_k.has_value()) {
@@ -1142,32 +1188,67 @@ at::Tensor nope_qkv_decoding(
         batch.has_value() ? batch.value().data_ptr<int32_t>() : nullptr;
     auto seqpos_ =
         seqpos.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>();
-    CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
-        1,
-        CacheLogicalDtype::FP8,
-        PositionEmbeddingMode::NOPE,
-        batch_,
-        seqpos_,
-        0,
-        0,
-        0,
-        0,
-        block_tables_ptr,
-        page_size,
-        block_tables_b_stride,
-        (cache_seqpos_.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-        actual_batch_size_ptr,
-        false,
-        0,
-        0,
-        0,
-        0,
-        false,
-        k_norm);
+    if (cache_logical_dtype == CacheLogicalDtype::FP8) {
+#if (defined(USE_ROCM) && ROCM_VERSION >= 60200) || \
+    (defined(CUDA_VERSION) && CUDA_VERSION >= 12000)
+      CUDA_KERNEL_ASSERT(num_groups_ == 1);
+      CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL(
+          1,
+          CacheLogicalDtype::FP8,
+          PositionEmbeddingMode::NOPE,
+          batch_,
+          seqpos_,
+          0,
+          0,
+          0,
+          0,
+          block_tables_ptr,
+          page_size,
+          block_tables_b_stride,
+          (cache_seqpos_
+               .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
+          actual_batch_size_ptr,
+          false,
+          0,
+          0,
+          0,
+          0,
+          false,
+          k_norm);

-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    return XQ_O;
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+#else
+      throw std::runtime_error("CUDA version is older than 12.0");
+#endif
+    } else {
+      CALL_INT4_KERNEL_WITH_KV_GROUPWISE_QUANT_CHECK(
+          CALL_ROPE_XPOS_QKV_VARSEQ_PREFILL_GROUPWISE_KERNEL,
+          num_groups_,
+          CacheLogicalDtype::INT4,
+          PositionEmbeddingMode::NOPE,
+          batch_,
+          seqpos_,
+          0,
+          0,
+          0,
+          0,
+          block_tables_ptr,
+          page_size,
+          block_tables_b_stride,
+          (cache_seqpos_
+               .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
+          actual_batch_size_ptr,
+          false,
+          0,
+          0,
+          0,
+          0,
+          false,
+          k_norm);
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
+    }
   }
+  return XQ_O;
 }

 at::Tensor rope_qkv_varseq_prefill(
@@ -1316,7 +1397,7 @@ at::Tensor rope_qkv_varseq_prefill(
         lo_freq_factor,
         hi_freq_factor,
         write_k_back,
-        false);
+        k_norm);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
@@ -1621,7 +1702,7 @@ at::Tensor rope_qkv_decoding(
         lo_freq_factor,
         hi_freq_factor,
         false,
-        false);
+        k_norm);

     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
