@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
    }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
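
When a row needs more than one warp, the rewritten kernel reduces in two stages: each warp first folds its 32 lanes together with the shuffle-based warp_reduce_max / warp_reduce_sum helpers, lane 0 of every warp then publishes its partial result into the shared buf array (whose unused slots warp 0 pre-fills with the identity value, -INFINITY or 0.f), and a final warp-wide reduction over buf leaves the block-wide result in every thread. The standalone sketch below is not part of the patch; it just reproduces that pattern for a block-wide max, with a made-up block_max kernel, a fixed 256-thread block and a small test in main:

// sketch.cu -- illustration only, not code from the patch
#include <cfloat>
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE  32
#define BLOCK_SIZE 256   // hypothetical block size for this demo (<= 1024)

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

// block-wide max of x[0..n), same two-stage pattern as soft_max_f32 above
static __global__ void block_max(const float * x, float * out, const int n) {
    __shared__ float buf[WARP_SIZE]; // one slot per possible warp (<= 32 warps per block)

    const int tid     = threadIdx.x;
    const int warp_id = tid / WARP_SIZE;
    const int lane_id = tid % WARP_SIZE;

    float v = -FLT_MAX;
    for (int i = tid; i < n; i += BLOCK_SIZE) {
        v = fmaxf(v, x[i]);
    }

    v = warp_reduce_max(v);            // stage 1: reduce within each warp
    if (warp_id == 0) {
        buf[lane_id] = -FLT_MAX;       // clear slots of warps that do not exist
    }
    __syncthreads();
    if (lane_id == 0) {
        buf[warp_id] = v;              // one partial max per warp
    }
    __syncthreads();
    v = warp_reduce_max(buf[lane_id]); // stage 2: reduce the per-warp partials

    if (tid == 0) {
        *out = v;
    }
}

int main() {
    const int n = 1000;
    float h[1000];
    for (int i = 0; i < n; ++i) h[i] = (float)(i % 97);

    float *d_x, *d_out, result = 0.0f;
    cudaMalloc(&d_x, n*sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_x, h, n*sizeof(float), cudaMemcpyHostToDevice);

    block_max<<<1, BLOCK_SIZE>>>(d_x, d_out, n);

    cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("block max = %.1f (expected 96.0)\n", result);

    cudaFree(d_x);
    cudaFree(d_out);
    return 0;
}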
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth, 1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
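
The launcher now assigns one block per row and sizes the block to the smallest power of two, starting at one warp, that covers the row width, capped at CUDA_SOFT_MAX_BLOCK_SIZE; rows wider than the cap are still handled by the strided col loop inside the kernel. A quick host-side check of that selection rule (plain C++; the soft_max_block_size helper name and the example widths are made up for illustration):

#include <cstdio>

#define WARP_SIZE 32
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024

// same selection rule as in soft_max_f32_cuda above
static int soft_max_block_size(int ncols_x) {
    int nth = WARP_SIZE;
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    return nth;
}

int main() {
    const int widths[] = {16, 32, 33, 100, 512, 4096}; // arbitrary example row widths
    for (int w : widths) {
        printf("ncols_x = %4d -> block size %4d\n", w, soft_max_block_size(w));
    }
    // prints 32, 32, 64, 128, 512, 1024
    return 0;
}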
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
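
Putting the hunks together, the updated op as I read the diff computes, row by row, dst = softmax(scale*src0 + src1), with the optional mask src1 broadcast across src0's rows (rowy = rowx % nrows_y) and scale read from dst->op_params. A rough CPU reference of that semantics, for illustration only; the soft_max_ref name and the tiny example in main are made up and not part of ggml:

#include <math.h>
#include <stdio.h>

// CPU reference for the fused CUDA kernel, as read from the diff:
// per row, dst = softmax(scale*x + mask), mask broadcast via rowy = rowx % nrows_y
static void soft_max_ref(const float * x, const float * mask, float * dst,
                         int ncols, int nrows_x, int nrows_y, float scale) {
    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y;
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            const float v = x[rowx*ncols + c]*scale + (mask ? mask[rowy*ncols + c] : 0.0f);
            max_val = fmaxf(max_val, v);
        }
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            const float v = x[rowx*ncols + c]*scale + (mask ? mask[rowy*ncols + c] : 0.0f);
            dst[rowx*ncols + c] = expf(v - max_val);
            sum += dst[rowx*ncols + c];
        }
        for (int c = 0; c < ncols; ++c) {
            dst[rowx*ncols + c] /= sum;
        }
    }
}

int main() {
    // made-up example: 2 rows of 4 logits, one shared mask row, scale 0.5
    const float x[8]    = {1, 2, 3, 4,  4, 3, 2, 1};
    const float mask[4] = {0, 0, -INFINITY, -INFINITY}; // e.g. masking out the last two positions
    float dst[8];
    soft_max_ref(x, mask, dst, 4, 2, 1, 0.5f);
    for (int i = 0; i < 8; ++i) {
        printf("%.4f%c", dst[i], (i % 4 == 3) ? '\n' : ' ');
    }
    return 0;
}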