Skip to content

Commit e00b4a8

Browse files
authored
Fix more int overflow during quant (PPL/CUDA). (#6563)
* Fix more int overflow during quant.
* Fix some more int overflow in softmax.
* Revert back to int64_t.
1 parent 7bb36cc commit e00b4a8

File tree

2 files changed

+88
-88
lines changed

2 files changed

+88
-88
lines changed

ggml-cuda/convert.cu

+84-84
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55

66
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
77
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
8-
const int64_t i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
8+
const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
99

1010
if (i >= k) {
1111
return;
1212
}
1313

1414
const int64_t ib = i/qk; // block index
15-
const int iqs = (i%qk)/qr; // quant index
16-
const int iybs = i - i%qk; // y block start index
17-
const int y_offset = qr == 1 ? 1 : qk/2;
15+
const int64_t iqs = (i%qk)/qr; // quant index
16+
const int64_t iybs = i - i%qk; // y block start index
17+
const int64_t y_offset = qr == 1 ? 1 : qk/2;
1818

1919
// dequantize
2020
dfloat2 v;
@@ -29,7 +29,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
2929
#if __CUDA_ARCH__ >= CC_PASCAL
3030
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
3131

32-
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
32+
const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
3333
const int * x0 = ((int *) vx) + blockIdx.x * nint;
3434
half2 * y2 = (half2 *) (y + i0);
3535

@@ -73,9 +73,9 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
7373
const int64_t i = blockIdx.x;
7474

7575
// assume 32 threads
76-
const int tid = threadIdx.x;
77-
const int il = tid/8;
78-
const int ir = tid%8;
76+
const int64_t tid = threadIdx.x;
77+
const int64_t il = tid/8;
78+
const int64_t ir = tid%8;
7979
const int64_t ib = 8*i + ir;
8080
if (ib >= nb32) {
8181
return;
@@ -101,9 +101,9 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
101101
const int64_t i = blockIdx.x;
102102

103103
// assume 32 threads
104-
const int tid = threadIdx.x;
105-
const int il = tid/8;
106-
const int ir = tid%8;
104+
const int64_t tid = threadIdx.x;
105+
const int64_t il = tid/8;
106+
const int64_t ir = tid%8;
107107
const int64_t ib = 8*i + ir;
108108
if (ib >= nb32) {
109109
return;
@@ -127,14 +127,14 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
127127
template<typename dst_t>
128128
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
129129

130-
const int i = blockIdx.x;
130+
const int64_t i = blockIdx.x;
131131
const block_q2_K * x = (const block_q2_K *) vx;
132132

133-
const int tid = threadIdx.x;
133+
const int64_t tid = threadIdx.x;
134134
#if QK_K == 256
135-
const int n = tid/32;
136-
const int l = tid - 32*n;
137-
const int is = 8*n + l/16;
135+
const int64_t n = tid/32;
136+
const int64_t l = tid - 32*n;
137+
const int64_t is = 8*n + l/16;
138138

139139
const uint8_t q = x[i].qs[32*n + l];
140140
dst_t * y = yy + i*QK_K + 128*n;
@@ -146,8 +146,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
146146
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
147147
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
148148
#else
149-
const int is = tid/16; // 0 or 1
150-
const int il = tid%16; // 0...15
149+
const int64_t is = tid/16; // 0 or 1
150+
const int64_t il = tid%16; // 0...15
151151
const uint8_t q = x[i].qs[il] >> (2*is);
152152
dst_t * y = yy + i*QK_K + 16*is + il;
153153
float dall = __low2half(x[i].dm);
@@ -161,19 +161,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
161161
template<typename dst_t>
162162
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
163163

164-
const int i = blockIdx.x;
164+
const int64_t i = blockIdx.x;
165165
const block_q3_K * x = (const block_q3_K *) vx;
166166

167167
#if QK_K == 256
168-
const int r = threadIdx.x/4;
169-
const int tid = r/2;
170-
const int is0 = r%2;
171-
const int l0 = 16*is0 + 4*(threadIdx.x%4);
172-
const int n = tid / 4;
173-
const int j = tid - 4*n;
168+
const int64_t r = threadIdx.x/4;
169+
const int64_t tid = r/2;
170+
const int64_t is0 = r%2;
171+
const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
172+
const int64_t n = tid / 4;
173+
const int64_t j = tid - 4*n;
174174

175175
uint8_t m = 1 << (4*n + j);
176-
int is = 8*n + 2*j + is0;
176+
int64_t is = 8*n + 2*j + is0;
177177
int shift = 2*j;
178178

179179
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@@ -189,11 +189,11 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
189189

190190
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
191191
#else
192-
const int tid = threadIdx.x;
193-
const int is = tid/16; // 0 or 1
194-
const int il = tid%16; // 0...15
195-
const int im = il/8; // 0...1
196-
const int in = il%8; // 0...7
192+
const int64_t tid = threadIdx.x;
193+
const int64_t is = tid/16; // 0 or 1
194+
const int64_t il = tid%16; // 0...15
195+
const int64_t im = il/8; // 0...1
196+
const int64_t in = il%8; // 0...7
197197

198198
dst_t * y = yy + i*QK_K + 16*is + il;
199199

@@ -227,15 +227,15 @@ template<typename dst_t>
227227
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
228228
const block_q4_K * x = (const block_q4_K *) vx;
229229

230-
const int i = blockIdx.x;
230+
const int64_t i = blockIdx.x;
231231

232232
#if QK_K == 256
233233
// assume 32 threads
234-
const int tid = threadIdx.x;
235-
const int il = tid/8;
236-
const int ir = tid%8;
237-
const int is = 2*il;
238-
const int n = 4;
234+
const int64_t tid = threadIdx.x;
235+
const int64_t il = tid/8;
236+
const int64_t ir = tid%8;
237+
const int64_t is = 2*il;
238+
const int64_t n = 4;
239239

240240
dst_t * y = yy + i*QK_K + 64*il + n*ir;
241241

@@ -254,7 +254,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
254254
y[l +32] = d2 * (q[l] >> 4) - m2;
255255
}
256256
#else
257-
const int tid = threadIdx.x;
257+
const int64_t tid = threadIdx.x;
258258
const uint8_t * q = x[i].qs;
259259
dst_t * y = yy + i*QK_K;
260260
const float d = (float)x[i].dm[0];
@@ -268,14 +268,14 @@ template<typename dst_t>
268268
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
269269
const block_q5_K * x = (const block_q5_K *) vx;
270270

271-
const int i = blockIdx.x;
271+
const int64_t i = blockIdx.x;
272272

273273
#if QK_K == 256
274274
// assume 64 threads - this is very slightly better than the one below
275-
const int tid = threadIdx.x;
276-
const int il = tid/16; // il is in 0...3
277-
const int ir = tid%16; // ir is in 0...15
278-
const int is = 2*il; // is is in 0...6
275+
const int64_t tid = threadIdx.x;
276+
const int64_t il = tid/16; // il is in 0...3
277+
const int64_t ir = tid%16; // ir is in 0...15
278+
const int64_t is = 2*il; // is is in 0...6
279279

280280
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
281281

@@ -298,11 +298,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
298298
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
299299
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
300300
#else
301-
const int tid = threadIdx.x;
301+
const int64_t tid = threadIdx.x;
302302
const uint8_t q = x[i].qs[tid];
303-
const int im = tid/8; // 0...3
304-
const int in = tid%8; // 0...7
305-
const int is = tid/16; // 0 or 1
303+
const int64_t im = tid/8; // 0...3
304+
const int64_t in = tid%8; // 0...7
305+
const int64_t is = tid/16; // 0 or 1
306306
const uint8_t h = x[i].qh[in] >> im;
307307
const float d = x[i].d;
308308
dst_t * y = yy + i*QK_K + tid;
@@ -359,13 +359,13 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
359359
template<typename dst_t>
360360
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
361361

362-
const int i = blockIdx.x;
362+
const int64_t i = blockIdx.x;
363363
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
364364

365-
const int tid = threadIdx.x;
365+
const int64_t tid = threadIdx.x;
366366
#if QK_K == 256
367-
const int il = tid/8; // 0...3
368-
const int ib = tid%8; // 0...7
367+
const int64_t il = tid/8; // 0...3
368+
const int64_t ib = tid%8; // 0...7
369369
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
370370
const uint16_t * q2 = x[i].qs + 4*ib;
371371
const uint8_t * aux8 = (const uint8_t *)q2;
@@ -383,13 +383,13 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
383383
template<typename dst_t>
384384
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
385385

386-
const int i = blockIdx.x;
386+
const int64_t i = blockIdx.x;
387387
const block_iq2_xs * x = (const block_iq2_xs *) vx;
388388

389-
const int tid = threadIdx.x;
389+
const int64_t tid = threadIdx.x;
390390
#if QK_K == 256
391-
const int il = tid/8; // 0...3
392-
const int ib = tid%8; // 0...7
391+
const int64_t il = tid/8; // 0...3
392+
const int64_t ib = tid%8; // 0...7
393393
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
394394
const uint16_t * q2 = x[i].qs + 4*ib;
395395
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@@ -405,13 +405,13 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
405405
template<typename dst_t>
406406
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
407407

408-
const int i = blockIdx.x;
408+
const int64_t i = blockIdx.x;
409409
const block_iq2_s * x = (const block_iq2_s *) vx;
410410

411-
const int tid = threadIdx.x;
411+
const int64_t tid = threadIdx.x;
412412
#if QK_K == 256
413-
const int il = tid/8; // 0...3
414-
const int ib = tid%8; // 0...7
413+
const int64_t il = tid/8; // 0...3
414+
const int64_t ib = tid%8; // 0...7
415415
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
416416
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
417417
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@@ -426,13 +426,13 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
426426
template<typename dst_t>
427427
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
428428

429-
const int i = blockIdx.x;
429+
const int64_t i = blockIdx.x;
430430
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
431431

432-
const int tid = threadIdx.x;
432+
const int64_t tid = threadIdx.x;
433433
#if QK_K == 256
434-
const int il = tid/8; // 0...3
435-
const int ib = tid%8; // 0...7
434+
const int64_t il = tid/8; // 0...3
435+
const int64_t ib = tid%8; // 0...7
436436
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
437437
const uint8_t * q3 = x[i].qs + 8*ib;
438438
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@@ -454,13 +454,13 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
454454
template<typename dst_t>
455455
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
456456

457-
const int i = blockIdx.x;
457+
const int64_t i = blockIdx.x;
458458
const block_iq3_s * x = (const block_iq3_s *) vx;
459459

460-
const int tid = threadIdx.x;
460+
const int64_t tid = threadIdx.x;
461461
#if QK_K == 256
462-
const int il = tid/8; // 0...3
463-
const int ib = tid%8; // 0...7
462+
const int64_t il = tid/8; // 0...3
463+
const int64_t ib = tid%8; // 0...7
464464
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
465465
const uint8_t * qs = x[i].qs + 8*ib;
466466
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@@ -480,13 +480,13 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
480480
template<typename dst_t>
481481
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
482482

483-
const int i = blockIdx.x;
483+
const int64_t i = blockIdx.x;
484484
const block_iq1_s * x = (const block_iq1_s *) vx;
485485

486-
const int tid = threadIdx.x;
486+
const int64_t tid = threadIdx.x;
487487
#if QK_K == 256
488-
const int il = tid/8; // 0...3
489-
const int ib = tid%8; // 0...7
488+
const int64_t il = tid/8; // 0...3
489+
const int64_t ib = tid%8; // 0...7
490490
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
491491
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
492492
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@@ -506,18 +506,18 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
506506
template<typename dst_t>
507507
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
508508

509-
const int i = blockIdx.x;
509+
const int64_t i = blockIdx.x;
510510
const block_iq1_m * x = (const block_iq1_m *) vx;
511511

512-
const int tid = threadIdx.x;
512+
const int64_t tid = threadIdx.x;
513513
#if QK_K == 256
514-
const int il = tid/8; // 0...3
515-
const int ib = tid%8; // 0...7
514+
const int64_t il = tid/8; // 0...3
515+
const int64_t ib = tid%8; // 0...7
516516
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
517517
const uint16_t * sc = (const uint16_t *)x[i].scales;
518518
iq1m_scale_t scale;
519519
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
520-
const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
520+
const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
521521
const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
522522
const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
523523
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
@@ -537,12 +537,12 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
537537
template<typename dst_t>
538538
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
539539

540-
const int i = blockIdx.x;
540+
const int64_t i = blockIdx.x;
541541
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
542542

543-
const int tid = threadIdx.x;
544-
const int il = tid/8; // 0...3
545-
const int ib = tid%8; // 0...7
543+
const int64_t tid = threadIdx.x;
544+
const int64_t il = tid/8; // 0...3
545+
const int64_t ib = tid%8; // 0...7
546546
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
547547
const uint8_t * q4 = x[ib].qs + 4*il;
548548
const float d = (float)x[ib].d;
@@ -556,12 +556,12 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
556556
#if QK_K != 64
557557
template<typename dst_t>
558558
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
559-
const int i = blockIdx.x;
559+
const int64_t i = blockIdx.x;
560560
const block_iq4_xs * x = (const block_iq4_xs *)vx;
561561

562-
const int tid = threadIdx.x;
563-
const int il = tid/8; // 0...3
564-
const int ib = tid%8; // 0...7
562+
const int64_t tid = threadIdx.x;
563+
const int64_t il = tid/8; // 0...3
564+
const int64_t ib = tid%8; // 0...7
565565
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
566566
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
567567
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);

ggml-cuda/softmax.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
2828
extern __shared__ float data_soft_max_f32[];
2929
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
3030
// shared memory buffer to cache values between iterations:
31-
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
31+
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
3232

3333
float max_val = -INFINITY;
3434

@@ -40,8 +40,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
4040
break;
4141
}
4242

43-
const int ix = rowx*ncols + col;
44-
const int iy = rowy*ncols + col;
43+
const int64_t ix = (int64_t)rowx*ncols + col;
44+
const int64_t iy = (int64_t)rowy*ncols + col;
4545

4646
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
4747

@@ -109,7 +109,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f
109109
return;
110110
}
111111

112-
const int idst = rowx*ncols + col;
112+
const int64_t idst = (int64_t)rowx*ncols + col;
113113
dst[idst] = vals[col] * inv_sum;
114114
}
115115
}

0 commit comments

Comments
 (0)