Skip to content

Commit ac41d1b

Browse files
authored
fix: compatible with torch 2.2 (#478)
torch 2.2 no member function `getCurrentCUDABlasLtHandle` torch 2.3 and 2.4 works well
1 parent 1a6b17e commit ac41d1b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

include/flashinfer/bmm_fp8.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ void bmm_fp8_internal_cublaslt(const AT* A, const BT* B, DT* D, int batch_size,
152152
auto workspace = allocator.allocate(workspace_size);
153153
cublasLtMatmulHeuristicResult_t heuristic_result = {};
154154
int returned_result = 0;
155-
auto lt_handle = at::cuda::getCurrentCUDABlasLtHandle();
155+
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
156156
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
157157
lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(),
158158
d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result,

0 commit comments

Comments
 (0)