Skip to content

Commit ca40e0e

Browse files
committed
fix
Signed-off-by: yangqun <qun.yang@intel.com>
1 parent bca30bd commit ca40e0e

2 files changed

Lines changed: 5 additions & 9 deletions

File tree

csrc/xpu/gdn_attn/xe_2/chunk_gated_delta_rule_kernels_xe2.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,8 +1016,8 @@ CUTE_DEVICE void chunk_fwd_o_kernel(
10161016
float g_cumsum_value =
10171017
a[(chunk_offset + e) + v_head_id * total_virtual_seqlen];
10181018
g_slm_ptr[e] = g_cumsum_value;
1019+
g_multi_slm_ptr[e] = sycl::native::exp(g_last_value - g_cumsum_value);
10191020
g_exp_slm_ptr[e] = sycl::native::exp(g_cumsum_value);
1020-
g_multi_slm_ptr[e] = g_last_value_exp / g_exp_slm_ptr[e];
10211021
}
10221022

10231023
CUTE_UNROLL
@@ -1148,8 +1148,7 @@ CUTE_DEVICE void chunk_fwd_o_kernel(
11481148
// Fused gemm: W×S[dv] -> tSrU_d, Q×S[dv] -> tSrO_c
11491149
// S is loaded once and reused in registers for both
11501150
gemm_TTS_fused_2A(
1151-
W_tensor, Q_tensor, S_tensor,
1152-
tSrU_d, tSrO_c, 0, 0, dv, mma);
1151+
W_tensor, Q_tensor, S_tensor, tSrU_d, tSrO_c, 0, 0, dv, mma);
11531152

11541153
// --- WS epilogue: U_new[dv] = U_old[dv] - W×S[dv] ---
11551154
auto tCrU_d = thr_copy_U_d.partition_sg_fragment_S(gU_C);

csrc/xpu/gdn_attn/xe_2/gemm.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -407,12 +407,9 @@ CUTE_DEVICE void gemm_TTS_fused_2A(
407407

408408
auto wg_tile = mma.tile_mnk();
409409

410-
Tensor gA1 = local_tile(
411-
cA1, select<0, 2>(wg_tile), make_coord(wg_m1, _));
412-
Tensor gA2 = local_tile(
413-
cA2, select<0, 2>(wg_tile), make_coord(wg_m2, _));
414-
Tensor gB = local_tile(
415-
cB, select<1, 2>(wg_tile), make_coord(wg_n, _));
410+
Tensor gA1 = local_tile(cA1, select<0, 2>(wg_tile), make_coord(wg_m1, _));
411+
Tensor gA2 = local_tile(cA2, select<0, 2>(wg_tile), make_coord(wg_m2, _));
412+
Tensor gB = local_tile(cB, select<1, 2>(wg_tile), make_coord(wg_n, _));
416413

417414
auto copy_a1 = get_block_2d_copy_A<void>(mma, A1);
418415
auto copy_a2 = get_block_2d_copy_A<void>(mma, A2);

0 commit comments

Comments
 (0)