Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/rocblas/clients/common/rocblas_arguments.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ void Arguments::init()
HMM = false;
graph_test = false;
repeatability_check = false;
alpha_beta_stride = false;

use_hipblaslt = -1;

Expand Down
1 change: 1 addition & 0 deletions projects/rocblas/clients/gtest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ set(rocblas_no_tensile_test_source
general_gtest.cpp
set_get_pointer_mode_gtest.cpp
set_get_atomics_mode_gtest.cpp
set_get_alpha_beta_stride_gtest.cpp
logging_mode_gtest.cpp
ostream_threadsafety_gtest.cpp
set_get_vector_gtest.cpp
Expand Down
1 change: 1 addition & 0 deletions projects/rocblas/clients/gtest/aux_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include: set_get_vector_gtest.yaml
include: logging_mode_gtest.yaml
include: set_get_pointer_mode_gtest.yaml
include: set_get_atomics_mode_gtest.yaml
include: set_get_alpha_beta_stride_gtest.yaml
include: ostream_threadsafety_gtest.yaml
include: multiheaded_gtest.yaml
include: atomics_mode_gtest.yaml
Expand Down
8 changes: 8 additions & 0 deletions projects/rocblas/clients/gtest/blas2/gemv_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,16 @@ namespace
name << '_' << arg.stride_y;

if(GEMV_TYPE != GEMV)
{
name << '_' << arg.batch_count;

if(arg.alpha_beta_stride)
{
name << '_' << arg.stride_c;
name << '_' << arg.stride_d;
}
}

if(arg.api & c_API_64)
{
name << "_I64";
Expand Down
16 changes: 16 additions & 0 deletions projects/rocblas/clients/gtest/gemv_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,22 @@ Tests:
batch_count: [ 2 ]
gpu_arch: '9??'

- name: gemv_stride_alpha_beta
category: pre_checkin
function:
- gemv_batched
- gemv_strided_batched
precision: *gemv_bfloat_half_single_double_complex_real_precisions
transA: [ N, T, C ]
matrix_size: *qmcpack_matrix_size_range
incx_incy: *incx_incy_range
alpha_beta: *alpha_beta_range
batch_count: [ 3 ]
alpha_beta_stride: true
stride_c: 2 # alpha beta stride
stride_d: 2
pointer_mode_host: false

- name: gemv_row_vectorized_coverage
category: pre_checkin
function:
Expand Down
103 changes: 103 additions & 0 deletions projects/rocblas/clients/gtest/set_get_alpha_beta_stride_gtest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/* ************************************************************************
* Copyright (C) 2020-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */

#include "client_utility.hpp"
#include "rocblas.hpp"
#include "rocblas_data.hpp"
#include "rocblas_datatype2string.hpp"
#include "rocblas_test.hpp"
#include <string>

namespace
{
template <typename...>
struct testing_set_get_alpha_beta_stride : rocblas_test_valid
{
void operator()(const Arguments&)
{
rocblas_handle handle;
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));

CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device));

rocblas_stride batch_alpha_stride = -1;
rocblas_stride batch_beta_stride = -1;
CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride));
CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride));
EXPECT_EQ(0, batch_alpha_stride);
EXPECT_EQ(0, batch_beta_stride);

CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 7));
CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride));
EXPECT_EQ(7, batch_alpha_stride);
CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 7));
CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride));
EXPECT_EQ(7, batch_beta_stride);

CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 0));
CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 0));
CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride));
CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride));
EXPECT_EQ(0, batch_alpha_stride);
EXPECT_EQ(0, batch_beta_stride);

// stored regardless of mode, but only utilized in device mode kernels
CHECK_ROCBLAS_ERROR(rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host));

CHECK_ROCBLAS_ERROR(rocblas_set_batch_alpha_stride(handle, 7));
CHECK_ROCBLAS_ERROR(rocblas_set_batch_beta_stride(handle, 11));

CHECK_ROCBLAS_ERROR(rocblas_get_batch_alpha_stride(handle, &batch_alpha_stride));
EXPECT_EQ(7, batch_alpha_stride);
CHECK_ROCBLAS_ERROR(rocblas_get_batch_beta_stride(handle, &batch_beta_stride));
EXPECT_EQ(11, batch_beta_stride);

CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
}
};

struct set_get_alpha_beta_stride
: RocBLAS_Test<set_get_alpha_beta_stride, testing_set_get_alpha_beta_stride>
{
static bool type_filter(const Arguments&)
{
return true;
}

static bool function_filter(const Arguments& arg)
{
return !strcmp(arg.function, "set_get_alpha_beta_stride");
}

static std::string name_suffix(const Arguments& arg)
{
return RocBLAS_TestName<set_get_alpha_beta_stride>(arg.name);
}
};

TEST_P(set_get_alpha_beta_stride, auxiliary_tensile)
{
CATCH_SIGNALS_AND_EXCEPTIONS_AS_FAILURES(testing_set_get_alpha_beta_stride<>{}(GetParam()));
}
INSTANTIATE_TEST_CATEGORIES(set_get_alpha_beta_stride)

} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
include: rocblas_common.yaml
include: known_bugs.yaml

Tests:
- name: set_get_alpha_beta_stride
category: quick
function: set_get_alpha_beta_stride
precision: *single_precision
...
34 changes: 24 additions & 10 deletions projects/rocblas/clients/include/blas2/testing_gemv_batched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ void testing_gemv_batched(const Arguments& arg)
rocblas_operation transA = char2rocblas_operation(arg.transA);
int64_t batch_count = arg.batch_count;

bool ab_striding = arg.alpha_beta_stride;
int64_t alpha_stride = ab_striding ? arg.stride_c : 0;
int64_t beta_stride = ab_striding ? arg.stride_d : 0;

rocblas_local_handle handle{arg};

// argument sanity check before allocating invalid memory
Expand Down Expand Up @@ -344,23 +348,24 @@ void testing_gemv_batched(const Arguments& arg)
HOST_MEMCHECK(host_batch_vector<Ti>, hx, (dim_x, incx, batch_count));
HOST_MEMCHECK(host_batch_vector<To>, hy, (dim_y, incy, batch_count));
HOST_MEMCHECK(host_batch_vector<To>, hy_gold, (dim_y, incy, batch_count));
HOST_MEMCHECK(host_vector<Tex>, halpha, (1));
HOST_MEMCHECK(host_vector<Tex>, hbeta, (1));
HOST_MEMCHECK(host_vector<Tex>, halpha, (batch_count, alpha_stride));
HOST_MEMCHECK(host_vector<Tex>, hbeta, (batch_count, beta_stride));

// Allocate device memory
DEVICE_MEMCHECK(device_batch_matrix<Ti>, dA, (M, N, lda, batch_count));
DEVICE_MEMCHECK(device_batch_vector<Ti>, dx, (dim_x, incx, batch_count));
DEVICE_MEMCHECK(device_batch_vector<To>, dy, (dim_y, incy, batch_count));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha, (batch_count, alpha_stride));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta, (batch_count, beta_stride));

// Initialize data on host memory
rocblas_init_matrix(
hA, arg, rocblas_client_alpha_sets_nan, rocblas_client_general_matrix, true);
rocblas_init_vector(hx, arg, rocblas_client_alpha_sets_nan, false, true);
rocblas_init_vector(hy, arg, rocblas_client_beta_sets_nan);
halpha[0] = h_alpha;
hbeta[0] = h_beta;

rocblas_init_vector_alternating_zero(halpha, h_alpha);
rocblas_init_vector_alternating_zero(hbeta, h_beta);

hy_gold.copy_from(hy);

Expand Down Expand Up @@ -448,8 +453,8 @@ void testing_gemv_batched(const Arguments& arg)
DEVICE_MEMCHECK(device_batch_matrix<Ti>, dA_copy, (M, N, lda, batch_count));
DEVICE_MEMCHECK(device_batch_vector<Ti>, dx_copy, (dim_x, incx, batch_count));
DEVICE_MEMCHECK(device_batch_vector<To>, dy_copy, (dim_y, incy, batch_count));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha_copy, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta_copy, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha_copy, (batch_count, alpha_stride));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta_copy, (batch_count, beta_stride));

CHECK_HIP_ERROR(dA_copy.transfer_from(hA));
CHECK_HIP_ERROR(dx_copy.transfer_from(hx));
Expand Down Expand Up @@ -489,8 +494,17 @@ void testing_gemv_batched(const Arguments& arg)
cpu_time_used = get_time_us_no_sync();
for(int64_t b = 0; b < batch_count; ++b)
{
ref_gemv<Ti, To>(
transA, M, N, h_alpha, hA[b], lda, hx[b], incx, h_beta, hy_gold[b], incy);
ref_gemv<Ti, To>(transA,
M,
N,
halpha[b * alpha_stride],
hA[b],
lda,
hx[b],
incx,
hbeta[b * beta_stride],
hy_gold[b],
incy);
}
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ void testing_gemv_strided_batched(const Arguments& arg)
int64_t stride_y = arg.stride_y;
int64_t batch_count = arg.batch_count;

bool ab_striding = arg.alpha_beta_stride;
int64_t alpha_stride = ab_striding ? arg.stride_c : 0;
int64_t beta_stride = ab_striding ? arg.stride_d : 0;

rocblas_local_handle handle{arg};

size_t dim_x, row_A;
Expand Down Expand Up @@ -393,24 +397,25 @@ void testing_gemv_strided_batched(const Arguments& arg)
HOST_MEMCHECK(host_strided_batch_vector<Ti>, hx, (dim_x, incx, stride_x, batch_count));
HOST_MEMCHECK(host_strided_batch_vector<To>, hy, (dim_y, incy, stride_y, batch_count));
HOST_MEMCHECK(host_strided_batch_vector<To>, hy_gold, (dim_y, incy, stride_y, batch_count));
HOST_MEMCHECK(host_vector<Tex>, halpha, (1));
HOST_MEMCHECK(host_vector<Tex>, hbeta, (1));
halpha[0] = h_alpha;
hbeta[0] = h_beta;
HOST_MEMCHECK(host_vector<Tex>, halpha, (batch_count, alpha_stride));
HOST_MEMCHECK(host_vector<Tex>, hbeta, (batch_count, beta_stride));

// Allocate device memory
DEVICE_MEMCHECK(device_strided_batch_matrix<Ti>, dA, (M, N, lda, stride_a, batch_count));
DEVICE_MEMCHECK(device_strided_batch_vector<Ti>, dx, (dim_x, incx, stride_x, batch_count));
DEVICE_MEMCHECK(device_strided_batch_vector<To>, dy, (dim_y, incy, stride_y, batch_count));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha, (batch_count, alpha_stride));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta, (batch_count, beta_stride));

// Initialize data on host memory
rocblas_init_matrix(
hA, arg, rocblas_client_alpha_sets_nan, rocblas_client_general_matrix, true);
rocblas_init_vector(hx, arg, rocblas_client_alpha_sets_nan, false, true);
rocblas_init_vector(hy, arg, rocblas_client_beta_sets_nan);

rocblas_init_vector_alternating_zero(halpha, h_alpha);
rocblas_init_vector_alternating_zero(hbeta, h_beta);

hy_gold.copy_from(hy);

// copy data from CPU to device
Expand Down Expand Up @@ -511,8 +516,8 @@ void testing_gemv_strided_batched(const Arguments& arg)
DEVICE_MEMCHECK(device_strided_batch_vector<To>,
dy_copy,
(dim_y, incy, stride_y, batch_count));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha_copy, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta_copy, (1));
DEVICE_MEMCHECK(device_vector<Tex>, d_alpha_copy, (batch_count, alpha_stride));
DEVICE_MEMCHECK(device_vector<Tex>, d_beta_copy, (batch_count, beta_stride));

CHECK_HIP_ERROR(dA_copy.transfer_from(hA));
CHECK_HIP_ERROR(dx_copy.transfer_from(hx));
Expand Down Expand Up @@ -555,8 +560,17 @@ void testing_gemv_strided_batched(const Arguments& arg)
cpu_time_used = get_time_us_no_sync();
for(int64_t b = 0; b < batch_count; ++b)
{
ref_gemv<Ti, To>(
transA, M, N, h_alpha, hA[b], lda, hx[b], incx, h_beta, hy_gold[b], incy);
ref_gemv<Ti, To>(transA,
M,
N,
halpha[b * alpha_stride],
hA[b],
lda,
hx[b],
incx,
hbeta[b * beta_stride],
hy_gold[b],
incy);
}
cpu_time_used = get_time_us_no_sync() - cpu_time_used;

Expand Down
11 changes: 11 additions & 0 deletions projects/rocblas/clients/include/client_utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ class rocblas_local_handle

void pre_test(const Arguments& arg)
{
if(arg.alpha_beta_stride && arg.pointer_mode_device)
{
// for now no GEMM applicability so c/d stride used
rocblas_set_batch_alpha_stride(m_handle, arg.stride_c);
rocblas_set_batch_beta_stride(m_handle, arg.stride_d);
}
#if HIP_VERSION >= 50500000
arg.graph_test ? rocblas_stream_begin_capture() : NOOP;
#endif
Expand All @@ -158,6 +164,11 @@ class rocblas_local_handle
#if HIP_VERSION >= 50500000
arg.graph_test ? rocblas_stream_end_capture() : NOOP;
#endif
if(arg.alpha_beta_stride && arg.pointer_mode_device)
{
rocblas_set_batch_alpha_stride(m_handle, 0);
rocblas_set_batch_beta_stride(m_handle, 0);
}
}
};

Expand Down
2 changes: 2 additions & 0 deletions projects/rocblas/clients/include/rocblas_arguments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ struct Arguments
bool HMM; // xnack+
bool graph_test;
bool repeatability_check;
bool alpha_beta_stride;

int use_hipblaslt;

Expand Down Expand Up @@ -272,6 +273,7 @@ struct Arguments
OPER(HMM) SEP \
OPER(graph_test) SEP \
OPER(repeatability_check) SEP \
OPER(alpha_beta_stride) SEP \
OPER(use_hipblaslt) SEP \
OPER(cleanup)
// clang-format on
Expand Down
6 changes: 6 additions & 0 deletions projects/rocblas/clients/include/rocblas_common.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ Arguments:
- HMM: c_bool
- graph_test: c_bool
- repeatability_check: c_bool
- alpha_beta_stride: c_bool
- use_hipblaslt: c_int32
- cleanup: c_bool

Expand Down Expand Up @@ -638,6 +639,10 @@ Defaults:
alphai: 0.0
beta: 0.0
betai: 0.0
stride_a: 0
stride_b: 0
stride_c: 0
stride_d: 0
transA: '*'
transB: '*'
side: '*'
Expand All @@ -659,6 +664,7 @@ Defaults:
api: C
graph_test: false
repeatability_check: false
alpha_beta_stride: false
norm_check: 0
unit_check: 1
res_check: 0
Expand Down
Loading
Loading