Skip to content

Commit 1fcdcc2

Browse files
cuda : performance optimizations (#1530)
* xor hack * block y dim * loop unrolling * Fixed cmake LLAMA_CUDA_BY option * Removed hipblas compatibility code * Define GGML_CUDA_DMMV_BLOCK_Y if not defined * Fewer iters, more ops per iter * Renamed DMMV X/Y compilation options
1 parent ac7876a commit 1fcdcc2

File tree

3 files changed

+111
-65
lines changed

3 files changed

+111
-65
lines changed

CMakeLists.txt

+29-25
Original file line numberDiff line numberDiff line change
@@ -37,42 +37,44 @@ endif()
3737
#
3838

3939
# general
40-
option(LLAMA_STATIC "llama: static link libraries" OFF)
41-
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
42-
option(LLAMA_LTO "llama: enable link time optimization" OFF)
40+
option(LLAMA_STATIC "llama: static link libraries" OFF)
41+
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
42+
option(LLAMA_LTO "llama: enable link time optimization" OFF)
4343

4444
# debug
45-
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
46-
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
47-
option(LLAMA_GPROF "llama: enable gprof" OFF)
45+
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
46+
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
47+
option(LLAMA_GPROF "llama: enable gprof" OFF)
4848

4949
# sanitizers
50-
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
51-
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
52-
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
50+
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
51+
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
52+
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
5353

5454
# instruction set specific
55-
option(LLAMA_AVX "llama: enable AVX" ON)
56-
option(LLAMA_AVX2 "llama: enable AVX2" ON)
57-
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
58-
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
59-
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
60-
option(LLAMA_FMA "llama: enable FMA" ON)
55+
option(LLAMA_AVX "llama: enable AVX" ON)
56+
option(LLAMA_AVX2 "llama: enable AVX2" ON)
57+
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
58+
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
59+
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
60+
option(LLAMA_FMA "llama: enable FMA" ON)
6161
# in MSVC F16C is implied with AVX2/AVX512
6262
if (NOT MSVC)
63-
option(LLAMA_F16C "llama: enable F16C" ON)
63+
option(LLAMA_F16C "llama: enable F16C" ON)
6464
endif()
6565

6666
# 3rd party libs
67-
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
68-
option(LLAMA_BLAS "llama: use BLAS" OFF)
69-
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
70-
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
71-
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
72-
73-
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
74-
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
75-
option(LLAMA_BUILD_SERVER "llama: build server example" OFF)
67+
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
68+
option(LLAMA_BLAS "llama: use BLAS" OFF)
69+
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
70+
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
71+
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
72+
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
73+
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
74+
75+
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
76+
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
77+
option(LLAMA_BUILD_SERVER "llama: build server example" OFF)
7678

7779
#
7880
# Build info header
@@ -184,6 +186,8 @@ if (LLAMA_CUBLAS)
184186
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
185187

186188
add_compile_definitions(GGML_USE_CUBLAS)
189+
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
190+
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
187191

188192
if (LLAMA_STATIC)
189193
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)

Makefile

+11-1
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS
133133
OBJS += ggml-cuda.o
134134
NVCC = nvcc
135135
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
136+
ifdef LLAMA_CUDA_DMMV_X
137+
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
138+
else
139+
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
140+
endif # LLAMA_CUDA_DMMV_X
141+
ifdef LLAMA_CUDA_DMMV_Y
142+
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
143+
else
144+
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
145+
endif # LLAMA_CUDA_DMMV_Y
136146
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
137147
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
138-
endif
148+
endif # LLAMA_CUBLAS
139149
ifdef LLAMA_CLBLAST
140150
CFLAGS += -DGGML_USE_CLBLAST
141151
CXXFLAGS += -DGGML_USE_CLBLAST

ggml-cuda.cu

+71-39
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,19 @@ typedef struct {
8383
} block_q8_0;
8484
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
8585

86+
#define WARP_SIZE 32
87+
8688
#define CUDA_MUL_BLOCK_SIZE 256
89+
8790
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
88-
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
91+
92+
// dmmv = dequantize_mul_mat_vec
93+
#ifndef GGML_CUDA_DMMV_X
94+
#define GGML_CUDA_DMMV_X 32
95+
#endif
96+
#ifndef GGML_CUDA_DMMV_Y
97+
#define GGML_CUDA_DMMV_Y 1
98+
#endif
8999

90100
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
91101
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
200210
dequantize_kernel(vx, ib, iqs, v0, v1);
201211
}
202212

203-
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
213+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
204214
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
205-
const int row = blockIdx.x;
215+
// qk = quantized weights per x block
216+
// qr = number of quantized weights per data value in x block
217+
const int row = blockIdx.x*blockDim.y + threadIdx.y;
206218
const int tid = threadIdx.x;
207219

220+
const int iter_stride = 2*GGML_CUDA_DMMV_X;
221+
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
208222
const int y_offset = qr == 1 ? 1 : qk/2;
209223

210-
__shared__ float tmp[block_size]; // separate sum for each thread
211-
tmp[tid] = 0;
224+
float tmp = 0; // partial sum for thread in warp
212225

213-
for (int i = 0; i < ncols/block_size; i += 2) {
214-
const int col = i*block_size + 2*tid;
215-
const int ib = (row*ncols + col)/qk; // block index
216-
const int iqs = (col%qk)/qr; // quant index
226+
for (int i = 0; i < ncols; i += iter_stride) {
227+
const int col = i + vals_per_iter*tid;
228+
const int ib = (row*ncols + col)/qk; // x block index
229+
const int iqs = (col%qk)/qr; // x quant index
217230
const int iybs = col - col%qk; // y block start index
218231

219-
// dequantize
220-
float v0, v1;
221-
dequantize_kernel(vx, ib, iqs, v0, v1);
232+
// processing >2 values per i iter is faster for fast GPUs
233+
#pragma unroll
234+
for (int j = 0; j < vals_per_iter; j += 2) {
235+
// process 2 vals per j iter
236+
237+
// dequantize
238+
float v0, v1;
239+
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
240+
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
222241

223-
// matrix multiplication
224-
tmp[tid] += v0 * y[iybs + iqs + 0];
225-
tmp[tid] += v1 * y[iybs + iqs + y_offset];
242+
// matrix multiplication
243+
tmp += v0 * y[iybs + iqs + j/qr + 0];
244+
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
245+
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
246+
}
226247
}
227248

228249
// sum up partial sums and write back result
229250
__syncthreads();
230-
for (int s=block_size/2; s>0; s>>=1) {
231-
if (tid < s) {
232-
tmp[tid] += tmp[tid + s];
233-
}
234-
__syncthreads();
251+
#pragma unroll
252+
for (int mask = 16; mask > 0; mask >>= 1) {
253+
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
235254
}
255+
236256
if (tid == 0) {
237-
dst[row] = tmp[0];
257+
dst[row] = tmp;
238258
}
239259
}
240260

@@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
269289
}
270290

271291
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
272-
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
273-
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
274-
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
292+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
293+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
294+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
295+
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
296+
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
275297
}
276298

277299
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
278-
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
279-
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
280-
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
300+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
301+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
302+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
303+
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
304+
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
281305
}
282306

283307
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
284-
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
285-
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
286-
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
308+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
309+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
310+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
311+
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
312+
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
287313
}
288314

289315
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
290-
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
291-
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
292-
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
316+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
317+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
318+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
319+
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
320+
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
293321
}
294322

295323
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
296-
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
297-
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
298-
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
324+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
325+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
326+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
327+
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
328+
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
299329
}
300330

301331
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
304334
}
305335

306336
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
307-
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
308-
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
309-
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
337+
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
338+
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
339+
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
340+
dequantize_mul_mat_vec<1, 1, convert_f16>
341+
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
310342
}
311343

312344
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {

0 commit comments

Comments
 (0)