
Commit 368a807

llama: fix the KV cache quants q4_0 and q8_0 leading to a server abort in large-context chat. ggml-org#8073
Credit: @mengkin

1 parent a3b0e8b

File tree: 2 files changed, +339 −1 lines

ggml/src/ggml-cpu/ggml-cpu.c (+311 −1)
@@ -3358,6 +3358,308 @@ static void ggml_compute_forward_dup_same_cont(

```c
static void ggml_compute_forward_dup_q4(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ggml_row_size(src0->type, ne00); // bytes per q4_0 row
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    if (ggml_is_contiguous(dst)) {
        GGML_ASSERT(nb00 == sizeof(block_q4_0));

        const size_t rs = ne00; // elements per destination row

        if (dst->type == GGML_TYPE_F32) {
            float * dst_ptr = (float *) dst->data;
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        dequantize_row_q4_0(src_ptr, dst_ptr + id, ne00);
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else if (dst->type == GGML_TYPE_F16) {
            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
            float tmp[QK4_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        // dequantize one block at a time so tmp cannot overflow
                        for (int64_t ib = 0; ib < ne00/QK4_0; ib++) {
                            dequantize_row_q4_0(src_ptr + ib, tmp, QK4_0);
                            for (int64_t i00 = 0; i00 < QK4_0; i00++) {
                                dst_ptr[id + ib*QK4_0 + i00] = GGML_FP32_TO_FP16(tmp[i00]);
                            }
                        }
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else if (dst->type == GGML_TYPE_BF16) {
            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
            float tmp[QK4_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        for (int64_t ib = 0; ib < ne00/QK4_0; ib++) {
                            dequantize_row_q4_0(src_ptr + ib, tmp, QK4_0);
                            for (int64_t i00 = 0; i00 < QK4_0; i00++) {
                                dst_ptr[id + ib*QK4_0 + i00] = GGML_FP32_TO_BF16(tmp[i00]);
                            }
                        }
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
            ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
            GGML_ASSERT(QK4_0 % ggml_blck_size(dst->type) == 0);
            const size_t rs_dst = ggml_row_size(dst->type, ne00);  // bytes per destination row
            const size_t bs_dst = ggml_row_size(dst->type, QK4_0); // bytes per QK4_0 chunk
            float tmp[QK4_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs_dst * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        // dequantize and requantize one QK4_0 chunk at a time
                        for (int64_t ib = 0; ib < ne00/QK4_0; ib++) {
                            dequantize_row_q4_0(src_ptr + ib, tmp, QK4_0);
                            quantize_row_q(tmp, (char *) dst->data + id + ib*bs_dst, QK4_0);
                        }
                        id += rs_dst;
                    }
                    id += rs_dst * (ne01 - ir1);
                }
            }
        } else {
            GGML_ABORT("fatal error"); // TODO: implement
        }
    } else {
        // dst is not contiguous: address every element through its strides
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    const block_q4_0 * src_row = (const block_q4_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        // low nibbles hold the first half of each block, high nibbles the second
                        const block_q4_0 * b = src_row + i00/QK4_0;
                        const int64_t j = i00 % QK4_0;
                        const int q = j < QK4_0/2 ? (b->qs[j] & 0x0F) : (b->qs[j - QK4_0/2] >> 4);
                        const float v = GGML_FP16_TO_FP32(b->d) * (q - 8);
                        char * dst_ptr = (char *) dst->data + i00*nb0 + i01*nb1 + i02*nb2 + i03*nb3;
                        if (dst->type == GGML_TYPE_F32) {
                            *(float *) dst_ptr = v;
                        } else if (dst->type == GGML_TYPE_F16) {
                            *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(v);
                        } else if (dst->type == GGML_TYPE_BF16) {
                            *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(v);
                        } else {
                            GGML_ABORT("fatal error"); // TODO: implement
                        }
                    }
                }
            }
        }
    }
}
```
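For reference, this is the q4_0 layout the new dup path decodes: each block packs 32 weights as 16 nibble pairs plus an fp16 scale, with the low nibbles landing in the first half of the block and the high nibbles in the second. A minimal standalone sketch; the struct and the simplified half-decoder are stand-ins for ggml's `block_q4_0` and `ggml_fp16_t`, not the library's own definitions:

```c
#include <stdint.h>

#define QK4_0 32

typedef struct {
    uint16_t d;             // scale, IEEE fp16 bits
    uint8_t  qs[QK4_0 / 2]; // 32 values packed as nibbles
} block_q4_0_ref;

// minimal fp16 -> fp32 decoder (ignores NaN/Inf and subnormals for brevity)
static float fp16_to_fp32_ref(uint16_t h) {
    const uint32_t sign = (uint32_t) (h >> 15) << 31;
    const int32_t  exp  = (h >> 10) & 0x1F;
    const uint32_t man  = h & 0x3FF;
    if (exp == 0 && man == 0) { union { uint32_t u; float f; } v = { sign }; return v.f; }
    const uint32_t bits = sign | (uint32_t) (exp - 15 + 127) << 23 | man << 13;
    union { uint32_t u; float f; } v = { bits };
    return v.f;
}

// same arithmetic as dequantize_row_q4_0 applied to a single block
static void dequantize_block_q4_0_ref(const block_q4_0_ref * x, float y[QK4_0]) {
    const float d = fp16_to_fp32_ref(x->d);
    for (int j = 0; j < QK4_0/2; ++j) {
        // low nibbles fill the first half of the block, high nibbles the second
        y[j]           = d * ((x->qs[j] & 0x0F) - 8);
        y[j + QK4_0/2] = d * ((x->qs[j] >>   4) - 8);
    }
}
```

A block whose `qs` bytes are all 0x88 decodes to 32 zeros regardless of the scale, which makes a quick sanity check. The same hunk continues with the q8_0 counterpart: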
```c
static void ggml_compute_forward_dup_q8(
        const struct ggml_compute_params * params,
        struct ggml_tensor * dst) {

    const struct ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

    GGML_TENSOR_UNARY_OP_LOCALS

    const int ith = params->ith; // thread index
    const int nth = params->nth; // number of threads

    // parallelize by rows
    const int nr = ne01;
    // number of rows per thread
    const int dr = (nr + nth - 1) / nth;
    // row range for this thread
    const int ir0 = dr * ith;
    const int ir1 = MIN(ir0 + dr, nr);

    if (src0->type == dst->type &&
        ne00 == ne0 &&
        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
        // copy by rows
        const size_t rs = ggml_row_size(src0->type, ne00); // bytes per q8_0 row
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    memcpy(
                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
                        rs);
                }
            }
        }
        return;
    }

    if (ggml_is_contiguous(dst)) {
        GGML_ASSERT(nb00 == sizeof(block_q8_0));

        const size_t rs = ne00; // elements per destination row

        if (dst->type == GGML_TYPE_F32) {
            float * dst_ptr = (float *) dst->data;
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        dequantize_row_q8_0(src_ptr, dst_ptr + id, ne00);
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else if (dst->type == GGML_TYPE_F16) {
            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
            float tmp[QK8_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        // dequantize one block at a time so tmp cannot overflow
                        for (int64_t ib = 0; ib < ne00/QK8_0; ib++) {
                            dequantize_row_q8_0(src_ptr + ib, tmp, QK8_0);
                            for (int64_t i00 = 0; i00 < QK8_0; i00++) {
                                dst_ptr[id + ib*QK8_0 + i00] = GGML_FP32_TO_FP16(tmp[i00]);
                            }
                        }
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else if (dst->type == GGML_TYPE_BF16) {
            ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
            float tmp[QK8_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        for (int64_t ib = 0; ib < ne00/QK8_0; ib++) {
                            dequantize_row_q8_0(src_ptr + ib, tmp, QK8_0);
                            for (int64_t i00 = 0; i00 < QK8_0; i00++) {
                                dst_ptr[id + ib*QK8_0 + i00] = GGML_FP32_TO_BF16(tmp[i00]);
                            }
                        }
                        id += rs;
                    }
                    id += rs * (ne01 - ir1);
                }
            }
        } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
            ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
            GGML_ASSERT(QK8_0 % ggml_blck_size(dst->type) == 0);
            const size_t rs_dst = ggml_row_size(dst->type, ne00);  // bytes per destination row
            const size_t bs_dst = ggml_row_size(dst->type, QK8_0); // bytes per QK8_0 chunk
            float tmp[QK8_0];
            size_t id = 0;
            for (int64_t i03 = 0; i03 < ne03; i03++) {
                for (int64_t i02 = 0; i02 < ne02; i02++) {
                    id += rs_dst * ir0;
                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
                        const block_q8_0 * src_ptr = (const block_q8_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                        // dequantize and requantize one QK8_0 chunk at a time
                        for (int64_t ib = 0; ib < ne00/QK8_0; ib++) {
                            dequantize_row_q8_0(src_ptr + ib, tmp, QK8_0);
                            quantize_row_q(tmp, (char *) dst->data + id + ib*bs_dst, QK8_0);
                        }
                        id += rs_dst;
                    }
                    id += rs_dst * (ne01 - ir1);
                }
            }
        } else {
            GGML_ABORT("fatal error"); // TODO: implement
        }
    } else {
        // dst is not contiguous: address every element through its strides
        for (int64_t i03 = 0; i03 < ne03; i03++) {
            for (int64_t i02 = 0; i02 < ne02; i02++) {
                for (int64_t i01 = ir0; i01 < ir1; i01++) {
                    const block_q8_0 * src_row = (const block_q8_0 *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
                    for (int64_t i00 = 0; i00 < ne00; i00++) {
                        const block_q8_0 * b = src_row + i00/QK8_0;
                        const float v = GGML_FP16_TO_FP32(b->d) * b->qs[i00 % QK8_0];
                        char * dst_ptr = (char *) dst->data + i00*nb0 + i01*nb1 + i02*nb2 + i03*nb3;
                        if (dst->type == GGML_TYPE_F32) {
                            *(float *) dst_ptr = v;
                        } else if (dst->type == GGML_TYPE_F16) {
                            *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(v);
                        } else if (dst->type == GGML_TYPE_BF16) {
                            *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(v);
                        } else {
                            GGML_ABORT("fatal error"); // TODO: implement
                        }
                    }
                }
            }
        }
    }
}
```
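Both handlers split work across threads with the same ceil-divide row partition used elsewhere in ggml-cpu.c. A tiny sketch with hypothetical sizes, showing that the per-thread ranges tile [0, nr) exactly once:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // rows (ne01), hypothetical
    const int nth = 4;  // threads

    const int dr = (nr + nth - 1) / nth; // rows per thread, rounded up
    for (int ith = 0; ith < nth; ith++) {
        const int ir0 = dr * ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    // prints [0,3) [3,6) [6,9) [9,10): every row covered exactly once
    return 0;
}
```

The second hunk of this file registers the new handlers in the dup dispatch: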
@@ -4457,9 +4759,17 @@ static void ggml_compute_forward_dup(

```c
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_dup_f32(params, dst);
            } break;
        case GGML_TYPE_Q4_0:
            {
                ggml_compute_forward_dup_q4(params, dst);
            } break;
        case GGML_TYPE_Q8_0:
            {
                ggml_compute_forward_dup_q8(params, dst);
            } break;
        default:
            {
                GGML_ABORT("unsupported forward dup from type %d to type %d", src0->type, dst->type);
            }
    }
}
```
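On the CPU backend this dispatch is reached whenever a GGML_OP_CPY/GGML_OP_DUP node has a q4_0 or q8_0 source, which is exactly what a quantized KV cache produces once the context grows. A rough sketch of exercising the new path through the public ggml API; signatures shift between ggml revisions (for instance, `ggml_graph_compute_with_ctx` may be declared in ggml-cpu.h), so treat it as an outline rather than a drop-in test:

```c
#include "ggml.h"
#include <stdio.h>
#include <stdlib.h>

// Sketch: copy a q8_0 tensor into an f32 tensor, which routes through
// ggml_compute_forward_dup_q8 on the CPU backend.
int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t n = 256; // multiple of QK8_0
    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, n);
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  n);

    // fill src by quantizing n floats into its q8_0 buffer
    float * data = malloc(n*sizeof(float));
    for (int64_t i = 0; i < n; i++) data[i] = (float) i / n;
    ggml_quantize_chunk(GGML_TYPE_Q8_0, data, src->data, 0, 1, n, NULL);

    struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cpy);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 4);

    const float * out = (const float *) dst->data;
    printf("out[0] = %f, out[255] = %f\n", out[0], out[255]);

    free(data);
    ggml_free(ctx);
    return 0;
}
```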

ggml/src/ggml-cuda/cpy.cu (+28)
@@ -131,6 +131,20 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {

```cpp
static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
    const block_q4_0 * xi = (const block_q4_0 *) cxi;
    float * dsti = (float *) cdsti;

    const float d = (float) xi->d;

    for (int j = 0; j < QK4_0/2; ++j) {
        // low nibbles decode into the first half of the block, high nibbles
        // into the second half, inverting the layout cpy_blck_f32_q4_0 writes
        const float x0 = (xi->qs[j] & 0x0F) - 8;
        const float x1 = (xi->qs[j] >>   4) - 8;

        dsti[j + 0]       = x0 * d;
        dsti[j + QK4_0/2] = x1 * d;
    }
}
```
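`cpy_blck_q4_0_f32` has to invert the packing that `cpy_blck_f32_q4_0` performs, including the deinterleaved nibble order. A host-side round-trip sketch in plain C (a float scale stands in for the fp16 field, and the pack_/unpack_ helpers are hypothetical names, not ggml functions):

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

// pack: the same scheme as cpy_blck_f32_q4_0, with a float scale for brevity
static void pack_q4_0_ref(const float x[QK4_0], float * d_out, uint8_t qs[QK4_0/2]) {
    float amax = 0.0f, vmax = 0.0f;
    for (int j = 0; j < QK4_0; ++j) {
        if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); vmax = x[j]; }
    }
    const float d  = vmax / -8.0f;
    const float id = d != 0.0f ? 1.0f/d : 0.0f;
    for (int j = 0; j < QK4_0/2; ++j) {
        const uint8_t q0 = (uint8_t) fminf(15.0f, x[j]           * id + 8.5f);
        const uint8_t q1 = (uint8_t) fminf(15.0f, x[j + QK4_0/2] * id + 8.5f);
        qs[j] = q0 | (q1 << 4); // first half in low nibbles, second half in high
    }
    *d_out = d;
}

// unpack: the same arithmetic as cpy_blck_q4_0_f32 above
static void unpack_q4_0_ref(float d, const uint8_t qs[QK4_0/2], float y[QK4_0]) {
    for (int j = 0; j < QK4_0/2; ++j) {
        y[j]           = ((qs[j] & 0x0F) - 8) * d;
        y[j + QK4_0/2] = ((qs[j] >>   4) - 8) * d;
    }
}

int main(void) {
    float x[QK4_0], y[QK4_0], d;
    uint8_t qs[QK4_0/2];
    for (int j = 0; j < QK4_0; ++j) x[j] = sinf((float) j);
    pack_q4_0_ref(x, &d, qs);
    unpack_q4_0_ref(d, qs, y);
    for (int j = 0; j < QK4_0; ++j) {
        // each element should come back within one quantization step (|d|)
        printf("%2d: % .4f -> % .4f\n", j, x[j], y[j]);
    }
    return 0;
}
```

The next hunk adds the host-side launcher for the new device function: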
@@ -446,6 +460,16 @@ static void ggml_cpy_f32_q4_0_cuda(

```cpp
static void ggml_cpy_q4_0_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {

    const int num_blocks = ne;
    // the block-size template argument must be QK4_0 to match the q4_0
    // layout, consistent with ggml_cuda_cpy_fn below
    cpy_q_f32<cpy_blck_q4_0_f32, QK4_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
```
@@ -556,6 +580,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg

```cpp
        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
```

@@ -598,6 +624,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {

```cpp
        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
```
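With both dispatch points updated, the CUDA backend accepts the copy pairs a quantized KV cache needs in both directions. A compact sketch of that predicate (the helper is illustrative, not part of ggml):

```c
#include "ggml.h"

// Sketch: the copy pairs relevant to quantized KV caches after this commit.
static int cuda_cpy_supports_kv_pair(enum ggml_type src, enum ggml_type dst) {
    const int src_q = src == GGML_TYPE_Q4_0 || src == GGML_TYPE_Q8_0;
    const int dst_q = dst == GGML_TYPE_Q4_0 || dst == GGML_TYPE_Q8_0;
    return (src == GGML_TYPE_F32 && dst_q)  // quantize on cache write
        || (dst == GGML_TYPE_F32 && src_q); // dequantize on cache read (q4_0 -> f32 is the new pair)
}
```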
