diff --git a/projects/rocblas/clients/common/rocblas_arguments.cpp b/projects/rocblas/clients/common/rocblas_arguments.cpp index f4828a11263..c1d53fbd251 100644 --- a/projects/rocblas/clients/common/rocblas_arguments.cpp +++ b/projects/rocblas/clients/common/rocblas_arguments.cpp @@ -152,6 +152,7 @@ void Arguments::init() HMM = false; graph_test = false; repeatability_check = false; + alpha_beta_stride = false; use_hipblaslt = -1; diff --git a/projects/rocblas/clients/gtest/CMakeLists.txt b/projects/rocblas/clients/gtest/CMakeLists.txt index 6195ecff4b2..290b8138e9a 100644 --- a/projects/rocblas/clients/gtest/CMakeLists.txt +++ b/projects/rocblas/clients/gtest/CMakeLists.txt @@ -44,6 +44,7 @@ set(rocblas_no_tensile_test_source general_gtest.cpp set_get_pointer_mode_gtest.cpp set_get_atomics_mode_gtest.cpp + set_get_alpha_beta_stride_gtest.cpp logging_mode_gtest.cpp ostream_threadsafety_gtest.cpp set_get_vector_gtest.cpp diff --git a/projects/rocblas/clients/gtest/aux_gtest.yaml b/projects/rocblas/clients/gtest/aux_gtest.yaml index b9dd3fb5566..d2354555d17 100644 --- a/projects/rocblas/clients/gtest/aux_gtest.yaml +++ b/projects/rocblas/clients/gtest/aux_gtest.yaml @@ -4,6 +4,7 @@ include: set_get_vector_gtest.yaml include: logging_mode_gtest.yaml include: set_get_pointer_mode_gtest.yaml include: set_get_atomics_mode_gtest.yaml +include: set_get_alpha_beta_stride_gtest.yaml include: ostream_threadsafety_gtest.yaml include: multiheaded_gtest.yaml include: atomics_mode_gtest.yaml diff --git a/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp b/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp index e825b186af9..06b3b7afef7 100644 --- a/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp +++ b/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp @@ -168,8 +168,16 @@ namespace name << '_' << arg.stride_y; if(GEMV_TYPE != GEMV) + { name << '_' << arg.batch_count; + if(arg.alpha_beta_stride) + { + name << '_' << arg.stride_c; + name << '_' << arg.stride_d; + } + } + 
if(arg.api & c_API_64) { name << "_I64"; diff --git a/projects/rocblas/clients/gtest/gemv_gtest.yaml b/projects/rocblas/clients/gtest/gemv_gtest.yaml index be4506efc93..18c4cd2b9ed 100644 --- a/projects/rocblas/clients/gtest/gemv_gtest.yaml +++ b/projects/rocblas/clients/gtest/gemv_gtest.yaml @@ -651,6 +651,22 @@ Tests: batch_count: [ 2 ] gpu_arch: '9??' +- name: gemv_stride_alpha_beta + category: pre_checkin + function: + - gemv_batched + - gemv_strided_batched + precision: *gemv_bfloat_half_single_double_complex_real_precisions + transA: [ N, T, C ] + matrix_size: *qmcpack_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ 3 ] + alpha_beta_stride: true + stride_c: 2 # alpha beta stride + stride_d: 2 + pointer_mode_host: false + - name: gemv_row_vectorized_coverage category: pre_checkin function: diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp new file mode 100644 index 00000000000..ebb8c054f8f --- /dev/null +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp @@ -0,0 +1,103 @@ +/* ************************************************************************ + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#include "client_utility.hpp" +#include "rocblas.hpp" +#include "rocblas_data.hpp" +#include "rocblas_datatype2string.hpp" +#include "rocblas_test.hpp" +#include <cstring> + +namespace +{ + template <typename...> + struct testing_set_get_alpha_beta_stride : rocblas_test_valid + { + void operator()(const Arguments&) + { + rocblas_handle handle; + CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); + + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); + + rocblas_stride batch_alpha_stride = -1; + rocblas_stride batch_beta_stride = -1; + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(0, batch_alpha_stride); + EXPECT_EQ(0, batch_beta_stride); + + CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + EXPECT_EQ(7, batch_alpha_stride); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(7, batch_beta_stride); + + CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + 
CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(0, batch_alpha_stride); + EXPECT_EQ(0, batch_beta_stride); + + // stored regardless of mode, but only utilized in device mode kernels + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 11)); + + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + EXPECT_EQ(7, batch_alpha_stride); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(11, batch_beta_stride); + + CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); + } + }; + + struct set_get_alpha_beta_stride + : RocBLAS_Test<set_get_alpha_beta_stride, testing_set_get_alpha_beta_stride> + { + static bool type_filter(const Arguments&) + { + return true; + } + + static bool function_filter(const Arguments& arg) + { + return !strcmp(arg.function, "set_get_alpha_beta_stride"); + } + + static std::string name_suffix(const Arguments& arg) + { + return RocBLAS_TestName<set_get_alpha_beta_stride>(arg.name); + } + }; + + TEST_P(set_get_alpha_beta_stride, auxiliary_tensile) + { + CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(testing_set_get_alpha_beta_stride<>{}(GetParam())); + } + INSTANTIATE_TEST_CATEGORIES(set_get_alpha_beta_stride) + +} // namespace diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.yaml b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.yaml new file mode 100644 index 00000000000..f4af643f25a --- /dev/null +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.yaml @@ -0,0 +1,10 @@ +--- +include: rocblas_common.yaml +include: known_bugs.yaml + +Tests: +- name: set_get_alpha_beta_stride + category: quick + function: set_get_alpha_beta_stride + precision: *single_precision +... 
diff --git a/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp b/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp index 90452cbe140..db67c67465d 100644 --- a/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp +++ b/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp @@ -300,6 +300,10 @@ void testing_gemv_batched(const Arguments& arg) rocblas_operation transA = char2rocblas_operation(arg.transA); int64_t batch_count = arg.batch_count; + bool ab_striding = arg.alpha_beta_stride; + int64_t alpha_stride = ab_striding ? arg.stride_c : 0; + int64_t beta_stride = ab_striding ? arg.stride_d : 0; + rocblas_local_handle handle{arg}; // argument sanity check before allocating invalid memory @@ -344,23 +348,24 @@ void testing_gemv_batched(const Arguments& arg) HOST_MEMCHECK(host_batch_vector, hx, (dim_x, incx, batch_count)); HOST_MEMCHECK(host_batch_vector, hy, (dim_y, incy, batch_count)); HOST_MEMCHECK(host_batch_vector, hy_gold, (dim_y, incy, batch_count)); - HOST_MEMCHECK(host_vector, halpha, (1)); - HOST_MEMCHECK(host_vector, hbeta, (1)); + HOST_MEMCHECK(host_vector, halpha, (batch_count, alpha_stride)); + HOST_MEMCHECK(host_vector, hbeta, (batch_count, beta_stride)); // Allocate device memory DEVICE_MEMCHECK(device_batch_matrix, dA, (M, N, lda, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dx, (dim_x, incx, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dy, (dim_y, incy, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha, (1)); - DEVICE_MEMCHECK(device_vector, d_beta, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta, (batch_count, beta_stride)); // Initialize data on host memory rocblas_init_matrix( hA, arg, rocblas_client_alpha_sets_nan, rocblas_client_general_matrix, true); rocblas_init_vector(hx, arg, rocblas_client_alpha_sets_nan, false, true); rocblas_init_vector(hy, arg, rocblas_client_beta_sets_nan); - halpha[0] = h_alpha; - 
hbeta[0] = h_beta; + + rocblas_init_vector_alternating_zero(halpha, h_alpha); + rocblas_init_vector_alternating_zero(hbeta, h_beta); hy_gold.copy_from(hy); @@ -448,8 +453,8 @@ void testing_gemv_batched(const Arguments& arg) DEVICE_MEMCHECK(device_batch_matrix, dA_copy, (M, N, lda, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dx_copy, (dim_x, incx, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dy_copy, (dim_y, incy, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha_copy, (1)); - DEVICE_MEMCHECK(device_vector, d_beta_copy, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha_copy, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta_copy, (batch_count, beta_stride)); CHECK_HIP_ERROR(dA_copy.transfer_from(hA)); CHECK_HIP_ERROR(dx_copy.transfer_from(hx)); @@ -489,8 +494,17 @@ void testing_gemv_batched(const Arguments& arg) cpu_time_used = get_time_us_no_sync(); for(int64_t b = 0; b < batch_count; ++b) { - ref_gemv( - transA, M, N, h_alpha, hA[b], lda, hx[b], incx, h_beta, hy_gold[b], incy); + ref_gemv(transA, + M, + N, + halpha[b * alpha_stride], + hA[b], + lda, + hx[b], + incx, + hbeta[b * beta_stride], + hy_gold[b], + incy); } cpu_time_used = get_time_us_no_sync() - cpu_time_used; diff --git a/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp b/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp index 651aded0d32..6d5eb650640 100644 --- a/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp +++ b/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp @@ -342,6 +342,10 @@ void testing_gemv_strided_batched(const Arguments& arg) int64_t stride_y = arg.stride_y; int64_t batch_count = arg.batch_count; + bool ab_striding = arg.alpha_beta_stride; + int64_t alpha_stride = ab_striding ? arg.stride_c : 0; + int64_t beta_stride = ab_striding ? 
arg.stride_d : 0; + rocblas_local_handle handle{arg}; size_t dim_x, row_A; @@ -393,17 +397,15 @@ void testing_gemv_strided_batched(const Arguments& arg) HOST_MEMCHECK(host_strided_batch_vector, hx, (dim_x, incx, stride_x, batch_count)); HOST_MEMCHECK(host_strided_batch_vector, hy, (dim_y, incy, stride_y, batch_count)); HOST_MEMCHECK(host_strided_batch_vector, hy_gold, (dim_y, incy, stride_y, batch_count)); - HOST_MEMCHECK(host_vector, halpha, (1)); - HOST_MEMCHECK(host_vector, hbeta, (1)); - halpha[0] = h_alpha; - hbeta[0] = h_beta; + HOST_MEMCHECK(host_vector, halpha, (batch_count, alpha_stride)); + HOST_MEMCHECK(host_vector, hbeta, (batch_count, beta_stride)); // Allocate device memory DEVICE_MEMCHECK(device_strided_batch_matrix, dA, (M, N, lda, stride_a, batch_count)); DEVICE_MEMCHECK(device_strided_batch_vector, dx, (dim_x, incx, stride_x, batch_count)); DEVICE_MEMCHECK(device_strided_batch_vector, dy, (dim_y, incy, stride_y, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha, (1)); - DEVICE_MEMCHECK(device_vector, d_beta, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta, (batch_count, beta_stride)); // Initialize data on host memory rocblas_init_matrix( @@ -411,6 +413,9 @@ void testing_gemv_strided_batched(const Arguments& arg) rocblas_init_vector(hx, arg, rocblas_client_alpha_sets_nan, false, true); rocblas_init_vector(hy, arg, rocblas_client_beta_sets_nan); + rocblas_init_vector_alternating_zero(halpha, h_alpha); + rocblas_init_vector_alternating_zero(hbeta, h_beta); + hy_gold.copy_from(hy); // copy data from CPU to device @@ -511,8 +516,8 @@ void testing_gemv_strided_batched(const Arguments& arg) DEVICE_MEMCHECK(device_strided_batch_vector, dy_copy, (dim_y, incy, stride_y, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha_copy, (1)); - DEVICE_MEMCHECK(device_vector, d_beta_copy, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha_copy, (batch_count, alpha_stride)); + 
DEVICE_MEMCHECK(device_vector, d_beta_copy, (batch_count, beta_stride)); CHECK_HIP_ERROR(dA_copy.transfer_from(hA)); CHECK_HIP_ERROR(dx_copy.transfer_from(hx)); @@ -555,8 +560,17 @@ void testing_gemv_strided_batched(const Arguments& arg) cpu_time_used = get_time_us_no_sync(); for(int64_t b = 0; b < batch_count; ++b) { - ref_gemv( - transA, M, N, h_alpha, hA[b], lda, hx[b], incx, h_beta, hy_gold[b], incy); + ref_gemv(transA, + M, + N, + halpha[b * alpha_stride], + hA[b], + lda, + hx[b], + incx, + hbeta[b * beta_stride], + hy_gold[b], + incy); } cpu_time_used = get_time_us_no_sync() - cpu_time_used; diff --git a/projects/rocblas/clients/include/client_utility.hpp b/projects/rocblas/clients/include/client_utility.hpp index c2908a7e337..1cb09e50bfc 100644 --- a/projects/rocblas/clients/include/client_utility.hpp +++ b/projects/rocblas/clients/include/client_utility.hpp @@ -148,6 +148,12 @@ class rocblas_local_handle void pre_test(const Arguments& arg) { + if(arg.alpha_beta_stride && arg.pointer_mode_device) + { + // for now no GEMM applicability so c/d stride used + rocblas_set_batch_alpha_stride(m_handle, arg.stride_c); + rocblas_set_batch_beta_stride(m_handle, arg.stride_d); + } #if HIP_VERSION >= 50500000 arg.graph_test ? rocblas_stream_begin_capture() : NOOP; #endif @@ -158,6 +164,11 @@ class rocblas_local_handle #if HIP_VERSION >= 50500000 arg.graph_test ? 
rocblas_stream_end_capture() : NOOP; #endif + if(arg.alpha_beta_stride && arg.pointer_mode_device) + { + rocblas_set_batch_alpha_stride(m_handle, 0); + rocblas_set_batch_beta_stride(m_handle, 0); + } } }; diff --git a/projects/rocblas/clients/include/rocblas_arguments.hpp b/projects/rocblas/clients/include/rocblas_arguments.hpp index ec9d2d398d1..02c79bbad89 100644 --- a/projects/rocblas/clients/include/rocblas_arguments.hpp +++ b/projects/rocblas/clients/include/rocblas_arguments.hpp @@ -182,6 +182,7 @@ struct Arguments bool HMM; // xnack+ bool graph_test; bool repeatability_check; + bool alpha_beta_stride; int use_hipblaslt; @@ -272,6 +273,7 @@ struct Arguments OPER(HMM) SEP \ OPER(graph_test) SEP \ OPER(repeatability_check) SEP \ + OPER(alpha_beta_stride) SEP \ OPER(use_hipblaslt) SEP \ OPER(cleanup) // clang-format on diff --git a/projects/rocblas/clients/include/rocblas_common.yaml b/projects/rocblas/clients/include/rocblas_common.yaml index d68535e33bd..2dc3a529752 100644 --- a/projects/rocblas/clients/include/rocblas_common.yaml +++ b/projects/rocblas/clients/include/rocblas_common.yaml @@ -586,6 +586,7 @@ Arguments: - HMM: c_bool - graph_test: c_bool - repeatability_check: c_bool + - alpha_beta_stride: c_bool - use_hipblaslt: c_int32 - cleanup: c_bool @@ -638,6 +639,10 @@ Defaults: alphai: 0.0 beta: 0.0 betai: 0.0 + stride_a: 0 + stride_b: 0 + stride_c: 0 + stride_d: 0 transA: '*' transB: '*' side: '*' @@ -659,6 +664,7 @@ Defaults: api: C graph_test: false repeatability_check: false + alpha_beta_stride: false norm_check: 0 unit_check: 1 res_check: 0 diff --git a/projects/rocblas/clients/include/rocblas_init.hpp b/projects/rocblas/clients/include/rocblas_init.hpp index 735789a002b..49e2465d2b7 100644 --- a/projects/rocblas/clients/include/rocblas_init.hpp +++ b/projects/rocblas/clients/include/rocblas_init.hpp @@ -51,6 +51,18 @@ typedef enum rocblas_check_nan_init_ } rocblas_check_nan_init; +// Initialize vector so adjacent entries have alternating zero and 
passed value. +template <typename T> +void rocblas_init_vector_alternating_zero(host_vector<T>& A, T value) +{ + auto M = A.n(); + auto inc = A.inc(); + for(size_t i = 0; i < M; ++i) + { + A[i * inc] = (i & 1) ? T(0) : value; + } +} + // Initialize matrix so adjacent entries have alternating sign. // In gemm if either A or B are initialized with alternating // sign the reduction sum will be summing positive diff --git a/projects/rocblas/docs/reference/helper-functions.rst b/projects/rocblas/docs/reference/helper-functions.rst index 5af8f6e9313..42a61877814 100644 --- a/projects/rocblas/docs/reference/helper-functions.rst +++ b/projects/rocblas/docs/reference/helper-functions.rst @@ -19,6 +19,10 @@ Auxiliary functions .. doxygenfunction:: rocblas_get_pointer_mode .. doxygenfunction:: rocblas_set_atomics_mode .. doxygenfunction:: rocblas_get_atomics_mode +.. doxygenfunction:: rocblas_set_batch_alpha_stride +.. doxygenfunction:: rocblas_get_batch_alpha_stride +.. doxygenfunction:: rocblas_set_batch_beta_stride +.. doxygenfunction:: rocblas_get_batch_beta_stride .. doxygenfunction:: rocblas_pointer_to_mode .. doxygenfunction:: rocblas_initialize .. doxygenfunction:: rocblas_status_to_string diff --git a/projects/rocblas/docs/reference/level-2.rst b/projects/rocblas/docs/reference/level-2.rst index d11a1ab7400..1a557480e5d 100644 --- a/projects/rocblas/docs/reference/level-2.rst +++ b/projects/rocblas/docs/reference/level-2.rst @@ -85,6 +85,8 @@ The ``gemv`` functions support the ``_64`` interface. See the :ref:`ILP64 API` s ``gemv_batched`` functions have an implementation which uses atomic operations. See the :ref:`Atomic Operations` section for more information. The ``gemv_batched`` functions support the ``_64`` interface. See the :ref:`ILP64 API` section. +The ``gemv_batched`` functions support ``rocblas_set_batch_alpha_stride`` and ``rocblas_set_batch_beta_stride`` when the ``rocblas_handle`` is +in mode ``rocblas_pointer_mode_device``. .. 
doxygenfunction:: rocblas_sgemv_strided_batched :outline: @@ -104,6 +106,8 @@ The ``gemv_batched`` functions support the ``_64`` interface. See the :ref:`ILP6 ``gemv_strided_batched`` functions have an implementation which uses atomic operations. See the :ref:`Atomic Operations` section for more information. The ``gemv_strided_batched`` functions support the ``_64`` interface. See the :ref:`ILP64 API` section. +The ``gemv_strided_batched`` functions support ``rocblas_set_batch_alpha_stride`` and ``rocblas_set_batch_beta_stride`` when the ``rocblas_handle`` is +in mode ``rocblas_pointer_mode_device``. .. _rocblas_ger: diff --git a/projects/rocblas/library/include/internal/rocblas-auxiliary.h b/projects/rocblas/library/include/internal/rocblas-auxiliary.h index 81a50ce4be9..d3945818cc4 100644 --- a/projects/rocblas/library/include/internal/rocblas-auxiliary.h +++ b/projects/rocblas/library/include/internal/rocblas-auxiliary.h @@ -75,6 +75,34 @@ ROCBLAS_EXPORT rocblas_status rocblas_set_atomics_mode(rocblas_handle hand ROCBLAS_EXPORT rocblas_status rocblas_get_atomics_mode(rocblas_handle handle, rocblas_atomics_mode* atomics_mode); +/*! \brief Set the stride between successive per-batch alpha values for the limited set of batched and strided_batched functions that support it. +Only applies when the handle is in rocblas_pointer_mode_device, so alpha must be a device side allocation. +A non-zero stride makes supporting batched and strided_batched functions interpret the alpha pointer as a pointer to a vector of values, one per batch element. +The default value 0 treats alpha as a pointer to a single scalar. Functions that support this stride say so in their own documentation. +Warning: the stride is modal state stored in the handle. Restore it to 0 when it no longer applies to later function calls. + */ +ROCBLAS_EXPORT rocblas_status rocblas_set_batch_alpha_stride(rocblas_handle handle, + rocblas_stride alpha_stride); + +/*! 
\brief Get the batch alpha stride currently stored in the handle (0 by default).
+ */ +ROCBLAS_EXPORT rocblas_status rocblas_get_batch_alpha_stride(rocblas_handle handle, + rocblas_stride* alpha_stride); + +/*! \brief Set the stride between successive per-batch beta values for the limited set of batched and strided_batched functions that support it. +Only applies when the handle is in rocblas_pointer_mode_device, so beta must be a device side allocation. +A non-zero stride makes supporting batched and strided_batched functions interpret the beta pointer as a pointer to a vector of values, one per batch element. +The default value 0 treats beta as a pointer to a single scalar. Functions that support this stride say so in their own documentation. +Warning: the stride is modal state stored in the handle. Restore it to 0 when it no longer applies to later function calls. + */ +ROCBLAS_EXPORT rocblas_status rocblas_set_batch_beta_stride(rocblas_handle handle, + rocblas_stride beta_stride); + +/*! \brief Get the batch beta stride currently stored in the handle (0 by default). + */ +ROCBLAS_EXPORT rocblas_status rocblas_get_batch_beta_stride(rocblas_handle handle, + rocblas_stride* beta_stride); + /*! \brief Set ``rocblas_math_mode``. 
*/ ROCBLAS_EXPORT rocblas_status rocblas_set_math_mode(rocblas_handle handle, diff --git a/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp b/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp index 59add8295dd..d2a52bdfe5c 100644 --- a/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp +++ b/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp @@ -230,28 +230,29 @@ namespace } // we don't instantiate _template for mixed types so directly calling launcher - rocblas_status status = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, - transA, - m, - n, - alpha, - 0, - A, - 0, - lda, - 0, - x, - 0, - incx, - 0, - beta, - 0, - y, - 0, - incy, - 0, - batch_count, - (Tex*)w_mem); + rocblas_status status + = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, + transA, + m, + n, + alpha, + handle->get_stride_alpha(), + A, + 0, + lda, + 0, + x, + 0, + incx, + 0, + beta, + handle->get_stride_beta(), + y, + 0, + incy, + 0, + batch_count, + (Tex*)w_mem); status = (status != rocblas_status_success) ? 
status : perf_status; if(status != rocblas_status_success) diff --git a/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp b/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp index 71e0cbf0c37..5464ed9e649 100644 --- a/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp +++ b/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp @@ -259,28 +259,29 @@ namespace } // we don't instantiate _template for mixed types so directly calling launcher - rocblas_status status = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, - transA, - m, - n, - alpha, - 0, - A, - 0, - lda, - strideA, - x, - 0, - incx, - stridex, - beta, - 0, - y, - 0, - incy, - stridey, - batch_count, - (Tex*)w_mem); + rocblas_status status + = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, + transA, + m, + n, + alpha, + handle->get_stride_alpha(), + A, + 0, + lda, + strideA, + x, + 0, + incx, + stridex, + beta, + handle->get_stride_beta(), + y, + 0, + incy, + stridey, + batch_count, + (Tex*)w_mem); status = (status != rocblas_status_success) ? 
status : perf_status; if(status != rocblas_status_success) diff --git a/projects/rocblas/library/src/include/handle.hpp b/projects/rocblas/library/src/include/handle.hpp index 268640ef341..0fa2964ae06 100644 --- a/projects/rocblas/library/src/include/handle.hpp +++ b/projects/rocblas/library/src/include/handle.hpp @@ -412,6 +412,10 @@ struct _rocblas_handle // default atomics mode does not allows atomic operations rocblas_atomics_mode atomics_mode = rocblas_atomics_not_allowed; + // optional stride between successive alpha/beta values for advanced batched use; 0 is default + rocblas_stride stride_alpha = 0; + rocblas_stride stride_beta = 0; + // Selects the benchmark library to be used for solution selection rocblas_performance_metric performance_metric = rocblas_default_performance_metric; @@ -440,6 +444,26 @@ struct _rocblas_handle this->data_ptr = data_ptr; } + rocblas_stride get_stride_alpha() const + { + // only applicable to device mode alpha, load_scalar ignores for value + return stride_alpha; + } + void set_stride_alpha(rocblas_stride stride) + { + stride_alpha = stride; + } + + rocblas_stride get_stride_beta() const + { + // only applicable to device mode beta, load_scalar ignores for value + return stride_beta; + } + void set_stride_beta(rocblas_stride stride) + { + stride_beta = stride; + } + // C interfaces for manipulating device memory friend rocblas_status(::rocblas_start_device_memory_size_query)(_rocblas_handle*); friend rocblas_status(::rocblas_stop_device_memory_size_query)(_rocblas_handle*, size_t*); diff --git a/projects/rocblas/library/src/rocblas_auxiliary.cpp b/projects/rocblas/library/src/rocblas_auxiliary.cpp index 37a4bc5b097..295dfe67388 100644 --- a/projects/rocblas/library/src/rocblas_auxiliary.cpp +++ b/projects/rocblas/library/src/rocblas_auxiliary.cpp @@ -126,6 +126,90 @@ catch(...) return exception_to_rocblas_status(); } +/******************************************************************************* + * ! 
\brief get alpha stride; returns rocblas_status_invalid_pointer when alpha_stride is null + ******************************************************************************/ +extern "C" rocblas_status rocblas_get_batch_alpha_stride(rocblas_handle handle, + rocblas_stride* alpha_stride) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + if(!alpha_stride) + return rocblas_status_invalid_pointer; + *alpha_stride = handle->get_stride_alpha(); + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_get_batch_alpha_stride", *alpha_stride); + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + +/******************************************************************************* + * ! \brief set alpha stride + ******************************************************************************/ +extern "C" rocblas_status rocblas_set_batch_alpha_stride(rocblas_handle handle, + rocblas_stride alpha_stride) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_set_batch_alpha_stride", alpha_stride); + handle->set_stride_alpha(alpha_stride); + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + +/******************************************************************************* + * ! 
\brief get beta stride; returns rocblas_status_invalid_pointer when beta_stride is null + ******************************************************************************/ +extern "C" rocblas_status rocblas_get_batch_beta_stride(rocblas_handle handle, + rocblas_stride* beta_stride) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + if(!beta_stride) + return rocblas_status_invalid_pointer; + *beta_stride = handle->get_stride_beta(); + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_get_batch_beta_stride", *beta_stride); + return rocblas_status_success; +} +catch(...) 
+{ + return exception_to_rocblas_status(); +} + +/******************************************************************************* + * ! \brief set beta stride + ******************************************************************************/ +extern "C" rocblas_status rocblas_set_batch_beta_stride(rocblas_handle handle, + rocblas_stride beta_stride) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_set_batch_beta_stride", beta_stride); + handle->set_stride_beta(beta_stride); + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + /******************************************************************************* * ! \brief get math mode ******************************************************************************/ diff --git a/projects/rocblas/next-cmake/clients/test/CMakeLists.txt b/projects/rocblas/next-cmake/clients/test/CMakeLists.txt index 69e8aae76ac..fc10796f259 100644 --- a/projects/rocblas/next-cmake/clients/test/CMakeLists.txt +++ b/projects/rocblas/next-cmake/clients/test/CMakeLists.txt @@ -71,6 +71,7 @@ set(yaml_test_files "${_CMAKE_CURRENT_SOURCE_DIR}/sbmv_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/scal_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_atomics_mode_gtest.yaml" + "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_alpha_beta_stride_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_matrix_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_pointer_mode_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_vector_gtest.yaml" diff --git a/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt b/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt index a2a6408ecb5..777a11cf884 100644 --- a/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt +++ b/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt @@ -19,6 +19,7 @@ target_sources( 
"${_CMAKE_CURRENT_SOURCE_DIR}/../general_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_pointer_mode_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_atomics_mode_gtest.cpp" + "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_alpha_beta_stride_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../logging_mode_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../ostream_threadsafety_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_vector_gtest.cpp"