From dace4ab3bf598ae2f74b0a9a43df17854ac0692d Mon Sep 17 00:00:00 2001 From: Torre Zuk Date: Fri, 15 May 2026 08:49:43 -0600 Subject: [PATCH 1/3] proposal concepts --- projects/rocblas/clients/gtest/CMakeLists.txt | 1 + projects/rocblas/clients/gtest/aux_gtest.yaml | 1 + .../gtest/set_get_alpha_beta_inc_gtest.cpp | 95 +++++++++++++++++++ .../gtest/set_get_alpha_beta_inc_gtest.yaml | 10 ++ .../docs/reference/helper-functions.rst | 4 + .../include/internal/rocblas-auxiliary.h | 17 ++++ .../rocblas/library/src/include/handle.hpp | 4 + .../rocblas/library/src/rocblas_auxiliary.cpp | 76 +++++++++++++++ .../next-cmake/clients/test/CMakeLists.txt | 1 + .../clients/test/src/CMakeLists.txt | 1 + 10 files changed, 210 insertions(+) create mode 100644 projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp create mode 100644 projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml diff --git a/projects/rocblas/clients/gtest/CMakeLists.txt b/projects/rocblas/clients/gtest/CMakeLists.txt index 6195ecff4b2..543eeb1d284 100644 --- a/projects/rocblas/clients/gtest/CMakeLists.txt +++ b/projects/rocblas/clients/gtest/CMakeLists.txt @@ -44,6 +44,7 @@ set(rocblas_no_tensile_test_source general_gtest.cpp set_get_pointer_mode_gtest.cpp set_get_atomics_mode_gtest.cpp + set_get_alpha_beta_inc_gtest.cpp logging_mode_gtest.cpp ostream_threadsafety_gtest.cpp set_get_vector_gtest.cpp diff --git a/projects/rocblas/clients/gtest/aux_gtest.yaml b/projects/rocblas/clients/gtest/aux_gtest.yaml index b9dd3fb5566..92bc1fbf040 100644 --- a/projects/rocblas/clients/gtest/aux_gtest.yaml +++ b/projects/rocblas/clients/gtest/aux_gtest.yaml @@ -4,6 +4,7 @@ include: set_get_vector_gtest.yaml include: logging_mode_gtest.yaml include: set_get_pointer_mode_gtest.yaml include: set_get_atomics_mode_gtest.yaml +include: set_get_alpha_beta_inc_gtest.yaml include: ostream_threadsafety_gtest.yaml include: multiheaded_gtest.yaml include: atomics_mode_gtest.yaml diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp b/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp new file mode 100644 index 00000000000..4cb88362750 --- /dev/null +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp @@ -0,0 +1,95 @@ +/* ************************************************************************ + * Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#include "client_utility.hpp" +#include "rocblas.hpp" +#include "rocblas_data.hpp" +#include "rocblas_datatype2string.hpp" +#include "rocblas_test.hpp" +#include + +namespace +{ + template + struct testing_set_get_alpha_beta_inc : rocblas_test_valid + { + void operator()(const Arguments&) + { + rocblas_handle handle; + CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); + + rocblas_int alpha_inc = -1; + rocblas_int beta_inc = -1; + CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); + CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); + EXPECT_EQ(0, alpha_inc); + EXPECT_EQ(0, beta_inc); + + CHECK_ROCBLAS_ERROR(rocblas_set_alpha_inc(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); + EXPECT_EQ(7, alpha_inc); + CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); + EXPECT_EQ(0, beta_inc); + + CHECK_ROCBLAS_ERROR(rocblas_set_beta_inc(handle, 11)); + CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); + EXPECT_EQ(11, beta_inc); + CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); + EXPECT_EQ(7, alpha_inc); + + CHECK_ROCBLAS_ERROR(rocblas_set_alpha_inc(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_set_beta_inc(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); + CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); + EXPECT_EQ(0, alpha_inc); + EXPECT_EQ(0, beta_inc); + + CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); + } + }; + + struct set_get_alpha_beta_inc + : RocBLAS_Test + { + static bool type_filter(const Arguments&) + { + return true; + } + + static bool function_filter(const Arguments& arg) + { + return !strcmp(arg.function, "set_get_alpha_beta_inc"); + } + + static std::string name_suffix(const Arguments& arg) + { + return RocBLAS_TestName(arg.name); + } + }; + + TEST_P(set_get_alpha_beta_inc, auxiliary_tensile) + { + CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(testing_set_get_alpha_beta_inc<>{}(GetParam())); + } + INSTANTIATE_TEST_CATEGORIES(set_get_alpha_beta_inc) + +} // namespace diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml b/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml new file mode 100644 index 00000000000..ffad2e4e1dd --- /dev/null +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml @@ -0,0 +1,10 @@ +--- +include: rocblas_common.yaml +include: known_bugs.yaml + +Tests: +- name: set_get_alpha_beta_inc + category: quick + function: set_get_alpha_beta_inc + precision: *single_precision +... diff --git a/projects/rocblas/docs/reference/helper-functions.rst b/projects/rocblas/docs/reference/helper-functions.rst index 5af8f6e9313..f5261a772f7 100644 --- a/projects/rocblas/docs/reference/helper-functions.rst +++ b/projects/rocblas/docs/reference/helper-functions.rst @@ -19,6 +19,10 @@ Auxiliary functions .. doxygenfunction:: rocblas_get_pointer_mode .. doxygenfunction:: rocblas_set_atomics_mode .. doxygenfunction:: rocblas_get_atomics_mode +.. doxygenfunction:: rocblas_set_alpha_inc +.. doxygenfunction:: rocblas_get_alpha_inc +.. doxygenfunction:: rocblas_set_beta_inc +.. doxygenfunction:: rocblas_get_beta_inc .. doxygenfunction:: rocblas_pointer_to_mode .. doxygenfunction:: rocblas_initialize .. doxygenfunction:: rocblas_status_to_string diff --git a/projects/rocblas/library/include/internal/rocblas-auxiliary.h b/projects/rocblas/library/include/internal/rocblas-auxiliary.h index 81a50ce4be9..bd3215a2a42 100644 --- a/projects/rocblas/library/include/internal/rocblas-auxiliary.h +++ b/projects/rocblas/library/include/internal/rocblas-auxiliary.h @@ -75,6 +75,23 @@ ROCBLAS_EXPORT rocblas_status rocblas_set_atomics_mode(rocblas_handle hand ROCBLAS_EXPORT rocblas_status rocblas_get_atomics_mode(rocblas_handle handle, rocblas_atomics_mode* atomics_mode); +/*! \brief Set alpha increment stored on the handle (element count between successive alpha values). Default is 0. +Used by limited set of L2 and L1 batched functions to specify the spacing between successive alpha values. + */ +ROCBLAS_EXPORT rocblas_status rocblas_set_alpha_inc(rocblas_handle handle, rocblas_int alpha_inc); + +/*! \brief Get alpha increment from the handle. + */ +ROCBLAS_EXPORT rocblas_status rocblas_get_alpha_inc(rocblas_handle handle, rocblas_int* alpha_inc); + +/*! \brief Set beta increment stored on the handle (element count between successive beta values). Default is 0. + */ +ROCBLAS_EXPORT rocblas_status rocblas_set_beta_inc(rocblas_handle handle, rocblas_int beta_inc); + +/*! \brief Get beta increment from the handle. + */ +ROCBLAS_EXPORT rocblas_status rocblas_get_beta_inc(rocblas_handle handle, rocblas_int* beta_inc); + /*! \brief Set ``rocblas_math_mode``. */ ROCBLAS_EXPORT rocblas_status rocblas_set_math_mode(rocblas_handle handle, diff --git a/projects/rocblas/library/src/include/handle.hpp b/projects/rocblas/library/src/include/handle.hpp index 268640ef341..3b2785cc6dd 100644 --- a/projects/rocblas/library/src/include/handle.hpp +++ b/projects/rocblas/library/src/include/handle.hpp @@ -412,6 +412,10 @@ struct _rocblas_handle // default atomics mode does not allows atomic operations rocblas_atomics_mode atomics_mode = rocblas_atomics_not_allowed; + // optional increment (in elements) between successive alpha/beta values for advanced batched use; 0 is default + rocblas_int alpha_inc = 0; + rocblas_int beta_inc = 0; + // Selects the benchmark library to be used for solution selection rocblas_performance_metric performance_metric = rocblas_default_performance_metric; diff --git a/projects/rocblas/library/src/rocblas_auxiliary.cpp b/projects/rocblas/library/src/rocblas_auxiliary.cpp index 37a4bc5b097..e51cea72a15 100644 --- a/projects/rocblas/library/src/rocblas_auxiliary.cpp +++ b/projects/rocblas/library/src/rocblas_auxiliary.cpp @@ -126,6 +126,82 @@ catch(...) return exception_to_rocblas_status(); } +/******************************************************************************* + * ! \brief get alpha increment + ******************************************************************************/ +extern "C" rocblas_status rocblas_get_alpha_inc(rocblas_handle handle, rocblas_int* alpha_inc) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + *alpha_inc = handle->alpha_inc; + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_get_alpha_inc", *alpha_inc); + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + +/******************************************************************************* + * ! \brief set alpha increment + ******************************************************************************/ +extern "C" rocblas_status rocblas_set_alpha_inc(rocblas_handle handle, rocblas_int alpha_inc) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_set_alpha_inc", alpha_inc); + handle->alpha_inc = alpha_inc; + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + +/******************************************************************************* + * ! \brief get beta increment + ******************************************************************************/ +extern "C" rocblas_status rocblas_get_beta_inc(rocblas_handle handle, rocblas_int* beta_inc) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + *beta_inc = handle->beta_inc; + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_get_beta_inc", *beta_inc); + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + +/******************************************************************************* + * ! \brief set beta increment + ******************************************************************************/ +extern "C" rocblas_status rocblas_set_beta_inc(rocblas_handle handle, rocblas_int beta_inc) +try +{ + if(!handle) + return rocblas_status_invalid_handle; + rocblas_internal_logger logger; + if(handle->layer_mode & rocblas_layer_mode_log_trace) + logger.log_trace(handle, "rocblas_set_beta_inc", beta_inc); + handle->beta_inc = beta_inc; + return rocblas_status_success; +} +catch(...) +{ + return exception_to_rocblas_status(); +} + /******************************************************************************* * ! \brief get math mode ******************************************************************************/ diff --git a/projects/rocblas/next-cmake/clients/test/CMakeLists.txt b/projects/rocblas/next-cmake/clients/test/CMakeLists.txt index 69e8aae76ac..fc10796f259 100644 --- a/projects/rocblas/next-cmake/clients/test/CMakeLists.txt +++ b/projects/rocblas/next-cmake/clients/test/CMakeLists.txt @@ -71,6 +71,7 @@ set(yaml_test_files "${_CMAKE_CURRENT_SOURCE_DIR}/sbmv_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/scal_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_atomics_mode_gtest.yaml" + "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_alpha_beta_stride_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_matrix_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_pointer_mode_gtest.yaml" "${_CMAKE_CURRENT_SOURCE_DIR}/set_get_vector_gtest.yaml" diff --git a/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt b/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt index a2a6408ecb5..777a11cf884 100644 --- a/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt +++ b/projects/rocblas/next-cmake/clients/test/src/CMakeLists.txt @@ -19,6 +19,7 @@ target_sources( "${_CMAKE_CURRENT_SOURCE_DIR}/../general_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_pointer_mode_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_atomics_mode_gtest.cpp" + "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_alpha_beta_stride_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../logging_mode_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../ostream_threadsafety_gtest.cpp" "${_CMAKE_CURRENT_SOURCE_DIR}/../set_get_vector_gtest.cpp" From 1d681ade99282b8b7f702c02009fdcd076512c3a Mon Sep 17 00:00:00 2001 From: Torre Zuk Date: Tue, 19 May 2026 16:17:47 -0600 Subject: [PATCH 2/3] stride form of the same changes --- projects/rocblas/clients/gtest/CMakeLists.txt | 2 +- projects/rocblas/clients/gtest/aux_gtest.yaml | 2 +- ...pp => set_get_alpha_beta_stride_gtest.cpp} | 60 +++++++++---------- ...l => set_get_alpha_beta_stride_gtest.yaml} | 4 +- .../docs/reference/helper-functions.rst | 8 +-- .../include/internal/rocblas-auxiliary.h | 20 ++++--- .../rocblas/library/src/include/handle.hpp | 6 +- .../rocblas/library/src/rocblas_auxiliary.cpp | 32 +++++----- 8 files changed, 69 insertions(+), 65 deletions(-) rename projects/rocblas/clients/gtest/{set_get_alpha_beta_inc_gtest.cpp => set_get_alpha_beta_stride_gtest.cpp} (53%) rename projects/rocblas/clients/gtest/{set_get_alpha_beta_inc_gtest.yaml => set_get_alpha_beta_stride_gtest.yaml} (62%) diff --git a/projects/rocblas/clients/gtest/CMakeLists.txt b/projects/rocblas/clients/gtest/CMakeLists.txt index 543eeb1d284..290b8138e9a 100644 --- a/projects/rocblas/clients/gtest/CMakeLists.txt +++ b/projects/rocblas/clients/gtest/CMakeLists.txt @@ -44,7 +44,7 @@ set(rocblas_no_tensile_test_source general_gtest.cpp set_get_pointer_mode_gtest.cpp set_get_atomics_mode_gtest.cpp - set_get_alpha_beta_inc_gtest.cpp + set_get_alpha_beta_stride_gtest.cpp logging_mode_gtest.cpp ostream_threadsafety_gtest.cpp set_get_vector_gtest.cpp diff --git a/projects/rocblas/clients/gtest/aux_gtest.yaml b/projects/rocblas/clients/gtest/aux_gtest.yaml index 92bc1fbf040..d2354555d17 100644 --- a/projects/rocblas/clients/gtest/aux_gtest.yaml +++ b/projects/rocblas/clients/gtest/aux_gtest.yaml @@ -4,7 +4,7 @@ include: set_get_vector_gtest.yaml include: logging_mode_gtest.yaml include: set_get_pointer_mode_gtest.yaml include: set_get_atomics_mode_gtest.yaml -include: set_get_alpha_beta_inc_gtest.yaml +include: set_get_alpha_beta_stride_gtest.yaml include: ostream_threadsafety_gtest.yaml include: multiheaded_gtest.yaml include: atomics_mode_gtest.yaml diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp similarity index 53% rename from projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp rename to projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp index 4cb88362750..195f4998b99 100644 --- a/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.cpp +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp @@ -30,45 +30,45 @@ namespace { template - struct testing_set_get_alpha_beta_inc : rocblas_test_valid + struct testing_set_get_alpha_beta_stride : rocblas_test_valid { void operator()(const Arguments&) { rocblas_handle handle; CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); - rocblas_int alpha_inc = -1; - rocblas_int beta_inc = -1; - CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); - CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); - EXPECT_EQ(0, alpha_inc); - EXPECT_EQ(0, beta_inc); + rocblas_stride stride_alpha = -1; + rocblas_stride stride_beta = -1; + CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); + EXPECT_EQ(0, stride_alpha); + EXPECT_EQ(0, stride_beta); - CHECK_ROCBLAS_ERROR(rocblas_set_alpha_inc(handle, 7)); - CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); - EXPECT_EQ(7, alpha_inc); - CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); - EXPECT_EQ(0, beta_inc); + CHECK_ROCBLAS_ERROR(rocblas_set_stride_alpha(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); + EXPECT_EQ(7, stride_alpha); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); + EXPECT_EQ(0, stride_beta); - CHECK_ROCBLAS_ERROR(rocblas_set_beta_inc(handle, 11)); - CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); - EXPECT_EQ(11, beta_inc); - CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); - EXPECT_EQ(7, alpha_inc); + CHECK_ROCBLAS_ERROR(rocblas_set_stride_beta(handle, 11)); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); + EXPECT_EQ(11, stride_beta); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); + EXPECT_EQ(7, stride_alpha); - CHECK_ROCBLAS_ERROR(rocblas_set_alpha_inc(handle, 0)); - CHECK_ROCBLAS_ERROR(rocblas_set_beta_inc(handle, 0)); - CHECK_ROCBLAS_ERROR(rocblas_get_alpha_inc(handle, &alpha_inc)); - CHECK_ROCBLAS_ERROR(rocblas_get_beta_inc(handle, &beta_inc)); - EXPECT_EQ(0, alpha_inc); - EXPECT_EQ(0, beta_inc); + CHECK_ROCBLAS_ERROR(rocblas_set_stride_alpha(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_set_stride_beta(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); + CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); + EXPECT_EQ(0, stride_alpha); + EXPECT_EQ(0, stride_beta); CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); } }; - struct set_get_alpha_beta_inc - : RocBLAS_Test + struct set_get_alpha_beta_stride + : RocBLAS_Test { static bool type_filter(const Arguments&) { @@ -77,19 +77,19 @@ namespace static bool function_filter(const Arguments& arg) { - return !strcmp(arg.function, "set_get_alpha_beta_inc"); + return !strcmp(arg.function, "set_get_alpha_beta_stride"); } static std::string name_suffix(const Arguments& arg) { - return RocBLAS_TestName(arg.name); + return RocBLAS_TestName(arg.name); } }; - TEST_P(set_get_alpha_beta_inc, auxiliary_tensile) + TEST_P(set_get_alpha_beta_stride, auxiliary_tensile) { - CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(testing_set_get_alpha_beta_inc<>{}(GetParam())); + CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(testing_set_get_alpha_beta_stride<>{}(GetParam())); } - INSTANTIATE_TEST_CATEGORIES(set_get_alpha_beta_inc) + INSTANTIATE_TEST_CATEGORIES(set_get_alpha_beta_stride) } // namespace diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.yaml similarity index 62% rename from projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml rename to projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.yaml index ffad2e4e1dd..f4af643f25a 100644 --- a/projects/rocblas/clients/gtest/set_get_alpha_beta_inc_gtest.yaml +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.yaml @@ -3,8 +3,8 @@ include: rocblas_common.yaml include: known_bugs.yaml Tests: -- name: set_get_alpha_beta_inc +- name: set_get_alpha_beta_stride category: quick - function: set_get_alpha_beta_inc + function: set_get_alpha_beta_stride precision: *single_precision ... diff --git a/projects/rocblas/docs/reference/helper-functions.rst b/projects/rocblas/docs/reference/helper-functions.rst index f5261a772f7..8bf63d39ffa 100644 --- a/projects/rocblas/docs/reference/helper-functions.rst +++ b/projects/rocblas/docs/reference/helper-functions.rst @@ -19,10 +19,10 @@ Auxiliary functions .. doxygenfunction:: rocblas_get_pointer_mode .. doxygenfunction:: rocblas_set_atomics_mode .. doxygenfunction:: rocblas_get_atomics_mode -.. doxygenfunction:: rocblas_set_alpha_inc -.. doxygenfunction:: rocblas_get_alpha_inc -.. doxygenfunction:: rocblas_set_beta_inc -.. doxygenfunction:: rocblas_get_beta_inc +.. doxygenfunction:: rocblas_set_stride_alpha +.. doxygenfunction:: rocblas_get_stride_alpha +.. doxygenfunction:: rocblas_set_stride_beta +.. doxygenfunction:: rocblas_get_stride_beta .. doxygenfunction:: rocblas_pointer_to_mode .. doxygenfunction:: rocblas_initialize .. doxygenfunction:: rocblas_status_to_string diff --git a/projects/rocblas/library/include/internal/rocblas-auxiliary.h b/projects/rocblas/library/include/internal/rocblas-auxiliary.h index bd3215a2a42..55207c95ad3 100644 --- a/projects/rocblas/library/include/internal/rocblas-auxiliary.h +++ b/projects/rocblas/library/include/internal/rocblas-auxiliary.h @@ -75,22 +75,26 @@ ROCBLAS_EXPORT rocblas_status rocblas_set_atomics_mode(rocblas_handle hand ROCBLAS_EXPORT rocblas_status rocblas_get_atomics_mode(rocblas_handle handle, rocblas_atomics_mode* atomics_mode); -/*! \brief Set alpha increment stored on the handle (element count between successive alpha values). Default is 0. +/*! \brief Set alpha stride stored on the handle (spacing between successive alpha values). Default is 0. Used by limited set of L2 and L1 batched functions to specify the spacing between successive alpha values. */ -ROCBLAS_EXPORT rocblas_status rocblas_set_alpha_inc(rocblas_handle handle, rocblas_int alpha_inc); +ROCBLAS_EXPORT rocblas_status rocblas_set_stride_alpha(rocblas_handle handle, + rocblas_stride stride_alpha); -/*! \brief Get alpha increment from the handle. +/*! \brief Get alpha stride from the handle. */ -ROCBLAS_EXPORT rocblas_status rocblas_get_alpha_inc(rocblas_handle handle, rocblas_int* alpha_inc); +ROCBLAS_EXPORT rocblas_status rocblas_get_stride_alpha(rocblas_handle handle, + rocblas_stride* stride_alpha); -/*! \brief Set beta increment stored on the handle (element count between successive beta values). Default is 0. +/*! \brief Set beta stride stored on the handle (spacing between successive beta values). Default is 0. */ -ROCBLAS_EXPORT rocblas_status rocblas_set_beta_inc(rocblas_handle handle, rocblas_int beta_inc); +ROCBLAS_EXPORT rocblas_status rocblas_set_stride_beta(rocblas_handle handle, + rocblas_stride stride_beta); -/*! \brief Get beta increment from the handle. +/*! \brief Get beta stride from the handle. */ -ROCBLAS_EXPORT rocblas_status rocblas_get_beta_inc(rocblas_handle handle, rocblas_int* beta_inc); +ROCBLAS_EXPORT rocblas_status rocblas_get_stride_beta(rocblas_handle handle, + rocblas_stride* stride_beta); /*! \brief Set ``rocblas_math_mode``. */ diff --git a/projects/rocblas/library/src/include/handle.hpp b/projects/rocblas/library/src/include/handle.hpp index 3b2785cc6dd..1e8b211bdd9 100644 --- a/projects/rocblas/library/src/include/handle.hpp +++ b/projects/rocblas/library/src/include/handle.hpp @@ -412,9 +412,9 @@ struct _rocblas_handle // default atomics mode does not allows atomic operations rocblas_atomics_mode atomics_mode = rocblas_atomics_not_allowed; - // optional increment (in elements) between successive alpha/beta values for advanced batched use; 0 is default - rocblas_int alpha_inc = 0; - rocblas_int beta_inc = 0; + // optional stride between successive alpha/beta values for advanced batched use; 0 is default + rocblas_stride stride_alpha = 0; + rocblas_stride stride_beta = 0; // Selects the benchmark library to be used for solution selection rocblas_performance_metric performance_metric = rocblas_default_performance_metric; diff --git a/projects/rocblas/library/src/rocblas_auxiliary.cpp b/projects/rocblas/library/src/rocblas_auxiliary.cpp index e51cea72a15..88c903a2a34 100644 --- a/projects/rocblas/library/src/rocblas_auxiliary.cpp +++ b/projects/rocblas/library/src/rocblas_auxiliary.cpp @@ -127,17 +127,17 @@ catch(...) } /******************************************************************************* - * ! \brief get alpha increment + * ! \brief get alpha stride ******************************************************************************/ -extern "C" rocblas_status rocblas_get_alpha_inc(rocblas_handle handle, rocblas_int* alpha_inc) +extern "C" rocblas_status rocblas_get_stride_alpha(rocblas_handle handle, rocblas_stride* stride_alpha) try { if(!handle) return rocblas_status_invalid_handle; - *alpha_inc = handle->alpha_inc; + *stride_alpha = handle->stride_alpha; rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_get_alpha_inc", *alpha_inc); + logger.log_trace(handle, "rocblas_get_stride_alpha", *stride_alpha); return rocblas_status_success; } catch(...) @@ -146,17 +146,17 @@ catch(...) } /******************************************************************************* - * ! \brief set alpha increment + * ! \brief set alpha stride ******************************************************************************/ -extern "C" rocblas_status rocblas_set_alpha_inc(rocblas_handle handle, rocblas_int alpha_inc) +extern "C" rocblas_status rocblas_set_stride_alpha(rocblas_handle handle, rocblas_stride stride_alpha) try { if(!handle) return rocblas_status_invalid_handle; rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_set_alpha_inc", alpha_inc); - handle->alpha_inc = alpha_inc; + logger.log_trace(handle, "rocblas_set_stride_alpha", stride_alpha); + handle->stride_alpha = stride_alpha; return rocblas_status_success; } catch(...) @@ -165,17 +165,17 @@ catch(...) } /******************************************************************************* - * ! \brief get beta increment + * ! \brief get beta stride ******************************************************************************/ -extern "C" rocblas_status rocblas_get_beta_inc(rocblas_handle handle, rocblas_int* beta_inc) +extern "C" rocblas_status rocblas_get_stride_beta(rocblas_handle handle, rocblas_stride* stride_beta) try { if(!handle) return rocblas_status_invalid_handle; - *beta_inc = handle->beta_inc; + *stride_beta = handle->stride_beta; rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_get_beta_inc", *beta_inc); + logger.log_trace(handle, "rocblas_get_stride_beta", *stride_beta); return rocblas_status_success; } catch(...) @@ -184,17 +184,17 @@ catch(...) } /******************************************************************************* - * ! \brief set beta increment + * ! \brief set beta stride ******************************************************************************/ -extern "C" rocblas_status rocblas_set_beta_inc(rocblas_handle handle, rocblas_int beta_inc) +extern "C" rocblas_status rocblas_set_stride_beta(rocblas_handle handle, rocblas_stride stride_beta) try { if(!handle) return rocblas_status_invalid_handle; rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_set_beta_inc", beta_inc); - handle->beta_inc = beta_inc; + logger.log_trace(handle, "rocblas_set_stride_beta", stride_beta); + handle->stride_beta = stride_beta; return rocblas_status_success; } catch(...) From 4aa0822bc25643ba48c5ef95f0e7c119ff585d54 Mon Sep 17 00:00:00 2001 From: Torre Zuk Date: Fri, 29 May 2026 15:45:52 -0600 Subject: [PATCH 3/3] flesh out new batch alpha beta stride API for gemv --- .../clients/common/rocblas_arguments.cpp | 1 + .../clients/gtest/blas2/gemv_gtest.cpp | 8 +++ .../rocblas/clients/gtest/gemv_gtest.yaml | 16 ++++++ .../gtest/set_get_alpha_beta_stride_gtest.cpp | 52 +++++++++++-------- .../include/blas2/testing_gemv_batched.hpp | 34 ++++++++---- .../blas2/testing_gemv_strided_batched.hpp | 34 ++++++++---- .../clients/include/client_utility.hpp | 11 ++++ .../clients/include/rocblas_arguments.hpp | 2 + .../clients/include/rocblas_common.yaml | 6 +++ .../rocblas/clients/include/rocblas_init.hpp | 12 +++++ .../docs/reference/helper-functions.rst | 8 +-- projects/rocblas/docs/reference/level-2.rst | 4 ++ .../include/internal/rocblas-auxiliary.h | 35 ++++++++----- .../src/blas2/rocblas_gemv_batched_imp.hpp | 45 ++++++++-------- .../rocblas_gemv_strided_batched_imp.hpp | 45 ++++++++-------- .../rocblas/library/src/include/handle.hpp | 20 +++++++ .../rocblas/library/src/rocblas_auxiliary.cpp | 28 +++++----- 17 files changed, 245 insertions(+), 116 deletions(-) diff --git a/projects/rocblas/clients/common/rocblas_arguments.cpp b/projects/rocblas/clients/common/rocblas_arguments.cpp index f4828a11263..c1d53fbd251 100644 --- a/projects/rocblas/clients/common/rocblas_arguments.cpp +++ b/projects/rocblas/clients/common/rocblas_arguments.cpp @@ -152,6 +152,7 @@ void Arguments::init() HMM = false; graph_test = false; repeatability_check = false; + alpha_beta_stride = false; use_hipblaslt = -1; diff --git a/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp b/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp index e825b186af9..06b3b7afef7 100644 --- a/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp +++ b/projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp @@ -168,8 +168,16 @@ namespace name << '_' << arg.stride_y; if(GEMV_TYPE != GEMV) + { name << '_' << arg.batch_count; + if(arg.alpha_beta_stride) + { + name << '_' << arg.stride_c; + name << '_' << arg.stride_d; + } + } + if(arg.api & c_API_64) { name << "_I64"; diff --git a/projects/rocblas/clients/gtest/gemv_gtest.yaml b/projects/rocblas/clients/gtest/gemv_gtest.yaml index be4506efc93..18c4cd2b9ed 100644 --- a/projects/rocblas/clients/gtest/gemv_gtest.yaml +++ b/projects/rocblas/clients/gtest/gemv_gtest.yaml @@ -651,6 +651,22 @@ Tests: batch_count: [ 2 ] gpu_arch: '9??' +- name: gemv_stride_alpha_beta + category: pre_checkin + function: + - gemv_batched + - gemv_strided_batched + precision: *gemv_bfloat_half_single_double_complex_real_precisions + transA: [ N, T, C ] + matrix_size: *qmcpack_matrix_size_range + incx_incy: *incx_incy_range + alpha_beta: *alpha_beta_range + batch_count: [ 3 ] + alpha_beta_stride: true + stride_c: 2 # alpha beta stride + stride_d: 2 + pointer_mode_host: false + - name: gemv_row_vectorized_coverage category: pre_checkin function: diff --git a/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp index 195f4998b99..ebb8c054f8f 100644 --- a/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp +++ b/projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp @@ -37,31 +37,39 @@ namespace rocblas_handle handle; CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); - rocblas_stride stride_alpha = -1; - rocblas_stride stride_beta = -1; - CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); - EXPECT_EQ(0, stride_alpha); - EXPECT_EQ(0, stride_beta); + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device)); - CHECK_ROCBLAS_ERROR(rocblas_set_stride_alpha(handle, 7)); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); - EXPECT_EQ(7, stride_alpha); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); - EXPECT_EQ(0, stride_beta); + rocblas_stride batch_alpha_stride = -1; + rocblas_stride batch_beta_stride = -1; + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(0, batch_alpha_stride); + EXPECT_EQ(0, batch_beta_stride); - CHECK_ROCBLAS_ERROR(rocblas_set_stride_beta(handle, 11)); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); - EXPECT_EQ(11, stride_beta); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); - EXPECT_EQ(7, stride_alpha); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + EXPECT_EQ(7, batch_alpha_stride); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(7, batch_beta_stride); - CHECK_ROCBLAS_ERROR(rocblas_set_stride_alpha(handle, 0)); - CHECK_ROCBLAS_ERROR(rocblas_set_stride_beta(handle, 0)); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_alpha(handle, &stride_alpha)); - CHECK_ROCBLAS_ERROR(rocblas_get_stride_beta(handle, &stride_beta)); - EXPECT_EQ(0, stride_alpha); - EXPECT_EQ(0, stride_beta); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 0)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(0, batch_alpha_stride); + EXPECT_EQ(0, batch_beta_stride); + + // stored regardless of mode, but only utilized in device mode kernels + CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host)); + + CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 7)); + CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 11)); + + CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride)); + EXPECT_EQ(7, batch_alpha_stride); + CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride)); + EXPECT_EQ(11, batch_beta_stride); CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); } diff --git a/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp b/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp index 90452cbe140..db67c67465d 100644 --- a/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp +++ b/projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp @@ -300,6 +300,10 @@ void testing_gemv_batched(const Arguments& arg) rocblas_operation transA = char2rocblas_operation(arg.transA); int64_t batch_count = arg.batch_count; + bool ab_striding = arg.alpha_beta_stride; + int64_t alpha_stride = ab_striding ? arg.stride_c : 0; + int64_t beta_stride = ab_striding ? arg.stride_d : 0; + rocblas_local_handle handle{arg}; // argument sanity check before allocating invalid memory @@ -344,23 +348,24 @@ void testing_gemv_batched(const Arguments& arg) HOST_MEMCHECK(host_batch_vector, hx, (dim_x, incx, batch_count)); HOST_MEMCHECK(host_batch_vector, hy, (dim_y, incy, batch_count)); HOST_MEMCHECK(host_batch_vector, hy_gold, (dim_y, incy, batch_count)); - HOST_MEMCHECK(host_vector, halpha, (1)); - HOST_MEMCHECK(host_vector, hbeta, (1)); + HOST_MEMCHECK(host_vector, halpha, (batch_count, alpha_stride)); + HOST_MEMCHECK(host_vector, hbeta, (batch_count, beta_stride)); // Allocate device memory DEVICE_MEMCHECK(device_batch_matrix, dA, (M, N, lda, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dx, (dim_x, incx, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dy, (dim_y, incy, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha, (1)); - DEVICE_MEMCHECK(device_vector, d_beta, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta, (batch_count, beta_stride)); // Initialize data on host memory rocblas_init_matrix( hA, arg, rocblas_client_alpha_sets_nan, rocblas_client_general_matrix, true); rocblas_init_vector(hx, arg, rocblas_client_alpha_sets_nan, false, true); rocblas_init_vector(hy, arg, rocblas_client_beta_sets_nan); - halpha[0] = h_alpha; - hbeta[0] = h_beta; + + rocblas_init_vector_alternating_zero(halpha, h_alpha); + rocblas_init_vector_alternating_zero(hbeta, h_beta); hy_gold.copy_from(hy); @@ -448,8 +453,8 @@ void testing_gemv_batched(const Arguments& arg) DEVICE_MEMCHECK(device_batch_matrix, dA_copy, (M, N, lda, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dx_copy, (dim_x, incx, batch_count)); DEVICE_MEMCHECK(device_batch_vector, dy_copy, (dim_y, incy, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha_copy, (1)); - DEVICE_MEMCHECK(device_vector, d_beta_copy, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha_copy, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta_copy, (batch_count, beta_stride)); CHECK_HIP_ERROR(dA_copy.transfer_from(hA)); CHECK_HIP_ERROR(dx_copy.transfer_from(hx)); @@ -489,8 +494,17 @@ void testing_gemv_batched(const Arguments& arg) cpu_time_used = get_time_us_no_sync(); for(int64_t b = 0; b < batch_count; ++b) { - ref_gemv( - transA, M, N, h_alpha, hA[b], lda, hx[b], incx, h_beta, hy_gold[b], incy); + ref_gemv(transA, + M, + N, + halpha[b * alpha_stride], + hA[b], + lda, + hx[b], + incx, + hbeta[b * beta_stride], + hy_gold[b], + incy); } cpu_time_used = get_time_us_no_sync() - cpu_time_used; diff --git a/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp b/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp index 651aded0d32..6d5eb650640 100644 --- a/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp +++ b/projects/rocblas/clients/include/blas2/testing_gemv_strided_batched.hpp @@ -342,6 +342,10 @@ void testing_gemv_strided_batched(const Arguments& arg) int64_t stride_y = arg.stride_y; int64_t batch_count = arg.batch_count; + bool ab_striding = arg.alpha_beta_stride; + int64_t alpha_stride = ab_striding ? arg.stride_c : 0; + int64_t beta_stride = ab_striding ? arg.stride_d : 0; + rocblas_local_handle handle{arg}; size_t dim_x, row_A; @@ -393,17 +397,15 @@ void testing_gemv_strided_batched(const Arguments& arg) HOST_MEMCHECK(host_strided_batch_vector, hx, (dim_x, incx, stride_x, batch_count)); HOST_MEMCHECK(host_strided_batch_vector, hy, (dim_y, incy, stride_y, batch_count)); HOST_MEMCHECK(host_strided_batch_vector, hy_gold, (dim_y, incy, stride_y, batch_count)); - HOST_MEMCHECK(host_vector, halpha, (1)); - HOST_MEMCHECK(host_vector, hbeta, (1)); - halpha[0] = h_alpha; - hbeta[0] = h_beta; + HOST_MEMCHECK(host_vector, halpha, (batch_count, alpha_stride)); + HOST_MEMCHECK(host_vector, hbeta, (batch_count, beta_stride)); // Allocate device memory DEVICE_MEMCHECK(device_strided_batch_matrix, dA, (M, N, lda, stride_a, batch_count)); DEVICE_MEMCHECK(device_strided_batch_vector, dx, (dim_x, incx, stride_x, batch_count)); DEVICE_MEMCHECK(device_strided_batch_vector, dy, (dim_y, incy, stride_y, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha, (1)); - DEVICE_MEMCHECK(device_vector, d_beta, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta, (batch_count, beta_stride)); // Initialize data on host memory rocblas_init_matrix( @@ -411,6 +413,9 @@ void testing_gemv_strided_batched(const Arguments& arg) rocblas_init_vector(hx, arg, rocblas_client_alpha_sets_nan, false, true); rocblas_init_vector(hy, arg, rocblas_client_beta_sets_nan); + rocblas_init_vector_alternating_zero(halpha, h_alpha); + rocblas_init_vector_alternating_zero(hbeta, h_beta); + hy_gold.copy_from(hy); // copy data from CPU to device @@ -511,8 +516,8 @@ void testing_gemv_strided_batched(const Arguments& arg) DEVICE_MEMCHECK(device_strided_batch_vector, dy_copy, (dim_y, incy, stride_y, batch_count)); - DEVICE_MEMCHECK(device_vector, d_alpha_copy, (1)); - DEVICE_MEMCHECK(device_vector, d_beta_copy, (1)); + DEVICE_MEMCHECK(device_vector, d_alpha_copy, (batch_count, alpha_stride)); + DEVICE_MEMCHECK(device_vector, d_beta_copy, (batch_count, beta_stride)); CHECK_HIP_ERROR(dA_copy.transfer_from(hA)); CHECK_HIP_ERROR(dx_copy.transfer_from(hx)); @@ -555,8 +560,17 @@ void testing_gemv_strided_batched(const Arguments& arg) cpu_time_used = get_time_us_no_sync(); for(int64_t b = 0; b < batch_count; ++b) { - ref_gemv( - transA, M, N, h_alpha, hA[b], lda, hx[b], incx, h_beta, hy_gold[b], incy); + ref_gemv(transA, + M, + N, + halpha[b * alpha_stride], + hA[b], + lda, + hx[b], + incx, + hbeta[b * beta_stride], + hy_gold[b], + incy); } cpu_time_used = get_time_us_no_sync() - cpu_time_used; diff --git a/projects/rocblas/clients/include/client_utility.hpp b/projects/rocblas/clients/include/client_utility.hpp index c2908a7e337..1cb09e50bfc 100644 --- a/projects/rocblas/clients/include/client_utility.hpp +++ b/projects/rocblas/clients/include/client_utility.hpp @@ -148,6 +148,12 @@ class rocblas_local_handle void pre_test(const Arguments& arg) { + if(arg.alpha_beta_stride && arg.pointer_mode_device) + { + // for now no GEMM applicability so c/d stride used + rocblas_set_batch_alpha_stride(m_handle, arg.stride_c); + rocblas_set_batch_beta_stride(m_handle, arg.stride_d); + } #if HIP_VERSION >= 50500000 arg.graph_test ? rocblas_stream_begin_capture() : NOOP; #endif @@ -158,6 +164,11 @@ class rocblas_local_handle #if HIP_VERSION >= 50500000 arg.graph_test ? rocblas_stream_end_capture() : NOOP; #endif + if(arg.alpha_beta_stride && arg.pointer_mode_device) + { + rocblas_set_batch_alpha_stride(m_handle, 0); + rocblas_set_batch_beta_stride(m_handle, 0); + } } }; diff --git a/projects/rocblas/clients/include/rocblas_arguments.hpp b/projects/rocblas/clients/include/rocblas_arguments.hpp index ec9d2d398d1..02c79bbad89 100644 --- a/projects/rocblas/clients/include/rocblas_arguments.hpp +++ b/projects/rocblas/clients/include/rocblas_arguments.hpp @@ -182,6 +182,7 @@ struct Arguments bool HMM; // xnack+ bool graph_test; bool repeatability_check; + bool alpha_beta_stride; int use_hipblaslt; @@ -272,6 +273,7 @@ struct Arguments OPER(HMM) SEP \ OPER(graph_test) SEP \ OPER(repeatability_check) SEP \ + OPER(alpha_beta_stride) SEP \ OPER(use_hipblaslt) SEP \ OPER(cleanup) // clang-format on diff --git a/projects/rocblas/clients/include/rocblas_common.yaml b/projects/rocblas/clients/include/rocblas_common.yaml index d68535e33bd..2dc3a529752 100644 --- a/projects/rocblas/clients/include/rocblas_common.yaml +++ b/projects/rocblas/clients/include/rocblas_common.yaml @@ -586,6 +586,7 @@ Arguments: - HMM: c_bool - graph_test: c_bool - repeatability_check: c_bool + - alpha_beta_stride: c_bool - use_hipblaslt: c_int32 - cleanup: c_bool @@ -638,6 +639,10 @@ Defaults: alphai: 0.0 beta: 0.0 betai: 0.0 + stride_a: 0 + stride_b: 0 + stride_c: 0 + stride_d: 0 transA: '*' transB: '*' side: '*' @@ -659,6 +664,7 @@ Defaults: api: C graph_test: false repeatability_check: false + alpha_beta_stride: false norm_check: 0 unit_check: 1 res_check: 0 diff --git a/projects/rocblas/clients/include/rocblas_init.hpp b/projects/rocblas/clients/include/rocblas_init.hpp index 735789a002b..49e2465d2b7 100644 --- a/projects/rocblas/clients/include/rocblas_init.hpp +++ b/projects/rocblas/clients/include/rocblas_init.hpp @@ -51,6 +51,18 @@ typedef enum rocblas_check_nan_init_ } rocblas_check_nan_init; +// Initialize vector so adjacent entries have alternating zero and passed value. +template +void rocblas_init_vector_alternating_zero(host_vector& A, T value) +{ + auto M = A.n(); + auto inc = A.inc(); + for(size_t i = 0; i < M; ++i) + { + A[i * inc] = (i & 1) ? T(0) : value; + } +} + // Initialize matrix so adjacent entries have alternating sign. // In gemm if either A or B are initialized with alternating // sign the reduction sum will be summing positive diff --git a/projects/rocblas/docs/reference/helper-functions.rst b/projects/rocblas/docs/reference/helper-functions.rst index 8bf63d39ffa..42a61877814 100644 --- a/projects/rocblas/docs/reference/helper-functions.rst +++ b/projects/rocblas/docs/reference/helper-functions.rst @@ -19,10 +19,10 @@ Auxiliary functions .. doxygenfunction:: rocblas_get_pointer_mode .. doxygenfunction:: rocblas_set_atomics_mode .. doxygenfunction:: rocblas_get_atomics_mode -.. doxygenfunction:: rocblas_set_stride_alpha -.. doxygenfunction:: rocblas_get_stride_alpha -.. doxygenfunction:: rocblas_set_stride_beta -.. doxygenfunction:: rocblas_get_stride_beta +.. doxygenfunction:: rocblas_set_batch_alpha_stride +.. doxygenfunction:: rocblas_get_batch_alpha_stride +.. doxygenfunction:: rocblas_set_batch_beta_stride +.. doxygenfunction:: rocblas_get_batch_beta_stride .. doxygenfunction:: rocblas_pointer_to_mode .. doxygenfunction:: rocblas_initialize .. doxygenfunction:: rocblas_status_to_string diff --git a/projects/rocblas/docs/reference/level-2.rst b/projects/rocblas/docs/reference/level-2.rst index d11a1ab7400..1a557480e5d 100644 --- a/projects/rocblas/docs/reference/level-2.rst +++ b/projects/rocblas/docs/reference/level-2.rst @@ -85,6 +85,8 @@ The ``gemv`` functions support the ``_64`` interface. See the :ref:`ILP64 API` s ``gemv_batched`` functions have an implementation which uses atomic operations. See the :ref:`Atomic Operations` section for more information. The ``gemv_batched`` functions support the ``_64`` interface. See the :ref:`ILP64 API` section. +The ``gemv_batched`` functions support ``rocblas_set_batch_alpha_stride`` and ``rocblas_set_batch_beta_stride`` when the ``rocblas_handle`` is +in mode ``rocblas_pointer_mode_device``. .. doxygenfunction:: rocblas_sgemv_strided_batched :outline: @@ -104,6 +106,8 @@ The ``gemv_batched`` functions support the ``_64`` interface. See the :ref:`ILP6 ``gemv_strided_batched`` functions have an implementation which uses atomic operations. See the :ref:`Atomic Operations` section for more information. The ``gemv_strided_batched`` functions support the ``_64`` interface. See the :ref:`ILP64 API` section. +The ``gemv_strided_batched`` functions support ``rocblas_set_batch_alpha_stride`` and ``rocblas_set_batch_beta_stride`` when the ``rocblas_handle`` is +in mode ``rocblas_pointer_mode_device``. .. _rocblas_ger: diff --git a/projects/rocblas/library/include/internal/rocblas-auxiliary.h b/projects/rocblas/library/include/internal/rocblas-auxiliary.h index 55207c95ad3..d3945818cc4 100644 --- a/projects/rocblas/library/include/internal/rocblas-auxiliary.h +++ b/projects/rocblas/library/include/internal/rocblas-auxiliary.h @@ -75,26 +75,33 @@ ROCBLAS_EXPORT rocblas_status rocblas_set_atomics_mode(rocblas_handle hand ROCBLAS_EXPORT rocblas_status rocblas_get_atomics_mode(rocblas_handle handle, rocblas_atomics_mode* atomics_mode); -/*! \brief Set alpha stride stored on the handle (spacing between successive alpha values). Default is 0. -Used by limited set of L2 and L1 batched functions to specify the spacing between successive alpha values. +/*! \brief Set alpha stride for limited set of batched and strided_batched functions to specify the stride for alpha between successive batch elements. +Only applies to rocblas_pointer_mode_device and thus device side allocations. +It enables interpretation of the alpha pointer for both batched and strided_batched functions as a pointer to a vector of values. +Default value is 0 which treats it as a pointer to a single scalar. Support is denoted with specific function documentation. +Warning this is a modal like state in the handle. Restore to value 0 if no longer applicable to later function calls. */ -ROCBLAS_EXPORT rocblas_status rocblas_set_stride_alpha(rocblas_handle handle, - rocblas_stride stride_alpha); +ROCBLAS_EXPORT rocblas_status rocblas_set_batch_alpha_stride(rocblas_handle handle, + rocblas_stride alpha_stride); -/*! \brief Get alpha stride from the handle. +/*! \brief Get batch alpha stride from the handle. */ -ROCBLAS_EXPORT rocblas_status rocblas_get_stride_alpha(rocblas_handle handle, - rocblas_stride* stride_alpha); - -/*! \brief Set beta stride stored on the handle (spacing between successive beta values). Default is 0. +ROCBLAS_EXPORT rocblas_status rocblas_get_batch_alpha_stride(rocblas_handle handle, + rocblas_stride* alpha_stride); + +/*! \brief Set beta stride for limited set of batched and strided_batched functions to specify the stride for beta between successive batch elements. +Only applies to rocblas_pointer_mode_device and thus device side allocations. +It enables interpretation of the beta pointer for both batched and strided_batched functions as a pointer to a vector of values. +Default value is 0 which treats it as a pointer to a single scalar. Support is denoted with specific function documentation. +Warning this is a modal like state in the handle. Restore to value 0 if no longer applicable to later function calls. */ -ROCBLAS_EXPORT rocblas_status rocblas_set_stride_beta(rocblas_handle handle, - rocblas_stride stride_beta); +ROCBLAS_EXPORT rocblas_status rocblas_set_batch_beta_stride(rocblas_handle handle, + rocblas_stride beta_stride); -/*! \brief Get beta stride from the handle. +/*! \brief Get batch beta stride from the handle. */ -ROCBLAS_EXPORT rocblas_status rocblas_get_stride_beta(rocblas_handle handle, - rocblas_stride* stride_beta); +ROCBLAS_EXPORT rocblas_status rocblas_get_batch_beta_stride(rocblas_handle handle, + rocblas_stride* beta_stride); /*! \brief Set ``rocblas_math_mode``. */ diff --git a/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp b/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp index 59add8295dd..d2a52bdfe5c 100644 --- a/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp +++ b/projects/rocblas/library/src/blas2/rocblas_gemv_batched_imp.hpp @@ -230,28 +230,29 @@ namespace } // we don't instantiate _template for mixed types so directly calling launcher - rocblas_status status = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, - transA, - m, - n, - alpha, - 0, - A, - 0, - lda, - 0, - x, - 0, - incx, - 0, - beta, - 0, - y, - 0, - incy, - 0, - batch_count, - (Tex*)w_mem); + rocblas_status status + = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, + transA, + m, + n, + alpha, + handle->get_stride_alpha(), + A, + 0, + lda, + 0, + x, + 0, + incx, + 0, + beta, + handle->get_stride_beta(), + y, + 0, + incy, + 0, + batch_count, + (Tex*)w_mem); status = (status != rocblas_status_success) ? status : perf_status; if(status != rocblas_status_success) diff --git a/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp b/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp index 71e0cbf0c37..5464ed9e649 100644 --- a/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp +++ b/projects/rocblas/library/src/blas2/rocblas_gemv_strided_batched_imp.hpp @@ -259,28 +259,29 @@ namespace } // we don't instantiate _template for mixed types so directly calling launcher - rocblas_status status = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, - transA, - m, - n, - alpha, - 0, - A, - 0, - lda, - strideA, - x, - 0, - incx, - stridex, - beta, - 0, - y, - 0, - incy, - stridey, - batch_count, - (Tex*)w_mem); + rocblas_status status + = ROCBLAS_API(rocblas_internal_gemv_launcher)(handle, + transA, + m, + n, + alpha, + handle->get_stride_alpha(), + A, + 0, + lda, + strideA, + x, + 0, + incx, + stridex, + beta, + handle->get_stride_beta(), + y, + 0, + incy, + stridey, + batch_count, + (Tex*)w_mem); status = (status != rocblas_status_success) ? status : perf_status; if(status != rocblas_status_success) diff --git a/projects/rocblas/library/src/include/handle.hpp b/projects/rocblas/library/src/include/handle.hpp index 1e8b211bdd9..0fa2964ae06 100644 --- a/projects/rocblas/library/src/include/handle.hpp +++ b/projects/rocblas/library/src/include/handle.hpp @@ -444,6 +444,26 @@ struct _rocblas_handle this->data_ptr = data_ptr; } + rocblas_stride get_stride_alpha() const + { + // only applicable to device mode alpha, load_scalar ignores for value + return stride_alpha; + } + void set_stride_alpha(rocblas_stride stride) + { + stride_alpha = stride; + } + + rocblas_stride get_stride_beta() const + { + // only applicable to device mode beta, load_scalar ignores for value + return stride_beta; + } + void set_stride_beta(rocblas_stride stride) + { + stride_beta = stride; + } + // C interfaces for manipulating device memory friend rocblas_status(::rocblas_start_device_memory_size_query)(_rocblas_handle*); friend rocblas_status(::rocblas_stop_device_memory_size_query)(_rocblas_handle*, size_t*); diff --git a/projects/rocblas/library/src/rocblas_auxiliary.cpp b/projects/rocblas/library/src/rocblas_auxiliary.cpp index 88c903a2a34..295dfe67388 100644 --- a/projects/rocblas/library/src/rocblas_auxiliary.cpp +++ b/projects/rocblas/library/src/rocblas_auxiliary.cpp @@ -129,15 +129,16 @@ catch(...) /******************************************************************************* * ! \brief get alpha stride ******************************************************************************/ -extern "C" rocblas_status rocblas_get_stride_alpha(rocblas_handle handle, rocblas_stride* stride_alpha) +extern "C" rocblas_status rocblas_get_batch_alpha_stride(rocblas_handle handle, + rocblas_stride* alpha_stride) try { if(!handle) return rocblas_status_invalid_handle; - *stride_alpha = handle->stride_alpha; + *alpha_stride = handle->get_stride_alpha(); rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_get_stride_alpha", *stride_alpha); + logger.log_trace(handle, "rocblas_get_batch_alpha_stride", *alpha_stride); return rocblas_status_success; } catch(...) @@ -148,15 +149,16 @@ catch(...) /******************************************************************************* * ! \brief set alpha stride ******************************************************************************/ -extern "C" rocblas_status rocblas_set_stride_alpha(rocblas_handle handle, rocblas_stride stride_alpha) +extern "C" rocblas_status rocblas_set_batch_alpha_stride(rocblas_handle handle, + rocblas_stride alpha_stride) try { if(!handle) return rocblas_status_invalid_handle; rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_set_stride_alpha", stride_alpha); - handle->stride_alpha = stride_alpha; + logger.log_trace(handle, "rocblas_set_batch_alpha_stride", alpha_stride); + handle->set_stride_alpha(alpha_stride); return rocblas_status_success; } catch(...) @@ -167,15 +169,16 @@ catch(...) /******************************************************************************* * ! \brief get beta stride ******************************************************************************/ -extern "C" rocblas_status rocblas_get_stride_beta(rocblas_handle handle, rocblas_stride* stride_beta) +extern "C" rocblas_status rocblas_get_batch_beta_stride(rocblas_handle handle, + rocblas_stride* beta_stride) try { if(!handle) return rocblas_status_invalid_handle; - *stride_beta = handle->stride_beta; + *beta_stride = handle->get_stride_beta(); rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_get_stride_beta", *stride_beta); + logger.log_trace(handle, "rocblas_get_batch_beta_stride", *beta_stride); return rocblas_status_success; } catch(...) @@ -186,15 +189,16 @@ catch(...) /******************************************************************************* * ! \brief set beta stride ******************************************************************************/ -extern "C" rocblas_status rocblas_set_stride_beta(rocblas_handle handle, rocblas_stride stride_beta) +extern "C" rocblas_status rocblas_set_batch_beta_stride(rocblas_handle handle, + rocblas_stride beta_stride) try { if(!handle) return rocblas_status_invalid_handle; rocblas_internal_logger logger; if(handle->layer_mode & rocblas_layer_mode_log_trace) - logger.log_trace(handle, "rocblas_set_stride_beta", stride_beta); - handle->stride_beta = stride_beta; + logger.log_trace(handle, "rocblas_set_batch_beta_stride", beta_stride); + handle->set_stride_beta(beta_stride); return rocblas_status_success; } catch(...)