Skip to content

Commit 05ef24d

Browse files
Quantized dot products for CUDA mul mat vec
1 parent 698efad commit 05ef24d

File tree

4 files changed

+414
-91
lines changed

4 files changed

+414
-91
lines changed

CMakeLists.txt

+8-3
Original file line number | Diff line number | Diff line change
@@ -68,8 +68,9 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework
6868
option(LLAMA_BLAS "llama: use BLAS" OFF)
6969
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
7070
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
71+
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
7172
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
72-
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
73+
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
7374
option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
7475
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
7576
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
@@ -246,8 +247,12 @@ if (LLAMA_CUBLAS)
246247
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
247248

248249
add_compile_definitions(GGML_USE_CUBLAS)
250+
add_compile_definitions(GGML_CUDA_FORCE_DMMV=${LLAMA_CUDA_FORCE_DMMV})
249251
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
250-
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
252+
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
253+
if (DEFINED LLAMA_CUDA_DMMV_Y)
254+
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility
255+
endif()
251256
if (LLAMA_CUDA_DMMV_F16)
252257
add_compile_definitions(GGML_CUDA_DMMV_F16)
253258
endif()
@@ -263,7 +268,7 @@ if (LLAMA_CUBLAS)
263268
if (LLAMA_CUDA_DMMV_F16)
264269
set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
265270
else()
266-
set(CMAKE_CUDA_ARCHITECTURES "52") # lowest CUDA 12 standard
271+
set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
267272
endif()
268273
endif()
269274
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

Makefile

+9-4
Original file line number | Diff line number | Diff line change
@@ -166,16 +166,21 @@ ifdef LLAMA_CUBLAS
166166
OBJS += ggml-cuda.o
167167
NVCC = nvcc
168168
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
169+
ifdef LLAMA_CUDA_FORCE_DMMV
170+
NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV=$(LLAMA_CUDA_FORCE_DMMV)
171+
endif # LLAMA_CUDA_FORCE_DMMV
169172
ifdef LLAMA_CUDA_DMMV_X
170173
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
171174
else
172175
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
173176
endif # LLAMA_CUDA_DMMV_X
174-
ifdef LLAMA_CUDA_DMMV_Y
175-
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
177+
ifdef LLAMA_CUDA_MMV_Y
178+
NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
179+
else ifdef LLAMA_CUDA_DMMV_Y
180+
NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_DMMV_Y) # for backwards compatibility
176181
else
177-
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
178-
endif # LLAMA_CUDA_DMMV_Y
182+
NVCCFLAGS += -DGGML_CUDA_MMV_Y=1
183+
endif # LLAMA_CUDA_MMV_Y
179184
ifdef LLAMA_CUDA_DMMV_F16
180185
NVCCFLAGS += -DGGML_CUDA_DMMV_F16
181186
endif # LLAMA_CUDA_DMMV_F16

README.md

+2-1
Original file line number | Diff line number | Diff line change
@@ -344,8 +344,9 @@ Building the program with BLAS support may lead to some performance improvements
344344
345345
| Option | Legal values | Default | Description |
346346
|-------------------------|------------------------|---------|-------------|
347+
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 7.0/Turing/RTX 2000 or higher). Does not affect k-quants. |
347348
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
348-
| LLAMA_CUDA_DMMV_Y | Positive integer | 1 | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
349+
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
349350
| LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
350351
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
351352

0 commit comments

Comments (0)