Skip to content

Commit 141f77a

Browse files
authored
[CK Tile] Fix FMHA LSE calculation and potential division by zero (#3326)
This commit addresses numerical stability issues in the BlockFmhaPipelineQRKSVS pipeline when the bias contains -inf masking values:
1. Explicitly handle the case where the accumulated exponential sum (l) is zero. In this case, the LSE is now correctly set to negative infinity, preventing log(0) errors.
2. Extend the zero-check protection in the normalization step to cover the ELEMENTWISE_BIAS case, preventing potential division by zero.
1 parent c9f112b commit 141f77a

1 file changed

Lines changed: 21 additions & 9 deletions

File tree

include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -714,26 +714,35 @@ struct BlockFmhaPipelineQRKSVS
714714
constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
715715
sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
716716
constexpr auto i_idx = make_tuple(idx0);
717-
#if CK_TILE_FMHA_FWD_FAST_EXP2
718-
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
719-
BiasEnum == BlockAttentionBiasEnum::ALIBI)
717+
// In the masked biased case, the entire row can be suppressed and the accumulated
718+
// softmax denominator becomes zero; treat it as log(0) = -inf to avoid NaNs.
719+
if(l_[i_idx] == 0.0f)
720720
{
721-
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
721+
lse(i_idx) = -numeric<LSEDataType>::infinity();
722722
}
723723
else
724724
{
725-
if constexpr(kHasLogitsSoftCap)
725+
#if CK_TILE_FMHA_FWD_FAST_EXP2
726+
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
727+
BiasEnum == BlockAttentionBiasEnum::ALIBI)
726728
{
727729
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
728730
}
729731
else
730732
{
731-
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
733+
if constexpr(kHasLogitsSoftCap)
734+
{
735+
lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
736+
}
737+
else
738+
{
739+
lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
740+
}
732741
}
733-
}
734742
#else
735-
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
743+
lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
736744
#endif
745+
}
737746
});
738747

739748
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
@@ -745,7 +754,10 @@ struct BlockFmhaPipelineQRKSVS
745754
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
746755
constexpr auto i_idx = make_tuple(idx0);
747756
const auto tmp = [&]() {
748-
if constexpr(FmhaMask::IsMasking)
757+
// When bias carries -inf masks the denominator can be zero; guard the normalization
758+
// so we do not divide by zero after a fully masked row.
759+
if constexpr(FmhaMask::IsMasking ||
760+
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
749761
{
750762
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
751763
}

0 commit comments

Comments
 (0)