@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 256
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -4719,11 +4720,12 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 
 // the CUDA soft max implementation differs from the CPU implementation
 // instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
-    const int rowx = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void soft_max_f32_warp(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+
+    const int block_size = blockDim.x;
 
     float max_val = -INFINITY;
 
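The rest of soft_max_f32_warp is unchanged by this hunk: the elided lines reduce max_val and the exponent sum across the threads of the row. As an illustration of that pattern only (an assumption about the elided body, not a copy of it), a warp-wide butterfly reduction in CUDA typically looks like this:

// Generic warp-level reductions of the kind the elided kernel body relies on.
// Illustrative sketch; not taken from this patch.
static __device__ __forceinline__ float warp_reduce_max(float x) {
    // butterfly exchange across the 32 lanes of a warp
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_sum(float x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}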
@@ -4763,6 +4765,66 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
     }
 }
 
+// use shared memory to reduce the number of global memory reads
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE];
+
+    buf[tid] = -INFINITY;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        buf[tid] = max(buf[tid], x[ix]*scale + (y ? y[iy] : 0.0f));
+    }
+
+    __syncthreads();
+
+    // find the max value in the block
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] = max(buf[tid], buf[tid + i]);
+        }
+        __syncthreads();
+    }
+
+    float tmp = 0.f;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - buf[0]);
+        tmp += val;
+        dst[ix] = val;
+    }
+
+    __syncthreads();
+
+    buf[tid] = tmp;
+
+    __syncthreads();
+
+    // sum up partial sums
+    for (int i = block_size/2; i > 0; i >>= 1) {
+        if (tid < i) {
+            buf[tid] += buf[tid + i];
+        }
+        __syncthreads();
+    }
+
+    const float inv_tmp = 1.f / buf[0];
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const int i = rowx*ncols + col;
+        dst[i] *= inv_tmp;
+    }
+}
+
 static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
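Both kernels compute the same result: a row-wise, single-precision softmax of x*scale plus an optional mask y that is broadcast across rows. A minimal CPU sketch of that computation, useful for checking the CUDA output (soft_max_f32_ref is a hypothetical helper, not part of this patch):

// Hypothetical single-precision CPU reference for the masked, scaled softmax.
#include <cmath>

static void soft_max_f32_ref(const float * x, const float * y, float * dst,
                             int ncols, int nrows_x, int nrows_y, float scale) {
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
        float max_val = -INFINITY;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
            max_val = fmaxf(max_val, v);
        }
        float sum = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float v = x[rowx*ncols + col]*scale + (y ? y[rowy*ncols + col] : 0.0f);
            dst[rowx*ncols + col] = expf(v - max_val); // subtract the max for numerical stability
            sum += dst[rowx*ncols + col];
        }
        for (int col = 0; col < ncols; ++col) {
            dst[rowx*ncols + col] /= sum;
        }
    }
}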
@@ -5796,7 +5858,9 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
 }
 
 static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
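The launcher now picks the smallest power of two that covers ncols_x, starting at one warp and capped at CUDA_SOFT_MAX_BLOCK_SIZE, so the tree reductions in soft_max_f32 always run with a power-of-two blockDim.x. A small host-side sketch of that selection (soft_max_block_size is a hypothetical helper written here for illustration):

// Hypothetical stand-alone check of the block-size selection used above.
#include <cstdio>

#define WARP_SIZE 32
#define CUDA_SOFT_MAX_BLOCK_SIZE 256

static int soft_max_block_size(int ncols_x) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    return nth;
}

int main() {
    // expected: 32, 32, 64, 128, 256, 256
    const int tests[] = {1, 32, 33, 100, 256, 4096};
    for (int ncols : tests) {
        printf("ncols_x = %4d -> block size %d\n", ncols, soft_max_block_size(ncols));
    }
    return 0;
}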