Commit d606905
Authored by young-developer (Oleksii Maryshchenko)
cuda : use CUDA memory pool with async memory allocation/deallocation when available (ggml-org#3903)
* Use CUDA memory pools for async alloc/dealloc.
* If the CUDA device doesn't support memory pools, fall back to the old implementation.
* Removed redundant cublasSetStream.

Co-authored-by: Oleksii Maryshchenko <[email protected]>
1 parent 4ff1046 · commit d606905
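For context before the diff: the change relies on CUDA's stream-ordered allocator (memory pools, available since CUDA 11.2). Below is a minimal standalone sketch of that pattern, not code from this commit — the CHECK macro, buffer name, and sizes are illustrative, and the commit's fallback to the legacy ggml_cuda_pool_malloc path on devices without pool support is omitted here.

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    #define CHECK(call)                                                 \
        do {                                                            \
            cudaError_t err_ = (call);                                  \
            if (err_ != cudaSuccess) {                                  \
                fprintf(stderr, "CUDA error %s at %s:%d\n",             \
                        cudaGetErrorString(err_), __FILE__, __LINE__);  \
                exit(1);                                                \
            }                                                           \
        } while (0)

    int main() {
        int device = 0;
        CHECK(cudaSetDevice(device));

        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));

        // Fetch the device's default memory pool. On drivers/devices without
        // memory-pool support this call fails; the commit uses that failure to
        // fall back to its existing synchronous pool allocator.
        cudaMemPool_t pool;
        CHECK(cudaDeviceGetMemPool(&pool, device));

        // Raise the release threshold so freed blocks stay cached in the pool
        // across synchronizations instead of being returned to the driver.
        uint64_t threshold = UINT64_MAX;
        CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold));

        // Stream-ordered allocation, use, and free: all three are enqueued on
        // the same stream, so no device-wide synchronization is required.
        const size_t n = 1 << 20;
        float * buf = nullptr;
        CHECK(cudaMallocFromPoolAsync((void **) &buf, n * sizeof(float), pool, stream));
        CHECK(cudaMemsetAsync(buf, 0, n * sizeof(float), stream));
        CHECK(cudaFreeAsync(buf, stream));

        CHECK(cudaStreamSynchronize(stream));
        CHECK(cudaStreamDestroy(stream));
        return 0;
    }

Setting cudaMemPoolAttrReleaseThreshold to UINT64_MAX mirrors what the commit does in ggml_init_cublas: freed blocks are kept in the pool for reuse rather than released back to the OS, which is the buffer-reuse behaviour the old ggml pool provided.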

1 file changed: +77 -51 lines


ggml-cuda.cu (+77 -51)
@@ -181,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cudaError_t err_ = (err); \
         if (err_ != cudaSuccess) { \
-            int id; \
-            cudaGetDevice(&id); \
+            int dev_id; \
+            cudaGetDevice(&dev_id); \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_)); \
-            fprintf(stderr, "current device: %d\n", id); \
+            fprintf(stderr, "current device: %d\n", dev_id); \
             exit(1); \
         } \
     } while (0)
@@ -195,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            int id; \
-            cudaGetDevice(&id); \
+            int dev_id; \
+            cudaGetDevice(&dev_id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                 err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
-            fprintf(stderr, "current device: %d\n", id); \
+            fprintf(stderr, "current device: %d\n", dev_id); \
             exit(1); \
         } \
     } while (0)
@@ -465,6 +465,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5772,6 +5773,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
+static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_malloc(size, actual_size);
+    }
+    void *ptr;
+    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
+    *actual_size = size;
+    return ptr;
+}
+
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5790,6 +5801,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
 
+static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_free(ptr, actual_size);
+    }
+    CUDA_CHECK(cudaFreeAsync(ptr, stream));
+}
+
 void ggml_init_cublas() {
     static bool initialized = false;
 
@@ -5844,6 +5862,13 @@ void ggml_init_cublas() {
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
         CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+
+        // configure memory pool
+        cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
+        if (err == cudaSuccess) {
+            size_t treshold = UINT64_MAX;
+            CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
+        }
     }
 
     // configure logging to stdout
@@ -6437,7 +6462,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6448,13 +6473,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+        size_t dst_f16_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6472,14 +6496,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        ggml_cuda_pool_free(dst_f16, dst_as);
+        if (dst_f16_as != 0) {
+            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
+        }
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
+            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
         }
-
         if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
+            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
         }
     }
     else {
@@ -6489,7 +6514,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6506,7 +6531,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &beta, dst_dd_i, ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
         }
     }
 
@@ -6929,29 +6954,30 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
         }
 
         if (convert_src1_to_q8_1) {
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                CUDA_CHECK(cudaGetLastError());
+                // CUDA_CHECK(cudaGetLastError());
             }
         }
 
         if (dst_on_device) {
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
         }
     }
 
@@ -7077,24 +7103,6 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(ggml_cuda_set_device(id));
-
-        // free buffers again when done
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
-        }
-    }
-
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7112,6 +7120,21 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
+
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
+        }
+    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7298,11 +7321,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7349,10 +7372,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     } else {
         // use cublasGemmBatchedEx
        const int ne23 = ne12*ne13;
-
-        void ** ptrs_as = nullptr;
+        // allocate device memory for pointers
         size_t ptrs_s = 0;
-        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+        void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7365,7 +7387,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 dst->nb[2], dst->nb[3],
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
-
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
@@ -7375,16 +7396,21 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        // free device memory for pointers
+        if (ptrs_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream);
+        }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
+    if (src1_as != 0) {
+        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
+    }
+    if (dst_as != 0) {
+        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
+    }
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
