
Commit 50cb666

Authored Apr 21, 2023
Improve cuBLAS performance by using a memory pool (#1094)
* Improve cuBLAS performance by using a memory pool
* Move cuda specific definitions to ggml-cuda.h/cu
* Add CXX flags to nvcc
* Change memory pool synchronization mechanism to a spin lock

General code cleanup
1 parent 25d7abb commit 50cb666

File tree

4 files changed: +168 -107 lines changed
 

Makefile

+6 -4
@@ -101,11 +101,13 @@ ifdef LLAMA_OPENBLAS
     LDFLAGS += -lopenblas
 endif
 ifdef LLAMA_CUBLAS
-    CFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
-    LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
-    OBJS    += ggml-cuda.o
+    CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
+    LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
+    OBJS      += ggml-cuda.o
+    NVCC      = nvcc
+    NVCCFLAGS = --forward-unknown-to-host-linker -arch=native
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
-    nvcc -arch=native -c -o $@ $<
+    $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@
 endif
 ifdef LLAMA_GPROF
     CFLAGS += -pg

ggml-cuda.cu

+93 -19
@@ -1,5 +1,7 @@
 #include <stdint.h>
+#include <stdio.h>
 #include <cuda_fp16.h>
+#include <atomic>
 #include "ggml-cuda.h"
 
 typedef uint16_t ggml_fp16_t;
@@ -29,14 +31,12 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
 
 #define QK4_3 16
 typedef struct {
-    __half d;               // delta
-    __half m;               // min
-    uint8_t qs[QK4_3 / 2];  // nibbles / quants
+    __half  d;              // delta
+    __half  m;              // min
+    uint8_t qs[QK4_3 / 2];  // nibbles / quants
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
-
-
 static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
@@ -131,24 +131,98 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
     }
 }
 
-extern "C" {
-    __host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_0;
-        dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
-    }
+void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_0;
+    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_1;
+    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_2;
+    dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK4_3;
+    dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+}
 
-    __host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_1;
-        dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+// buffer pool for cuda
+#define MAX_CUDA_BUFFERS 16
+
+struct scoped_spin_lock {
+    std::atomic_flag& lock;
+    scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
+        while (lock.test_and_set(std::memory_order_acquire)) {
+            ; // spin
+        }
+    }
+    ~scoped_spin_lock() {
+        lock.clear(std::memory_order_release);
+    }
+    scoped_spin_lock(const scoped_spin_lock&) = delete;
+    scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
+};
+
+struct cuda_buffer {
+    void * ptr = nullptr;
+    size_t size = 0;
+};
+
+static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.size >= size && b.ptr != nullptr) {
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
     }
+    void * ptr;
+    CUDA_CHECK(cudaMalloc((void **) &ptr, size));
+    *actual_size = size;
+    return ptr;
+}
+
+void ggml_cuda_pool_free(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
 
-    __host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_2;
-        dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
+    for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
+        cuda_buffer& b = g_cuda_buffer_pool[i];
+        if (b.ptr == nullptr) {
+            b.ptr = ptr;
+            b.size = size;
+            return;
+        }
     }
+    fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
+    CUDA_CHECK(cudaFree(ptr));
+}
+
+cublasHandle_t g_cublasH = NULL;
+cudaStream_t g_cudaStream = NULL;
+
+void ggml_init_cublas(void) {
+    if (g_cublasH == NULL) {
+        // create cublas handle, bind a stream
+        CUBLAS_CHECK(cublasCreate(&g_cublasH));
+
+        CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
+
+        CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
 
-    __host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-        const int nb = k / QK4_3;
-        dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
+        // configure logging to stdout
+        // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
     }
 }
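The point of the pool is to let the hot cuBLAS mat-mul path reuse device buffers across calls instead of paying for cudaMalloc/cudaFree every time: freed buffers are parked in g_cuda_buffer_pool and handed back to the next request they can satisfy. The sketch below shows how a caller might use this API; it is not taken from this commit (the actual call sites live in the fourth changed file, presumably ggml.c, which is not shown in this excerpt), and the helper name sgemm_pooled plus the matrix arguments are illustrative assumptions. Only g_cublasH, g_cudaStream, the pool functions, and the CHECK macros come from the diff above.

// Hypothetical caller-side sketch: C = A * B in single precision through the
// pooled buffers. Assumes ggml_init_cublas() has already been called.
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "ggml-cuda.h"

static void sgemm_pooled(const float * hA, const float * hB, float * hC,
                         int m, int n, int k) {
    size_t a_size = 0, b_size = 0, c_size = 0;  // actual sizes handed back by the pool

    // Reuse a previously freed buffer if one is large enough, otherwise the pool falls back to cudaMalloc.
    float * dA = (float *) ggml_cuda_pool_malloc(sizeof(float) * m * k, &a_size);
    float * dB = (float *) ggml_cuda_pool_malloc(sizeof(float) * k * n, &b_size);
    float * dC = (float *) ggml_cuda_pool_malloc(sizeof(float) * m * n, &c_size);

    // Stage the inputs on the shared non-blocking stream created by ggml_init_cublas().
    CUDA_CHECK(cudaMemcpyAsync(dA, hA, sizeof(float) * m * k, cudaMemcpyHostToDevice, g_cudaStream));
    CUDA_CHECK(cudaMemcpyAsync(dB, hB, sizeof(float) * k * n, cudaMemcpyHostToDevice, g_cudaStream));

    const float alpha = 1.0f;
    const float beta  = 0.0f;
    // cuBLAS is column-major; with no transposition this computes C (m x n) = A (m x k) * B (k x n).
    CUBLAS_CHECK(cublasSgemm(g_cublasH, CUBLAS_OP_N, CUBLAS_OP_N,
                             m, n, k, &alpha, dA, m, dB, k, &beta, dC, m));

    CUDA_CHECK(cudaMemcpyAsync(hC, dC, sizeof(float) * m * n, cudaMemcpyDeviceToHost, g_cudaStream));
    CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));

    // Return the buffers to the pool (not cudaFree) so the next call can reuse them.
    ggml_cuda_pool_free(dA, a_size);
    ggml_cuda_pool_free(dB, b_size);
    ggml_cuda_pool_free(dC, c_size);
}

Compared with allocating and freeing per mat-mul, the only extra steady-state cost is the spin-locked scan over at most MAX_CUDA_BUFFERS pool entries.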

ggml-cuda.h

+29
@@ -1,7 +1,36 @@
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
+#define CUDA_CHECK(err) \
+    do { \
+        cudaError_t err_ = (err); \
+        if (err_ != cudaSuccess) { \
+            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                cudaGetErrorString(err_)); \
+            exit(1); \
+        } \
+    } while (0)
+
+#define CUBLAS_CHECK(err) \
+    do { \
+        cublasStatus_t err_ = (err); \
+        if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            exit(1); \
+        } \
+    } while (0)
+
+extern cublasHandle_t g_cublasH;
+extern cudaStream_t g_cudaStream;
+
+void ggml_init_cublas(void);
+void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
+void ggml_cuda_pool_free(void * ptr, size_t size);
+
 void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
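As a quick way to exercise the pool semantics exposed by this header, the following hypothetical snippet (not part of the commit) allocates a buffer, returns it to the pool, and immediately requests the same size again; with the implementation above, the second request should come back with the same device pointer rather than triggering a second cudaMalloc. The file name and build line in the comment are assumptions.

// Hypothetical standalone check of buffer reuse; build line is an assumption,
// e.g. nvcc pool_check.cu ggml-cuda.o -lcublas -o pool_check
#include <cstdio>
#include <cstdlib>
#include "ggml-cuda.h"

int main(void) {
    ggml_init_cublas();                                  // create cuBLAS handle + non-blocking stream

    size_t actual = 0;
    void * a = ggml_cuda_pool_malloc(1 << 20, &actual);  // first request: falls through to cudaMalloc
    ggml_cuda_pool_free(a, actual);                      // parked in g_cuda_buffer_pool, not freed

    size_t actual2 = 0;
    void * b = ggml_cuda_pool_malloc(1 << 20, &actual2); // same size again: served from the pool
    printf("reused same buffer: %s\n", a == b ? "yes" : "no");

    ggml_cuda_pool_free(b, actual2);
    return 0;
}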
