Skip to content

Commit

Permalink
[HIPIFY][BLAS][fix] Added support for the missing `hipblas(S|D)gemvBa…
Browse files Browse the repository at this point in the history
…tched`

+ Updated synthetic tests, the regenerated `hipify-perl`, and `BLAS` `CUDA2HIP` documentation
  • Loading branch information
emankov committed Jun 24, 2024
1 parent 8f85a4d commit 2371d60
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 8 deletions.
4 changes: 2 additions & 2 deletions bin/hipify-perl
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,7 @@ sub simpleSubstitutions {
subst("cublasDgemmStridedBatched", "hipblasDgemmStridedBatched", "library");
subst("cublasDgemm_v2", "hipblasDgemm", "library");
subst("cublasDgemv", "hipblasDgemv", "library");
subst("cublasDgemvBatched", "hipblasDgemvBatched", "library");
subst("cublasDgemvStridedBatched", "hipblasDgemvStridedBatched", "library");
subst("cublasDgemv_v2", "hipblasDgemv", "library");
subst("cublasDgeqrfBatched", "hipblasDgeqrfBatched", "library");
Expand Down Expand Up @@ -4117,6 +4118,7 @@ sub simpleSubstitutions {
subst("cublasSgemmStridedBatched", "hipblasSgemmStridedBatched", "library");
subst("cublasSgemm_v2", "hipblasSgemm", "library");
subst("cublasSgemv", "hipblasSgemv", "library");
subst("cublasSgemvBatched", "hipblasSgemvBatched", "library");
subst("cublasSgemvStridedBatched", "hipblasSgemvStridedBatched", "library");
subst("cublasSgemv_v2", "hipblasSgemv", "library");
subst("cublasSgeqrfBatched", "hipblasSgeqrfBatched", "library");
Expand Down Expand Up @@ -11442,7 +11444,6 @@ sub warnHipOnlyUnsupportedFunctions {
"cublasSger_v2_64",
"cublasSger_64",
"cublasSgemvStridedBatched_64",
"cublasSgemvBatched",
"cublasSgemm_v2_64",
"cublasSgemm_64",
"cublasSgemmStridedBatched_64",
Expand Down Expand Up @@ -11587,7 +11588,6 @@ sub warnHipOnlyUnsupportedFunctions {
"cublasDger_v2_64",
"cublasDger_64",
"cublasDgemvStridedBatched_64",
"cublasDgemvBatched",
"cublasDgemm_v2_64",
"cublasDgemm_64",
"cublasDgemmStridedBatched_64",
Expand Down
4 changes: 2 additions & 2 deletions docs/tables/CUBLAS_API_supported_by_HIP.md
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,7 @@
|`cublasDgemm_64`|12.0| | | | | | | | | |
|`cublasDgemm_v2`| | | | |`hipblasDgemm`|1.8.2| | | | |
|`cublasDgemm_v2_64`|12.0| | | | | | | | | |
|`cublasDgemvBatched`|11.6| | | | | | | | | |
|`cublasDgemvBatched`|11.6| | | |`hipblasDgemvBatched`|3.0.0| | | | |
|`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | |6.2.0|
|`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | |
|`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | |
Expand Down Expand Up @@ -1133,7 +1133,7 @@
|`cublasSgemm_64`|12.0| | | | | | | | | |
|`cublasSgemm_v2`| | | | |`hipblasSgemm`|1.8.2| | | | |
|`cublasSgemm_v2_64`|12.0| | | | | | | | | |
|`cublasSgemvBatched`|11.6| | | | | | | | | |
|`cublasSgemvBatched`|11.6| | | |`hipblasSgemvBatched`|1.6.0| | | | |
|`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | |6.2.0|
|`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | |
|`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | |
Expand Down
4 changes: 2 additions & 2 deletions docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -1081,7 +1081,7 @@
|`cublasDgemm_64`|12.0| | | | | | | | | | | | | | | |
|`cublasDgemm_v2`| | | | |`hipblasDgemm`|1.8.2| | | | |`rocblas_dgemm`|1.5.0| | | | |
|`cublasDgemm_v2_64`|12.0| | | | | | | | | | | | | | | |
|`cublasDgemvBatched`|11.6| | | | | | | | | | | | | | | |
|`cublasDgemvBatched`|11.6| | | |`hipblasDgemvBatched`|3.0.0| | | | | | | | | | |
|`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | |6.2.0| | | | | | |
|`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | | | | | | | |
|`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |
Expand Down Expand Up @@ -1133,7 +1133,7 @@
|`cublasSgemm_64`|12.0| | | | | | | | | | | | | | | |
|`cublasSgemm_v2`| | | | |`hipblasSgemm`|1.8.2| | | | |`rocblas_sgemm`|1.5.0| | | | |
|`cublasSgemm_v2_64`|12.0| | | | | | | | | | | | | | | |
|`cublasSgemvBatched`|11.6| | | | | | | | | | | | | | | |
|`cublasSgemvBatched`|11.6| | | |`hipblasSgemvBatched`|1.6.0| | | | | | | | | | |
|`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | |6.2.0| | | | | | |
|`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | | | | | | | |
|`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |
Expand Down
6 changes: 4 additions & 2 deletions src/CUDA2HIP_BLAS_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,9 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP {
{"cublasGemmGroupedBatchedEx_64", {"hipblasGemmGroupedBatchedEx_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},

// BATCH GEMV
{"cublasSgemvBatched", {"hipblasSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasSgemvBatched", {"hipblasSgemvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}},
{"cublasSgemvBatched_64", {"hipblasSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED | HIP_EXPERIMENTAL}},
{"cublasDgemvBatched", {"hipblasDgemvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasDgemvBatched", {"hipblasDgemvBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}},
{"cublasDgemvBatched_64", {"hipblasDgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED | HIP_EXPERIMENTAL}},
{"cublasCgemvBatched", {"hipblasCgemvBatched_v2", "rocblas_cgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}},
{"cublasCgemvBatched_64", {"hipblasCgemvBatched_v2_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED | HIP_EXPERIMENTAL}},
Expand Down Expand Up @@ -2068,6 +2068,8 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
{"hipblasZgemvBatched_v2_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}},
{"hipblasSgemvStridedBatched", {HIP_3000, HIP_0, HIP_0 }},
{"hipblasDgemvStridedBatched", {HIP_3000, HIP_0, HIP_0 }},
{"hipblasSgemvBatched", {HIP_1060, HIP_0, HIP_0 }},
{"hipblasDgemvBatched", {HIP_3000, HIP_0, HIP_0 }},

{"rocblas_status_to_string", {HIP_3050, HIP_0, HIP_0 }},
{"rocblas_sscal", {HIP_1050, HIP_0, HIP_0 }},
Expand Down
10 changes: 10 additions & 0 deletions tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,16 @@ int main() {
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasDgemvStridedBatched(hipblasHandle_t handle, hipblasOperation_t transA, int m, int n, const double* alpha, const double* AP, int lda, hipblasStride strideA, const double* x, int incx, hipblasStride stridex, const double* beta, double* y, int incy, hipblasStride stridey, int batchCount);
// CHECK: blasStatus = hipblasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount);
blasStatus = cublasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const float* const Aarray[], int lda, const float* const xarray[], int incx, const float* beta, float* const yarray[], int incy, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasSgemvBatched(hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, const float* alpha, const float* const AP[], int lda, const float* const x[], int incx, const float* beta, float* const y[], int incy, int batchCount);
// CHECK: blasStatus = hipblasSgemvBatched(blasHandle, blasOperation, m, n, &fa, fAarray_const, lda, fXarray_const, incx, &fb, fYarray, incy, batchCount);
blasStatus = cublasSgemvBatched(blasHandle, blasOperation, m, n, &fa, fAarray_const, lda, fXarray_const, incx, &fb, fYarray, incy, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const double* alpha, const double* const Aarray[], int lda, const double* const xarray[], int incx, const double* beta, double* const yarray[], int incy, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasDgemvBatched(hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, const double* alpha, const double* const AP[], int lda, const double* const x[], int incx, const double* beta, double* const y[], int incy, int batchCount);
// CHECK: blasStatus = hipblasDgemvBatched(blasHandle, blasOperation, m, n, &da, dAarray_const, lda, dXarray_const, incx, &db, dYarray, incy, batchCount);
blasStatus = cublasDgemvBatched(blasHandle, blasOperation, m, n, &da, dAarray_const, lda, dXarray_const, incx, &db, dYarray, incy, batchCount);
#endif

#if CUDA_VERSION >= 12000
Expand Down

0 comments on commit 2371d60

Please sign in to comment.