Commit 5c7b7ec

[rocm-libraries] ROCm/rocm-libraries#7272 (commit d02f3c0)
[ck_tile][fmha_bwd] Fix sink_host OOB in group mode reference runner (#7272)

## Summary

In `fmha_bwd_runner.hpp`, the `sink_host` `HostTensor` is allocated with first dimension `shape_batch` (= 1 in group mode), but the reference forward loop accesses `sink_host(wb, i_h)` with `wb ∈ [0, batch-1]`. For any `wb >= 1` this is an out-of-bounds heap read, silently corrupting the reference forward math chain (`lse_host`, `o_host`) and turning the bwd-side `d_sink_head_acc` reference into non-deterministic garbage.

`HostTensor::operator()` does not bounds check, so the OOB is not caught at runtime.

This manifests as intermittent `tile_example_fmha_bwd` failures (25–67% fail rate) when `-sink_grad=1` is combined with `-mode=1` (group mode), with bit-exact but spurious `max_err` values like 4.27 / 14.6.

## Fix

One-line: allocate `sink_host` with `batch` (the real per-batch dim) instead of `shape_batch`, mirroring how `sink_host` is accessed by the loop.

```diff
- sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
+ sink_grad ? std::array<ck_tile::index_t, 2>{batch, nhead}
```

## Repro

```
tile_example_fmha_bwd -b=2 -h=2 -s=516 -s_k=253 -prec=bf16 -d=72 \
    -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 \
    -v=3 -mode=1 -kname=1 -sink_grad=1
```

## Verification

- 0/30 fail on the repro config after fix
- Baselines (before fix):
  - sink=1, mask=n: 25% fail rate (p ≈ 1.8e-4)
  - sink=1, mask=t: 67% fail rate (p ≈ 6e-15)

## Attribution

Shape bug introduced together with sink_grad in #5504. Unrelated to #6914 (which is a fwd-only fix on a different code path).

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
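The p-values quoted under Verification are not derived in the message. One reading that reproduces the first figure, offered here as an assumption rather than the author's stated method, is the probability of seeing 30 consecutive passes if the pre-fix failure rate still applied and runs were independent. The helper name `prob_all_pass` is hypothetical.

```cpp
#include <cmath>

// Probability of observing zero failures in `runs` independent trials
// when each trial fails with probability `fail_rate`.
// Assumption: this is how the quoted p-values were obtained; the commit
// message does not say so explicitly.
double prob_all_pass(double fail_rate, int runs)
{
    return std::pow(1.0 - fail_rate, runs);
}
```

Under this reading, `prob_all_pass(0.25, 30)` is about 1.8e-4, matching the mask=n baseline. `prob_all_pass(0.67, 30)` comes out at a few times 1e-15, the same order as the quoted 6e-15; the exact value depends on how the 67% rate was rounded.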
1 parent 6989cf8 commit 5c7b7ec

2 files changed

Lines changed: 81 additions & 1 deletion

example/ck_tile/01_fmha/fmha_bwd_runner.hpp

Lines changed: 1 addition & 1 deletion
@@ -264,7 +264,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
     ck_tile::HostTensor<LSEDataType> lse_host(
         std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
     ck_tile::HostTensor<LSEDataType> sink_host(
-        sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
+        sink_grad ? std::array<ck_tile::index_t, 2>{batch, nhead}
                   : std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
     if(sink_grad)
     {
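The effect of the one-line change can be illustrated in isolation. The sketch below uses a hypothetical `Tensor2D` type, not `ck_tile::HostTensor`, and adds a checked accessor purely so that the mismatch between the allocated first dimension and the loop range becomes observable. The real `HostTensor::operator()` is described in the commit message as unchecked, which is why the original bug read heap garbage instead of failing.

```cpp
#include <array>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a 2-D host tensor; NOT ck_tile::HostTensor.
struct Tensor2D
{
    std::array<std::size_t, 2> lens;
    std::vector<float> data;

    explicit Tensor2D(std::array<std::size_t, 2> l) : lens(l), data(l[0] * l[1], 0.0f) {}

    // Checked accessor: throws instead of touching memory past the allocation.
    float& at(std::size_t i, std::size_t j)
    {
        if(i >= lens[0] || j >= lens[1])
            throw std::out_of_range("Tensor2D index out of range");
        return data[i * lens[1] + j]; // row-major offset
    }
};

// Returns true when every (wb, i_h) visited by a batch x nhead loop
// lies inside the tensor, false as soon as one index falls outside.
bool loop_stays_in_bounds(Tensor2D& t, std::size_t batch, std::size_t nhead)
{
    try
    {
        for(std::size_t wb = 0; wb < batch; ++wb)
            for(std::size_t i_h = 0; i_h < nhead; ++i_h)
                t.at(wb, i_h) = 1.0f;
    }
    catch(const std::out_of_range&)
    {
        return false;
    }
    return true;
}
```

With `batch = 2` and `nhead = 2`, a tensor allocated as `{1, nhead}` (the pre-fix shape in group mode) makes `loop_stays_in_bounds` return false at `wb = 1`, while one allocated as `{batch, nhead}` (the post-fix shape) passes.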

test/ck_tile/fmha/test_fmha_bwd.cpp

Lines changed: 80 additions & 0 deletions
@@ -995,3 +995,83 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
         GTEST_SKIP() << "No instance for multi-batch padding";
     ASSERT_EQ(result, bwd_result::success);
 }
+
+// ============================================================================
+// Regression test for sink_host group-mode OOB fix (PR #7272)
+// ----------------------------------------------------------------------------
+// Bug: in group mode, fmha_bwd_runner.hpp allocated sink_host with first
+// dimension shape_batch (=1) but the fwd reference loop iterates wb in
+// [0, batch-1], causing out-of-bounds reads of heap garbage when batch > 1.
+//
+// Repro condition: sink_grad=true AND mode=group AND batch>=2.
+// Without the fix, the fwd reference computes a poisoned LSE and the bwd
+// validation fails non-deterministically (~25-67% failure rate observed
+// across 30 trial runs at b=2,h=2,s=516,s_k=253,d=72,bf16,mask=no).
+// With the fix (1-line change shape_batch -> batch on line 267 of
+// fmha_bwd_runner.hpp), all 30 runs PASS.
+//
+// This test exercises the fixed code path; a regression that re-introduces
+// the OOB will be detected as flaky/failing validation in CI.
+// ============================================================================
+class SinkGradGroupMode : public TestWithParam<FmhaBwdTestParam>
+{
+};
+INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
+                         SinkGradGroupMode,
+                         Combine(Values(mode_enum::group), // group mode required to hit OOB
+                                 Values(std::tuple{72, -1}, // hdim covered by repro command
+                                        std::tuple{64, -1},
+                                        std::tuple{128, -1}),
+                                 Values(std::tuple{true, true}), // perm matching repro
+                                 Values("n"),                     // bias=n matching repro
+                                 Values(false),                   // use_dbias
+                                 Values(0.0f),                    // no dropout
+                                 Values(std::tuple{0, 0, false}), // seed/offset/prefs
+                                 Values(std::tuple{2, 2, -1, 516, 253, "0"}, // exact repro config
+                                        std::tuple{2, 2, -1, 516, 253, "1"}, // + causal top-left
+                                        std::tuple{
+                                            2, 2, -1, 516, 253, "2"}, // + causal bottom-right
+                                        std::tuple{3, 4, 2, 259, -1, "0"},    // larger batch, square
+                                        std::tuple{4, 2, -1, 200, 180, "0"}), // batch=4 stress
+                                 Values(false) // deterministic
+                                 ));
+TEST_P(SinkGradGroupMode, DataTypeConfig)
+{
+    auto [mode, hdims, perm, bias_str, use_dbias, p_drop, drop_misc, dims_mask, det] = GetParam();
+    auto [hdim_q, hdim_v] = hdims;
+    auto [i_perm, o_perm] = perm;
+    auto [drop_seed, drop_offset, drop_prefs] = drop_misc;
+    auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask;
+
+    auto result = fmha_bwd_run<DataTypeConfig>(
+        mode,
+        batch,
+        nhead,
+        nhead_k,
+        {seqlen_q},
+        {seqlen_k},
+        {-1},
+        {-1},
+        hdim_q,
+        hdim_v,
+        i_perm,
+        o_perm,
+        0, // scale
+        bias_str,
+        use_dbias,
+        p_drop,
+        drop_seed,
+        drop_offset,
+        drop_prefs,
+        mask_str,
+        true, // sink_grad: critical to trigger sink_host alloc/access path
+        det,
+        init_method,
+        static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
+        1,
+        stream_config);
+
+    if(result == bwd_result::no_instance)
+        GTEST_SKIP() << "No instance for sink_grad group-mode regression";
+    ASSERT_EQ(result, bwd_result::success);
+}
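The repro condition in the test comment (group mode with batch >= 2) follows from the leading dimension collapsing to 1 in group mode, as the commit message states. The sketch below shows one way such shapes could be derived. It assumes that group mode packs all sequences along the sequence axis, which the commit does not spell out. `Shapes` and `make_shapes` are hypothetical names, not ck_tile APIs.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper types; not part of ck_tile.
struct Shapes
{
    std::size_t shape_batch;  // leading dimension of packed tensors
    std::size_t shape_seqlen; // sequence-axis extent of packed tensors
};

// Batch mode: one row per batch entry, all with the same sequence length.
// Group mode (assumed layout): sequences are concatenated along the
// sequence axis, so the leading dimension collapses to 1.
Shapes make_shapes(bool group_mode, const std::vector<std::size_t>& seqlens)
{
    if(!group_mode)
        return {seqlens.size(), seqlens.empty() ? std::size_t{0} : seqlens.front()};
    return {1, std::accumulate(seqlens.begin(), seqlens.end(), std::size_t{0})};
}
```

Under this assumption, per-sequence tensors such as `lse_host` can legitimately use `shape_batch`, whereas a per-batch quantity indexed by `wb`, such as `sink_host`, needs the true `batch` count. With a single batch entry the two coincide, which would explain why the test only uses batch >= 2.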
