Skip to content

Commit fd70e7a

Browse files
committed
test
1 parent 580fe20 commit fd70e7a

File tree

1 file changed

+69
-5
lines changed

1 file changed

+69
-5
lines changed

ggml-cuda.cu

+69-5
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
443443
#define CUDA_SCALE_BLOCK_SIZE 256
444444
#define CUDA_CLAMP_BLOCK_SIZE 256
445445
#define CUDA_ROPE_BLOCK_SIZE 256
446+
#define CUDA_SOFT_MAX_BLOCK_SIZE 256
446447
#define CUDA_ALIBI_BLOCK_SIZE 32
447448
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
448449
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -4719,11 +4720,12 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
47194720

47204721
// the CUDA soft max implementation differs from the CPU implementation
47214722
// instead of doubles floats are used
4722-
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
4723-
const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
4723+
static __global__ void soft_max_f32_warp(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
4724+
const int tid = threadIdx.x;
4725+
const int rowx = blockIdx.x;
47244726
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
4725-
const int block_size = blockDim.y;
4726-
const int tid = threadIdx.y;
4727+
4728+
const int block_size = blockDim.x;
47274729

47284730
float max_val = -INFINITY;
47294731

@@ -4763,6 +4765,66 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
47634765
}
47644766
}
47654767

4768+
// softmax over each row of x (optionally adding the broadcast mask y), using a
// block-wide shared-memory reduction to cut down on global memory reads.
// launch layout: one block per row of x; blockDim.x must be a power of two
// no larger than CUDA_SOFT_MAX_BLOCK_SIZE (the host wrapper guarantees this).
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
    const int tid        = threadIdx.x;
    const int rowx       = blockIdx.x;
    const int rowy       = rowx % nrows_y; // the mask y is broadcast across the rows of x
    const int block_size = blockDim.x;

    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];

    // each thread scans a strided slice of the row and keeps a running maximum
    buf[tid] = -INFINITY;
    for (int col = tid; col < ncols; col += block_size) {
        const int ix = rowx*ncols + col;
        const int iy = rowy*ncols + col;
        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
    }
    __syncthreads();

    // tree reduction: after the loop buf[0] holds the maximum of the whole row
    for (int offset = block_size/2; offset > 0; offset >>= 1) {
        if (tid < offset) {
            buf[tid] = max(buf[tid], buf[tid + offset]);
        }
        __syncthreads();
    }

    // exponentiate (shifted by the row max for numerical stability), store the
    // unnormalized values, and accumulate a per-thread partial sum
    float partial_sum = 0.f;
    for (int col = tid; col < ncols; col += block_size) {
        const int ix = rowx*ncols + col;
        const int iy = rowy*ncols + col;
        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
        partial_sum += val;
        dst[ix] = val;
    }

    __syncthreads(); // all reads of buf[0] above must finish before buf is reused

    buf[tid] = partial_sum;
    __syncthreads();

    // tree reduction: after the loop buf[0] holds the sum of the whole row
    for (int offset = block_size/2; offset > 0; offset >>= 1) {
        if (tid < offset) {
            buf[tid] += buf[tid + offset];
        }
        __syncthreads();
    }

    // normalize in place
    const float inv_sum = 1.f / buf[0];
    for (int col = tid; col < ncols; col += block_size) {
        dst[rowx*ncols + col] *= inv_sum;
    }
}
4827+
47664828
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
47674829
const int i = blockDim.x*blockIdx.x + threadIdx.x;
47684830

@@ -5796,7 +5858,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
57965858
}
57975859

57985860
// host-side launcher for the softmax kernel: one block per row of x, with a
// power-of-two block size in [WARP_SIZE, CUDA_SOFT_MAX_BLOCK_SIZE] grown to
// cover ncols_x (the kernel's tree reductions rely on the power-of-two size)
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) {
        nth *= 2;
    }
    const dim3 block_dims(nth, 1, 1);
    const dim3 block_nums(nrows_x, 1, 1);
    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}

0 commit comments

Comments
 (0)