 #define CUDA_Q8_0_NE_ALIGN 2048

 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);

     if (i >= k) {
         return;
     }

-    const int ib = i/qk; // block index
+    const int64_t ib = i/qk; // block index
     const int iqs = (i%qk)/qr; // quant index
     const int iybs = i - i%qk; // y block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
@@ -25,7 +25,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
 }

 template <bool need_check>
-static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
 #if __CUDA_ARCH__ >= CC_PASCAL
     constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

@@ -68,13 +68,13 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
 template<typename dst_t>
 static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;

     // assume 32 threads
     const int tid = threadIdx.x;
     const int il = tid/8;
     const int ir = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -96,13 +96,13 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
 template<typename dst_t>
 static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {

-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;

     // assume 32 threads
     const int tid = threadIdx.x;
     const int il = tid/8;
     const int ir = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }
@@ -313,14 +313,14 @@ template<typename dst_t>
 static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;

-    const int i = blockIdx.x;
+    const int64_t i = blockIdx.x;
 #if QK_K == 256

     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
-    const int ip  = tid/32;      // ip is 0 or 1
-    const int il  = tid - 32*ip; // 0...32
-    const int is  = 8*ip + il/16;
+    const int64_t tid = threadIdx.x;
+    const int64_t ip  = tid/32;      // ip is 0 or 1
+    const int64_t il  = tid - 32*ip; // 0...32
+    const int64_t is  = 8*ip + il/16;

     dst_t * y = yy + i*QK_K + 128*ip + il;

@@ -337,9 +337,9 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
 #else

     // assume 32 threads
-    const int tid = threadIdx.x;
-    const int ip  = tid/16;      // 0 or 1
-    const int il  = tid - 16*ip; // 0...15
+    const int64_t tid = threadIdx.x;
+    const int64_t ip  = tid/16;      // 0 or 1
+    const int64_t il  = tid - 16*ip; // 0...15

     dst_t * y = yy + i*QK_K + 16*ip + il;

@@ -571,12 +571,12 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
 #endif

 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }

-static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
     if (k % CUDA_Q8_0_NE_ALIGN == 0) {
         const bool need_check = false;
@@ -588,7 +588,7 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half *
 }

 template<typename dst_t>
-static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -598,7 +598,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu
 }

 template<typename dst_t>
-static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -608,27 +608,27 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
 }

 template<typename dst_t>
-static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }

 template<typename dst_t>
-static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }

 template<typename dst_t>
-static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -638,7 +638,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu
 }

 template<typename dst_t>
-static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
 #if QK_K == 256
     dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
@@ -648,55 +648,55 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 }

 template<typename dst_t>
-static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
     dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
 }

 template<typename dst_t>
-static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = (k + QK_K - 1) / QK_K;
 #if QK_K == 64
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
@@ -706,8 +706,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
 }

 template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
-    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

     if (i >= k) {
         return;
@@ -719,7 +719,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
 }

 template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
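
Note on the convert_unary hunk above: widening the destination to int64_t is not enough on its own, because blockDim.x, blockIdx.x and threadIdx.x are 32-bit unsigned values, so their product wraps modulo 2^32 before it is ever stored. The explicit (int64_t) cast on blockDim.x promotes the multiplication itself to 64-bit. A minimal host-side sketch of the difference (not part of the patch; the launch parameters below are made up):

// Sketch only: shows the 32-bit wraparound that the (int64_t) cast avoids.
#include <cstdint>
#include <cstdio>

int main() {
    const unsigned int block_dim = 256;      // hypothetical threads per block
    const unsigned int block_idx = 20000000; // hypothetical block index deep into a large grid

    // 256 * 20,000,000 = 5,120,000,000 > UINT32_MAX, so the unsigned 32-bit product wraps
    const int64_t wrapped = block_dim*block_idx;           // multiplied in 32 bits, then widened
    const int64_t widened = (int64_t) block_dim*block_idx; // multiplied in 64 bits

    printf("wrapped = %lld, widened = %lld\n", (long long) wrapped, (long long) widened);
    // prints: wrapped = 825032704, widened = 5120000000
    return 0;
}

The same promotion is what lets the kernel's index i address elements past 2^31 now that k is passed as a 64-bit element count.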