
Commit 8cc91dc

ggml : add llamafile sgemm (ggml-org#6414)
This change upstreams llamafile's CPU matrix multiplication kernels, which improve image and prompt evaluation speed. For starters, Q4_0 and Q8_0 weights should go ~40% faster on CPU. The biggest benefits are with data types like f16 / f32, which process prompts 2x faster, making them faster than quantized data types for prompt evaluation. This change also introduces bona fide AVX512 support, since tinyBLAS is able to exploit the larger register file.

For example, on my CPU, llama.cpp's llava-cli processes an image prompt at 305 tokens/second using the Q4_K and Q4_0 types, which has always been faster than using f16 LLaVA weights, which at HEAD go 188 tokens/second. With this change, f16 LLaVA performance leapfrogs to 464 tokens/second. On an Intel Core i9-14900K this change improves f16 prompt performance by 5x: for example, llama.cpp at HEAD with Mistral 7B f16 processes a 215-token prompt at 13 tok/sec, and with this change it goes 52 tok/sec.

This is mostly thanks to my vectorized outer-product kernels, but also because I added support for correctly counting the number of cores on Alder Lake, so the default thread count discounts Intel's new efficiency cores. Core counting is currently only implemented on Linux.

This work was sponsored by Mozilla, which has given permission to change the license of this code from Apache 2.0 to MIT. To read more about what's improved and how it works, see: https://justine.lol/matmul/
1 parent dbceec8 commit 8cc91dc

12 files changed (+1312, -12 lines)
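The vectorized outer-product kernels mentioned in the commit message work by keeping a small tile of the output matrix resident in registers while streaming through panels of the inputs, so each loaded element is reused many times before being evicted. For intuition only, here is a minimal C++ sketch of that register-tiling idea; it is not llamafile's kernel, the 4x4 tile size and plain f32 types are illustrative assumptions, and edge tiles are omitted for brevity.

#include <cstddef>
#include <cstdio>

// Illustrative 4x4 register-tiled matmul micro-kernel (row-major, f32).
// Not llamafile's code: real kernels vectorize the tile and handle edge
// cases; this only shows the data-reuse pattern of an outer-product kernel.
static void gemm_tile_4x4(const float *A, const float *B, float *C,
                          std::size_t M, std::size_t N, std::size_t K,
                          std::size_t lda, std::size_t ldb, std::size_t ldc) {
    for (std::size_t i = 0; i + 4 <= M; i += 4) {
        for (std::size_t j = 0; j + 4 <= N; j += 4) {
            float acc[4][4] = {}; // the C tile stays in registers
            for (std::size_t k = 0; k < K; ++k) {
                // one rank-1 update: 4 loads from A, 4 from B, 16 multiply-adds
                for (int ii = 0; ii < 4; ++ii)
                    for (int jj = 0; jj < 4; ++jj)
                        acc[ii][jj] += A[(i + ii) * lda + k] * B[k * ldb + (j + jj)];
            }
            for (int ii = 0; ii < 4; ++ii)
                for (int jj = 0; jj < 4; ++jj)
                    C[(i + ii) * ldc + (j + jj)] = acc[ii][jj];
        }
    }
}

int main() {
    const std::size_t M = 4, N = 4, K = 8;
    float A[M * K], B[K * N], C[M * N];
    for (std::size_t i = 0; i < M * K; ++i) A[i] = 1.0f;
    for (std::size_t i = 0; i < K * N; ++i) B[i] = 0.5f;
    gemm_tile_4x4(A, B, C, M, N, K, K, N, N);
    std::printf("C[0][0] = %g\n", C[0]); // 8 * 1.0 * 0.5 = 4
    return 0;
}

Real kernels vectorize the inner tile with AVX/AVX512/NEON and specialize per quantization type; AVX512's wider and larger register file allows a bigger resident output tile, which is where the additional prompt-eval throughput described above comes from.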

CMakeLists.txt (+2)

@@ -1151,6 +1151,8 @@ add_library(ggml OBJECT
             ggml-backend.h
             ggml-quants.c
             ggml-quants.h
+            sgemm.cpp
+            sgemm.h
             ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
             ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
             ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}

Makefile (+9, -1)

@@ -219,6 +219,11 @@ ifdef LLAMA_DISABLE_LOGS
 MK_CPPFLAGS += -DLOG_DISABLE_LOGS
 endif # LLAMA_DISABLE_LOGS
 
+# disable ggml.c's use of sgemm.cpp
+ifdef LLAMA_NO_LLAMAFILE
+MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0
+endif
+
 # warnings
 WARN_FLAGS = -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
 MK_CFLAGS += $(WARN_FLAGS) -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int \
@@ -676,13 +681,16 @@ ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
 ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h ggml-common.h
 	$(CC) $(CFLAGS) -c $< -o $@
 
+sgemm.o: sgemm.cpp sgemm.h ggml.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 unicode.o: unicode.cpp unicode.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
 unicode-data.o: unicode-data.cpp unicode-data.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
-OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o
+OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o unicode.o unicode-data.o sgemm.o
 
 llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

Package.swift (+1)

@@ -4,6 +4,7 @@ import PackageDescription
 
 var sources = [
     "ggml.c",
+    "sgemm.cpp",
     "llama.cpp",
     "unicode.cpp",
     "unicode-data.cpp",

build.zig (+8, -7)

@@ -112,6 +112,7 @@ pub fn build(b: *std.build.Builder) !void {
     make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;
 
     const ggml = make.obj("ggml", "ggml.c");
+    const sgemm = make.obj("sgemm", "sgemm.cpp");
     const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
     const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
     const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
@@ -128,14 +129,14 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
 
-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
     if (server.target.isWindows()) {
         server.linkSystemLibrary("ws2_32");
     }

common/common.cpp (+73)

@@ -108,6 +108,79 @@ int32_t get_num_physical_cores() {
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
+#if defined(__x86_64__) && defined(__linux__)
+#include <pthread.h>
+
+static void cpuid(unsigned leaf, unsigned subleaf,
+                  unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
+    __asm__("movq\t%%rbx,%%rsi\n\t"
+            "cpuid\n\t"
+            "xchgq\t%%rbx,%%rsi"
+            : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
+            : "0"(leaf), "2"(subleaf));
+}
+
+static int pin_cpu(int cpu) {
+    cpu_set_t mask;
+    CPU_ZERO(&mask);
+    CPU_SET(cpu, &mask);
+    return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
+}
+
+static bool is_hybrid_cpu(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
+    return !!(edx & (1u << 15));
+}
+
+static bool is_running_on_efficiency_core(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
+    int intel_atom = 0x20;
+    int core_type = (eax & 0xff000000u) >> 24;
+    return core_type == intel_atom;
+}
+
+static int count_math_cpus(int cpu_count) {
+    int result = 0;
+    for (int cpu = 0; cpu < cpu_count; ++cpu) {
+        if (pin_cpu(cpu)) {
+            return -1;
+        }
+        if (is_running_on_efficiency_core()) {
+            continue; // efficiency cores harm lockstep threading
+        }
+        ++cpu; // hyperthreading isn't useful for linear algebra
+        ++result;
+    }
+    return result;
+}
+
+#endif // __x86_64__ && __linux__
+
+/**
+ * Returns number of CPUs on system that are useful for math.
+ */
+int get_math_cpu_count() {
+#if defined(__x86_64__) && defined(__linux__)
+    int cpu_count = sysconf(_SC_NPROCESSORS_ONLN);
+    if (cpu_count < 1) {
+        return get_num_physical_cores();
+    }
+    if (is_hybrid_cpu()) {
+        cpu_set_t affinity;
+        if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
+            int result = count_math_cpus(cpu_count);
+            pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
+            if (result > 0) {
+                return result;
+            }
+        }
+    }
+#endif
+    return get_num_physical_cores();
+}
+
 void process_escapes(std::string & input) {
     std::size_t input_len = input.length();
     std::size_t output_idx = 0;
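For readers who want to see what this detection reports on their own machine, the following is a small standalone probe that uses the same CPUID leaves as the code above: leaf 7, EDX bit 15 is the hybrid-part flag, and leaf 0x1A, EAX[31:24] is the core type, where 0x20 denotes an efficiency (Atom) core and 0x40 a performance core. The program is only an illustration, not part of the commit; it assumes Linux on x86-64 with a GCC/Clang toolchain.

// hybrid_probe.cpp -- illustrative only; build with: g++ -O2 hybrid_probe.cpp -pthread
#include <cstdio>
#include <pthread.h>
#include <sched.h>
#include <unistd.h>

static void cpuid(unsigned leaf, unsigned subleaf,
                  unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
    __asm__("movq\t%%rbx,%%rsi\n\t"
            "cpuid\n\t"
            "xchgq\t%%rbx,%%rsi"
            : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
            : "0"(leaf), "2"(subleaf));
}

int main() {
    unsigned eax, ebx, ecx, edx;
    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
    bool hybrid = (edx >> 15) & 1;              // CPUID.07H:EDX[15] = hybrid part
    std::printf("hybrid CPU: %s\n", hybrid ? "yes" : "no");

    long ncpu = sysconf(_SC_NPROCESSORS_ONLN);
    for (long cpu = 0; cpu < ncpu; ++cpu) {
        cpu_set_t mask;
        CPU_ZERO(&mask);
        CPU_SET(cpu, &mask);
        // pin to one CPU so the core-type query describes that CPU
        if (pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask)) break;
        cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
        unsigned core_type = eax >> 24;         // 0x20 = Atom/E-core, 0x40 = Core/P-core
        std::printf("cpu %ld: core type 0x%02x\n", cpu, core_type);
    }
    return 0;
}

On non-hybrid or older CPUs leaf 0x1A is not meaningful and typically reads as zero, which matches why the commit only walks the per-CPU core types after the hybrid flag has been confirmed.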

common/common.h (+2, -1)

@@ -39,6 +39,7 @@ extern char const *LLAMA_BUILD_TARGET;
 
 struct llama_control_vector_load_info;
 
+int get_math_cpu_count();
 int32_t get_num_physical_cores();
 
 //
@@ -48,7 +49,7 @@ int32_t get_num_physical_cores();
 struct gpt_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
 
-    int32_t n_threads = get_num_physical_cores();
+    int32_t n_threads = get_math_cpu_count();
     int32_t n_threads_draft = -1;
     int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
     int32_t n_threads_batch_draft = -1;

examples/llama-bench/llama-bench.cpp (+1, -1)

@@ -190,7 +190,7 @@ static const cmd_params cmd_params_defaults = {
     /* n_ubatch     */ {512},
     /* type_k       */ {GGML_TYPE_F16},
     /* type_v       */ {GGML_TYPE_F16},
-    /* n_threads    */ {get_num_physical_cores()},
+    /* n_threads    */ {get_math_cpu_count()},
     /* n_gpu_layers */ {99},
     /* split_mode   */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu     */ {0},

ggml-impl.h (+1, -1)

@@ -88,7 +88,7 @@ typedef uint16_t ggml_fp16_internal_t;
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
 #if !defined(__riscv)
 #include <immintrin.h>
 #endif

ggml-quants.c (+1, -1)

@@ -132,7 +132,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 }
 
 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if defined(__AVXVNNI__) || defined(__AVX512VNNI__)
+#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
     return _mm256_cvtepi32_ps(summed_pairs);
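This guard change matters because `_mm256_dpbusd_epi32` is a 256-bit intrinsic: `__AVX512VNNI__` alone only guarantees the 512-bit form, while the 256-bit encoding additionally requires `__AVX512VL__` (or the separate AVX-VNNI extension). As a hedged illustration of what the intrinsic computes under the corrected guard (a standalone snippet, not ggml code; compile with, for example, -mavx512vnni -mavx512vl):

#include <cstdio>
#include <immintrin.h>

// Each 32-bit lane accumulates four unsigned*signed byte products, which is
// the operation mul_sum_us8_pairs_float() builds on.
int main() {
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
    const __m256i ax  = _mm256_set1_epi8(3);   // unsigned bytes (e.g. activations)
    const __m256i sy  = _mm256_set1_epi8(-2);  // signed bytes (e.g. weights)
    const __m256i acc = _mm256_dpbusd_epi32(_mm256_setzero_si256(), ax, sy);
    int out[8];
    _mm256_storeu_si256((__m256i *)out, acc);
    std::printf("lane 0 = %d\n", out[0]);      // 4 * (3 * -2) = -24
#else
    std::puts("256-bit VNNI path not available with these compile flags");
#endif
    return 0;
}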

ggml.c (+54)

@@ -4,6 +4,7 @@
 #include "ggml-impl.h"
 #include "ggml-quants.h"
 #include "ggml.h"
+#include "sgemm.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -32,6 +33,14 @@
 #include <unistd.h>
 #endif
 
+#ifndef GGML_USE_LLAMAFILE
+#ifdef __ARM_FEATURE_MATMUL_INT8
+#define GGML_USE_LLAMAFILE 0
+#else
+#define GGML_USE_LLAMAFILE 1
+#endif
+#endif
+
 #if defined(_MSC_VER)
 // disable "possible loss of data" to avoid hundreds of casts
 // we should just be careful :)
@@ -10810,6 +10819,28 @@ static void ggml_compute_forward_mul_mat(
     }
 #endif
 
+#if GGML_USE_LLAMAFILE
+    if (nb10 == ggml_type_size(src1->type)) {
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(src0->type),
+                                     (const char *)src1->data + i12*nb12 + i13*nb13,
+                                     nb11/ggml_type_size(src1->type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     ith, nth,
+                                     params->type,
+                                     src0->type,
+                                     src1->type,
+                                     dst->type))
+                    goto UseGgmlGemm1;
+        return;
+    }
+UseGgmlGemm1:;
+#endif
+
     if (params->type == GGML_TASK_TYPE_INIT) {
         if (ith != 0) {
             return;
@@ -10841,6 +10872,29 @@ static void ggml_compute_forward_mul_mat(
     const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
+#if GGML_USE_LLAMAFILE
+    if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(src0->type),
+                                     (const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
+                                                            nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
+                                     row_size/ggml_type_size(vec_dot_type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     ith, nth,
+                                     params->type,
+                                     src0->type,
+                                     vec_dot_type,
+                                     dst->type))
+                    goto UseGgmlGemm2;
+        return;
+    }
+UseGgmlGemm2:;
+#endif
+
     const int64_t nr0 = ne01; // src0 rows
     const int64_t nr1 = ne1*ne12*ne13; // src1 rows
 
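The call sites above rely on a simple contract: when `llamafile_sgemm` reports that it did not handle a case (returning false), ggml jumps past the fast path via the `UseGgmlGemm` labels and runs its existing matrix-multiplication code. A minimal sketch of that try-the-fast-kernel-or-fall-back pattern follows; the function names and float-only types are illustrative assumptions, not the sgemm.h interface.

#include <cstdint>
#include <cstdio>

// Sketch of the fallback contract, not ggml/sgemm code: the fast path returns
// false for cases it does not support, and the caller runs the generic path.
static bool try_fast_sgemm(int64_t m, int64_t n, int64_t k,
                           const float *A, int64_t lda,
                           const float *B, int64_t ldb,
                           float *C, int64_t ldc) {
    if (k % 16 != 0) {
        return false; // unsupported shape here; a real kernel also checks types
    }
    for (int64_t i = 0; i < m; ++i)
        for (int64_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l)
                acc += A[i * lda + l] * B[j * ldb + l]; // rows of A and B hold k values
            C[i * ldc + j] = acc;
        }
    return true;
}

static void generic_gemm(int64_t m, int64_t n, int64_t k,
                         const float *A, int64_t lda,
                         const float *B, int64_t ldb,
                         float *C, int64_t ldc) {
    for (int64_t i = 0; i < m; ++i)
        for (int64_t j = 0; j < n; ++j) {
            float acc = 0.0f;
            for (int64_t l = 0; l < k; ++l)
                acc += A[i * lda + l] * B[j * ldb + l];
            C[i * ldc + j] = acc;
        }
}

int main() {
    float A[2 * 16], B[3 * 16], C[2 * 3];
    for (int i = 0; i < 2 * 16; ++i) A[i] = 1.0f;
    for (int i = 0; i < 3 * 16; ++i) B[i] = 2.0f;
    // mirrors the ggml call sites: try the fast path first, fall back on refusal
    if (!try_fast_sgemm(2, 3, 16, A, 16, B, 16, C, 3)) {
        generic_gemm(2, 3, 16, A, 16, B, 16, C, 3);
    }
    std::printf("C[0][0] = %g\n", C[0]); // 16 * 1.0 * 2.0 = 32
    return 0;
}

Because a refusal is always safe, support for new data types or instruction sets can be added to the fast kernel incrementally without touching the call sites.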
