Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

k-quants with super-block size of 64 #2001

Merged
merged 35 commits into from
Jun 26, 2023
Merged
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d2f12ac
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
9fe2a2b
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
1f6195c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
aebd547
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
2b2ab31
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
bcf8c5c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
c6c3536
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 21, 2023
5aae4b8
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
41e46ec
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
460dd84
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
3bd9ae7
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
03f30c8
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
cda47a6
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
80c75fe
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
2b2a13c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
9d27d8d
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
2ff543c
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 22, 2023
d92c5a9
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
fae24af
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
e1bbcfc
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
167a0bb
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
6081a65
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
ff83e32
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
285eeb1
k_quants: WIP super-blocks with 64 weights
Kawrakow Jun 23, 2023
8b98d01
k_quants: call them _K, not _k, also on Metal
Kawrakow Jun 23, 2023
558a194
k_quants: correctly define QK_K in llama.cpp
Kawrakow Jun 23, 2023
333ffcc
Fixed bug in q4_K quantization added with the 64-block addition
Kawrakow Jun 23, 2023
88412a1
Simplify via lambda
Kawrakow Jun 23, 2023
aeefd4e
k_quants: swicth Q3_K to 4-bit scales when QK_K = 64
Kawrakow Jun 24, 2023
ce19b96
k_quants: switch Q4_K to 4-bit scales when QK_K = 64
Kawrakow Jun 24, 2023
4f61506
k_quants: forgot to add the Metal changes in last commit
Kawrakow Jun 24, 2023
ccf4901
k_quants: change Q5_K to be type 0 when QK_K = 64
Kawrakow Jun 24, 2023
2da3a59
k_quants: AVX2 implementation for new 64-weight Q5_K
Kawrakow Jun 24, 2023
53e81ca
k_quants: 10% faster ARM_NEON Q5_K dot product
Kawrakow Jun 24, 2023
5fd8337
k_quants: fixed issue caused by merging with master
Kawrakow Jun 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
k_quants: WIP super-blocks with 64 weights
Kawrakow committed Jun 26, 2023
commit d2f12ac354552bcfba1dbc9c8593296d81b70452
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -75,6 +75,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)

option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -292,6 +293,9 @@ endif()
if (LLAMA_K_QUANTS)
set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h)
add_compile_definitions(GGML_USE_K_QUANTS)
if (LLAMA_QKK_64)
add_compile_definitions(GGML_QKK_64)
endif()
endif()

if (LLAMA_CLBLAST)
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -131,6 +131,10 @@ ifndef LLAMA_NO_K_QUANTS
CFLAGS += -DGGML_USE_K_QUANTS
CXXFLAGS += -DGGML_USE_K_QUANTS
OBJS += k_quants.o
ifdef LLAMA_QKK_64
CFLAGS += -DGGML_QKK_64
CXXFLAGS += -DGGML_QKK_64
endif
endif

ifndef LLAMA_NO_ACCELERATE
243 changes: 227 additions & 16 deletions k_quants.c
Original file line number Diff line number Diff line change
@@ -330,11 +330,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
}
}

#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif

x += QK_K;

@@ -352,6 +358,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int

const uint8_t * q = x[i].qs;

#if QK_K == 256
int is = 0;
float dl, ml;
for (int n = 0; n < QK_K; n += 128) {
@@ -370,7 +377,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int
}
q += 32;
}

#else
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
for (int l = 0; l < 16; ++l) {
y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1;
y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2;
y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3;
y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4;
}
y += QK_K;
#endif
}
}

@@ -412,6 +431,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
}
}

#if QK_K == 256
memset(y[i].scales, 0, 12);
if (max_scale) {
float iscale = -32.f/max_scale;
@@ -445,9 +465,36 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
L[16*j + ii] = l + 4;
}
}
#else
if (max_scale) {
float iscale = -128.f/max_scale;
for (int j = 0; j < QK_K/16; ++j) {
int l = nearest_int(iscale*scales[j]);
l = MAX(-128, MIN(127, l));
y[i].scales[j] = l;
}
y[i].d = ggml_fp32_to_fp16(1/iscale);
} else {
for (int j = 0; j < QK_K/16; ++j) {
y[i].scales[j] = 0;
}
y[i].d = ggml_fp32_to_fp16(0.f);
}
for (int j = 0; j < QK_K/16; ++j) {
float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
if (!d) {
continue;
}
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int(x[16*j + ii]/d);
l = MAX(-4, MIN(3, l));
L[16*j + ii] = l + 4;
}
}
#endif

memset(y[i].hmask, 0, QK_K/8);
// We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
// We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
int m = 0;
uint8_t hm = 1;
for (int j = 0; j < QK_K; ++j) {
@@ -459,19 +506,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict
m = 0; hm <<= 1;
}
}
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif

x += QK_K;
}
}

#if QK_K == 256
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
assert(QK_K == 256);
const int nb = k / QK_K;

const uint32_t kmask1 = 0x03030303;
@@ -519,6 +572,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int

}
}
#else
void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
assert(QK_K == 64);
const int nb = k / QK_K;

for (int i = 0; i < nb; i++) {

const float d_all = ggml_fp16_to_fp32(x[i].d);

const uint8_t * restrict q = x[i].qs;
const uint8_t * restrict hm = x[i].hmask;

const float d1 = d_all * x[i].scales[0];
const float d2 = d_all * x[i].scales[1];
const float d3 = d_all * x[i].scales[2];
const float d4 = d_all * x[i].scales[3];

for (int l=0; l<8; ++l) {
uint8_t h = hm[l];
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
}
y += QK_K;
}
}
#endif

void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
quantize_row_q3_K_reference(x, vy, k);
@@ -544,11 +630,14 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
const int nb = k / QK_K;

uint8_t L[QK_K];
#if QK_K == 256
float mins[QK_K/32];
float scales[QK_K/32];
#endif

for (int i = 0; i < nb; i++) {

#if QK_K == 256
float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
@@ -594,9 +683,28 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
L[32*j + ii] = l;
}
}
#else
for (int j = 0; j < QK_K/32; ++j) {
float min;
float scale = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &min, 5);
y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
}

for (int j = 0; j < QK_K/32; ++j) {
const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
if (!d) continue;
const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(15, l));
L[32*j + ii] = l;
}
}
#endif
uint8_t * q = y[i].qs;
for (int j = 0; j < QK_K; j += 64) {
for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
}

x += QK_K;
@@ -610,11 +718,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int

for (int i = 0; i < nb; i++) {

const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);

const uint8_t * q = x[i].qs;

#if QK_K == 256

const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);

int is = 0;
uint8_t sc, m;
for (int j = 0; j < QK_K; j += 64) {
@@ -626,6 +736,15 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
q += 32; is += 2;
}
#else
float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
for (int l = 0; l < 32; ++l) {
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
y[l+32] = d2 * (q[l] >> 4) - m2;
}
y += QK_K;
#endif

}
}
@@ -654,11 +773,15 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
const int nb = k / QK_K;

uint8_t L[QK_K];
#if QK_K == 256
float mins[QK_K/32];
float scales[QK_K/32];
#endif

for (int i = 0; i < nb; i++) {

#if QK_K == 256

float max_scale = 0; // as we are deducting the min, scales are always positive
float max_min = 0;
for (int j = 0; j < QK_K/32; ++j) {
@@ -725,6 +848,42 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
m1 <<= 2; m2 <<= 2;
ql += 32;
}
#else
for (int j = 0; j < QK_K/32; ++j) {
float min;
float scale = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &min, 5);
y[i].d[2*j+0] = ggml_fp32_to_fp16(scale);
y[i].d[2*j+1] = ggml_fp32_to_fp16(min);
}
for (int j = 0; j < QK_K/32; ++j) {
const float d = ggml_fp16_to_fp32(y[i].d[2*j+0]);
if (!d) continue;
const float dm = ggml_fp16_to_fp32(y[i].d[2*j+1]);
for (int ii = 0; ii < 32; ++ii) {
int l = nearest_int((x[32*j + ii] + dm)/d);
l = MAX(0, MIN(31, l));
L[32*j + ii] = l;
}
}

uint8_t * restrict qh = y[i].qh;
uint8_t * restrict ql = y[i].qs;
memset(qh, 0, QK_K/8);

for (int j = 0; j < 32; ++j) {
int jm = j%8;
int is = j/8;
int l1 = L[j];
if (l1 > 15) {
l1 -= 16; qh[jm] |= (1 << is);
}
int l2 = L[j + 32];
if (l2 > 15) {
l2 -= 16; qh[jm] |= (1 << (4 + is));
}
ql[j] = l1 | (l2 << 4);
}
#endif

x += QK_K;

@@ -737,12 +896,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int

for (int i = 0; i < nb; i++) {

const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);

const uint8_t * ql = x[i].qs;
const uint8_t * qh = x[i].qh;

#if QK_K == 256

const float d = ggml_fp16_to_fp32(x[i].d);
const float min = ggml_fp16_to_fp32(x[i].dmin);

int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
@@ -756,6 +917,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
#else
float d1 = ggml_fp16_to_fp32(x[i].d[0]), m1 = ggml_fp16_to_fp32(x[i].d[1]);
float d2 = ggml_fp16_to_fp32(x[i].d[2]), m2 = ggml_fp16_to_fp32(x[i].d[3]);
for (int l = 0; l < 8; ++l) {
y[l+ 0] = d1 * ((ql[l+ 0] & 0xF) + (qh[l] & 0x01 ? 16 : 0)) - m1;
y[l+ 8] = d1 * ((ql[l+ 8] & 0xF) + (qh[l] & 0x02 ? 16 : 0)) - m1;
y[l+16] = d1 * ((ql[l+16] & 0xF) + (qh[l] & 0x04 ? 16 : 0)) - m1;
y[l+24] = d1 * ((ql[l+24] & 0xF) + (qh[l] & 0x08 ? 16 : 0)) - m1;
y[l+32] = d2 * ((ql[l+ 0] >> 4) + (qh[l] & 0x10 ? 16 : 0)) - m2;
y[l+40] = d2 * ((ql[l+ 8] >> 4) + (qh[l] & 0x20 ? 16 : 0)) - m2;
y[l+48] = d2 * ((ql[l+16] >> 4) + (qh[l] & 0x40 ? 16 : 0)) - m2;
y[l+56] = d2 * ((ql[l+24] >> 4) + (qh[l] & 0x80 ? 16 : 0)) - m2;
}
y += QK_K;
#endif
}
}

@@ -823,6 +999,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict

uint8_t * restrict ql = y[i].ql;
uint8_t * restrict qh = y[i].qh;
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[j + l + 0] & 0xF;
@@ -836,6 +1013,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict
ql += 64;
qh += 32;
}
#else
for (int l = 0; l < 32; ++l) {
const uint8_t q1 = L[l + 0] & 0xF;
const uint8_t q2 = L[l + 32] & 0xF;
ql[l] = q1 | (q2 << 4);
}
for (int l = 0; l < 16; ++l) {
qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6);
}
#endif

x += QK_K;

@@ -854,6 +1041,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict sc = x[i].scales;

#if QK_K == 256
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
@@ -871,6 +1059,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int
qh += 32;
sc += 8;
}
#else
for (int l = 0; l < 16; ++l) {
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
y[l+ 0] = d * sc[0] * q1;
y[l+16] = d * sc[1] * q2;
y[l+32] = d * sc[2] * q3;
y[l+48] = d * sc[3] * q4;
}
y += 64;
#endif

}
}
@@ -1611,18 +1812,23 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri

for (int i = 0; i < nb; ++i) {

#if QK_K == 256
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;

memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
#else
// TODO
const float d = 0; const float dmin = 0;
#endif

const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;

const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

@@ -1840,18 +2046,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

for (int i = 0; i < nb; ++i) {

const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

const uint8_t * restrict q5 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;

#if QK_K == 256
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

memcpy(utmp, x[i].scales, 12);
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
const uint32_t uaux = utmp[1] & kmask1;
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[2] = uaux;
utmp[0] &= kmask1;
#else
// TODO
const float d = 0, dmin = 0;
#endif
Comment on lines +2622 to +2625
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Postponed for later or did you missed to implement this?


const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

39 changes: 33 additions & 6 deletions k_quants.h
Original file line number Diff line number Diff line change
@@ -7,7 +7,13 @@
#include <stddef.h>

// Super-block size
#ifdef GGML_QKK_64
#define QK_K 64
#define K_SCALE_SIZE 4
#else
#define QK_K 256
#define K_SCALE_SIZE 12
#endif

//
// Super-block quantization structures
@@ -32,35 +38,56 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w
typedef struct {
uint8_t hmask[QK_K/8]; // quants - high bit
uint8_t qs[QK_K/4]; // quants - low 2 bits
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
#ifdef GGML_QKK_64
int8_t scales[K_SCALE_SIZE];
#else
uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
#endif
ggml_fp16_t d; // super-block scale
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");

// 4-bit quantization
// 16 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2, "wrong q4_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
#endif
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");

// 5-bit quantization
// 16 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
#ifdef GGML_QKK_64
typedef struct {
ggml_fp16_t d[2*QK_K/32]; // super-block scales/mins
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*QK_K/32*sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#else
typedef struct {
ggml_fp16_t d; // super-block scale for quantized scales
ggml_fp16_t dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif

// 6-bit quantization
// weight is represented as x = a * q