
Commit 5dc9dd7

llama : add Command R Plus support (ggml-org#6491)
* Add Command R Plus GGUF

* Add Command R Plus GGUF

* Loading works up to LayerNorm2D

* Export new tensors in 1D so they are not quantized.

* Fix embedding layer based on Noeda's example

* Whitespace

* Add line

* Fix unexpected tokens on MPS. Re-add F16 fix. (Noeda)

* dranger003: Fix block index overflow in CUDA dequantizing.

* Reverted blocked multiplication code as it still has issues and could affect other Llama arches

* export norms as f32

* fix overflow issues during quant and other cleanup

* Type convention

  Co-authored-by: Georgi Gerganov <[email protected]>

* dranger003: Fix more int overflow during quant.

---------

Co-authored-by: S <[email protected]>
Co-authored-by: S <[email protected]>
Co-authored-by: slaren <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent e11a899 commit 5dc9dd7

16 files changed, +366 -326 lines changed
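The common thread in the CUDA changes below is widening index arithmetic from int to int64_t: Command R Plus has tensors large enough that 32-bit element indices overflow during dequantization. A minimal sketch of the failure mode and of the fix pattern used throughout this commit (the kernels and sizes here are illustrative, not from the commit itself):

#include <cstdint>

// Illustrative kernels, not from llama.cpp. With 32-bit arithmetic the global
// thread index wraps once a tensor exceeds INT_MAX elements, so reads and
// writes land at negative or aliased offsets.
__global__ void scale_i32(const float * x, float * y, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x; // wraps for k > 2^31
    if (i >= k) return;
    y[i] = 2.0f*x[i];
}

// The fix applied throughout this commit: 64-bit counts and indices, with the
// cast placed so the first multiply already happens in 64 bits.
__global__ void scale_i64(const float * x, float * y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) return;
    y[i] = 2.0f*x[i];
}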

convert-hf-to-gguf.py

+1 -1

@@ -160,7 +160,7 @@ def write_tensors(self):
                 data = data.astype(np.float32)
 
             # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
-            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")):
                 data = data.astype(np.float32)
 
             # if f16 desired, convert any float32 2-dim weight tensors to float16
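The converter now keeps any tensor ending in _norm.weight in float32 when writing an f16 GGUF, alongside the existing 1-D rule. A hedged, standalone illustration (not from the repo) of why f16 norm weights are risky: float16 has a 10-bit mantissa, so a scale is perturbed on every f32 -> f16 round trip, while norm vectors are small enough that keeping them in f32 costs almost nothing:

#include <cuda_fp16.h>
#include <cstdio>

// Standalone illustration (compile with nvcc). A layer-norm scale stored as
// f16 loses low-order bits; the value below keeps only ~3-4 significant
// decimal digits after the round trip.
int main() {
    const float w = 1.2345678f;
    const float r = __half2float(__float2half(w)); // f32 -> f16 -> f32
    printf("f32 value:            %.7f\n", w);
    printf("after f16 round trip: %.7f\n", r);
    return 0;
}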

ggml-cuda.cu

+3 -3

@@ -1225,7 +1225,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
+    int64_t ldc = id == ctx.device ? ne0 : row_diff;
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
@@ -1377,8 +1377,8 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
+    const int64_t nb2 = dst->nb[2];
+    const int64_t nb3 = dst->nb[3];
 
     GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
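Widening ldc and the nb2/nb3 strides matters because the overflow happens in the products derived from them (element offsets such as col*ldc), not in the values themselves. Illustrative numbers only (Command R Plus has a 256000-entry vocabulary, so ne0 can plausibly be 256000; the column count is made up):

#include <cstdint>
#include <cstdio>

int main() {
    const int     ldc32 = 256000;   // fits in int comfortably on its own
    const int64_t ldc64 = 256000;
    const int     col   = 16384;

    // 16384 * 256000 = 4,194,304,000 > INT_MAX: the 32-bit product overflows
    // (undefined behaviour in C++; in practice it wraps negative).
    printf("32-bit col*ldc: %d\n", col*ldc32);
    printf("64-bit col*ldc: %lld\n", (long long)(col*ldc64));
    return 0;
}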

ggml-cuda/common.cuh

+1 -1

@@ -394,7 +394,7 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
 
 //////////////////////

ggml-cuda/convert.cu

+37 -37

@@ -4,14 +4,14 @@
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
 
     if (i >= k) {
         return;
     }
 
-    const int ib = i/qk; // block index
+    const int64_t ib = i/qk; // block index
     const int iqs = (i%qk)/qr; // quant index
     const int iybs = i - i%qk; // y block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
 }
 
 template <bool need_check>
-static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
 #if __CUDA_ARCH__ >= CC_PASCAL
     constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
 
@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
 template<typename dst_t>
 static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il = tid/8;
     const int ir = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 template<typename dst_t>
 static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 
     // assume 32 threads
     const int tid = threadIdx.x;
     const int il = tid/8;
     const int ir = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -313,14 +313,14 @@ template<typename dst_t>
 static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 #if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
-    const int ip = tid/32; // ip is 0 or 1
-    const int il = tid - 32*ip; // 0...32
-    const int is = 8*ip + il/16;
+    const int64_t tid = threadIdx.x;
+    const int64_t ip = tid/32; // ip is 0 or 1
+    const int64_t il = tid - 32*ip; // 0...32
+    const int64_t is = 8*ip + il/16;
 
     dst_t * y = yy + i*QK_K + 128*ip + il;
 
@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 #else
 
     // assume 32 threads
-    const int tid = threadIdx.x;
-    const int ip = tid/16; // 0 or 1
-    const int il = tid - 16*ip; // 0...15
+    const int64_t tid = threadIdx.x;
+    const int64_t ip = tid/16; // 0 or 1
+    const int64_t il = tid - 16*ip; // 0...15
 
     dst_t * y = yy + i*QK_K + 16*ip + il;
 
@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 #endif
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
     if (k % CUDA_Q8_0_NE_ALIGN == 0) {
         const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
 }
 
 template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
 }
 
 template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
 template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
 }
 
 template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
 }
 
 template<typename dst_t>
-static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
 #if QK_K == 64
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
 }
 
 template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
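A detail worth calling out in convert_unary: the fix is written (int64_t)blockDim.x*blockIdx.x + threadIdx.x, with the cast on the first operand, so the multiply itself runs in 64 bits. Casting the finished product instead would widen an already-wrapped 32-bit value. A host-side demonstration with made-up launch geometry:

#include <cstdint>
#include <cstdio>

int main() {
    // Made-up geometry: 20M blocks of 256 threads = 5.12e9 threads > 2^32.
    const unsigned int block_dim = 256;
    const unsigned int block_idx = 20000000;

    const int64_t too_late = (int64_t)(block_dim*block_idx); // 32-bit product wraps, then widens
    const int64_t correct  = (int64_t)block_dim*block_idx;   // widened before the multiply

    printf("cast after the multiply:  %lld\n", (long long)too_late);
    printf("cast before the multiply: %lld\n", (long long)correct);
    return 0;
}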

ggml-cuda/convert.cuh

+1 -1

@@ -3,7 +3,7 @@
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 template<typename T>
-using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half> to_fp16_cuda_t;
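to_t_cuda_t is the host-facing half of the same contract: callers fetch one of the dequantize_row_*_cuda wrappers through this pointer type and hand it the total element count, so k must be 64-bit here as well or the count is truncated at the call boundary. A self-contained sketch (the copy "converter" is a hypothetical stand-in for the real wrappers above):

#include <cstdint>
#include <cuda_fp16.h>

// Mirrors the updated alias (__restrict__ dropped for brevity).
template<typename T>
using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
typedef to_t_cuda_t<half> to_fp16_cuda_t;

// Stand-in converter: just copies k half values. Real implementations launch
// a dequantize kernel, as in ggml-cuda/convert.cu.
static void to_fp16_copy_cuda(const void * x, half * y, int64_t k, cudaStream_t stream) {
    cudaMemcpyAsync(y, x, k*sizeof(half), cudaMemcpyDeviceToDevice, stream);
}

int main() {
    // If to_fp16_copy_cuda still took `int k`, this assignment would not
    // compile, which is how the typedef change forces every converter in the
    // tree to be updated together.
    to_fp16_cuda_t to_fp16_cuda = to_fp16_copy_cuda;
    (void)to_fp16_cuda;
    return 0;
}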

ggml-cuda/dequantize.cuh

+5 -5

@@ -1,6 +1,6 @@
 #include "common.cuh"
 
-static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -19,7 +19,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q4_1 * x = (const block_q4_1 *) vx;
 
     const dfloat d = __low2half(x[ib].dm);
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_0 * x = (const block_q5_0 *) vx;
 
     const dfloat d = x[ib].d;
@@ -62,7 +62,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q5_1 * x = (const block_q5_1 *) vx;
 
     const dfloat d = __low2half(x[ib].dm);
@@ -86,7 +86,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
 #endif // GGML_CUDA_F16
 }
 
-static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
     const dfloat d = x[ib].d;
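Every dequantize_q*_0/1 helper here has to change in lockstep with the dequantize_kernel_t typedef in common.cuh, because the helpers are plugged into kernels like dequantize_block as template parameters and a signature mismatch is a compile error. A trimmed sketch of that mechanism (dfloat2 is approximated as float2 here; the toy "format" just reads two consecutive floats instead of unpacking a quantized block):

#include <cstdint>

typedef float2 dfloat2; // assumption: the real dfloat2 is half2 or float2 depending on GGML_CUDA_F16

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

// Toy dequantizer matching the typedef; the real ones unpack block_q4_0 etc.
static __device__ __forceinline__ void dequantize_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v) {
    const float * x = (const float *) vx;
    v.x = x[2*ib + iqs + 0];
    v.y = x[2*ib + iqs + 1];
}

// The helper is a template argument, as in convert.cu, so the int64_t ib in
// the typedef propagates to every dequantizer at compile time.
template <dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_sketch(const void * vx, float * y, const int64_t k) {
    const int64_t i = 2*((int64_t)blockDim.x*blockIdx.x + threadIdx.x);
    if (i >= k) {
        return;
    }
    dfloat2 v;
    dequantize_kernel(vx, i/2, 0, v);
    y[i + 0] = v.x;
    y[i + 1] = v.y;
}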

ggml-cuda/dmmv.cu

+3 -3

@@ -565,7 +565,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
     }
 }
 
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
 
     // automatic half -> float type cast if dfloat == float
@@ -577,7 +577,7 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
     // qr = number of quantized weights per data value in x block
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
 
     if (row >= nrows) {
         return;
@@ -598,7 +598,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
 
     for (int i = 0; i < ncols; i += iter_stride) {
         const int col = i + vals_per_iter*tid;
-        const int ib = (row*ncols + col)/qk; // x block index
+        const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index
         const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
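The dmmv change is the same overflow in two dimensions: row and ncols each fit easily in int, but their product is an element index into the full weight matrix and exceeds 2^31 for large models, so the multiplication itself has to be lifted to 64 bits. Illustrative magnitudes (not taken from a real model):

#include <cstdint>
#include <cstdio>

int main() {
    const int row   = 200000; // fits in int
    const int ncols = 12288;  // fits in int

    // 200000 * 12288 = 2,457,600,000 > INT_MAX: the 32-bit product overflows
    // (undefined behaviour; typically wraps negative), corrupting the block index.
    printf("32-bit row*ncols: %d\n", row*ncols);
    printf("64-bit row*ncols: %lld\n", (long long)((int64_t)row*ncols));
    return 0;
}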
