Skip to content

Commit 5f93949

Browse files
authored
ggml : unit test for quantization functions (#953)
* Unit test for quantization functions Use the ggml_internal_get_quantize_fn function to loop through all quantization formats and run a sanity check on the result. Also add a microbenchmark that times these functions directly without running the rest of the GGML graph. * test-quantize-fns: CI fixes Fix issues uncovered in CI - need to use sizes divisible by 32*8 for loop unrolling - use intrinsic header that should work on Mac * test-quantize: remove Per PR comment, subsumed by test-quantize-fns * test-quantize: fix for q8_0 intermediates
1 parent 36b4f7e commit 5f93949

File tree

4 files changed

+466
-43
lines changed

4 files changed

+466
-43
lines changed

tests/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ function(llama_add_test source)
66
endfunction()
77

88
# llama_add_test(test-double-float.c) # SLOW
9-
llama_add_test(test-quantize.c)
9+
llama_add_test(test-quantize-fns.cpp)
10+
llama_add_test(test-quantize-perf.cpp)
1011
llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)

tests/test-quantize-fns.cpp

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Unit tests for quantization specific functions - quantize, dequantize and dot product
2+
3+
#include "ggml.h"
4+
5+
#undef NDEBUG
6+
#include <assert.h>
7+
#include <math.h>
8+
#include <stdio.h>
9+
#include <string>
10+
#include <vector>
11+
12+
13+
const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001;
14+
const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002;
15+
const float MAX_DOT_PRODUCT_ERROR = 0.02;
16+
17+
const char* RESULT_STR[] = {"ok", "FAILED"};
18+
19+
20+
// Generate synthetic data
21+
void generate_data(float offset, size_t n, float * dst) {
22+
for (size_t i = 0; i < n; i++) {
23+
dst[i] = 0.1 + 2*cosf(i + offset);
24+
}
25+
}
26+
27+
// Calculate RMSE between two float arrays
28+
float array_rmse(const float * a1, const float * a2, size_t n) {
29+
double sum = 0;
30+
for (size_t i = 0; i < n; i++) {
31+
double diff = a1[i] - a2[i];
32+
sum += diff * diff;
33+
}
34+
return sqrtf(sum) / n;
35+
}
36+
37+
// Total quantization error on test data
38+
float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
39+
std::vector<uint8_t> tmp_q(test_size);
40+
std::vector<float> tmp_out(test_size);
41+
42+
qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
43+
qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
44+
return array_rmse(test_data, tmp_out.data(), test_size);
45+
}
46+
47+
// Total quantization error on test data
48+
float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
49+
std::vector<uint8_t> tmp_q(test_size);
50+
std::vector<float> tmp_out(test_size);
51+
std::vector<float> tmp_out_ref(test_size);
52+
53+
qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
54+
qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
55+
56+
qfns.quantize_row_q_reference(test_data, tmp_q.data(), test_size);
57+
qfns.dequantize_row_q(tmp_q.data(), tmp_out_ref.data(), test_size);
58+
59+
return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
60+
}
61+
62+
float dot_product(const float * a1, const float * a2, size_t test_size) {
63+
double sum = 0;
64+
for (size_t i = 0; i < test_size; i++) {
65+
sum += a1[i] * a2[i];
66+
}
67+
return sum;
68+
}
69+
70+
// Total dot product error
71+
float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
72+
std::vector<uint8_t> tmp_q1(test_size);
73+
std::vector<uint8_t> tmp_q2(test_size*2);
74+
75+
qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size);
76+
qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
77+
78+
float result = INFINITY;
79+
qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data());
80+
81+
const float dot_ref = dot_product(test_data1, test_data2, test_size);
82+
83+
return fabsf(result - dot_ref) / test_size;
84+
}
85+
86+
int main(int argc, char * argv[]) {
87+
bool verbose = false;
88+
const size_t test_size = 32 * 128;
89+
90+
std::string arg;
91+
for (int i = 1; i < argc; i++) {
92+
arg = argv[i];
93+
94+
if (arg == "-v") {
95+
verbose = true;
96+
} else {
97+
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
98+
return 1;
99+
}
100+
}
101+
102+
std::vector<float> test_data(test_size);
103+
std::vector<float> test_data2(test_size);
104+
105+
generate_data(0.0, test_data.size(), test_data.data());
106+
generate_data(1.0, test_data2.size(), test_data2.data());
107+
108+
// Initialize GGML, ensures float conversion tables are initialized
109+
struct ggml_init_params ggml_params = {
110+
/* .mem_size = */ 1*1024,
111+
/* .mem_buffer = */ NULL,
112+
/* .no_alloc = */ true,
113+
};
114+
struct ggml_context * ctx = ggml_init(ggml_params);
115+
116+
int num_failed = 0;
117+
bool failed = false;
118+
119+
for (int i = 0; i < GGML_TYPE_COUNT; i++) {
120+
ggml_type type = (ggml_type) i;
121+
quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
122+
123+
if (qfns.quantize_row_q) {
124+
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
125+
failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR);
126+
num_failed += failed;
127+
if (failed || verbose) {
128+
printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
129+
}
130+
131+
const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
132+
failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
133+
num_failed += failed;
134+
if (failed || verbose) {
135+
printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
136+
}
137+
138+
const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
139+
failed = !(vec_dot_error < MAX_DOT_PRODUCT_ERROR);
140+
num_failed += failed;
141+
if (failed || verbose) {
142+
printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
143+
}
144+
}
145+
}
146+
147+
if (num_failed || verbose) {
148+
printf("%d tests failed\n", num_failed);
149+
}
150+
151+
ggml_free(ctx);
152+
153+
return num_failed > 0;
154+
}

0 commit comments

Comments
 (0)