Skip to content

Commit

Permalink
[HIPIFY][BLAS][fix] Added support for the missing `hipblas(S|D)gemvStridedBatched`
Browse files Browse the repository at this point in the history

+ Updated synthetic tests, the regenerated `hipify-perl`, and `BLAS` `CUDA2HIP` documentation
  • Loading branch information
emankov committed Jun 21, 2024
1 parent 66da687 commit 967de9b
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 10 deletions.
4 changes: 2 additions & 2 deletions bin/hipify-perl
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,7 @@ sub simpleSubstitutions {
subst("cublasDgemmStridedBatched", "hipblasDgemmStridedBatched", "library");
subst("cublasDgemm_v2", "hipblasDgemm", "library");
subst("cublasDgemv", "hipblasDgemv", "library");
subst("cublasDgemvStridedBatched", "hipblasDgemvStridedBatched", "library");
subst("cublasDgemv_v2", "hipblasDgemv", "library");
subst("cublasDgeqrfBatched", "hipblasDgeqrfBatched", "library");
subst("cublasDger", "hipblasDger", "library");
Expand Down Expand Up @@ -4116,6 +4117,7 @@ sub simpleSubstitutions {
subst("cublasSgemmStridedBatched", "hipblasSgemmStridedBatched", "library");
subst("cublasSgemm_v2", "hipblasSgemm", "library");
subst("cublasSgemv", "hipblasSgemv", "library");
subst("cublasSgemvStridedBatched", "hipblasSgemvStridedBatched", "library");
subst("cublasSgemv_v2", "hipblasSgemv", "library");
subst("cublasSgeqrfBatched", "hipblasSgeqrfBatched", "library");
subst("cublasSger", "hipblasSger", "library");
Expand Down Expand Up @@ -11439,7 +11441,6 @@ sub warnHipOnlyUnsupportedFunctions {
"cublasSger_v2_64",
"cublasSger_64",
"cublasSgemvStridedBatched_64",
"cublasSgemvStridedBatched",
"cublasSgemvBatched",
"cublasSgemm_v2_64",
"cublasSgemm_64",
Expand Down Expand Up @@ -11583,7 +11584,6 @@ sub warnHipOnlyUnsupportedFunctions {
"cublasDger_v2_64",
"cublasDger_64",
"cublasDgemvStridedBatched_64",
"cublasDgemvStridedBatched",
"cublasDgemvBatched",
"cublasDgemm_v2_64",
"cublasDgemm_64",
Expand Down
4 changes: 2 additions & 2 deletions docs/tables/CUBLAS_API_supported_by_HIP.md
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@
|`cublasDgemm_v2_64`|12.0| | | | | | | | | |
|`cublasDgemvBatched`|11.6| | | | | | | | | |
|`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | |6.2.0|
|`cublasDgemvStridedBatched`|11.6| | | | | | | | | |
|`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | |
|`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | |
|`cublasDsymm`| | | | |`hipblasDsymm`|3.6.0| | | | |
|`cublasDsymm_64`|12.0| | | | | | | | | |
Expand Down Expand Up @@ -1133,7 +1133,7 @@
|`cublasSgemm_v2_64`|12.0| | | | | | | | | |
|`cublasSgemvBatched`|11.6| | | | | | | | | |
|`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | |6.2.0|
|`cublasSgemvStridedBatched`|11.6| | | | | | | | | |
|`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | |
|`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | |
|`cublasSsymm`| | | | |`hipblasSsymm`|3.6.0| | | | |
|`cublasSsymm_64`|12.0| | | | | | | | | |
Expand Down
4 changes: 2 additions & 2 deletions docs/tables/CUBLAS_API_supported_by_HIP_and_ROC.md
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@
|`cublasDgemm_v2_64`|12.0| | | | | | | | | | | | | | | |
|`cublasDgemvBatched`|11.6| | | | | | | | | | | | | | | |
|`cublasDgemvBatched_64`|12.0| | | |`hipblasDgemvBatched_64`|6.2.0| | | |6.2.0| | | | | | |
|`cublasDgemvStridedBatched`|11.6| | | | | | | | | | | | | | | |
|`cublasDgemvStridedBatched`|11.6| | | |`hipblasDgemvStridedBatched`|3.0.0| | | | | | | | | | |
|`cublasDgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |
|`cublasDsymm`| | | | |`hipblasDsymm`|3.6.0| | | | |`rocblas_dsymm`|3.5.0| | | | |
|`cublasDsymm_64`|12.0| | | | | | | | | | | | | | | |
Expand Down Expand Up @@ -1133,7 +1133,7 @@
|`cublasSgemm_v2_64`|12.0| | | | | | | | | | | | | | | |
|`cublasSgemvBatched`|11.6| | | | | | | | | | | | | | | |
|`cublasSgemvBatched_64`|12.0| | | |`hipblasSgemvBatched_64`|6.2.0| | | |6.2.0| | | | | | |
|`cublasSgemvStridedBatched`|11.6| | | | | | | | | | | | | | | |
|`cublasSgemvStridedBatched`|11.6| | | |`hipblasSgemvStridedBatched`|3.0.0| | | | | | | | | | |
|`cublasSgemvStridedBatched_64`|12.0| | | | | | | | | | | | | | | |
|`cublasSsymm`| | | | |`hipblasSsymm`|3.6.0| | | | |`rocblas_ssymm`|3.5.0| | | | |
|`cublasSsymm_64`|12.0| | | | | | | | | | | | | | | |
Expand Down
8 changes: 5 additions & 3 deletions src/CUDA2HIP_BLAS_API_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,9 @@ const std::map<llvm::StringRef, hipCounter> CUDA_BLAS_FUNCTION_MAP {
{"cublasTSTgemvBatched_64", {"hipblasTSTgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasTSSgemvBatched", {"hipblasTSSgemvBatched", "rocblas_tssgemv_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, HIP_UNSUPPORTED}},
{"cublasTSSgemvBatched_64", {"hipblasTSSgemvBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasSgemvStridedBatched", {"hipblasSgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}},
{"cublasSgemvStridedBatched_64", {"hipblasSgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasDgemvStridedBatched", {"hipblasDgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasDgemvStridedBatched", {"hipblasDgemvStridedBatched", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, ROC_UNSUPPORTED}},
{"cublasDgemvStridedBatched_64", {"hipblasDgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
{"cublasCgemvStridedBatched", {"hipblasCgemvStridedBatched_v2", "rocblas_cgemv_strided_batched", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3}},
{"cublasCgemvStridedBatched_64", {"hipblasCgemvStridedBatched_64", "", CONV_LIB_FUNC, API_BLAS, SEC::BLAS_LEVEL_3, UNSUPPORTED}},
Expand Down Expand Up @@ -1455,7 +1455,7 @@ const std::map<llvm::StringRef, cudaAPIversions> CUDA_BLAS_FUNCTION_VER_MAP {
{"cublasTSTgemvBatched_64", {CUDA_120, CUDA_0, CUDA_0 }},
{"cublasTSSgemvBatched", {CUDA_116, CUDA_0, CUDA_0 }},
{"cublasTSSgemvBatched_64", {CUDA_120, CUDA_0, CUDA_0 }},
{"cublasSgemvStridedBatched", {CUDA_116, CUDA_0, CUDA_0 }},
{"cublasSgemvStridedBatched", {CUDA_116, CUDA_0, CUDA_0 }}, // A: CUDA_VERSION 11062, CUBLAS_VERSION 110902, CUBLAS_VER_MAJOR 11 CUBLAS_VER_MINOR 9 CUBLAS_VER_PATCH 2
{"cublasSgemvStridedBatched_64", {CUDA_120, CUDA_0, CUDA_0 }},
{"cublasDgemvStridedBatched", {CUDA_116, CUDA_0, CUDA_0 }},
{"cublasDgemvStridedBatched_64", {CUDA_120, CUDA_0, CUDA_0 }},
Expand Down Expand Up @@ -2062,6 +2062,8 @@ const std::map<llvm::StringRef, hipAPIversions> HIP_BLAS_FUNCTION_VER_MAP {
{"hipblasDgemvBatched_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}},
{"hipblasCgemvBatched_v2_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}},
{"hipblasZgemvBatched_v2_64", {HIP_6020, HIP_0, HIP_0, HIP_LATEST}},
{"hipblasSgemvStridedBatched", {HIP_3000, HIP_0, HIP_0 }},
{"hipblasDgemvStridedBatched", {HIP_3000, HIP_0, HIP_0 }},

{"rocblas_status_to_string", {HIP_3050, HIP_0, HIP_0 }},
{"rocblas_sscal", {HIP_1050, HIP_0, HIP_0 }},
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/synthetic/libraries/cublas2hipblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1703,7 +1703,7 @@ int main() {
blasStatus = cublasGemmStridedBatchedEx(blasHandle, transa, transb, m, n, k, aptr, Aptr, Atype, lda, strideA, Bptr, Btype, ldb, strideB, bptr, Cptr, Ctype, ldc, strideC, batchCount, blasComputeType, blasGemmAlgo);
#endif

#if CUDA_VERSION >= 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2
#if CUDA_VERSION > 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasCgemvBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const cuComplex* alpha, const cuComplex* const Aarray[], int lda, const cuComplex* const xarray[], int incx, const cuComplex* beta, cuComplex* const yarray[], int incy, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasCgemvBatched_v2(hipblasHandle_t handle, hipblasOperation_t trans, int m, int n, const hipComplex* alpha, const hipComplex* const AP[], int lda, const hipComplex* const x[], int incx, const hipComplex* beta, hipComplex* const y[], int incy, int batchCount);
// CHECK: blasStatus = hipblasCgemvBatched_v2(blasHandle, blasOperation, m, n, &complexa, complexAarray_const, lda, complexXarray_const, incx, &complexb, complexYarray, incy, batchCount);
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_tests/synthetic/libraries/cublas2hipblas_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,8 @@ int main() {
long long int strideA = 0;
long long int strideB = 0;
long long int strideC = 0;
long long int strideX = 0;
long long int strideY = 0;

#if CUDA_VERSION >= 7050
// CHECK: __half* ha = 0;
Expand Down Expand Up @@ -1875,6 +1877,18 @@ int main() {
blasStatus = cublasGemmStridedBatchedEx(blasHandle, transa, transb, m, n, k, aptr, Aptr, Atype, lda, strideA, Bptr, Btype, ldb, strideB, bptr, Cptr, Ctype, ldc, strideC, batchCount, blasComputeType, blasGemmAlgo);
#endif

#if CUDA_VERSION > 11060 && CUBLAS_VERSION >= 110902 // CUDA 11.6.2
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasSgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const float* alpha, const float* A, int lda, long long int strideA, const float* x, int incx, long long int stridex, const float* beta, float* y, int incy, long long int stridey, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasSgemvStridedBatched(hipblasHandle_t handle, hipblasOperation_t transA, int m, int n, const float* alpha, const float* AP, int lda, hipblasStride strideA, const float* x, int incx, hipblasStride stridex, const float* beta, float* y, int incy, hipblasStride stridey, int batchCount);
// CHECK: blasStatus = hipblasSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, &fA, lda, strideA, &fx, incx, strideX, &fb, &fy, incy, strideY, batchCount);
blasStatus = cublasSgemvStridedBatched(blasHandle, blasOperation, m, n, &fa, &fA, lda, strideA, &fx, incx, strideX, &fb, &fy, incy, strideY, batchCount);

// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasDgemvStridedBatched(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const double* alpha, const double* A, int lda, long long int strideA, const double* x, int incx, long long int stridex, const double* beta, double* y, int incy, long long int stridey, int batchCount);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasDgemvStridedBatched(hipblasHandle_t handle, hipblasOperation_t transA, int m, int n, const double* alpha, const double* AP, int lda, hipblasStride strideA, const double* x, int incx, hipblasStride stridex, const double* beta, double* y, int incy, hipblasStride stridey, int batchCount);
// CHECK: blasStatus = hipblasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount);
blasStatus = cublasDgemvStridedBatched(blasHandle, blasOperation, m, n, &da, &dA, lda, strideA, &dx, incx, strideX, &db, &dy, incy, strideY, batchCount);
#endif

#if CUDA_VERSION >= 12000
// CUDA: CUBLASAPI cublasStatus_t CUBLASWINAPI cublasIsamax_v2_64(cublasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result);
// HIP: HIPBLAS_EXPORT hipblasStatus_t hipblasIsamax_64(hipblasHandle_t handle, int64_t n, const float* x, int64_t incx, int64_t* result);
Expand Down

0 comments on commit 967de9b

Please sign in to comment.