Skip to content

Commit 3385dd4

Browse files
committed
bring ggml v2_cuda up to date with AMD changes
1 parent f915a46 commit 3385dd4

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

otherarch/ggml_v2-cuda-legacy.cu

+6-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
#include <hip/hip_runtime.h>
99
#include <hipblas/hipblas.h>
1010
#include <hip/hip_fp16.h>
11+
#ifdef __HIP_PLATFORM_AMD__
12+
// for rocblas_initialize()
13+
#include "rocblas/rocblas.h"
14+
#endif
1115
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
1216
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
1317
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
@@ -32,6 +36,7 @@
3236
#define cudaEventDisableTiming hipEventDisableTiming
3337
#define cudaEventRecord hipEventRecord
3438
#define cudaEvent_t hipEvent_t
39+
#define cudaEventDestroy hipEventDestroy
3540
#define cudaFree hipFree
3641
#define cudaFreeHost hipHostFree
3742
#define cudaGetDevice hipGetDevice
@@ -54,7 +59,7 @@
5459
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
5560
#define cudaStreamNonBlocking hipStreamNonBlocking
5661
#define cudaStreamSynchronize hipStreamSynchronize
57-
#define cudaStreamWaitEvent hipStreamWaitEvent
62+
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
5863
#define cudaStream_t hipStream_t
5964
#define cudaSuccess hipSuccess
6065
#else

otherarch/ggml_v2-cuda.cu

+6-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
#include <hip/hip_runtime.h>
99
#include <hipblas/hipblas.h>
1010
#include <hip/hip_fp16.h>
11+
#ifdef __HIP_PLATFORM_AMD__
12+
// for rocblas_initialize()
13+
#include "rocblas/rocblas.h"
14+
#endif
1115
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
1216
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
1317
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
@@ -32,6 +36,7 @@
3236
#define cudaEventDisableTiming hipEventDisableTiming
3337
#define cudaEventRecord hipEventRecord
3438
#define cudaEvent_t hipEvent_t
39+
#define cudaEventDestroy hipEventDestroy
3540
#define cudaFree hipFree
3641
#define cudaFreeHost hipHostFree
3742
#define cudaGetDevice hipGetDevice
@@ -54,14 +59,13 @@
5459
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
5560
#define cudaStreamNonBlocking hipStreamNonBlocking
5661
#define cudaStreamSynchronize hipStreamSynchronize
57-
#define cudaStreamWaitEvent hipStreamWaitEvent
62+
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
5863
#define cudaStream_t hipStream_t
5964
#define cudaSuccess hipSuccess
6065
#else
6166
#include <cuda_runtime.h>
6267
#include <cublas_v2.h>
6368
#include <cuda_fp16.h>
64-
6569
#endif
6670

6771
#include "ggml_v2-cuda.h"

0 commit comments

Comments
 (0)