Skip to content

Commit cbfb3e2

Browse files
samremesassistant-librarian[bot]
authored andcommitted
[rocm-libraries] ROCm/rocm-libraries#6611 (commit 5375c0f)
[CK_TILE] Preserve input strides in EightWaves async-load descriptor (#6611) `MakeAsyncLoadADramWindow` in `GemmPipelineAgBgCrCompAsyncEightWavesPolicy` was rebuilding the 6D view descriptor with `make_naive_tensor_descriptor_packed`, which synthesizes strides from lengths and assumes a dense layout. When the input view's leading-dim stride is larger than its inner length (non-packed memory layout), the resulting tile window stepped through memory at the wrong stride. Compose the unmerge transforms on top of the input view's existing descriptor instead, so the actual runtime strides are preserved and the correct `element_space_size` is inherited for bounds checking. ## Test Plan Added an unit test showing the problem. ## Test Result The new test fails before fixes and passes after. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1 parent 9d34174 commit cbfb3e2

4 files changed

Lines changed: 53 additions & 6 deletions

File tree

include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,15 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy
176176
const index_t M0 = integer_divide_ceil(rows, M1);
177177
const auto row_lens = make_tuple(M0, number<M1>{});
178178

179-
const auto d0 = make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
180-
const auto desc_0 = decltype(d0)( // set correct size (without padding)
181-
d0.get_transforms(),
182-
tensor_view_tmp.get_tensor_descriptor().get_element_space_size());
179+
// Build the 6D view by composing unmerge transforms on top of the
180+
// input view's existing descriptor. This preserves the input's actual
181+
// strides (so a non-packed leading-dim stride is honored) and inherits
182+
// its element_space_size for bounds checking.
183+
const auto desc_0 = transform_tensor_descriptor(
184+
tensor_view_tmp.get_tensor_descriptor(),
185+
make_tuple(make_unmerge_transform(row_lens), make_unmerge_transform(col_lens)),
186+
make_tuple(sequence<0>{}, sequence<1>{}),
187+
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4, 5>{}));
183188
const auto desc_1 = transform_tensor_descriptor(
184189
desc_0,
185190
make_tuple(make_pass_through_transform(M0),

test/ck_tile/gemm_block_scale/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
8686
)
8787
target_compile_options(test_tile_gemm_quant_abquant_eightwaves PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
8888

89+
add_gtest_executable(test_tile_gemm_quant_abquant_eightwaves_padded_stride
90+
test_gemm_quant_abquant_eightwaves_padded_stride.cpp
91+
)
92+
target_compile_options(test_tile_gemm_quant_abquant_eightwaves_padded_stride PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
93+
8994
# ABQuant split-K tests
9095
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
9196
test_gemm_quant_abquant_splitk_decode.cpp
@@ -281,6 +286,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
281286
test_tile_gemm_quant_abquant_a4w4_padding
282287
test_tile_gemm_quant_abquant_a4w4_preshuffle
283288
test_tile_gemm_quant_abquant_eightwaves
289+
test_tile_gemm_quant_abquant_eightwaves_padded_stride
284290
# ABQuant split-K tests
285291
test_tile_gemm_quant_abquant_splitk_decode
286292
test_tile_gemm_quant_abquant_splitk_prefill
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2+
// SPDX-License-Identifier: MIT
3+
4+
// Regression test for the EightWaves ABQuant pipeline on a B tensor whose
5+
// leading-dim stride is larger than the packed value. The async B-load
6+
// descriptor in the EightWaves policy must be built from the input view's
7+
// real strides so that the kernel addresses B correctly when stride_B is
8+
// larger than the inner length (e.g. row-aligned weight padding).
9+
10+
#include "test_gemm_quant_common.hpp"
11+
12+
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
13+
#ifdef CK_GFX950_SUPPORT
14+
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
15+
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
16+
// clang-format off
17+
using ABQuantEightWavesPaddedStrideTypes = ::testing::Types<
18+
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWaves, GroupSize1D_128, GroupSize2D128N, ColumnMajor>
19+
>;
20+
// clang-format on
21+
22+
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWavesPaddedStrideTypes);
23+
24+
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedPaddedBStrideTest)
25+
{
26+
// 256-byte row alignment for FP8 -> 256 elements of leading-dim padding.
27+
constexpr ck_tile::index_t k_batch = 1;
28+
constexpr ck_tile::index_t stride_B_pad = 256;
29+
this->run_test_with_validation(1024, 1024, 1024, k_batch, stride_B_pad);
30+
}
31+
#endif

test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,12 +1038,17 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
10381038
void run_test_with_validation(ck_tile::index_t M,
10391039
ck_tile::index_t N,
10401040
ck_tile::index_t K,
1041-
ck_tile::index_t k_batch = 1)
1041+
ck_tile::index_t k_batch = 1,
1042+
ck_tile::index_t stride_B_pad = 0)
10421043
{
10431044
const ck_tile::index_t stride_A =
10441045
ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
1046+
// stride_B_pad lets a test exercise a B tensor whose leading-dim stride is
1047+
// larger than the packed value (e.g. row-aligned padding). The host tensor,
1048+
// device buffer, and kernel args are all built with this padded stride so
1049+
// the kernel must honor the runtime stride to address B correctly.
10451050
const ck_tile::index_t stride_B =
1046-
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
1051+
ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})) + stride_B_pad;
10471052
const ck_tile::index_t stride_C =
10481053
ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));
10491054

0 commit comments

Comments
 (0)