
Commit ef47ec1

ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext
* cuda : implement soft_max_ext
* ggml : implement soft_max_ext (CPU)
* batched-bench : print threads

  ggml-ci

* metal : simplify soft_max encoding

  ggml-ci

* cuda : use 512 threads for soft_max instead of 32
* ggml : update soft max cpu
* cuda : do warp-based block reduce
* cuda : increase max block size to 1024
* cuda : fix warp reduction initialization of shared mem
* metal : warp-based reduction for soft max kernel
* metal : warp-based reduce for rms_norm
* metal : simplify soft max kernel

  ggml-ci

* alloc : fix build with debug
1 parent 1d14411 commit ef47ec1
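
The change set adds a fused "extended" softmax: instead of building attention scores from separate scale, mask-add, and softmax nodes, the backends below receive the raw logits, an optional f32 mask that is broadcast across rows (src1), and a scale factor stored in dst->op_params, and apply all three inside one kernel. The ggml.h declaration is not part of this excerpt, so the sketch below is an assumption inferred from the backend code, not the authoritative API:

    // assumed declaration (not shown in this excerpt); mask may be NULL,
    // the scale ends up in the op's op_params for the backends to read
    struct ggml_tensor * ggml_soft_max_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,      // e.g. KQ attention logits
            struct ggml_tensor  * mask,   // optional additive mask, broadcast over rows
            float                 scale); // e.g. 1.0f/sqrtf(n_embd_head)

    // one fused node instead of separate scale + mask-add + soft_max nodes:
    // kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);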

File tree

8 files changed: +298, -183 lines


examples/batched-bench/batched-bench.cpp (+1, -1)

@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
+    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");

ggml-alloc.c (+1, -1)

@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {

ggml-cuda.cu (+87, -43)

@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
    }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
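
The rewritten soft_max_f32 kernel reduces in two levels: each warp first collapses its partial max/sum with __shfl_xor_sync, and when the block is wider than one warp (up to CUDA_SOFT_MAX_BLOCK_SIZE = 1024 threads) one partial value per warp goes through the 32-entry shared buffer and is reduced again. Below is a minimal, self-contained sketch of that pattern outside of ggml (the kernel name demo_row_max, BLOCK_SIZE, and the test harness are illustrative only; warp_reduce_max is copied from the hunk above):

    #include <cstdio>
    #include <cstdlib>
    #include <cfloat>
    #include <cuda_runtime.h>

    #define WARP_SIZE  32
    #define BLOCK_SIZE 256 // multiple of WARP_SIZE, at most 1024

    static __device__ __forceinline__ float warp_reduce_max(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
        }
        return x;
    }

    // one block per row: every thread strides over the columns, then the block
    // agrees on the row maximum with the same two-level scheme as soft_max_f32
    static __global__ void demo_row_max(const float * x, float * dst, const int ncols) {
        const int row     = blockIdx.x;
        const int tid     = threadIdx.x;
        const int warp_id = tid / WARP_SIZE;
        const int lane_id = tid % WARP_SIZE;

        __shared__ float buf[1024/WARP_SIZE]; // one slot per possible warp

        float max_val = -FLT_MAX;
        for (int col = tid; col < ncols; col += blockDim.x) {
            max_val = fmaxf(max_val, x[row*ncols + col]);
        }

        // level 1: shuffle-reduce inside each warp (no shared memory, no barriers)
        max_val = warp_reduce_max(max_val);

        // level 2: one partial per warp through shared memory, re-reduced by every warp
        if (blockDim.x > WARP_SIZE) {
            if (warp_id == 0) {
                buf[lane_id] = -FLT_MAX; // pad the slots of warps that do not exist
            }
            __syncthreads();

            if (lane_id == 0) {
                buf[warp_id] = max_val;
            }
            __syncthreads();

            max_val = warp_reduce_max(buf[lane_id]);
        }

        if (tid == 0) {
            dst[row] = max_val;
        }
    }

    int main() {
        const int nrows = 4, ncols = 1000;

        float * hx = (float *) malloc(nrows*ncols*sizeof(float));
        for (int i = 0; i < nrows*ncols; ++i) {
            hx[i] = (float)((i*37) % 1009); // arbitrary test data
        }

        float * dx = nullptr;
        float * dd = nullptr;
        cudaMalloc(&dx, nrows*ncols*sizeof(float));
        cudaMalloc(&dd, nrows*sizeof(float));
        cudaMemcpy(dx, hx, nrows*ncols*sizeof(float), cudaMemcpyHostToDevice);

        demo_row_max<<<nrows, BLOCK_SIZE>>>(dx, dd, ncols);

        float hd[nrows];
        cudaMemcpy(hd, dd, nrows*sizeof(float), cudaMemcpyDeviceToHost);
        for (int r = 0; r < nrows; ++r) {
            printf("row %d: max = %.1f\n", r, hd[r]);
        }

        cudaFree(dx); cudaFree(dd); free(hx);
        return 0;
    }

The shuffle pass needs no synchronization at all, so the only barriers left are the two __syncthreads() calls around the 32-float shared buffer, which is why the kernel can now scale from 32 to 1024 threads per block.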

ggml-metal.m (+27, -16)

@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
                         int nth = 32; // SIMD width
 
                         if (ne00%4 == 0) {
+                            while (nth < ne00/4 && nth < 256) {
+                                nth *= 2;
+                            }
                             [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                         } else {
-                            do {
+                            while (nth < ne00 && nth < 1024) {
                                 nth *= 2;
-                            } while (nth <= ne00 && nth <= 1024);
-                            nth /= 2;
+                            }
                             [encoder setComputePipelineState:ctx->pipeline_soft_max];
                         }
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+                        const float scale = ((float *) dst->op_params)[0];
+
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                        [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
                         float eps;
                         memcpy(&eps, dst->op_params, sizeof(float));
 
-                        const int nth = MIN(512, ne00);
+                        int nth = 32; // SIMD width
+
+                        while (nth < ne00/4 && nth < 1024) {
+                            nth *= 2;
+                        }
 
                         [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                        [encoder setBytes:&eps length:sizeof( float) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+                        [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+                        [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                         const int64_t nrows = ggml_nrows(src0);
 