diff --git a/clang/lib/DPCT/RulesMathLib/BLASAPIMigration.cpp b/clang/lib/DPCT/RulesMathLib/BLASAPIMigration.cpp index a2a8c42298c6..c5720cdc7ad9 100644 --- a/clang/lib/DPCT/RulesMathLib/BLASAPIMigration.cpp +++ b/clang/lib/DPCT/RulesMathLib/BLASAPIMigration.cpp @@ -723,13 +723,20 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) { ReplInfo.BufferTypeInfo[ReplInfo.BufferTypeInfo.size() - 1]; std::string ReturnValueParamsStr; if (DpctGlobalInfo::getUsmLevel() == UsmLevel::UL_Restricted) { + CallExprReplStr = CallExprReplStr + ", " + ResultTempPtr; + if (FuncName == "cublasIsamax" || FuncName == "cublasIdamax" || + FuncName == "cublasIcamax" || FuncName == "cublasIzamax" || + FuncName == "cublasIsamin" || FuncName == "cublasIdamin" || + FuncName == "cublasIcamin" || FuncName == "cublasIzamin") { + CallExprReplStr = CallExprReplStr + ", oneapi::mkl::index_base::one"; + } requestFeature(HelperFeatureEnum::device_ext); auto DefaultQueue = DpctGlobalInfo::getDefaultQueue(CE); PrefixInsertStr = PrefixInsertStr + ResultType + "* " + ResultTempPtr + " = " + MapNames::getClNamespace() + - "malloc_shared<" + ResultType + ">(1, " + DefaultQueue + ");" + - getNL() + IndentStr + CallExprReplStr + ", " + - ResultTempPtr + ").wait();" + getNL() + IndentStr; + "malloc_shared<" + ResultType + ">(1, " + + DefaultQueue + ");" + getNL() + IndentStr + + CallExprReplStr + ").wait();" + getNL() + IndentStr; ReturnValueParamsStr = "(" + ResultTempPtr + "->real(), " + ResultTempPtr + "->imag())"; @@ -748,11 +755,18 @@ void BLASFunctionCallRule::runRule(const MatchFinder::MatchResult &Result) { ResultTempPtr + ", " + DefaultQueue + ");"; } } else { + CallExprReplStr = CallExprReplStr + ", " + ResultTempBuf; + if (FuncName == "cublasIsamax" || FuncName == "cublasIdamax" || + FuncName == "cublasIcamax" || FuncName == "cublasIzamax" || + FuncName == "cublasIsamin" || FuncName == "cublasIdamin" || + FuncName == "cublasIcamin" || FuncName == "cublasIzamin") { + CallExprReplStr = CallExprReplStr + ", oneapi::mkl::index_base::one"; + } PrefixInsertStr = PrefixInsertStr + MapNames::getClNamespace() + "buffer<" + ResultType + "> " + ResultTempBuf + "(" + MapNames::getClNamespace() + "range<1>(1));" + - getNL() + IndentStr + CallExprReplStr + ", " + - ResultTempBuf + ");" + getNL() + IndentStr; + getNL() + IndentStr + CallExprReplStr + ");" + + getNL() + IndentStr; ReturnValueParamsStr = "(" + ResultTempBuf + ".get_host_access(" + MapNames::getClNamespace() + "read_only)[0].real(), " + diff --git a/clang/test/dpct/cublas-usm-legacy.cu b/clang/test/dpct/cublas-usm-legacy.cu index 873f5512b4b9..c1bc3b38a87c 100644 --- a/clang/test/dpct/cublas-usm-legacy.cu +++ b/clang/test/dpct/cublas-usm-legacy.cu @@ -66,22 +66,22 @@ int main() { // CHECK: int res; // CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, q_ct1); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); // CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}}; // CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1); int res = cublasIsamax(n, x_S, incx); // CHECK: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, q_ct1); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); // CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}}; // CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1); res = cublasIdamax(n, x_D, incx); // CHECK: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, q_ct1); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_C, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_C, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); // CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}}; // CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1); res = cublasIcamax(n, x_C, incx); // CHECK: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, q_ct1); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); // CHECK-NEXT: res = *res_temp_ptr_ct{{[0-9]+}}; // CHECK-NEXT:sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1); res = cublasIzamax(n, x_Z, incx); @@ -89,7 +89,7 @@ int main() { // Because the return value of origin API is the result value, not the status, so keep using lambda here. // CHECK: if([&](){ // CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, q_ct1); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); // CHECK-NEXT: int64_t res_temp_val_ct{{[0-9]+}} = *res_temp_ptr_ct{{[0-9]+}}; // CHECK-NEXT: sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1); // CHECK-NEXT: return res_temp_val_ct{{[0-9]+}}; @@ -98,7 +98,7 @@ int main() { // CHECK: if(0!=[&](){ // CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, q_ct1); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); // CHECK-NEXT: int64_t res_temp_val_ct{{[0-9]+}} = *res_temp_ptr_ct{{[0-9]+}}; // CHECK-NEXT: sycl::free(res_temp_ptr_ct{{[0-9]+}}, q_ct1); // CHECK-NEXT: return res_temp_val_ct{{[0-9]+}}; @@ -233,7 +233,7 @@ int main() { //CHECK:int foo(){ //CHECK-NEXT: return [&](){ //CHECK-NEXT: int64_t* res_temp_ptr_ct{{[0-9]+}} = sycl::malloc_shared(1, dpct::get_in_order_queue()); -//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}).wait(); +//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, (std::complex*)x_Z, incx, res_temp_ptr_ct{{[0-9]+}}, oneapi::mkl::index_base::one).wait(); //CHECK-NEXT: int64_t res_temp_val_ct{{[0-9]+}} = *res_temp_ptr_ct{{[0-9]+}}; //CHECK-NEXT: sycl::free(res_temp_ptr_ct{{[0-9]+}}, dpct::get_in_order_queue()); //CHECK-NEXT: return res_temp_val_ct{{[0-9]+}}; diff --git a/clang/test/dpct/cublasLegacyCZ.cu b/clang/test/dpct/cublasLegacyCZ.cu index 4e3af19e9bed..db9768311c7f 100644 --- a/clang/test/dpct/cublasLegacyCZ.cu +++ b/clang/test/dpct/cublasLegacyCZ.cu @@ -60,7 +60,7 @@ int main() { // CHECK-NEXT: { // CHECK-NEXT: auto x_C_buf_ct{{[0-9]+}} = dpct::get_buffer>(x_C); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: res = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } int res = cublasIcamax(n, x_C, incx); @@ -68,7 +68,7 @@ int main() { // CHECK: { // CHECK-NEXT: auto x_Z_buf_ct{{[0-9]+}} = dpct::get_buffer>(x_Z); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } *result = cublasIzamax(n, x_Z, incx); @@ -77,7 +77,7 @@ int main() { // CHECK: { // CHECK-NEXT: auto x_C_buf_ct{{[0-9]+}} = dpct::get_buffer>(x_C); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_C_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } *result = cublasIcamin(n, x_C, incx); @@ -85,7 +85,7 @@ int main() { // CHECK: { // CHECK-NEXT: auto x_Z_buf_ct{{[0-9]+}} = dpct::get_buffer>(x_Z); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_Z_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } *result = cublasIzamin(n, x_Z, incx); diff --git a/clang/test/dpct/cublasLegacyLv123.cu b/clang/test/dpct/cublasLegacyLv123.cu index 21e448193852..28ed41f27886 100644 --- a/clang/test/dpct/cublasLegacyLv123.cu +++ b/clang/test/dpct/cublasLegacyLv123.cu @@ -50,7 +50,7 @@ int main() { // CHECK-NEXT: { // CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer(x_S); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: res = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } int res = cublasIsamax(n, x_S, incx); @@ -58,7 +58,7 @@ int main() { // CHECK: { // CHECK-NEXT: auto x_D_buf_ct{{[0-9]+}} = dpct::get_buffer(x_D); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } *result = cublasIdamax(n, x_D, incx); @@ -67,7 +67,7 @@ int main() { // CHECK: { // CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer(x_S); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } *result = cublasIsamin(n, x_S, incx); @@ -75,7 +75,7 @@ int main() { // CHECK: { // CHECK-NEXT: auto x_D_buf_ct{{[0-9]+}} = dpct::get_buffer(x_D); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamin(dpct::blas::descriptor::get_saved_queue(), n, x_D_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: *result = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } *result = cublasIdamin(n, x_D, incx); @@ -627,7 +627,7 @@ int main() { // CHECK: for(int i = [&](){ // CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer(x_S); // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); - // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); + // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: return res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: }();;){} for(int i = cublasIsamax(n, x_S, incx);;){} @@ -640,7 +640,7 @@ int main() { //CHECK-NEXT: return [&](){ //CHECK-NEXT: auto x_S_buf_ct{{[0-9]+}} = dpct::get_buffer(x_S); //CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); -//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}); +//CHECK-NEXT: oneapi::mkl::blas::column_major::iamax(dpct::blas::descriptor::get_saved_queue(), n, x_S_buf_ct{{[0-9]+}}, incx, res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); //CHECK-NEXT: return res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; //CHECK-NEXT: }(); //CHECK-NEXT:} diff --git a/clang/test/dpct/error-handling.cu b/clang/test/dpct/error-handling.cu index bd77999de606..dba13a86741e 100644 --- a/clang/test/dpct/error-handling.cu +++ b/clang/test/dpct/error-handling.cu @@ -705,7 +705,7 @@ void foo12() { // CHECK-NEXT: sycl::buffer res_temp_buf_ct{{[0-9]+}}(sycl::range<1>(1)); // CHECK-NEXT: oneapi::mkl::blas::column_major::iamax( // CHECK-NEXT: dpct::blas::descriptor::get_saved_queue(), 10, ct_0_buf_ct{{[0-9]+}}, 0, -// CHECK-NEXT: res_temp_buf_ct{{[0-9]+}}); +// CHECK-NEXT: res_temp_buf_ct{{[0-9]+}}, oneapi::mkl::index_base::one); // CHECK-NEXT: res = res_temp_buf_ct{{[0-9]+}}.get_host_access(sycl::read_only)[0]; // CHECK-NEXT: } // CHECK-NEXT: }