diff --git a/features/config/TEMPLATE_cublasLt.xml b/features/config/TEMPLATE_cublasLt.xml
new file mode 100644
index 00000000..cb15eeed
--- /dev/null
+++ b/features/config/TEMPLATE_cublasLt.xml
@@ -0,0 +1,13 @@
+
+
+
+ test
+
+
+
+
+
+
+
+
+
diff --git a/features/feature_case/cublasLt/matmul.cu b/features/feature_case/cublasLt/matmul.cu
new file mode 100644
index 00000000..0cf382f7
--- /dev/null
+++ b/features/feature_case/cublasLt/matmul.cu
@@ -0,0 +1,754 @@
+// ===------------ matmul.cu ----------------------------- *- CUDA -* ----=== //
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+
+#include <cublasLt.h>
+#include <cstdint>
+#include <cstdio>
+
+const constexpr int COL_TURING = 0;
+const constexpr int COL_AMPERE = 1;
+
+// The original source of the two functions below was distributed under this license:
+// Copyright (c) Facebook, Inc. and its affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+//
+// Repo: https://github.com/TimDettmers/bitsandbytes.git
+inline int checkCublasStatus(cublasStatus_t status) {
+ if (status != CUBLAS_STATUS_SUCCESS) {
+ printf("cuBLAS API failed with status %d\n", status);
+ //throw std::logic_error("cuBLAS API failed");
+ return 1;
+ }
+ return 0;
+}
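+
+// checkCublasStatus returns 1 on failure so callers can fold the results of a
+// whole call sequence into a single error flag with |=.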
+
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS>
+int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+ int has_error = 0;
+ cublasLtMatmulDesc_t matmulDesc = NULL;
+ cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+ cublasOperation_t opT = CUBLAS_OP_T;
+ cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
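+    // CUBLASLT_ORDER_COL32 packs the matrix in 32-column tiles;
+    // CUBLASLT_ORDER_COL4_4R2_8C is the int8 IMMA layout used on Turing and
+    // CUBLASLT_ORDER_COL32_2R_4R4 its Ampere counterpart, which is why
+    // FORMATB selects between the latter two below.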
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(FORMATB == COL_TURING)
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
+ else
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
+
+ if(DTYPE_OUT == 32)
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ int alpha = 1, beta = 0;
+    has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(!SCALE_ROWS)
+ {
+ float alpha = 1.0f, beta = 0.0f;
+      has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ }
+
+ cudaStreamSynchronize(0);
+
+ if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
+ if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
+ if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
+ if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
+ if(has_error == 1)
+ printf("error detected");
+
+ return has_error;
+}
+
+void transform(cublasLtHandle_t ltHandle, const void *in, int ld_in,
+ cublasLtMatrixLayout_t layout_in, void *out, int ld_out,
+ cublasLtMatrixLayout_t layout_out) {
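+  // Note: ld_in and ld_out are unused here; cublasLtMatrixTransform takes its
+  // leading dimensions from the layout descriptors passed alongside them.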
+ cublasLtMatrixTransformDesc_t transform_desc = NULL;
+ cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F);
+ float alpha = 1.0f, beta = 0.0f;
+ cublasLtMatrixTransform(ltHandle, transform_desc, &alpha, in, layout_in,
+ &beta, NULL, NULL, out, layout_out, 0);
+ cublasLtMatrixTransformDescDestroy(transform_desc);
+}
+
+// igemmlt
+bool test1() {
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+ const constexpr int m = 4;
+ const constexpr int n = 2;
+ const constexpr int k = 3;
+ int lda = m;
+ int ldb = n;
+ int ldc = m;
+ void *Adev;
+ void *Bdev;
+ void *Cdev;
+ cudaMalloc(&Adev, m * k * sizeof(int8_t));
+ cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+ cudaMalloc(&Cdev, m * n * sizeof(int32_t));
+
+ int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+ int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+ cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+ cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+ Cdesc_col_major = NULL;
+ cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+ cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+ cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_32I, m, n, ldc);
+
+ // Convert A and B
+ cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col4_4r2_8c = NULL,
+ Cdesc_col32 = NULL;
+ int8_t *A_col32, *B_col4_4r2_8c;
+ int32_t *C_col32;
+ cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+ cudaMalloc(&B_col4_4r2_8c, ((n + 8 - 1) / 8) * 8 * 32 * sizeof(std::int8_t));
+ cudaMalloc(&C_col32, m * 32 * sizeof(std::int32_t));
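+  // Packed-layout sizes: with k = 3 and n = 2 a single 32-column COL32 tile
+  // holds A and C (m * 32 elements each), while COL4_4R2_8C rounds the row
+  // count up to a multiple of 8, giving ((n + 8 - 1) / 8) * 8 * 32 for B.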
+ cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+ cublasLtMatrixLayoutCreate(&Bdesc_col4_4r2_8c, CUDA_R_8I, k, n,
+ ((n + 8 - 1) / 8) * 8 * 32);
+ cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_32I, m, n, m * 32);
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col4_4r2_8c = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+ cublasLtMatrixLayoutSetAttribute(Bdesc_col4_4r2_8c,
+ CUBLASLT_MATRIX_LAYOUT_ORDER, &col4_4r2_8c,
+ sizeof(col4_4r2_8c));
+ cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+
+ transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+ transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col4_4r2_8c, 8 * 32,
+ Bdesc_col4_4r2_8c);
+
+ // Matmul
+  igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A_col32, B_col4_4r2_8c,
+                             C_col32, nullptr, m * 32,
+                             ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+ // Convert C
+ transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+ cudaStreamSynchronize(0);
+
+ // Check result
+ int32_t Chost[m * n];
+ cudaMemcpy(Chost, Cdev, m * n * sizeof(int32_t), cudaMemcpyDeviceToHost);
+
+ bool error = false;
+ int32_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+ for (int i = 0; i < m * n; i++) {
+ if (Chost[i] != C_ref[i]) {
+ error = true;
+ break;
+ }
+ }
+ printf("c:\n");
+ for (int i = 0; i < m * n; i++)
+ printf("%d, ", Chost[i]);
+ printf("\n");
+
+ if (error) {
+ printf("error\n");
+ } else {
+ printf("success\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+ cublasLtMatrixLayoutDestroy(Adesc_col32);
+ cublasLtMatrixLayoutDestroy(Bdesc_col4_4r2_8c);
+ cublasLtMatrixLayoutDestroy(Cdesc_col32);
+ cublasLtMatrixLayoutDestroy(Adesc_col_major);
+ cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+ cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+ cudaFree(Adev);
+ cudaFree(Bdev);
+ cudaFree(Cdev);
+
+ return !error;
+}
+
+// igemmlt
+bool test2() {
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+ const constexpr int m = 4;
+ const constexpr int n = 2;
+ const constexpr int k = 3;
+ int lda = m;
+ int ldb = n;
+ int ldc = m;
+ void *Adev;
+ void *Bdev;
+ void *Cdev;
+ cudaMalloc(&Adev, m * k * sizeof(int8_t));
+ cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+ cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+ int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+ int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+ cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+ cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+ Cdesc_col_major = NULL;
+ cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+ cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+ cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+ // Convert A and B
+ cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col4_4r2_8c = NULL,
+ Cdesc_col32 = NULL;
+ int8_t *A_col32, *B_col4_4r2_8c;
+ int8_t *C_col32;
+ cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+ cudaMalloc(&B_col4_4r2_8c, ((n + 8 - 1) / 8) * 8 * 32 * sizeof(std::int8_t));
+ cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+ cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+ cublasLtMatrixLayoutCreate(&Bdesc_col4_4r2_8c, CUDA_R_8I, k, n,
+ ((n + 8 - 1) / 8) * 8 * 32);
+ cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col4_4r2_8c = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+ cublasLtMatrixLayoutSetAttribute(Bdesc_col4_4r2_8c,
+ CUBLASLT_MATRIX_LAYOUT_ORDER, &col4_4r2_8c,
+ sizeof(col4_4r2_8c));
+ cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+
+ transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+ transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col4_4r2_8c, 8 * 32,
+ Bdesc_col4_4r2_8c);
+
+ // Matmul
+  igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A_col32, B_col4_4r2_8c, C_col32,
+                            nullptr, m * 32, ((n + 8 - 1) / 8) * 8 * 32,
+                            m * 32);
+
+ // Convert C
+ transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+ cudaStreamSynchronize(0);
+
+ // Check result
+ int8_t Chost[m * n];
+ cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+ bool error = false;
+ int8_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+ for (int i = 0; i < m * n; i++) {
+ if (Chost[i] != C_ref[i]) {
+ error = true;
+ break;
+ }
+ }
+ printf("c:\n");
+ for (int i = 0; i < m * n; i++)
+ printf("%d, ", Chost[i]);
+ printf("\n");
+
+ if (error) {
+ printf("error\n");
+ } else {
+ printf("success\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+ cublasLtMatrixLayoutDestroy(Adesc_col32);
+ cublasLtMatrixLayoutDestroy(Bdesc_col4_4r2_8c);
+ cublasLtMatrixLayoutDestroy(Cdesc_col32);
+ cublasLtMatrixLayoutDestroy(Adesc_col_major);
+ cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+ cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+ cudaFree(Adev);
+ cudaFree(Bdev);
+ cudaFree(Cdev);
+
+ return !error;
+}
+
+// igemmlt
+bool test3() {
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+ const constexpr int m = 4;
+ const constexpr int n = 2;
+ const constexpr int k = 3;
+ int lda = m;
+ int ldb = n;
+ int ldc = m;
+ void *Adev;
+ void *Bdev;
+ void *Cdev;
+ cudaMalloc(&Adev, m * k * sizeof(int8_t));
+ cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+ cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+ int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+ int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+ cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+ cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+ Cdesc_col_major = NULL;
+ cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+ cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+ cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+ // Convert A and B
+ cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col4_4r2_8c = NULL,
+ Cdesc_col32 = NULL;
+ int8_t *A_col32, *B_col4_4r2_8c;
+ int8_t *C_col32;
+ cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+ cudaMalloc(&B_col4_4r2_8c, ((n + 8 - 1) / 8) * 8 * 32 * sizeof(std::int8_t));
+ cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+ cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+ cublasLtMatrixLayoutCreate(&Bdesc_col4_4r2_8c, CUDA_R_8I, k, n,
+ ((n + 8 - 1) / 8) * 8 * 32);
+ cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col4_4r2_8c = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+ cublasLtMatrixLayoutSetAttribute(Bdesc_col4_4r2_8c,
+ CUBLASLT_MATRIX_LAYOUT_ORDER, &col4_4r2_8c,
+ sizeof(col4_4r2_8c));
+ cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+
+ transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+ transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col4_4r2_8c, 8 * 32,
+ Bdesc_col4_4r2_8c);
+
+ float *alpha;
+ cudaMallocManaged(&alpha, 4 * sizeof(float));
+ alpha[0] = 0;
+ alpha[1] = 1;
+ alpha[2] = 2;
+ alpha[3] = 3;
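+
+  // SCALE_ROWS = 1 switches the matmul to alpha_device_vector_beta_zero
+  // pointer mode: alpha points to a device vector with one scale factor per
+  // output row, so row i of C is multiplied by alpha[i].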
+
+ // Matmul
+  igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A_col32, B_col4_4r2_8c, C_col32,
+                            alpha, m * 32, ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+ // Convert C
+ transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+ cudaStreamSynchronize(0);
+
+ // Check result
+ int8_t Chost[m * n];
+ cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+ bool error = false;
+ int8_t C_ref[m * n] = {0, 17, 40, 69, 0, 6, 16, 30};
+ for (int i = 0; i < m * n; i++) {
+ if (Chost[i] != C_ref[i]) {
+ error = true;
+ break;
+ }
+ }
+ printf("c:\n");
+ for (int i = 0; i < m * n; i++)
+ printf("%d, ", Chost[i]);
+ printf("\n");
+
+ if (error) {
+ printf("error\n");
+ } else {
+ printf("success\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+ cublasLtMatrixLayoutDestroy(Adesc_col32);
+ cublasLtMatrixLayoutDestroy(Bdesc_col4_4r2_8c);
+ cublasLtMatrixLayoutDestroy(Cdesc_col32);
+ cublasLtMatrixLayoutDestroy(Adesc_col_major);
+ cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+ cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+ cudaFree(Adev);
+ cudaFree(Bdev);
+ cudaFree(Cdev);
+ cudaFree(alpha);
+
+ return !error;
+}
+
+// igemmlt
+bool test4() {
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+ const constexpr int m = 4;
+ const constexpr int n = 2;
+ const constexpr int k = 3;
+ int lda = m;
+ int ldb = n;
+ int ldc = m;
+ void *Adev;
+ void *Bdev;
+ void *Cdev;
+ cudaMalloc(&Adev, m * k * sizeof(int8_t));
+ cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+ cudaMalloc(&Cdev, m * n * sizeof(int32_t));
+
+ int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+ int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+ cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+ cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+ Cdesc_col_major = NULL;
+ cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+ cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+ cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_32I, m, n, ldc);
+
+ // Convert A and B
+ cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col32_2r_4r4 = NULL,
+ Cdesc_col32 = NULL;
+ int8_t *A_col32, *B_col32_2r_4r4;
+ int32_t *C_col32;
+ cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+ cudaMalloc(&B_col32_2r_4r4,
+ ((n + 32 - 1) / 32) * 32 * 32 * sizeof(std::int8_t));
+ cudaMalloc(&C_col32, m * 32 * sizeof(std::int32_t));
+ cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+ cublasLtMatrixLayoutCreate(&Bdesc_col32_2r_4r4, CUDA_R_8I, k, n,
+ ((n + 32 - 1) / 32) * 32 * 32);
+ cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_32I, m, n, m * 32);
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col32_2r_4r4 = CUBLASLT_ORDER_COL32_2R_4R4;
+ cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+ cublasLtMatrixLayoutSetAttribute(Bdesc_col32_2r_4r4,
+ CUBLASLT_MATRIX_LAYOUT_ORDER, &col32_2r_4r4,
+ sizeof(col32_2r_4r4));
+ cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+
+ transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+ transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col32_2r_4r4, 8 * 32,
+ Bdesc_col32_2r_4r4);
+
+ // Matmul
+  igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A_col32, B_col32_2r_4r4,
+                             C_col32, nullptr, m * 32,
+                             ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+ // Convert C
+ transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+ cudaStreamSynchronize(0);
+
+ // Check result
+ int32_t Chost[m * n];
+ cudaMemcpy(Chost, Cdev, m * n * sizeof(int32_t), cudaMemcpyDeviceToHost);
+
+ bool error = false;
+ int32_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+ for (int i = 0; i < m * n; i++) {
+ if (Chost[i] != C_ref[i]) {
+ error = true;
+ break;
+ }
+ }
+ printf("c:\n");
+ for (int i = 0; i < m * n; i++)
+ printf("%d, ", Chost[i]);
+ printf("\n");
+
+ if (error) {
+ printf("error\n");
+ } else {
+ printf("success\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+ cublasLtMatrixLayoutDestroy(Adesc_col32);
+ cublasLtMatrixLayoutDestroy(Bdesc_col32_2r_4r4);
+ cublasLtMatrixLayoutDestroy(Cdesc_col32);
+ cublasLtMatrixLayoutDestroy(Adesc_col_major);
+ cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+ cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+ cudaFree(Adev);
+ cudaFree(Bdev);
+ cudaFree(Cdev);
+
+ return !error;
+}
+
+// igemmlt
+bool test5() {
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+ const constexpr int m = 4;
+ const constexpr int n = 2;
+ const constexpr int k = 3;
+ int lda = m;
+ int ldb = n;
+ int ldc = m;
+ void *Adev;
+ void *Bdev;
+ void *Cdev;
+ cudaMalloc(&Adev, m * k * sizeof(int8_t));
+ cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+ cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+ int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+ int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+ cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+ cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+ Cdesc_col_major = NULL;
+ cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+ cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+ cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+ // Convert A and B
+ cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col32_2r_4r4 = NULL,
+ Cdesc_col32 = NULL;
+ int8_t *A_col32, *B_col32_2r_4r4;
+ int8_t *C_col32;
+ cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+ cudaMalloc(&B_col32_2r_4r4,
+ ((n + 32 - 1) / 32) * 32 * 32 * sizeof(std::int8_t));
+ cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+ cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+ cublasLtMatrixLayoutCreate(&Bdesc_col32_2r_4r4, CUDA_R_8I, k, n,
+ ((n + 32 - 1) / 32) * 32 * 32);
+ cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col32_2r_4r4 = CUBLASLT_ORDER_COL32_2R_4R4;
+ cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+ cublasLtMatrixLayoutSetAttribute(Bdesc_col32_2r_4r4,
+ CUBLASLT_MATRIX_LAYOUT_ORDER, &col32_2r_4r4,
+ sizeof(col32_2r_4r4));
+ cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+
+ transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+ transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col32_2r_4r4, 8 * 32,
+ Bdesc_col32_2r_4r4);
+
+ // Matmul
+  igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A_col32, B_col32_2r_4r4,
+                            C_col32, nullptr, m * 32,
+                            ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+ // Convert C
+ transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+ cudaStreamSynchronize(0);
+
+ // Check result
+ int8_t Chost[m * n];
+ cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+ bool error = false;
+ int8_t C_ref[m * n] = {14, 17, 20, 23, 4, 6, 8, 10};
+ for (int i = 0; i < m * n; i++) {
+ if (Chost[i] != C_ref[i]) {
+ error = true;
+ break;
+ }
+ }
+ printf("c:\n");
+ for (int i = 0; i < m * n; i++)
+ printf("%d, ", Chost[i]);
+ printf("\n");
+
+ if (error) {
+ printf("error\n");
+ } else {
+ printf("success\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+ cublasLtMatrixLayoutDestroy(Adesc_col32);
+ cublasLtMatrixLayoutDestroy(Bdesc_col32_2r_4r4);
+ cublasLtMatrixLayoutDestroy(Cdesc_col32);
+ cublasLtMatrixLayoutDestroy(Adesc_col_major);
+ cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+ cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+ cudaFree(Adev);
+ cudaFree(Bdev);
+ cudaFree(Cdev);
+
+ return !error;
+}
+
+// igemmlt
+bool test6() {
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+ const constexpr int m = 4;
+ const constexpr int n = 2;
+ const constexpr int k = 3;
+ int lda = m;
+ int ldb = n;
+ int ldc = m;
+ void *Adev;
+ void *Bdev;
+ void *Cdev;
+ cudaMalloc(&Adev, m * k * sizeof(int8_t));
+ cudaMalloc(&Bdev, n * k * sizeof(int8_t));
+ cudaMalloc(&Cdev, m * n * sizeof(int8_t));
+
+ int8_t Ahost[m * k] = {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17};
+ int8_t Bhost[n * k] = {5, 4, -3, -2, 1, 0};
+
+ cudaMemcpy(Adev, Ahost, m * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(Bdev, Bhost, n * k * sizeof(int8_t), cudaMemcpyHostToDevice);
+
+ cublasLtMatrixLayout_t Adesc_col_major = NULL, Bdesc_col_major = NULL,
+ Cdesc_col_major = NULL;
+ cublasLtMatrixLayoutCreate(&Adesc_col_major, CUDA_R_8I, m, k, lda);
+ cublasLtMatrixLayoutCreate(&Bdesc_col_major, CUDA_R_8I, n, k, ldb);
+ cublasLtMatrixLayoutCreate(&Cdesc_col_major, CUDA_R_8I, m, n, ldc);
+
+ // Convert A and B
+ cublasLtMatrixLayout_t Adesc_col32 = NULL, Bdesc_col32_2r_4r4 = NULL,
+ Cdesc_col32 = NULL;
+ int8_t *A_col32, *B_col32_2r_4r4;
+ int8_t *C_col32;
+ cudaMalloc(&A_col32, m * 32 * sizeof(std::int8_t));
+ cudaMalloc(&B_col32_2r_4r4,
+ ((n + 32 - 1) / 32) * 32 * 32 * sizeof(std::int8_t));
+ cudaMalloc(&C_col32, m * 32 * sizeof(std::int8_t));
+ cublasLtMatrixLayoutCreate(&Adesc_col32, CUDA_R_8I, m, k, m * 32);
+ cublasLtMatrixLayoutCreate(&Bdesc_col32_2r_4r4, CUDA_R_8I, k, n,
+ ((n + 32 - 1) / 32) * 32 * 32);
+ cublasLtMatrixLayoutCreate(&Cdesc_col32, CUDA_R_8I, m, n, m * 32);
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col32_2r_4r4 = CUBLASLT_ORDER_COL32_2R_4R4;
+ cublasLtMatrixLayoutSetAttribute(Adesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+ cublasLtMatrixLayoutSetAttribute(Bdesc_col32_2r_4r4,
+ CUBLASLT_MATRIX_LAYOUT_ORDER, &col32_2r_4r4,
+ sizeof(col32_2r_4r4));
+ cublasLtMatrixLayoutSetAttribute(Cdesc_col32, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &col32, sizeof(col32));
+
+ transform(ltHandle, Adev, lda, Adesc_col_major, A_col32, m * 32, Adesc_col32);
+ transform(ltHandle, Bdev, ldb, Bdesc_col_major, B_col32_2r_4r4, 8 * 32,
+ Bdesc_col32_2r_4r4);
+
+ float *alpha;
+ cudaMallocManaged(&alpha, 4 * sizeof(float));
+ alpha[0] = 0;
+ alpha[1] = 1;
+ alpha[2] = 2;
+ alpha[3] = 3;
+
+ // Matmul
+  igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A_col32, B_col32_2r_4r4,
+                            C_col32, alpha, m * 32,
+                            ((n + 8 - 1) / 8) * 8 * 32, m * 32);
+
+ // Convert C
+ transform(ltHandle, C_col32, m * 32, Cdesc_col32, Cdev, ldc, Cdesc_col_major);
+ cudaStreamSynchronize(0);
+
+ // Check result
+ int8_t Chost[m * n];
+ cudaMemcpy(Chost, Cdev, m * n * sizeof(int8_t), cudaMemcpyDeviceToHost);
+
+ bool error = false;
+ int8_t C_ref[m * n] = {0, 17, 40, 69, 0, 6, 16, 30};
+ for (int i = 0; i < m * n; i++) {
+ if (Chost[i] != C_ref[i]) {
+ error = true;
+ break;
+ }
+ }
+ printf("c:\n");
+ for (int i = 0; i < m * n; i++)
+ printf("%d, ", Chost[i]);
+ printf("\n");
+
+ if (error) {
+ printf("error\n");
+ } else {
+ printf("success\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+ cublasLtMatrixLayoutDestroy(Adesc_col32);
+ cublasLtMatrixLayoutDestroy(Bdesc_col32_2r_4r4);
+ cublasLtMatrixLayoutDestroy(Cdesc_col32);
+ cublasLtMatrixLayoutDestroy(Adesc_col_major);
+ cublasLtMatrixLayoutDestroy(Bdesc_col_major);
+ cublasLtMatrixLayoutDestroy(Cdesc_col_major);
+ cudaFree(Adev);
+ cudaFree(Bdev);
+ cudaFree(Cdev);
+ cudaFree(alpha);
+
+ return !error;
+}
+
+// clang-format off
+// A (4*3)        B (2*3)
+//  6 10 14        5 -3 1
+//  7 11 15        4 -2 0
+//  8 12 16
+//  9 13 17
+//
+// alpha * A * op(B)         = alpha * C     = C
+//   0     6 10 14    5  4     0     14  4      0  0
+//   1     7 11 15   -3 -2     1     17  6     17  6
+//   2     8 12 16    1  0     2     20  8     40 16
+//   3     9 13 17             3     23 10     69 30
+//
+// alpha * A * op(B)         = alpha * C     = C
+//   1     6 10 14    5  4     1     14  4     14  4
+//         7 11 15   -3 -2           17  6     17  6
+//         8 12 16    1  0           20  8     20  8
+//         9 13 17                   23 10     23 10
+// clang-format on
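+
+// The C_ref arrays in the tests above are these products read in column-major
+// order.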
+
+int main() {
+ bool pass = true;
+ pass = test1() && pass;
+ pass = test2() && pass;
+ pass = test3() && pass;
+ pass = test4() && pass;
+ pass = test5() && pass;
+ pass = test6() && pass;
+ return pass ? 0 : 1;
+}
diff --git a/features/feature_case/cublasLt/transform.cu b/features/feature_case/cublasLt/transform.cu
new file mode 100644
index 00000000..3a1205ad
--- /dev/null
+++ b/features/feature_case/cublasLt/transform.cu
@@ -0,0 +1,600 @@
+// ===------------ transform.cu -------------------------- *- CUDA -* ----=== //
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+
+#include "cublasLt.h"
+#include <cstdio>
+#include <cstring>
+
+void transform(cublasLtHandle_t ltHandle, void *in, int ld_in,
+ cublasLtOrder_t order_in, void *out, int ld_out,
+ cublasLtOrder_t order_out, int dim1, int dim2) {
+ cublasLtMatrixLayout_t in_desc = NULL, out_desc = NULL;
+ cublasLtMatrixTransformDesc_t transform_desc = NULL;
+
+ cublasLtMatrixLayoutCreate(&in_desc, CUDA_R_8I, dim1, dim2, ld_in);
+ cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ld_out);
+
+ cublasLtMatrixLayoutSetAttribute(in_desc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &order_in, sizeof(order_in));
+ cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &order_out, sizeof(order_out));
+
+ cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F);
+
+ float alpha = 1.0f, beta = 0.0f;
+ cublasLtMatrixTransform(ltHandle, transform_desc, &alpha, in, in_desc, &beta,
+ NULL, NULL, out, out_desc, 0);
+
+ cublasLtMatrixLayoutDestroy(in_desc);
+ cublasLtMatrixLayoutDestroy(out_desc);
+ cublasLtMatrixTransformDescDestroy(transform_desc);
+}
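+
+// Each test below packs a small column-major matrix into one opaque layout
+// with cublasLtMatrixTransform, checks the packed bytes against a hand-written
+// reference, then unpacks it again and checks the round trip.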
+
+bool test_ROW() {
+ const constexpr int m = 2;
+ const constexpr int n = 33;
+ const constexpr int in_ld = 4;
+ void *in_dev;
+ cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+ int8_t in_host[n * in_ld];
+ int8_t value = 0;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % 4 < 2) {
+ in_host[i] = value;
+ value++;
+ } else
+ in_host[i] = 99;
+ }
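+  // in_host is a 2 x 33 column-major matrix with ld = 4: values 0..65 fill the
+  // two used rows of each column and 99 marks the padding rows.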
+ int8_t ref_2nd[n * in_ld];
+ std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+ cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+ cudaMemcpyHostToDevice);
+
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+ void *out_dev;
+ const constexpr int out_ld = 36;
+ cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+ cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+ transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+ CUBLASLT_ORDER_ROW, m, n);
+
+ int8_t out_host[out_ld * m];
+ cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_1st = true;
+ int8_t ref_1st[out_ld * m] =
+ {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 0, 0, 0,
+ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 0, 0, 0};
+ for (int i = 0; i < out_ld * m; i++) {
+ if (i % out_ld < n) {
+ if (out_host[i] != ref_1st[i]) {
+ pass_1st = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < out_ld * m; i++) {
+ printf("%d, ", out_host[i]);
+ }
+ printf("\n");
+ if (pass_1st) {
+ printf("ROW 1st pass\n");
+ } else {
+ printf("ROW 1st fail\n");
+ }
+
+ cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+ std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+ transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_ROW, in_dev, in_ld,
+ CUBLASLT_ORDER_COL, m, n);
+ cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_2nd = true;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % in_ld < m) {
+ if (in_host[i] != ref_2nd[i]) {
+ pass_2nd = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < n * in_ld; i++) {
+ printf("%d, ", in_host[i]);
+ }
+ printf("\n");
+ if (pass_2nd) {
+ printf("ROW 2nd pass\n");
+ } else {
+ printf("ROW 2nd fail\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+
+ return pass_1st && pass_2nd;
+}
+
+bool test_COL32() {
+ const constexpr int m = 2;
+ const constexpr int n = 33;
+ const constexpr int in_ld = 4;
+ void *in_dev;
+ cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+ int8_t in_host[n * in_ld];
+ int8_t value = 0;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % 4 < 2) {
+ in_host[i] = value;
+ value++;
+ } else
+ in_host[i] = 99;
+ }
+ int8_t ref_2nd[n * in_ld];
+ std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+ cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+ cudaMemcpyHostToDevice);
+
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+ void *out_dev;
+ const constexpr int out_ld = 64;
+ cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+ cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+ transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+ CUBLASLT_ORDER_COL32, m, n);
+
+ int8_t out_host[out_ld * m];
+ cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_1st = true;
+ int8_t ref_1st[out_ld * m] =
+ {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62,
+ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63,
+ 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ for (int i = 0; i < out_ld * m; i++) {
+ if (i % out_ld < n) {
+ if (out_host[i] != ref_1st[i]) {
+ pass_1st = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < out_ld * m; i++) {
+ printf("%d, ", out_host[i]);
+ }
+ printf("\n");
+ if (pass_1st) {
+ printf("COL32 1st pass\n");
+ } else {
+ printf("COL32 1st fail\n");
+ }
+
+ cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+ std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+ transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_COL32, in_dev, in_ld,
+ CUBLASLT_ORDER_COL, m, n);
+ cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_2nd = true;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % in_ld < m) {
+ if (in_host[i] != ref_2nd[i]) {
+ pass_2nd = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < n * in_ld; i++) {
+ printf("%d, ", in_host[i]);
+ }
+ printf("\n");
+ if (pass_2nd) {
+ printf("COL32 2nd pass\n");
+ } else {
+ printf("COL32 2nd fail\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+
+ return pass_1st && pass_2nd;
+}
+
+bool test_COL4_4R2_8C() {
+ const constexpr int m = 2;
+ const constexpr int n = 33;
+ const constexpr int in_ld = 4;
+ void *in_dev;
+ cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+ int8_t in_host[n * in_ld];
+ int8_t value = 0;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % 4 < 2) {
+ in_host[i] = value;
+ value++;
+ } else
+ in_host[i] = 99;
+ }
+ int8_t ref_2nd[n * in_ld];
+ std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+ cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+ cudaMemcpyHostToDevice);
+
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+ void *out_dev;
+ const constexpr int out_ld = (32 * 8) * 2;
+ cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+ cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+ transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+ CUBLASLT_ORDER_COL4_4R2_8C, m, n);
+
+ int8_t out_host[out_ld * m];
+ cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_1st = true;
+ int8_t ref_1st[out_ld * m] =
+ {0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 8, 10, 12, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 16, 18, 20, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 32, 34, 36, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 40, 42, 44, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 48, 50, 52, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 56, 58, 60, 62, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 3, 5, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 9, 11, 13, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 17, 19, 21, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 25, 27, 29, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 33, 35, 37, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 41, 43, 45, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 49, 51, 53, 55, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 57, 59, 61, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ for (int i = 0; i < out_ld * m; i++) {
+ if (i % out_ld < n) {
+ if (out_host[i] != ref_1st[i]) {
+ pass_1st = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < out_ld * m; i++) {
+ printf("%d, ", out_host[i]);
+ }
+ printf("\n");
+ if (pass_1st) {
+ printf("COL4_4R2_8C 1st pass\n");
+ } else {
+ printf("COL4_4R2_8C 1st fail\n");
+ }
+
+ cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+ std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+ transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_COL4_4R2_8C, in_dev,
+ in_ld, CUBLASLT_ORDER_COL, m, n);
+ cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_2nd = true;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % in_ld < m) {
+ if (in_host[i] != ref_2nd[i]) {
+ pass_2nd = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < n * in_ld; i++) {
+ printf("%d, ", in_host[i]);
+ }
+ printf("\n");
+ if (pass_2nd) {
+ printf("COL4_4R2_8C 2nd pass\n");
+ } else {
+ printf("COL4_4R2_8C 2nd fail\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+
+ return pass_1st && pass_2nd;
+}
+
+bool test_COL32_2R_4R4() {
+ const constexpr int m = 2;
+ const constexpr int n = 33;
+ const constexpr int in_ld = 4;
+ void *in_dev;
+ cudaMalloc(&in_dev, n * in_ld * sizeof(int8_t));
+
+ int8_t in_host[n * in_ld];
+ int8_t value = 0;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % 4 < 2) {
+ in_host[i] = value;
+ value++;
+ } else
+ in_host[i] = 99;
+ }
+ int8_t ref_2nd[n * in_ld];
+ std::memcpy(ref_2nd, in_host, n * in_ld * sizeof(int8_t));
+
+ cudaMemcpy(in_dev, in_host, n * in_ld * sizeof(int8_t),
+ cudaMemcpyHostToDevice);
+
+ cublasLtHandle_t ltHandle;
+  cublasLtCreate(&ltHandle);
+
+ void *out_dev;
+ const constexpr int out_ld = (32 * 32) * 2;
+ cudaMalloc(&out_dev, out_ld * m * sizeof(int8_t));
+ cudaMemset(out_dev, 0, out_ld * m * sizeof(int8_t));
+ transform(ltHandle, in_dev, in_ld, CUBLASLT_ORDER_COL, out_dev, out_ld,
+ CUBLASLT_ORDER_COL32_2R_4R4, m, n);
+
+ int8_t out_host[out_ld * m];
+ cudaMemcpy(out_host, out_dev, out_ld * m * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_1st = true;
+ int8_t ref_1st[out_ld * m] =
+ {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62,
+ 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 65, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+ for (int i = 0; i < out_ld * m; i++) {
+ if (i % out_ld < n) {
+ if (out_host[i] != ref_1st[i]) {
+ pass_1st = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < out_ld * m; i++) {
+ printf("%d, ", out_host[i]);
+ }
+ printf("\n");
+ if (pass_1st) {
+ printf("COL32_2R_4R4 1st pass\n");
+ } else {
+ printf("COL32_2R_4R4 1st fail\n");
+ }
+
+ cudaMemset(in_dev, 0, n * in_ld * sizeof(int8_t));
+ std::memset(in_host, 0, n * in_ld * sizeof(int8_t));
+ transform(ltHandle, out_dev, out_ld, CUBLASLT_ORDER_COL32_2R_4R4, in_dev,
+ in_ld, CUBLASLT_ORDER_COL, m, n);
+ cudaMemcpy(in_host, in_dev, n * in_ld * sizeof(int8_t),
+ cudaMemcpyDeviceToHost);
+
+ bool pass_2nd = true;
+ for (int i = 0; i < n * in_ld; i++) {
+ if (i % in_ld < m) {
+ if (in_host[i] != ref_2nd[i]) {
+ pass_2nd = false;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < n * in_ld; i++) {
+ printf("%d, ", in_host[i]);
+ }
+ printf("\n");
+ if (pass_2nd) {
+ printf("COL32_2R_4R4 2nd pass\n");
+ } else {
+ printf("COL32_2R_4R4 2nd fail\n");
+ }
+
+ cublasLtDestroy(ltHandle);
+
+ return pass_1st && pass_2nd;
+}
+
+// Input col_major matrix:
+// 2 rows * 33 columns, ld is 4
+int main() {
+ bool pass = true;
+ pass = test_ROW() && pass;
+ pass = test_COL32() && pass;
+ pass = test_COL4_4R2_8C() && pass;
+ pass = test_COL32_2R_4R4() && pass;
+ return pass ? 0 : 1;
+}
diff --git a/features/features.xml b/features/features.xml
index 8ad6e396..ff7838f2 100644
--- a/features/features.xml
+++ b/features/features.xml
@@ -342,5 +342,7 @@
+  <test testName="matmul" configFile="config/TEMPLATE_cublasLt.xml" />
+  <test testName="transform" configFile="config/TEMPLATE_cublasLt.xml" />
diff --git a/features/test_feature.py b/features/test_feature.py
index 2f6bab32..bbd6e08c 100644
--- a/features/test_feature.py
+++ b/features/test_feature.py
@@ -60,7 +60,7 @@
'thrust_swap_ranges', 'thrust_uninitialized_fill_n', 'thrust_equal', 'system_atomic', 'thrust_detail_types',
'operator_eq', 'operator_neq', 'operator_lege', 'thrust_system', 'thrust_reverse_copy',
'thrust_device_new_delete', 'thrust_temporary_buffer', 'thrust_malloc_free', 'codepin', 'thrust_unique_count',
- 'thrust_advance_trans_op_itr', 'cuda_stream_query']
+ 'thrust_advance_trans_op_itr', 'cuda_stream_query', 'matmul', 'transform']
occupancy_calculation_exper = ['occupancy_calculation']
@@ -166,7 +166,7 @@ def build_test():
'cudnn-binary', 'cudnn-bnp1', 'cudnn-bnp2', 'cudnn-bnp3', 'cudnn-normp1', 'cudnn-normp2', 'cudnn-normp3',
'cudnn-convp1', 'cudnn-convp2', 'cudnn-convp3', 'cudnn-convp4', 'cudnn-convp5', 'cudnn-convp6', 'cudnn-rnn',
'cudnn-GetErrorString', 'cudnn-convp7',
- 'cudnn-types', 'cudnn-version', 'cudnn-dropout'
+ 'cudnn-types', 'cudnn-version', 'cudnn-dropout', 'matmul'
]
no_fast_math_tests = ['math-emu-half-after11', 'math-emu-half2-after11', 'math-ext-half-after11', 'math-ext-half2-after11',
diff --git a/help_function/help_function.xml b/help_function/help_function.xml
index 3bdbd26e..fc047e53 100644
--- a/help_function/help_function.xml
+++ b/help_function/help_function.xml
@@ -215,5 +215,6 @@
+  <test testName="blas_gemm_utils_interface" configFile="config/TEMPLATE_help_function.xml" />
diff --git a/help_function/src/blas_gemm_utils_interface.cpp b/help_function/src/blas_gemm_utils_interface.cpp
new file mode 100644
index 00000000..551021b3
--- /dev/null
+++ b/help_function/src/blas_gemm_utils_interface.cpp
@@ -0,0 +1,152 @@
+// ===------ blas_gemm_utils_interface.cpp ----------------- *- C++ -* ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// ===--------------------------------------------------------------------=== //
+
+#include <sycl/sycl.hpp>
+#include <dpct/dpct.hpp>
+#include <dpct/blas_gemm_utils.hpp>
+#include <oneapi/mkl.hpp>
+
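+// The foo* functions below are compile-only interface checks for
+// dpct::blas_gemm::experimental; main() never calls them, so nothing is
+// executed on a device.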
+void foo1 () {
+ dpct::blas_gemm::experimental::descriptor_ptr ltHandle;
+ ltHandle = new dpct::blas_gemm::experimental::descriptor();
+ delete (ltHandle);
+
+ dpct::blas_gemm::experimental::matrix_layout_ptr matLayout;
+ dpct::library_data_t type;
+ uint64_t rows;
+ uint64_t cols;
+ int64_t ld;
+ matLayout =
+ new dpct::blas_gemm::experimental::matrix_layout_t(type, rows, cols, ld);
+
+ dpct::blas_gemm::experimental::matrix_layout_t::attribute attr1;
+ void *buf1;
+ size_t sizeInBytes1;
+ size_t *sizeWritten1;
+ matLayout->get_attribute(attr1, buf1);
+ matLayout->set_attribute(attr1, buf1);
+ delete (matLayout);
+
+ dpct::blas_gemm::experimental::matmul_desc_ptr matmulDesc;
+ dpct::compute_type computeType;
+ dpct::library_data_t scaleType;
+ matmulDesc =
+ new dpct::blas_gemm::experimental::matmul_desc_t(computeType, scaleType);
+
+ dpct::blas_gemm::experimental::matmul_desc_t::attribute attr2;
+ void *buf2;
+ size_t sizeInBytes2;
+ size_t *sizeWritten2;
+ matmulDesc->get_attribute(attr2, buf2);
+ matmulDesc->set_attribute(attr2, buf2);
+ delete (matmulDesc);
+
+ int matmulPreference;
+ void *buf3;
+ size_t sizeInBytes3;
+ size_t *sizeWritten3;
+
+ dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Ddesc;
+
+ int requestedAlgoCount = 1;
+ int heuristicResultsArray;
+ int returnAlgoCount;
+ returnAlgoCount = 1;
+}
+
+void foo2() {
+ dpct::blas_gemm::experimental::descriptor_ptr lightHandle;
+ dpct::blas_gemm::experimental::matmul_desc_ptr computeDesc;
+ const void *alpha;
+ const void *A;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+ const void *B;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+ const void *beta;
+ const void *C;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+ void *D;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Ddesc;
+ const int *algo;
+ void *workspace;
+ size_t workspaceSizeInBytes;
+ dpct::queue_ptr stream;
+ dpct::blas_gemm::experimental::matmul(lightHandle, computeDesc, alpha, A,
+ Adesc, B, Bdesc, beta, C, Cdesc, D,
+ Ddesc, stream);
+}
+
+void foo3() {
+ dpct::blas_gemm::experimental::order_t a;
+ a = dpct::blas_gemm::experimental::order_t::col;
+ a = dpct::blas_gemm::experimental::order_t::row;
+ a = dpct::blas_gemm::experimental::order_t::col32;
+ a = dpct::blas_gemm::experimental::order_t::col4_4r2_8c;
+ a = dpct::blas_gemm::experimental::order_t::col32_2r_4r4;
+
+ dpct::blas_gemm::experimental::pointer_mode_t b;
+ b = dpct::blas_gemm::experimental::pointer_mode_t::host;
+ b = dpct::blas_gemm::experimental::pointer_mode_t::device;
+ b = dpct::blas_gemm::experimental::pointer_mode_t::device_vector;
+ b = dpct::blas_gemm::experimental::pointer_mode_t::
+ alpha_device_vector_beta_zero;
+ b = dpct::blas_gemm::experimental::pointer_mode_t::
+ alpha_device_vector_beta_host;
+
+ dpct::blas_gemm::experimental::matrix_layout_t::attribute c;
+ c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::type;
+ c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::order;
+ c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::rows;
+ c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::cols;
+ c = dpct::blas_gemm::experimental::matrix_layout_t::attribute::ld;
+
+ dpct::blas_gemm::experimental::matmul_desc_t::attribute d;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::compute_type;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::scale_type;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::pointer_mode;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_a;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_b;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::trans_c;
+ d = dpct::blas_gemm::experimental::matmul_desc_t::attribute::epilogue;
+}
+
+void foo4() {
+ dpct::blas_gemm::experimental::transform_desc_ptr transformDesc;
+ dpct::library_data_t scaleType;
+ transformDesc =
+ new dpct::blas_gemm::experimental::transform_desc_t(scaleType);
+ oneapi::mkl::transpose opT = oneapi::mkl::transpose::trans;
+ size_t sizeWritten;
+ transformDesc->set_attribute(
+ dpct::blas_gemm::experimental::transform_desc_t::attribute::trans_a,
+ &opT);
+ transformDesc->get_attribute(
+ dpct::blas_gemm::experimental::transform_desc_t::attribute::trans_a,
+ &opT);
+ delete (transformDesc);
+
+ dpct::blas_gemm::experimental::descriptor_ptr lightHandle;
+ const void *alpha;
+ const void *A;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Adesc;
+ const void *beta;
+ const void *B;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Bdesc;
+ void *C;
+ dpct::blas_gemm::experimental::matrix_layout_ptr Cdesc;
+ dpct::queue_ptr stream;
+ dpct::blas_gemm::experimental::matrix_transform(
+ transformDesc, alpha, A, Adesc, beta, B, Bdesc, C, Cdesc, stream);
+}
+
+int main() {
+ return 0;
+}
diff --git a/help_function/test_help.py b/help_function/test_help.py
index b6c58527..52588464 100644
--- a/help_function/test_help.py
+++ b/help_function/test_help.py
@@ -45,7 +45,7 @@ def build_test():
"dnnl_utils_batch_normalization_2", "dnnl_utils_batch_normalization_3", "dnnl_utils_convolution_1",
"dnnl_utils_convolution_2", "dnnl_utils_convolution_3", "dnnl_utils_convolution_4", "dnnl_utils_convolution_5",
"dnnl_utils_normalization_1", "dnnl_utils_normalization_2", "dnnl_utils_normalization_3", "dnnl_utils_rnn",
- "dnnl_utils_version", "dnnl_utils_dropout"]
+ "dnnl_utils_version", "dnnl_utils_dropout", "blas_gemm_utils_interface"]
fft_cases = ["fft_utils_engine_buffer", "fft_utils_engine_usm", "fft_workspace_interface", "fft_set_workspace"]
lapack_cases = ["lapack_utils_buffer", "lapack_utils_usm"]
rng_cases = ["rng_generator", "rng_generator_vec_size_1", "rng_host"]