
Commit a00c424

KerfuffleV2 authored and olexiyb committed
Add ReLU and SQR CUDA ops to (partially) fix Persimmon offloading (ggml-org#4041)
* Add ReLU and SQR CUDA ops to fix Persimmon offloading
* Persimmon loader: More helpful error on CUDA/ROCM when offloading too many layers
1 parent 73d2aaa · commit a00c424

File tree: 2 files changed, +79 -0 lines changed


Diff for: ggml-cuda.cu (+72 lines)

@@ -433,6 +433,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
@@ -553,6 +555,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
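
For readers who want to try the new kernels outside of ggml, the following is a minimal, hypothetical test harness (not part of this commit) built around standalone copies of relu_f32 and sqr_f32 from the hunk above. It assumes a single GPU and the default stream, omits CUDA error checking for brevity, and uses a local BLOCK_SIZE define in place of the CUDA_RELU_BLOCK_SIZE / CUDA_SQR_BLOCK_SIZE constants; the grid sizing mirrors the *_f32_cuda helpers added further down in this diff.

// test_relu_sqr.cu (hypothetical, for illustration only)
#include <cstdio>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256   // stand-in for CUDA_RELU_BLOCK_SIZE / CUDA_SQR_BLOCK_SIZE

// Copies of the kernels added by this commit: one thread per element,
// with an early return for threads past the end of the buffer.
static __global__ void relu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = fmaxf(x[i], 0);
}

static __global__ void sqr_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = x[i] * x[i];
}

int main() {
    const int k = 1000;
    float h_x[k], h_relu[k], h_sqr[k];
    for (int i = 0; i < k; ++i) {
        h_x[i] = (float) (i - k/2);   // mix of negative and positive inputs
    }

    float *d_x, *d_relu, *d_sqr;
    cudaMalloc(&d_x,    k*sizeof(float));
    cudaMalloc(&d_relu, k*sizeof(float));
    cudaMalloc(&d_sqr,  k*sizeof(float));
    cudaMemcpy(d_x, h_x, k*sizeof(float), cudaMemcpyHostToDevice);

    // Same ceil-division grid sizing as relu_f32_cuda / sqr_f32_cuda below.
    const int num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
    relu_f32<<<num_blocks, BLOCK_SIZE>>>(d_x, d_relu, k);
    sqr_f32 <<<num_blocks, BLOCK_SIZE>>>(d_x, d_sqr,  k);

    cudaMemcpy(h_relu, d_relu, k*sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(h_sqr,  d_sqr,  k*sizeof(float), cudaMemcpyDeviceToHost);

    printf("x = %g -> relu = %g, sqr = %g\n", h_x[0], h_relu[0], h_sqr[0]);

    cudaFree(d_x);
    cudaFree(d_relu);
    cudaFree(d_sqr);
    return 0;
}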
@@ -4759,6 +4779,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
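
A quick worked example of the grid sizing used by relu_f32_cuda and sqr_f32_cuda (illustrative numbers, not taken from the diff): for k = 1000 elements and CUDA_RELU_BLOCK_SIZE = 256, num_blocks = (1000 + 256 - 1) / 256 = 1255 / 256 = 4, so 4 * 256 = 1024 threads are launched and the 24 surplus threads simply return at the i >= k bounds check inside the kernel.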
@@ -6128,6 +6158,34 @@ inline void ggml_cuda_op_silu(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_relu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_cuda_op_sqr(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_cuda_op_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
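
A note on the shape of ggml_cuda_op_relu and ggml_cuda_op_sqr: they mirror the existing ggml_cuda_op_silu above, since ggml_cuda_op_flatten (used by the wrappers later in this diff) appears to expect a uniform callback signature. The (void) casts at the end only suppress unused-parameter warnings for arguments the op does not need; both ops read src0_dd and write dst_dd across ggml_nelements(src0) contiguous F32 values, which is what the GGML_ASSERTs check.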
@@ -7160,6 +7218,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }
 
+static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
+}
+
+static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
+}
+
 static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }
@@ -7891,6 +7957,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
                 case GGML_UNARY_OP_SILU:
                     func = ggml_cuda_silu;
                     break;
+                case GGML_UNARY_OP_RELU:
+                    func = ggml_cuda_relu;
+                    break;
                 default:
                     return false;
             } break;
@@ -7909,6 +7978,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_SQR:
+            func = ggml_cuda_sqr;
+            break;
         case GGML_OP_CLAMP:
             if (!any_on_device) {
                 return false;
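
Dispatch note: ReLU is a ggml unary op, so its case is added inside the nested switch under GGML_OP_UNARY (next to GGML_UNARY_OP_SILU), while SQR is a regular top-level GGML_OP case slotted between GGML_OP_SCALE and GGML_OP_CLAMP. Both paths set func to a wrapper that goes through ggml_cuda_op_flatten, which, judging from the callback signature above, supplies the device pointers and CUDA stream to the per-op function.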

Diff for: llama.cpp (+7 lines)

@@ -2877,6 +2877,13 @@ static void llm_load_tensors(
     ggml_backend_type backend_output;
 
     if (n_gpu_layers > int(n_layer)) {
+#ifdef GGML_USE_CUBLAS
+        if (n_gpu_layers > int(n_layer + 1)) {
+            LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
+                __func__, n_layer + 1);
+            throw std::runtime_error("Persimmon CUDA offload failed");
+        }
+#endif
         // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
         // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
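
In practical terms: on CUDA/ROCm builds (where GGML_USE_CUBLAS is defined), loading a Persimmon model with --n-gpu-layers / -ngl set higher than n_layer + 1 now aborts early with the error above and a pointer to issue #4038; keeping -ngl at or below n_layer + 1 sidesteps it, since per the commit title the offloading fix is only partial.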
