@@ -3358,6 +3358,308 @@ static void ggml_compute_forward_dup_same_cont(
3358
3358
}
3359
3359
}
3360
3360
3361
+ static void ggml_compute_forward_dup_q4(
3362
+ const struct ggml_compute_params * params,
3363
+ struct ggml_tensor * dst) {
3364
+ const struct ggml_tensor * src0 = dst->src[0];
3365
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
3366
+ GGML_TENSOR_UNARY_OP_LOCALS
3367
+ const int ith = params->ith; // thread index
3368
+ const int nth = params->nth; // number of threads
3369
+ // parallelize by rows
3370
+ const int nr = ne01;
3371
+ // number of rows per thread
3372
+ const int dr = (nr + nth - 1) / nth;
3373
+ // row range for this thread
3374
+ const int ir0 = dr * ith;
3375
+ const int ir1 = MIN(ir0 + dr, nr);
3376
+ if (src0->type == dst->type &&
3377
+ ne00 == ne0 &&
3378
+ nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
3379
+ // copy by rows
3380
+ const size_t rs = ne00 * nb00;
3381
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3382
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3383
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3384
+ memcpy(
3385
+ ((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3),
3386
+ ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03),
3387
+ rs);
3388
+ }
3389
+ }
3390
+ }
3391
+ return;
3392
+ }
3393
+ if (ggml_is_contiguous(dst)) {
3394
+ if (nb00 == sizeof(block_q4_0)) {
3395
+ const size_t rs = ne00 / 2; // QK4_0/2 bytes per row
3396
+ if (dst->type == GGML_TYPE_F32) {
3397
+ float * dst_ptr = (float *) dst->data;
3398
+ for (int i03 = 0; i03 < ne03; i03++) {
3399
+ for (int i02 = 0; i02 < ne02; i02++) {
3400
+ size_t id = rs * ith;
3401
+ for (int i01 = ir0; i01 < ir1; i01++) {
3402
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3403
+ dequantize_row_q4_0(src_ptr, dst_ptr + id, ne00);
3404
+ id += rs;
3405
+ }
3406
+ id += rs * (ne01 - ir1);
3407
+ }
3408
+ }
3409
+ } else if (dst->type == GGML_TYPE_F16) {
3410
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
3411
+ for (int i03 = 0; i03 < ne03; i03++) {
3412
+ for (int i02 = 0; i02 < ne02; i02++) {
3413
+ size_t id = rs * ith;
3414
+ for (int i01 = ir0; i01 < ir1; i01++) {
3415
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3416
+ float tmp[QK4_0];
3417
+ dequantize_row_q4_0(src_ptr, tmp, ne00);
3418
+ for (int i00 = 0; i00 < QK4_0; i00++) {
3419
+ dst_ptr[id + i00] = GGML_FP32_TO_FP16(tmp[i00]);
3420
+ }
3421
+ id += rs;
3422
+ }
3423
+ id += rs * (ne01 - ir1);
3424
+ }
3425
+ }
3426
+ } else if (dst->type == GGML_TYPE_BF16) {
3427
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
3428
+ for (int i03 = 0; i03 < ne03; i03++) {
3429
+ for (int i02 = 0; i02 < ne02; i02++) {
3430
+ size_t id = rs * ith;
3431
+ for (int i01 = ir0; i01 < ir1; i01++) {
3432
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3433
+ float tmp[QK4_0];
3434
+ dequantize_row_q4_0(src_ptr, tmp, ne00);
3435
+ for (int i00 = 0; i00 < QK4_0; i00++) {
3436
+ dst_ptr[id + i00] = GGML_FP32_TO_BF16(tmp[i00]);
3437
+ }
3438
+ id += rs;
3439
+ }
3440
+ id += rs * (ne01 - ir1);
3441
+ }
3442
+ }
3443
+ } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
3444
+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
3445
+ float tmp[QK4_0];
3446
+ for (int i03 = 0; i03 < ne03; i03++) {
3447
+ for (int i02 = 0; i02 < ne02; i02++) {
3448
+ size_t id = rs * ith;
3449
+ for (int i01 = ir0; i01 < ir1; i01++) {
3450
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3451
+ dequantize_row_q4_0(src_ptr, tmp, ne00);
3452
+ quantize_row_q(tmp, dst->data + id, ne00);
3453
+ id += rs;
3454
+ }
3455
+ id += rs * (ne01 - ir1);
3456
+ }
3457
+ }
3458
+ } else {
3459
+ GGML_ABORT("fatal error"); // TODO: implement
3460
+ }
3461
+ }
3462
+ } else {
3463
+ if (dst->type == GGML_TYPE_F32) {
3464
+ float * dst_ptr = (float *) dst->data;
3465
+ for (int i03 = 0; i03 < ne03; i03++) {
3466
+ for (int i02 = 0; i02 < ne02; i02++) {
3467
+ size_t id = ith * QK4_0 / 2;
3468
+ for (int i01 = ir0; i01 < ir1; i01++) {
3469
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3470
+ for (int i00 = 0; i00 < QK4_0 / 2; i00++) {
3471
+ dst_ptr[id] = GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] & 0x0F) - 8);
3472
+ dst_ptr[id + 1] = GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] >> 4) - 8);
3473
+ id += 2;
3474
+ }
3475
+ }
3476
+ }
3477
+ }
3478
+ } else if (dst->type == GGML_TYPE_F16) {
3479
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
3480
+ for (int i03 = 0; i03 < ne03; i03++) {
3481
+ for (int i02 = 0; i02 < ne02; i02++) {
3482
+ size_t id = ith * QK4_0 / 2;
3483
+ for (int i01 = ir0; i01 < ir1; i01++) {
3484
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3485
+ for (int i00 = 0; i00 < QK4_0 / 2; i00++) {
3486
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] & 0x0F) - 8));
3487
+ dst_ptr[id + 1] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] >> 4) - 8));
3488
+ id += 2;
3489
+ }
3490
+ }
3491
+ }
3492
+ }
3493
+ } else if (dst->type == GGML_TYPE_BF16) {
3494
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
3495
+ for (int i03 = 0; i03 < ne03; i03++) {
3496
+ for (int i02 = 0; i02 < ne02; i02++) {
3497
+ size_t id = ith * QK4_0 / 2;
3498
+ for (int i01 = ir0; i01 < ir1; i01++) {
3499
+ const block_q4_0 * src_ptr = (const block_q4_0 *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3500
+ for (int i00 = 0; i00 < QK4_0 / 2; i00++) {
3501
+ dst_ptr[id] = GGML_FP32_TO_BF16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] & 0x0F) - 8));
3502
+ dst_ptr[id + 1] = GGML_FP32_TO_BF16(GGML_FP16_TO_FP32(src_ptr->d) * ((src_ptr->qs[i00] >> 4) - 8));
3503
+ id += 2;
3504
+ }
3505
+ }
3506
+ }
3507
+ }
3508
+ } else {
3509
+ GGML_ABORT("fatal error"); // TODO: implement
3510
+ }
3511
+ }
3512
+ return;
3513
+ }
3514
+ static void ggml_compute_forward_dup_q8(
3515
+ const struct ggml_compute_params * params,
3516
+ struct ggml_tensor * dst) {
3517
+ const struct ggml_tensor * src0 = dst->src[0];
3518
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
3519
+ GGML_TENSOR_UNARY_OP_LOCALS
3520
+ const int ith = params->ith; // thread index
3521
+ const int nth = params->nth; // number of threads
3522
+ // parallelize by rows
3523
+ const int nr = ne01;
3524
+ // number of rows per thread
3525
+ const int dr = (nr + nth - 1) / nth;
3526
+ // row range for this thread
3527
+ const int ir0 = dr * ith;
3528
+ const int ir1 = MIN(ir0 + dr, nr);
3529
+ if (src0->type == dst->type &&
3530
+ ne00 == ne0 &&
3531
+ nb00 >= ggml_type_size(src0->type) && nb0 >= ggml_type_size(dst->type)) {
3532
+ // copy by rows
3533
+ const size_t rs = ne00 * nb00;
3534
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3535
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3536
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3537
+ memcpy(
3538
+ ((char * ) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3),
3539
+ ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03),
3540
+ rs);
3541
+ }
3542
+ }
3543
+ }
3544
+ return;
3545
+ }
3546
+ if (ggml_is_contiguous(dst)) {
3547
+ const size_t rs = ne00 / QK8_0; // QK8_0 bytes per row
3548
+ if (dst->type == GGML_TYPE_F32) {
3549
+ float * dst_ptr = (float *) dst->data;
3550
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3551
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3552
+ size_t id = rs * ith;
3553
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3554
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3555
+ dequantize_row_q8_0(src_ptr, dst_ptr + id * QK8_0, ne00);
3556
+ id += rs;
3557
+ }
3558
+ id += rs * (ne01 - ir1);
3559
+ }
3560
+ }
3561
+ } else if (dst->type == GGML_TYPE_F16) {
3562
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
3563
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3564
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3565
+ size_t id = rs * ith;
3566
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3567
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3568
+ float tmp[QK8_0];
3569
+ dequantize_row_q8_0(src_ptr, tmp, ne00);
3570
+ for (int64_t i00 = 0; i00 < QK8_0; i00++) {
3571
+ dst_ptr[id * QK8_0 + i00] = GGML_FP32_TO_FP16(tmp[i00]);
3572
+ }
3573
+ id += rs;
3574
+ }
3575
+ id += rs * (ne01 - ir1);
3576
+ }
3577
+ }
3578
+ } else if (dst->type == GGML_TYPE_BF16) {
3579
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
3580
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3581
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3582
+ size_t id = rs * ith;
3583
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3584
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3585
+ float tmp[QK8_0];
3586
+ dequantize_row_q8_0(src_ptr, tmp, ne00);
3587
+ for (int64_t i00 = 0; i00 < QK8_0; i00++) {
3588
+ dst_ptr[id * QK8_0 + i00] = GGML_FP32_TO_BF16(tmp[i00]);
3589
+ }
3590
+ id += rs;
3591
+ }
3592
+ id += rs * (ne01 - ir1);
3593
+ }
3594
+ }
3595
+ } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
3596
+ ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
3597
+ float tmp[QK8_0];
3598
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3599
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3600
+ size_t id = rs * ith;
3601
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3602
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3603
+ dequantize_row_q8_0(src_ptr, tmp, ne00);
3604
+ quantize_row_q(tmp, dst->data + id * QK8_0, ne00);
3605
+ id += rs;
3606
+ }
3607
+ id += rs * (ne01 - ir1);
3608
+ }
3609
+ }
3610
+ } else {
3611
+ GGML_ABORT("fatal error"); // TODO: implement
3612
+ }
3613
+ } else {
3614
+ if (dst->type == GGML_TYPE_F32) {
3615
+ float * dst_ptr = (float *) dst->data;
3616
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3617
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3618
+ size_t id = ith * QK8_0;
3619
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3620
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3621
+ for (int64_t i00 = 0; i00 < QK8_0; i00++) {
3622
+ dst_ptr[id] = GGML_FP16_TO_FP32(src_ptr->d) * src_ptr->qs[i00];
3623
+ id += 1;
3624
+ }
3625
+ }
3626
+ }
3627
+ }
3628
+ } else if (dst->type == GGML_TYPE_F16) {
3629
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
3630
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3631
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3632
+ size_t id = ith * QK8_0;
3633
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3634
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3635
+ for (int64_t i00 = 0; i00 < QK8_0; i00++) {
3636
+ dst_ptr[id] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src_ptr->d) * src_ptr->qs[i00]);
3637
+ id += 1;
3638
+ }
3639
+ }
3640
+ }
3641
+ }
3642
+ } else if (dst->type == GGML_TYPE_BF16) {
3643
+ ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
3644
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
3645
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
3646
+ size_t id = ith * QK8_0;
3647
+ for (int64_t i01 = ir0; i01 < ir1; i01++) {
3648
+ const block_q8_0 * src_ptr = (const block_q8_0 *) ((char * ) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3649
+ for (int64_t i00 = 0; i00 < QK8_0; i00++) {
3650
+ dst_ptr[id] = GGML_FP32_TO_BF16(GGML_FP16_TO_FP32(src_ptr->d) * src_ptr->qs[i00]);
3651
+ id += 1;
3652
+ }
3653
+ }
3654
+ }
3655
+ }
3656
+ } else {
3657
+ GGML_ABORT("fatal error"); // TODO: implement
3658
+ }
3659
+ }
3660
+ return;
3661
+ }
3662
+
3361
3663
static void ggml_compute_forward_dup_f16(
3362
3664
const struct ggml_compute_params * params,
3363
3665
struct ggml_tensor * dst) {
@@ -4457,9 +4759,17 @@ static void ggml_compute_forward_dup(
4457
4759
{
4458
4760
ggml_compute_forward_dup_f32(params, dst);
4459
4761
} break;
4762
+ case GGML_TYPE_Q4_0:
4763
+ {
4764
+ ggml_compute_forward_dup_q4(params, dst);
4765
+ } break;
4766
+ case GGML_TYPE_Q8_0:
4767
+ {
4768
+ ggml_compute_forward_dup_q8(params, dst);
4769
+ } break;
4460
4770
default:
4461
4771
{
4462
- GGML_ABORT("fatal error" );
4772
+ GGML_ABORT("fatal error, not support forward dup oper from %d ot %d", src0->type, dst->type );
4463
4773
}
4464
4774
}
4465
4775
}
0 commit comments