From 69fea3744628e400ee0a97a59badb51124cad468 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Thu, 16 Apr 2026 07:17:08 +0000 Subject: [PATCH 1/8] Add implementation --- kernels/chunk_gated_delta_h.py | 773 ++++++++++++++++++++++ tests/kernels/test_chunk_gated_delta_h.py | 745 +++++++++++++++++++++ 2 files changed, 1518 insertions(+) create mode 100644 kernels/chunk_gated_delta_h.py create mode 100644 tests/kernels/test_chunk_gated_delta_h.py diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py new file mode 100644 index 000000000..740535fb7 --- /dev/null +++ b/kernels/chunk_gated_delta_h.py @@ -0,0 +1,773 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +""" +Gated Delta Net K5 hidden-state recurrence kernel using the @flyc.kernel API. + +Mirrors the Triton `chunk_gated_delta_rule_fwd_kernel_h_opt3` from ATOM/FLA, +rewritten in FlyDSL for AMD GPUs (gfx942/gfx950). + +For each chunk t (serial over NT chunks): + 1. Store h snapshot for downstream K6 + 2. v_new = u - w @ h (delta correction via MFMA) + 3. 
Gated decay + state update: + v_new *= exp(g_last - g_cumsum) + h = h * exp(g_last) + k^T @ v_new +""" + +import functools +import math + +import torch +import triton + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr.typing import T +from flydsl.expr import range_constexpr, arith, vector, gpu, rocdl, buffer_ops +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf, math as math_dialect, llvm as _llvm +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.compiler.protocol import fly_values +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from kernels.tensor_shim import GTensor, STensor, _to_raw + +_LOG2E = math.log2(math.e) # 1.4426950408889634 +_LLVM_GEP_DYNAMIC = -2147483648 + + +def _llvm_lds_ptr_ty(): + return ir.Type.parse("!llvm.ptr<3>") + + +def _llvm_exp2_f32(x): + """Emit llvm.exp2.f32 intrinsic directly (maps to single v_exp_f32 on AMD).""" + x_raw = _to_raw(x) + return _llvm.call_intrinsic( + ir.F32Type.get(), "llvm.exp2.f32", [x_raw], [], [] + ) + + +def _fast_exp(x): + """exp(x) via exp2(x * log2(e)) using the LLVM intrinsic.""" + log2e = arith.constant(_LOG2E, type=T.f32) + return _llvm_exp2_f32(arith.mulf(x, log2e)) + + +def _mfma_bf16_16x16x32(a_bf16x8, b_bf16x8, acc_f32x4): + """Single mfma_f32_16x16x32_bf16 instruction.""" + return rocdl.mfma_f32_16x16x32_bf16( + T.f32x4, a_bf16x8, b_bf16x8, acc_f32x4, 0, 0, 0 + ).res + + +# ── Utility helpers ────────────────────────────────────────────────────── + +def _prepare_lens(cu_seqlens): + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@functools.lru_cache(maxsize=8) +def _prepare_chunk_offsets(cu_seqlens_id, chunk_size, device): + cu_seqlens = torch._dynamo.utils.get_fake_value(cu_seqlens_id) if hasattr(torch._dynamo, 'utils') else None + return None + + +def prepare_chunk_offsets(cu_seqlens, chunk_size): + lens = _prepare_lens(cu_seqlens) + return torch.cat([ 
+ cu_seqlens.new_tensor([0]), + triton.cdiv(lens, chunk_size), + ]).cumsum(-1) + + +# ── Compile the kernel ─────────────────────────────────────────────────── + +def compile_chunk_gated_delta_h( + *, + K: int, + V: int, + BT: int = 64, + BV: int = 32, + H: int, + Hg: int, + USE_G: bool = True, + USE_INITIAL_STATE: bool = True, + STORE_FINAL_STATE: bool = True, + SAVE_NEW_VALUE: bool = True, + IS_VARLEN: bool = True, + WU_CONTIGUOUS: bool = True, +): + """Compile the GDN K5 kernel. + + Returns a @flyc.jit function: + launch_fn(k, v, w, v_new, g, h, h0, ht, + cu_seqlens, chunk_offsets, + T_val, T_flat, N_val, stream) + """ + assert K <= 256 + assert K % 64 == 0 + assert BV % 16 == 0 + NUM_K_BLOCKS = K // 64 + + WARP_SIZE = 64 + NUM_WARPS = 4 + BLOCK_THREADS = NUM_WARPS * WARP_SIZE + + WMMA_M = 16 + WMMA_N = 16 + WMMA_K = 32 + WMMA_C_FRAG = 4 + + M_REPEAT = BT // WMMA_M + N_REPEAT = BV // WMMA_N + + NUM_H_ACCS = NUM_K_BLOCKS * N_REPEAT + + # ── LDS layout: w and k store all K-blocks to reduce barriers ── + LDS_W_STRIDE = K + LDS_W_ELEMS = BT * LDS_W_STRIDE + LDS_W_BYTES = LDS_W_ELEMS * 2 + + LDS_K_STRIDE = K + LDS_K_ELEMS = BT * LDS_K_STRIDE + LDS_K_BYTES = LDS_K_ELEMS * 2 + + LDS_VN_STRIDE = BV + LDS_VN_ELEMS = BT * LDS_VN_STRIDE + LDS_VN_BYTES = LDS_VN_ELEMS * 2 + + LDS_H_STRIDE = BV + LDS_H_ELEMS = K * LDS_H_STRIDE + LDS_H_BYTES = LDS_H_ELEMS * 2 + + allocator = SmemAllocator(None, arch="gfx942", global_sym_name="gdn_h_smem") + lds_w_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_w_offset + LDS_W_BYTES + lds_k_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_k_offset + LDS_K_BYTES + lds_vn_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_vn_offset + LDS_VN_BYTES + lds_h_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_h_offset + LDS_H_BYTES + + # Cooperative load parameters + LOAD_VEC_WIDTH = 8 # 8 bf16 = 16 bytes = buffer_load_dwordx4 + THREADS_PER_ROW_64 = 64 // LOAD_VEC_WIDTH # 8 + 
ROWS_PER_BATCH_64 = BLOCK_THREADS // THREADS_PER_ROW_64 # 32 + NUM_LOAD_BATCHES_64 = BT // ROWS_PER_BATCH_64 # 2 + + @flyc.kernel(name="chunk_gdn_fwd_h_opt3") + def gdn_h_kernel( + k_tensor: fx.Tensor, + v_tensor: fx.Tensor, + w_tensor: fx.Tensor, + v_new_tensor: fx.Tensor, + g_tensor: fx.Tensor, + h_tensor: fx.Tensor, + h0_tensor: fx.Tensor, + ht_tensor: fx.Tensor, + cu_seqlens_tensor: fx.Tensor, + chunk_offsets_tensor: fx.Tensor, + T_val: fx.Int32, + T_flat: fx.Int32, + N_val: fx.Int32, + ): + i_v = arith.index_cast(T.i32, gpu.block_id("x")) + i_nh = arith.index_cast(T.i32, gpu.block_id("y")) + i_n = i_nh // fx.Int32(H) + i_h = i_nh % fx.Int32(H) + + tid = arith.index_cast(T.i32, gpu.thread_id("x")) + wid = tid // fx.Int32(WARP_SIZE) + lane = tid % fx.Int32(WARP_SIZE) + + k_ = GTensor(k_tensor, dtype=T.bf16, shape=(-1,)) + v_ = GTensor(v_tensor, dtype=T.bf16, shape=(-1,)) + w_ = GTensor(w_tensor, dtype=T.bf16, shape=(-1,)) + h_ = GTensor(h_tensor, dtype=T.bf16, shape=(-1,)) + g_ = GTensor(g_tensor, dtype=T.f32, shape=(-1,)) + + vn_ = GTensor(v_new_tensor, dtype=T.bf16, shape=(-1,)) + if USE_INITIAL_STATE: + h0_ = GTensor(h0_tensor, dtype=T.f32, shape=(-1,)) + if STORE_FINAL_STATE: + ht_ = GTensor(ht_tensor, dtype=T.f32, shape=(-1,)) + + if IS_VARLEN: + cu_ = GTensor(cu_seqlens_tensor, dtype=T.i32, shape=(-1,)) + co_ = GTensor(chunk_offsets_tensor, dtype=T.i32, shape=(-1,)) + + # ── LDS views ── + lds_base_ptr = allocator.get_base() + + # w tile (bf16) — separate from k + lds_w_ptr = SmemPtr(lds_base_ptr, lds_w_offset, T.bf16, shape=(LDS_W_ELEMS,)) + lds_w = STensor(lds_w_ptr, dtype=T.bf16, shape=(LDS_W_ELEMS,)) + + # k tile (bf16) — separate from w + lds_k_ptr = SmemPtr(lds_base_ptr, lds_k_offset, T.bf16, shape=(LDS_K_ELEMS,)) + lds_k = STensor(lds_k_ptr, dtype=T.bf16, shape=(LDS_K_ELEMS,)) + + # gated v_new (bf16) + lds_vn_ptr = SmemPtr(lds_base_ptr, lds_vn_offset, T.bf16, shape=(LDS_VN_ELEMS,)) + lds_vn = STensor(lds_vn_ptr, dtype=T.bf16, shape=(LDS_VN_ELEMS,)) 
+ + # h snapshot (bf16) + lds_h_ptr = SmemPtr(lds_base_ptr, lds_h_offset, T.bf16, shape=(LDS_H_ELEMS,)) + lds_h = STensor(lds_h_ptr, dtype=T.bf16, shape=(LDS_H_ELEMS,)) + + # ── Cooperative load decomposition ── + load_row_in_batch = tid // fx.Int32(THREADS_PER_ROW_64) + load_col_base = (tid % fx.Int32(THREADS_PER_ROW_64)) * fx.Int32(LOAD_VEC_WIDTH) + + # ── XOR swizzle: col ^ ((row & 7) << 3) at 8-element granularity for bf16 ── + def _xor_swizzle(row, col): + return col ^ ((row & fx.Int32(0x7)) << fx.Int32(3)) + + def _xor_swizzle_idx(row, col): + return col ^ ((row & arith.index(0x7)) << arith.index(3)) + + # ── LDS vector read helpers (generates ds_read_b128 for 8xbf16) ── + v8bf16_type = T.vec(8, T.bf16) + lds_w_memref = lds_w_ptr.get() + lds_k_memref = lds_k_ptr.get() + + def _lds_vec_read_w_bf16x8(elem_idx): + return vector.load_op(v8bf16_type, lds_w_memref, [elem_idx]) + + def _lds_vec_read_k_bf16x8(elem_idx): + return vector.load_op(v8bf16_type, lds_k_memref, [elem_idx]) + + # ── ds_read_b64_tr_b16 helper (gfx950) ── + v4bf16_type = T.vec(4, T.bf16) + + def _ds_read_tr_bf16x4(lds_byte_offset): + byte_idx = arith.index_cast(T.index, lds_byte_offset) + byte_i64 = arith.index_cast(T.i64, byte_idx) + ptr = _llvm.IntToPtrOp(_llvm_lds_ptr_ty(), byte_i64).result + return rocdl.ds_read_tr16_b64(v4bf16_type, ptr).result + + # ds_read_b64_tr_b16 lane decomposition + tr_k_group = (lane % fx.Int32(16)) // fx.Int32(4) + tr_col_sub = lane % fx.Int32(4) + tr_col_half = (lane % fx.Int32(32)) // fx.Int32(16) + lane_div_32 = lane // fx.Int32(32) + + # ── Prologue: compute bos, T_local, NT, boh ── + if IS_VARLEN: + bos = cu_[fx.Index(i_n)] + eos = cu_[fx.Index(i_n) + fx.Index(1)] + T_local = eos - bos + NT = (T_local + fx.Int32(BT - 1)) // fx.Int32(BT) + boh = co_[fx.Index(i_n)] + else: + bos = i_n * T_val + T_local = T_val + NT = (T_local + fx.Int32(BT - 1)) // fx.Int32(BT) + boh = i_n * NT + + # ── Base pointer offsets (element counts) ── + # h: [B, NT, H, K, V] — base = 
(boh*H + i_h) * K * V + h_base = (boh * fx.Int32(H) + i_h) * fx.Int32(K * V) + stride_h = fx.Int32(H * K * V) + + # k: [B, T, Hg, K] — base = (bos*Hg + i_h//(H//Hg)) * K + gqa_ratio = H // Hg + k_base = (bos * fx.Int32(Hg) + i_h // fx.Int32(gqa_ratio)) * fx.Int32(K) + stride_k = fx.Int32(Hg * K) + + if WU_CONTIGUOUS: + if IS_VARLEN: + v_base = (i_h * T_flat + bos) * fx.Int32(V) + w_base = (i_h * T_flat + bos) * fx.Int32(K) + else: + v_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(V) + w_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(K) + stride_v = fx.Int32(V) + stride_w = fx.Int32(K) + else: + v_base = (bos * fx.Int32(H) + i_h) * fx.Int32(V) + w_base = (bos * fx.Int32(H) + i_h) * fx.Int32(K) + stride_v = fx.Int32(H * V) + stride_w = fx.Int32(H * K) + + if IS_VARLEN: + vn_base = (i_h * T_flat + bos) * fx.Int32(V) + else: + vn_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(V) + + if USE_INITIAL_STATE: + h0_base = (i_nh * fx.Int32(K * V)) + if STORE_FINAL_STATE: + ht_base = (i_nh * fx.Int32(K * V)) + + # ── MFMA lane mapping for 16x16 tiles ── + lane_n = lane % fx.Int32(16) + lane_m_base = lane // fx.Int32(16) + + # index-typed versions for LDS addressing + wid_idx = arith.index_cast(T.index, wid) + lane_n_idx = arith.index_cast(T.index, lane_n) + lane_m_base_idx = arith.index_cast(T.index, lane_m_base) + + # ── Initialize h accumulators ── + acc_zero = arith.constant_vector(0.0, T.f32x4) + + # h_accs[kb][nr] = f32x4 accumulator for k-block kb, v-repeat nr + h_accs = [] + for _kb in range_constexpr(NUM_K_BLOCKS): + for _nr in range_constexpr(N_REPEAT): + h_accs.append(acc_zero) + + # ── Load initial state if provided ── + if USE_INITIAL_STATE: + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + h0_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + h0_elems = [] + for elem_i in range_constexpr(4): + h0_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + h0_off = 
h0_base + h0_row * fx.Int32(V) + h0_col + h0_elems.append(h0_[fx.Index(h0_off)]) + loaded_vec = vector.from_elements(T.f32x4, h0_elems) + acc_idx = kb * N_REPEAT + nr + h_accs[acc_idx] = arith.addf(h_accs[acc_idx], loaded_vec) + + # ── Main chunk loop ── + init_state = [_to_raw(v) for v in h_accs] + c_zero = arith.index(0) + c_one = arith.index(1) + nt_idx = arith.index_cast(T.index, NT) + + for i_t, state in range(c_zero, nt_idx, c_one, init=init_state): + h_accs_in = list(state) + i_t_i32 = arith.index_cast(T.i32, i_t) + + # ── 1. Prefetch all w K-blocks from global (overlap with h snapshot store) ── + w_prefetch_all = [] + w_prefetch_lds_all = [] + for kb in range_constexpr(NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = w_base + safe_row * stride_w + fx.Int32(kb * 64) + load_col_base + w_prefetch_all.append(w_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + w_prefetch_lds_all.append(row * fx.Int32(LDS_W_STRIDE) + fx.Int32(kb * 64) + load_col_base) + + # ── Store h snapshot to global + LDS (w[0] loads in flight) ── + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + acc_idx = kb * N_REPEAT + nr + acc_val = h_accs_in[acc_idx] + h_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + lds_h_col = fx.Int32(nr * 16) + lane_n + + for elem_i in range_constexpr(4): + f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) + bf16_val = arith.trunc_f(T.bf16, f32_val) + + h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + h_off = h_base + i_t_i32 * stride_h + h_row * fx.Int32(V) + h_col + h_[fx.Index(h_off)] = bf16_val + + lds_h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + 
fx.Int32(elem_i) + lds_h_idx = lds_h_row * fx.Int32(BV) + lds_h_col + lds_h[fx.Index(lds_h_idx)] = bf16_val + + # ── Store all w K-blocks to LDS in one batch ── + for i_wp in range_constexpr(NUM_K_BLOCKS * NUM_LOAD_BATCHES_64): + lds_w.vec_store((fx.Index(w_prefetch_lds_all[i_wp]),), w_prefetch_all[i_wp], LOAD_VEC_WIDTH) + + gpu.barrier() + + # ── 2. Delta correction: b_v = w @ h, then v_new = u - b_v ── + # Prefetch k[0] and u values during MFMA (overlap global loads with compute) + k_prefetch = [] + k_prefetch_lds = [] + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = k_base + safe_row * stride_k + fx.Int32(0 * 64) + load_col_base + k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + load_col_base) + + # Prefetch g values (overlap with MFMA below) + if USE_G: + next_chunk_end = (i_t_i32 + fx.Int32(1)) * fx.Int32(BT) + last_idx_raw = arith.select( + arith.cmpi(arith.CmpIPredicate.slt, next_chunk_end, T_local), + next_chunk_end, + T_local, + ) - fx.Int32(1) + g_last_off = (bos + last_idx_raw) * fx.Int32(H) + i_h + g_last_prefetch = g_[fx.Index(g_last_off)] + + g_row_prefetch = [] + for elem_i in range_constexpr(4): + abs_row = i_t_i32 * fx.Int32(BT) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_row_off = (bos + safe_row) * fx.Int32(H) + i_h + g_row_prefetch.append((g_[fx.Index(g_row_off)], in_bounds)) + + # Prefetch u values (overlap with MFMA below) + u_prefetch = [] + for nr in range_constexpr(N_REPEAT): + u_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + 
u_bt_row_raw = i_t_i32 * fx.Int32(BT) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + u_row_in_bounds = arith.cmpi(arith.CmpIPredicate.slt, u_bt_row_raw, T_local) + safe_u_row = arith.select(u_row_in_bounds, u_bt_row_raw, fx.Int32(0)) + u_off = v_base + safe_u_row * stride_v + u_col + u_prefetch.append(v_[fx.Index(u_off)]) + + bv_accs = [] + for _nr in range_constexpr(N_REPEAT): + bv_accs.append(arith.constant_vector(0.0, T.f32x4)) + + K_STEPS_PER_BLOCK = 64 // WMMA_K + + for kb in range_constexpr(NUM_K_BLOCKS): + for ks in range_constexpr(K_STEPS_PER_BLOCK): + w_lds_row_idx = wid_idx * arith.index(16) + lane_n_idx + w_lds_col_idx = arith.index(kb * 64 + ks * WMMA_K) + lane_m_base_idx * arith.index(8) + w_lds_idx = w_lds_row_idx * arith.index(LDS_W_STRIDE) + w_lds_col_idx + a_frag = _lds_vec_read_w_bf16x8(w_lds_idx) + + global_ks = kb * K_STEPS_PER_BLOCK + ks + + for nr in range_constexpr(N_REPEAT): + h_k_row = fx.Int32(global_ks * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group + h_v_col = fx.Int32(nr * 16) + tr_col_sub * fx.Int32(4) + h_lds_elem = h_k_row * fx.Int32(BV) + h_v_col + h_lds_byte = h_lds_elem * fx.Int32(2) + fx.Int32(lds_h_offset) + + h_lo = _ds_read_tr_bf16x4(h_lds_byte) + h_hi = _ds_read_tr_bf16x4(h_lds_byte + fx.Int32(4 * BV * 2)) + b_frag = vector.shuffle(h_lo, h_hi, [0, 1, 2, 3, 4, 5, 6, 7]) + + bv_accs[nr] = _mfma_bf16_16x16x32(a_frag, b_frag, bv_accs[nr]) + + # v_new = u - b_v (u values already prefetched) + vn_frags = [] + for nr in range_constexpr(N_REPEAT): + bv_val = bv_accs[nr] + u_f32_elems = [] + for elem_i in range_constexpr(4): + u_bf16 = u_prefetch[nr * 4 + elem_i] + u_f32_elems.append(arith.extf(T.f32, u_bf16)) + u_f32 = vector.from_elements(T.f32x4, u_f32_elems) + + vn_frags.append(arith.subf(u_f32, bv_val)) + + # ── 2b. 
Store v_new (pre-gating) for output ── + if SAVE_NEW_VALUE: + for nr in range_constexpr(N_REPEAT): + vn_val = vn_frags[nr] + vn_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + vn_bt_row = i_t_i32 * fx.Int32(BT) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + vn_in_bounds = arith.cmpi(arith.CmpIPredicate.slt, vn_bt_row, T_local) + _if_vn = scf.IfOp(vn_in_bounds) + with ir.InsertionPoint(_if_vn.then_block): + f32_v = vector.extract(vn_val, static_position=[elem_i], dynamic_position=[]) + bf16_v = arith.trunc_f(T.bf16, f32_v) + vn_off = vn_base + vn_bt_row * fx.Int32(V) + vn_col + vn_[fx.Index(vn_off)] = bf16_v + scf.YieldOp([]) + + # ── 3. Gating — g values prefetched before MFMA ── + if USE_G: + g_last = g_last_prefetch + exp_g_last = _fast_exp(g_last) + + gate_vec = arith.constant_vector(0.0, T.f32x4) + for elem_i in range_constexpr(4): + g_row, in_bounds = g_row_prefetch[elem_i] + gate = _fast_exp(arith.subf(g_last, g_row)) + gate_masked = arith.select(in_bounds, gate, arith.constant(0.0, type=T.f32)) + gate_vec = vector.insert(gate_masked, gate_vec, static_position=[elem_i], dynamic_position=[]) + + for nr in range_constexpr(N_REPEAT): + vn_frags[nr] = arith.mulf(vn_frags[nr], gate_vec) + + exp_g_last_vec = arith.constant_vector(0.0, T.f32x4) + for ei in range_constexpr(4): + exp_g_last_vec = vector.insert(exp_g_last, exp_g_last_vec, static_position=[ei], dynamic_position=[]) + + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + acc_idx = kb * N_REPEAT + nr + h_accs_in[acc_idx] = arith.mulf(h_accs_in[acc_idx], exp_g_last_vec) + + # ── 4. 
State update: h += k^T @ v_new_gated ── + BT_STEPS = BT // WMMA_K + + # Prefetch remaining k K-blocks (k[0] already prefetched during delta correction) + for kb_extra in range_constexpr(1, NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = k_base + safe_row * stride_k + fx.Int32(kb_extra * 64) + load_col_base + k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + fx.Int32(kb_extra * 64) + load_col_base) + + # Store gated v_new + all k K-blocks to LDS in one batch, single barrier + for nr in range_constexpr(N_REPEAT): + vn_val = vn_frags[nr] + lds_col = fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + f32_v = vector.extract(vn_val, static_position=[elem_i], dynamic_position=[]) + bf16_v = arith.trunc_f(T.bf16, f32_v) + lds_row = wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + lds_idx = lds_row * fx.Int32(LDS_VN_STRIDE) + lds_col + lds_vn[fx.Index(lds_idx)] = bf16_v + + for i_kp in range_constexpr(NUM_K_BLOCKS * NUM_LOAD_BATCHES_64): + lds_k.vec_store((fx.Index(k_prefetch_lds[i_kp]),), k_prefetch[i_kp], LOAD_VEC_WIDTH) + + gpu.barrier() + + for kb in range_constexpr(NUM_K_BLOCKS): + for bt_s in range_constexpr(BT_STEPS): + k_col_tr = wid * fx.Int32(16) + tr_col_sub * fx.Int32(4) + bt_row_tr = fx.Int32(bt_s * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group + k_lds_elem = bt_row_tr * fx.Int32(LDS_K_STRIDE) + fx.Int32(kb * 64) + k_col_tr + k_lds_byte = k_lds_elem * fx.Int32(2) + fx.Int32(lds_k_offset) + + k_lo = _ds_read_tr_bf16x4(k_lds_byte) + k_hi = _ds_read_tr_bf16x4(k_lds_byte + fx.Int32(4 * LDS_K_STRIDE * 2)) + k_a_frag = vector.shuffle(k_lo, k_hi, [0, 1, 2, 3, 4, 5, 6, 7]) + + for nr in 
range_constexpr(N_REPEAT): + vn_bt_row = fx.Int32(bt_s * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group + vn_v_col = fx.Int32(nr * 16) + tr_col_sub * fx.Int32(4) + vn_lds_elem = vn_bt_row * fx.Int32(LDS_VN_STRIDE) + vn_v_col + vn_lds_byte = vn_lds_elem * fx.Int32(2) + fx.Int32(lds_vn_offset) + + vn_lo = _ds_read_tr_bf16x4(vn_lds_byte) + vn_hi = _ds_read_tr_bf16x4(vn_lds_byte + fx.Int32(4 * BV * 2)) + vn_b_frag = vector.shuffle(vn_lo, vn_hi, [0, 1, 2, 3, 4, 5, 6, 7]) + + acc_idx = kb * N_REPEAT + nr + h_accs_in[acc_idx] = _mfma_bf16_16x16x32(k_a_frag, vn_b_frag, h_accs_in[acc_idx]) + + results = yield [_to_raw(v) for v in h_accs_in] + + h_accs_final = list(results) + + # ── Epilogue: store final state ── + if STORE_FINAL_STATE: + for kb in range_constexpr(NUM_K_BLOCKS): + for nr in range_constexpr(N_REPEAT): + acc_idx = kb * N_REPEAT + nr + acc_val = h_accs_final[acc_idx] + + ht_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n + for elem_i in range_constexpr(4): + f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) + ht_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) + ht_off = ht_base + ht_row * fx.Int32(V) + ht_col + ht_[fx.Index(ht_off)] = f32_val + + # ── Host launcher ────────────────────────────────────────────────────── + @flyc.jit + def launch_gdn_h( + k_tensor: fx.Tensor, + v_tensor: fx.Tensor, + w_tensor: fx.Tensor, + v_new_tensor: fx.Tensor, + g_tensor: fx.Tensor, + h_tensor: fx.Tensor, + h0_tensor: fx.Tensor, + ht_tensor: fx.Tensor, + cu_seqlens_tensor: fx.Tensor, + chunk_offsets_tensor: fx.Tensor, + T_val: fx.Int32, + T_flat: fx.Int32, + N_val: fx.Int32, + grid_v: fx.Int32, + grid_nh: fx.Int32, + stream: fx.Stream, + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + launcher = gdn_h_kernel( + k_tensor, v_tensor, w_tensor, v_new_tensor, g_tensor, + h_tensor, h0_tensor, ht_tensor, 
+ cu_seqlens_tensor, chunk_offsets_tensor, + T_val, T_flat, N_val, + ) + launcher.launch( + grid=(grid_v, grid_nh, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_gdn_h + + +# ── Python wrapper (matches Triton interface) ──────────────────────────── + +_compiled_kernels = {} +_autotune_cache = {} # (shape_key) -> best BV +_BV_CANDIDATES = [16, 32, 64] +_AUTOTUNE_WARMUP = 5 +_AUTOTUNE_ITERS = 25 + + +def _get_or_compile(K, V, BT, BV, H, Hg, use_g, use_h0, store_fs, save_vn, is_varlen, wu_contig): + cache_key = (K, V, BT, BV, H, Hg, use_g, use_h0, store_fs, save_vn, is_varlen, wu_contig) + if cache_key not in _compiled_kernels: + _compiled_kernels[cache_key] = compile_chunk_gated_delta_h( + K=K, V=V, BT=BT, BV=BV, H=H, Hg=Hg, + USE_G=use_g, USE_INITIAL_STATE=use_h0, + STORE_FINAL_STATE=store_fs, SAVE_NEW_VALUE=save_vn, + IS_VARLEN=is_varlen, WU_CONTIGUOUS=wu_contig, + ) + return _compiled_kernels[cache_key] + + +def _launch_kernel(launch_fn, BV, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream): + grid_v = triton.cdiv(V, BV) + grid_nh = N * H + launch_fn( + k, u, w, vn_arg, g_arg, + h, h0_arg, ht_arg, + cu_arg, co_arg, + T, T_flat, N, + grid_v, grid_nh, + stream, + ) + + +def chunk_gated_delta_rule_fwd_h_flydsl( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + wu_contiguous: bool = True, + BV: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """FlyDSL K5 wrapper with wrapper-level autotune over BV.""" + B, T, Hg, K = k.shape + BT = chunk_size + + if wu_contiguous: + H = w.shape[1] + V = u.shape[-1] + T_flat = w.shape[2] + else: + H = u.shape[-2] + V = u.shape[-1] + T_flat = w.shape[1] + + if cu_seqlens is None: + N, NT, 
chunk_offsets = B, triton.cdiv(T, BT), None + else: + N = len(cu_seqlens) - 1 + lens = cu_seqlens[1:] - cu_seqlens[:-1] + NT = sum(triton.cdiv(int(l), BT) for l in lens.tolist()) + chunk_offsets = torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(lens, BT), + ]).cumsum(-1).to(torch.int32) + + assert K <= 256 + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new_buf = k.new_empty(B, H, T_flat, V, dtype=u.dtype) + v_new = v_new_buf if save_new_value else None + + dummy = torch.empty(1, device=k.device, dtype=torch.float32) + g_arg = g if g is not None else dummy + h0_arg = initial_state if initial_state is not None else dummy + ht_arg = final_state if final_state is not None else dummy + vn_arg = v_new_buf + cu_arg = cu_seqlens.to(torch.int32) if cu_seqlens is not None else dummy.to(torch.int32) + co_arg = chunk_offsets if chunk_offsets is not None else dummy.to(torch.int32) + stream = torch.cuda.current_stream() + + use_g = g is not None + use_h0 = initial_state is not None + is_varlen = cu_seqlens is not None + + # Resolve BV: explicit > autotune cache > benchmark + if BV <= 0: + shape_key = (K, V, BT, H, Hg, T_flat, N, + use_g, use_h0, output_final_state, + save_new_value, is_varlen, wu_contiguous) + + if shape_key in _autotune_cache: + BV = _autotune_cache[shape_key] + else: + candidates = [bv for bv in _BV_CANDIDATES if bv <= V and V % bv == 0] + if len(candidates) <= 1: + BV = candidates[0] if candidates else 16 + else: + print(f"[K5 autotune] benchmarking BV in {candidates} ...") + best_bv, best_us = candidates[0], float('inf') + for bv in candidates: + fn = _get_or_compile(K, V, BT, bv, H, Hg, + use_g, use_h0, output_final_state, + save_new_value, is_varlen, wu_contiguous) + # warmup + for _ in range(_AUTOTUNE_WARMUP): + _launch_kernel(fn, bv, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream) + torch.cuda.synchronize() + # 
benchmark + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + for _ in range(_AUTOTUNE_ITERS): + _launch_kernel(fn, bv, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream) + e.record() + torch.cuda.synchronize() + us = s.elapsed_time(e) / _AUTOTUNE_ITERS * 1000 + print(f" BV={bv:3d}: {us:.2f} us") + if us < best_us: + best_us = us + best_bv = bv + BV = best_bv + print(f"[K5 autotune] best BV={BV} ({best_us:.2f} us)") + _autotune_cache[shape_key] = BV + + launch_fn = _get_or_compile(K, V, BT, BV, H, Hg, + use_g, use_h0, output_final_state, + save_new_value, is_varlen, wu_contiguous) + _launch_kernel(launch_fn, BV, V, N, H, + k, u, w, vn_arg, g_arg, h, h0_arg, ht_arg, + cu_arg, co_arg, T, T_flat, stream) + + return h, v_new, final_state + + +__all__ = [ + "compile_chunk_gated_delta_h", + "chunk_gated_delta_rule_fwd_h_flydsl", +] diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py new file mode 100644 index 000000000..7560702e6 --- /dev/null +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -0,0 +1,745 @@ +""" +Tests for FlyDSL K5: chunk_gated_delta_rule_fwd_h (GDN hidden-state recurrence) + +Correctness: compare FlyDSL kernel against a pure-PyTorch reference. +Performance: compare FlyDSL kernel against Triton opt3 kernel. +Rocprof: profile with rocprofv3 for accurate GPU kernel timing. 
+ +Runtime parameters derived from Qwen3.5-397B-A17B TP=8 serving config: + K=128, V=128, Hk=16->Hg=2, Hv=64->H=8, BT=64 + max_num_batched_tokens=8192, full_prompt_len=8000 + +Usage: + cd /workspace/FlyDSL + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Correct" + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Perf" + python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Rocprof" + + # Direct rocprofv3 profiling (without pytest): + python3 tests/kernels/test_chunk_gated_delta_h.py --mode rocprof + python3 tests/kernels/test_chunk_gated_delta_h.py --mode rocprof --full-prompt-len 1000 +""" + +import argparse +import csv +import ctypes +import subprocess +import sys +import os +from ctypes.util import find_library +from pathlib import Path + +import pytest +import torch +import triton + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from kernels.chunk_gated_delta_h import chunk_gated_delta_rule_fwd_h_flydsl + +# ── Triton opt3 kernel (inlined, no external dependency) ──────────────── + +import functools +import triton.language as tl + +def _check_platform(): + try: + backend = triton.runtime.driver.active.get_current_target().backend + except (RuntimeError, AttributeError): + backend = "cpu" + return {"cuda": "nvidia", "hip": "amd", "xpu": "intel"}.get(backend, backend) + +_use_cuda_graph = _check_platform() == "nvidia" and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + +def _tensor_cache(fn): + cache_entries = [] + @functools.wraps(fn) + def wrapper(*args, **kwargs): + nonlocal cache_entries + for i, (la, lk, lr) in enumerate(cache_entries): + if len(args) == len(la) and all(a is b for a, b in zip(args, la)) \ + and len(kwargs) == len(lk) and all(k in lk and v is lk[k] for k, v in kwargs.items()): + cache_entries = cache_entries[:i] + cache_entries[i+1:] + [(la, lk, lr)] + return lr + result = fn(*args, 
**kwargs) + if len(cache_entries) >= 8: + cache_entries.pop(0) + cache_entries.append((args, kwargs, result)) + return result + return wrapper + +@_tensor_cache +def _prepare_lens(cu_seqlens): + return cu_seqlens[1:] - cu_seqlens[:-1] + +@_tensor_cache +def _prepare_chunk_indices(cu_seqlens, chunk_size): + indices = torch.cat([torch.arange(n) for n in triton.cdiv(_prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + +@_tensor_cache +def _prepare_chunk_offsets(cu_seqlens, chunk_size): + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(_prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) + +@triton.heuristics({ + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.autotune( + configs=[triton.Config({"BV": BV}, num_warps=nw, num_stages=ns) + for nw in [2, 4] for ns in [1, 2, 3, 4] for BV in [16, 32, 64]], + key=["H", "K", "V", "BT", "IS_VARLEN"], + use_cuda_graph=_use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def _triton_fwd_kernel_h_opt3( + k, v, w, v_new, g, gk, h, h0, ht, + cu_seqlens, chunk_offsets, T, T_flat, + H: tl.constexpr, Hg: tl.constexpr, K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BV: tl.constexpr, + USE_G: tl.constexpr, USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, IS_VARLEN: tl.constexpr, + WU_CONTIGUOUS: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos = tl.load(cu_seqlens + i_n).to(tl.int32) + eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + 
i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: b_h4 = tl.zeros([64, BV], dtype=tl.float32) + h += ((boh * H + i_h) * K * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + if WU_CONTIGUOUS: + if IS_VARLEN: + v += ((i_h * T_flat + bos) * V).to(tl.int64) + w += ((i_h * T_flat + bos) * K).to(tl.int64) + else: + v += (((i_n * H + i_h) * T_flat) * V).to(tl.int64) + w += (((i_n * H + i_h) * T_flat) * K).to(tl.int64) + stride_v, stride_w = V, K + else: + v += ((bos * H + i_h) * V).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + stride_v, stride_w = H * V, H * K + if SAVE_NEW_VALUE: + if IS_VARLEN: v_new += ((i_h * T_flat + bos) * V).to(tl.int64) + else: v_new += (((i_n * H + i_h) * T_flat) * V).to(tl.int64) + stride_h, stride_k = H * K * V, Hg * K + if USE_INITIAL_STATE: h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: ht = ht + i_nh * K * V + if USE_INITIAL_STATE: + b_h1 += tl.load(tl.make_block_ptr(h0, (K,V),(V,1),(0,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + if K > 64: b_h2 += tl.load(tl.make_block_ptr(h0,(K,V),(V,1),(64,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + if K > 128: b_h3 += tl.load(tl.make_block_ptr(h0,(K,V),(V,1),(128,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + if K > 192: b_h4 += tl.load(tl.make_block_ptr(h0,(K,V),(V,1),(192,i_v*BV),(64,BV),(1,0)), boundary_check=(0,1)).to(tl.float32) + for i_t in range(NT): + tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(0,i_v*BV),(64,BV),(1,0)), b_h1.to(tl.bfloat16), boundary_check=(0,1)) + if K > 64: tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(64,i_v*BV),(64,BV),(1,0)), b_h2.to(tl.bfloat16), boundary_check=(0,1)) + if K > 128: 
tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(128,i_v*BV),(64,BV),(1,0)), b_h3.to(tl.bfloat16), boundary_check=(0,1)) + if K > 192: tl.store(tl.make_block_ptr(h+i_t*stride_h,(K,V),(V,1),(192,i_v*BV),(64,BV),(1,0)), b_h4.to(tl.bfloat16), boundary_check=(0,1)) + p_w = tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,0),(BT,64),(1,0)) + b_v = tl.dot(tl.load(p_w, boundary_check=(0,1)), b_h1.to(tl.bfloat16)) + if K > 64: b_v += tl.dot(tl.load(tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,64),(BT,64),(1,0)), boundary_check=(0,1)), b_h2.to(tl.bfloat16)) + if K > 128: b_v += tl.dot(tl.load(tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,128),(BT,64),(1,0)), boundary_check=(0,1)), b_h3.to(tl.bfloat16)) + if K > 192: b_v += tl.dot(tl.load(tl.make_block_ptr(w,(T,K),(stride_w,1),(i_t*BT,192),(BT,64),(1,0)), boundary_check=(0,1)), b_h4.to(tl.bfloat16)) + b_v = tl.load(tl.make_block_ptr(v,(T,V),(stride_v,1),(i_t*BT,i_v*BV),(BT,BV),(1,0)), boundary_check=(0,1)) - b_v + if SAVE_NEW_VALUE: + tl.store(tl.make_block_ptr(v_new,(T,V),(V,1),(i_t*BT,i_v*BV),(BT,BV),(1,0)), b_v.to(tl.bfloat16), boundary_check=(0,1)) + last_idx = min((i_t+1)*BT, T) - 1 + if USE_G: + m_t = (i_t*BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + bos*H + last_idx*H + i_h) + b_g = tl.load(tl.make_block_ptr(g+bos*H+i_h,(T,),(H,),(i_t*BT,),(BT,),(0,)), boundary_check=(0,)) + b_v = b_v * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None] + b_g_last = tl.exp(b_g_last) + b_h1 *= b_g_last + if K > 64: b_h2 *= b_g_last + if K > 128: b_h3 *= b_g_last + if K > 192: b_h4 *= b_g_last + if USE_GK: + o_k1 = tl.arange(0, 64) + b_h1 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+o_k1, mask=(o_k1 64: b_h2 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+64+o_k1, mask=(64+o_k1 128: b_h3 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+128+o_k1, mask=(128+o_k1 192: b_h4 *= tl.exp(tl.load(gk+(bos+last_idx)*H*K+i_h*K+192+o_k1, mask=(192+o_k1 64: b_h2 += 
tl.dot(tl.load(tl.make_block_ptr(k,(K,T),(1,stride_k),(64,i_t*BT),(64,BT),(0,1)), boundary_check=(0,1)), b_v) + if K > 128: b_h3 += tl.dot(tl.load(tl.make_block_ptr(k,(K,T),(1,stride_k),(128,i_t*BT),(64,BT),(0,1)), boundary_check=(0,1)), b_v) + if K > 192: b_h4 += tl.dot(tl.load(tl.make_block_ptr(k,(K,T),(1,stride_k),(192,i_t*BT),(64,BT),(0,1)), boundary_check=(0,1)), b_v) + if STORE_FINAL_STATE: + tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(0,i_v*BV),(64,BV),(1,0)), b_h1.to(tl.float32), boundary_check=(0,1)) + if K > 64: tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(64,i_v*BV),(64,BV),(1,0)), b_h2.to(tl.float32), boundary_check=(0,1)) + if K > 128: tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(128,i_v*BV),(64,BV),(1,0)), b_h3.to(tl.float32), boundary_check=(0,1)) + if K > 192: tl.store(tl.make_block_ptr(ht,(K,V),(V,1),(192,i_v*BV),(64,BV),(1,0)), b_h4.to(tl.float32), boundary_check=(0,1)) + +def fwd_h_triton_opt3( + k, w, u, g=None, gk=None, initial_state=None, + output_final_state=False, chunk_size=64, save_new_value=True, + cu_seqlens=None, wu_contiguous=False, +): + B, T, Hg, K = k.shape + BT = chunk_size + if wu_contiguous: + H, V, T_flat = w.shape[1], u.shape[-1], w.shape[2] + else: + H, V, T_flat = u.shape[-2], u.shape[-1], w.shape[1] + chunk_indices = _prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N = len(cu_seqlens) - 1 + NT = len(chunk_indices) + chunk_offsets = _prepare_chunk_offsets(cu_seqlens, BT) + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new = k.new_empty(B, H, T_flat, V, dtype=u.dtype) if save_new_value else None + _triton_fwd_kernel_h_opt3[(lambda meta: (triton.cdiv(V, meta["BV"]), N * H))]( + k=k, v=u, w=w, v_new=v_new, g=g, gk=gk, + h=h, h0=initial_state, ht=final_state, + cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, + T=T, T_flat=T_flat, 
H=H, Hg=Hg, K=K, V=V, BT=BT, + WU_CONTIGUOUS=wu_contiguous, + ) + return h, v_new, final_state + + + +# ── Global test configuration ────────────────────────────────────────── + +K = 128 +V = 128 +Hg = 2 +H = 8 +BT = 64 + +MAX_NUM_BATCHED_TOKENS = 8192 +FULL_PROMPT_LENS = [8000] + +NUM_WARMUP = 10 +NUM_ITERS = 200 + + +def _build_context_lens(full_prompt_len, max_tokens=MAX_NUM_BATCHED_TOKENS): + context_lens = [] + remaining = max_tokens + while remaining > 0: + cur = min(full_prompt_len, remaining) + context_lens.append(cur) + remaining -= cur + return context_lens + + +def _build_cu_seqlens(context_lens, device="cuda"): + scheduled_q_lens = context_lens + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(scheduled_q_lens), 0).tolist()), + dtype=torch.int32, + device=device, + ) + return scheduled_q_lens, cu_seqlens + + +def _make_inputs(context_lens, dtype=torch.bfloat16, device="cuda", + with_initial_state=True): + scheduled_q_lens, cu_seqlens = _build_cu_seqlens(context_lens, device=device) + T_total = int(cu_seqlens[-1].item()) + N = len(scheduled_q_lens) + B = 1 + + k = torch.randn(B, T_total, Hg, K, dtype=dtype, device=device) * 0.1 + w_orig = torch.randn(B, T_total, H, K, dtype=dtype, device=device) * 0.1 + u_orig = torch.randn(B, T_total, H, V, dtype=dtype, device=device) * 0.1 + g = torch.randn(T_total, H, dtype=torch.float32, device=device).abs() * -0.5 + g = g.cumsum(dim=0) + + w_c = w_orig.permute(0, 2, 1, 3).contiguous() + u_c = u_orig.permute(0, 2, 1, 3).contiguous() + + initial_state = None + if with_initial_state: + initial_state = torch.randn(N, H, K, V, dtype=torch.float32, device=device) * 0.01 + + return k, w_orig, u_orig, w_c, u_c, g, initial_state, cu_seqlens, scheduled_q_lens + + +# ── Pure-PyTorch reference ────────────────────────────────────────────── + +def ref_chunk_gated_delta_rule_fwd_h( + k, w, u, g, + initial_state=None, + output_final_state=False, + chunk_size=64, + cu_seqlens=None, +): + """Reference in FP32 for 
correctness checking.""" + B, T, Hg_dim, K_dim = k.shape + H_dim, V_dim = u.shape[-2], u.shape[-1] + BT_dim = chunk_size + if cu_seqlens is None: + NT = triton.cdiv(T, BT_dim) + else: + seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + NT = sum(triton.cdiv(int(seq_len), BT_dim) for seq_len in seq_lens) + gqa_ratio = H_dim // Hg_dim + + h_out = k.new_zeros(B, NT, H_dim, K_dim, V_dim, dtype=torch.float32) + v_new_out = torch.zeros_like(u, dtype=torch.float32) + + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + final_state = torch.zeros(N, H_dim, K_dim, V_dim, dtype=torch.float32, + device=k.device) if output_final_state else None + + for b_idx in range(B): + if cu_seqlens is not None: + seqs = [(s, cu_seqlens[s].item(), cu_seqlens[s + 1].item()) + for s in range(N)] + else: + seqs = [(b_idx, 0, T)] + + chunk_offset = 0 + for seq_idx, bos, eos in seqs: + seq_len = eos - bos + seq_nt = triton.cdiv(seq_len, BT_dim) + + for i_h in range(H_dim): + i_hg = i_h // gqa_ratio + h_state = torch.zeros(K_dim, V_dim, dtype=torch.float32, + device=k.device) + if initial_state is not None: + h_state = initial_state[seq_idx, i_h].float().clone() + + for i_t in range(seq_nt): + t_start = i_t * BT_dim + t_end = min(t_start + BT_dim, seq_len) + actual_bt = t_end - t_start + + h_out[b_idx, chunk_offset + i_t, i_h] = h_state.clone() + + w_chunk = w[b_idx, bos + t_start:bos + t_end, i_h].float() + u_chunk = u[b_idx, bos + t_start:bos + t_end, i_h].float() + b_v = u_chunk - w_chunk @ h_state + v_new_out[b_idx, bos + t_start:bos + t_end, i_h] = b_v + + last_idx = bos + t_end - 1 + g_last = g[last_idx, i_h].float() + g_chunk = g[bos + t_start:bos + t_end, i_h].float() + + mask = torch.zeros(BT_dim, device=k.device) + mask[:actual_bt] = 1.0 + gate = torch.where( + mask[:actual_bt].bool(), + torch.exp(g_last - g_chunk), + torch.zeros_like(g_chunk), + ) + b_v_gated = b_v * gate.unsqueeze(-1) + + h_state = h_state * torch.exp(g_last) + k_chunk = k[b_idx, bos + t_start:bos + 
t_end, i_hg].float() + b_v_gated_cast = b_v_gated.to(k.dtype).float() + h_state = h_state + k_chunk.T @ b_v_gated_cast + + if output_final_state: + final_state[seq_idx, i_h] = h_state + + chunk_offset += seq_nt + + return h_out, v_new_out.to(u.dtype), final_state + + +def _normalize_opt_v_new(vn_opt): + """Convert opt v_new layout [B, H, T, V] back to [B, T, H, V].""" + return vn_opt.permute(0, 2, 1, 3).contiguous() + + +# ── Correctness tests ─────────────────────────────────────────────────── + +class TestCorrectness: + """Correctness against PyTorch reference.""" + + @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) + def test_correctness_flydsl(self, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + + h_fly, vn_fly, fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + h_ref, vn_ref, fs_ref = ref_chunk_gated_delta_rule_fwd_h( + k, w_orig, u_orig, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + + torch.testing.assert_close( + h_fly.float(), h_ref.float(), atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + _normalize_opt_v_new(vn_fly).float(), vn_ref.float(), + atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + fs_fly.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) + + @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) + def test_correctness_triton_opt3(self, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + + h_tri, vn_tri, fs_tri = fwd_h_triton_opt3( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + h_ref, vn_ref, fs_ref = ref_chunk_gated_delta_rule_fwd_h( + k, w_orig, u_orig, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + + 
torch.testing.assert_close( + h_tri.float(), h_ref.float(), atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + _normalize_opt_v_new(vn_tri).float(), vn_ref.float(), + atol=1e-1, rtol=1e-1) + torch.testing.assert_close( + fs_tri.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) + + @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) + def test_correctness_flydsl_vs_triton(self, full_prompt_len): + """Direct comparison between FlyDSL and Triton opt3 kernels.""" + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + + h_fly, vn_fly, fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + h_tri, vn_tri, fs_tri = fwd_h_triton_opt3( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + h_fly_f, h_tri_f = h_fly.float(), h_tri.float() + vn_fly_f, vn_tri_f = vn_fly.float(), vn_tri.float() + fs_fly_f, fs_tri_f = fs_fly.float(), fs_tri.float() + + def _report(name, a, b): + diff = (a - b).abs() + diff_flat = diff.flatten() + sorted_diff, _ = diff_flat.sort() + n = sorted_diff.numel() + median_val = sorted_diff[n // 2].item() + p99_val = sorted_diff[min(int(n * 0.99), n - 1)].item() + print(f" {name}:") + print(f" FlyDSL range: [{a.min().item():.6f}, {a.max().item():.6f}]") + print(f" Triton range: [{b.min().item():.6f}, {b.max().item():.6f}]") + print(f" abs_err max={diff.max().item():.6f} " + f"mean={diff.mean().item():.6f} " + f"median={median_val:.6f} " + f"p99={p99_val:.6f}") + + print(f"\n[FlyDSL vs Triton opt3 full_prompt_len={full_prompt_len}]") + _report("h", h_fly_f, h_tri_f) + _report("v_new", vn_fly_f, vn_tri_f) + _report("final_state", fs_fly_f, fs_tri_f) + + torch.testing.assert_close(h_fly_f, h_tri_f, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(vn_fly_f, vn_tri_f, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(fs_fly_f, 
fs_tri_f, atol=1e-1, rtol=1e-1) + + +# ── Performance tests ─────────────────────────────────────────────────── + +def _bench_fn(fn, *args, **kwargs): + """Warmup + measure, return average us.""" + fn(*args, **kwargs) + torch.cuda.synchronize() + for _ in range(NUM_WARMUP): + fn(*args, **kwargs) + torch.cuda.synchronize() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + s.record() + for _ in range(NUM_ITERS): + fn(*args, **kwargs) + e.record() + torch.cuda.synchronize() + return s.elapsed_time(e) / NUM_ITERS * 1000 + + +PERF_SHAPES = [ + pytest.param(fpl, id=f"full{fpl}") + for fpl in FULL_PROMPT_LENS +] + + +class TestPerformance: + """Performance comparison: FlyDSL vs Triton opt3.""" + + @pytest.mark.parametrize("full_prompt_len", PERF_SHAPES) + def test_perf_comparison(self, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, scheduled_q_lens = _make_inputs( + context_lens) + total_tokens = int(cu[-1].item()) + + # FlyDSL kernel + us_fly = _bench_fn( + chunk_gated_delta_rule_fwd_h_flydsl, + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + print(f"\n[K5 FlyDSL T={total_tokens}] {us_fly:.2f} us") + + # Triton opt3 kernel for comparison + us_triton = _bench_fn( + fwd_h_triton_opt3, + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + speedup = us_triton / us_fly if us_fly > 0 else float('inf') + print(f"[K5 Triton opt3 T={total_tokens}] {us_triton:.2f} us") + print(f"[Speedup FlyDSL/Triton] {speedup:.3f}x") + + +# ── rocprofv3 profiling infrastructure ────────────────────────────────── + +TARGET_KERNEL_FLYDSL = "chunk_gdn_fwd_h_opt3" +TARGET_KERNEL_TRITON = "_triton_fwd_kernel_h_opt3" + + +def _load_roctx_library(): + """Load the roctx shared library for profiler pause/resume control.""" + for candidate in ("rocprofiler-sdk-roctx", 
"roctx64"): + libname = find_library(candidate) + if libname is None: + continue + lib = ctypes.CDLL(libname) + lib.roctxGetThreadId.argtypes = [ctypes.POINTER(ctypes.c_uint64)] + lib.roctxGetThreadId.restype = None + lib.roctxProfilerPause.argtypes = [ctypes.c_uint64] + lib.roctxProfilerPause.restype = None + lib.roctxProfilerResume.argtypes = [ctypes.c_uint64] + lib.roctxProfilerResume.restype = None + lib.roctxRangePushA.argtypes = [ctypes.c_char_p] + lib.roctxRangePushA.restype = ctypes.c_int + lib.roctxRangePop.argtypes = [] + lib.roctxRangePop.restype = ctypes.c_int + return lib + return None + + +def _roctx_thread_id(lib): + tid = ctypes.c_uint64() + lib.roctxGetThreadId(ctypes.byref(tid)) + return int(tid.value) + + +def _rocprof_worker(full_prompt_len): + """Inner worker: runs under rocprofv3 --selected-regions. + + Profiling starts paused. We warmup both kernels, then + Resume -> measured iterations -> Pause for each kernel sequentially. + """ + roctx = _load_roctx_library() + if roctx is None: + raise RuntimeError("roctx library not found; cannot run as profiling worker") + + tid = _roctx_thread_id(roctx) + + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + total_tokens = int(cu[-1].item()) + + run_fly = lambda: chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + # Warmup FlyDSL (paused) + print(f"[rocprof-worker] Warmup FlyDSL (T={total_tokens}) ...", flush=True) + for _ in range(NUM_WARMUP): + run_fly() + torch.cuda.synchronize() + + # Measure FlyDSL + roctx.roctxProfilerResume(tid) + roctx.roctxRangePushA(b"flydsl_k5_bench") + for _ in range(NUM_ITERS): + run_fly() + torch.cuda.synchronize() + roctx.roctxRangePop() + roctx.roctxProfilerPause(tid) + print(f"[rocprof-worker] FlyDSL: {NUM_ITERS} iterations done", flush=True) + + # Triton opt3 + run_tri = lambda: fwd_h_triton_opt3( + k, 
w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + + print(f"[rocprof-worker] Warmup Triton opt3 ...", flush=True) + for _ in range(NUM_WARMUP): + run_tri() + torch.cuda.synchronize() + + roctx.roctxProfilerResume(tid) + roctx.roctxRangePushA(b"triton_k5_bench") + for _ in range(NUM_ITERS): + run_tri() + torch.cuda.synchronize() + roctx.roctxRangePop() + roctx.roctxProfilerPause(tid) + print(f"[rocprof-worker] Triton: {NUM_ITERS} iterations done", flush=True) + + +def _parse_kernel_stats(stats_path: Path) -> dict[str, dict]: + """Parse kernel_stats CSV -> {name: {AverageNs, TotalDurationNs, Calls, ...}}.""" + result = {} + with stats_path.open(newline="") as f: + for row in csv.DictReader(f): + result[row["Name"]] = row + return result + + +def _print_rocprof_summary(stats_path: Path, total_tokens: int): + """Print a formatted summary from rocprofv3 kernel_stats CSV.""" + stats = _parse_kernel_stats(stats_path) + + targets = [ + ("FlyDSL", TARGET_KERNEL_FLYDSL), + ("Triton opt3", TARGET_KERNEL_TRITON), + ] + + results = {} + for label, kname in targets: + entry = stats.get(kname) + if entry is None: + for name in stats: + if kname in name: + entry = stats[name] + break + if entry is None: + print(f" {label}: kernel '{kname}' not found in stats") + continue + + avg_ns = float(entry["AverageNs"]) + min_ns = float(entry["MinNs"]) + max_ns = float(entry["MaxNs"]) + calls = int(entry["Calls"]) + total_ns = float(entry["TotalDurationNs"]) + results[label] = avg_ns + + print(f" {label} ({kname}):") + print(f" Calls: {calls}") + print(f" Average: {avg_ns / 1000:.2f} us ({avg_ns:.0f} ns)") + print(f" Min: {min_ns / 1000:.2f} us") + print(f" Max: {max_ns / 1000:.2f} us") + print(f" Total: {total_ns / 1e6:.2f} ms") + + if "FlyDSL" in results and "Triton opt3" in results: + speedup = results["Triton opt3"] / results["FlyDSL"] + print(f"\n Speedup (FlyDSL vs Triton): {speedup:.3f}x") + + if not stats: + print(" WARNING: no 
kernels found in stats file") + elif not results: + print(" Available kernels:") + for name in sorted(stats.keys()): + print(f" {name}") + + +def _do_rocprof(full_prompt_len): + """Outer driver: launches rocprofv3 wrapping this script in --_rocprof-worker mode.""" + repo_root = Path(__file__).resolve().parent.parent.parent + output_dir = repo_root / "rocprof_output" + output_dir.mkdir(parents=True, exist_ok=True) + output_stem = f"gdn_k5_fpl{full_prompt_len}" + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + + inner_cmd = [ + "python3", "-u", str(Path(__file__).resolve()), + "--_rocprof-worker", + "--full-prompt-len", str(full_prompt_len), + ] + rocprof_cmd = [ + "rocprofv3", + "--kernel-trace", + "--marker-trace", + "--output-format", "csv", + "-d", str(output_dir), + "-o", output_stem, + "--stats", + "--selected-regions", + "--", *inner_cmd, + ] + + context_lens = _build_context_lens(full_prompt_len) + total_tokens = sum(context_lens) + + print(f"\n[rocprof] full_prompt_len={full_prompt_len}, T={total_tokens}") + print(f"[rocprof] cmd: {' '.join(rocprof_cmd)}", flush=True) + result = subprocess.run(rocprof_cmd, cwd=repo_root, env=env) + + stats_path = output_dir / f"{output_stem}_kernel_stats.csv" + if stats_path.exists(): + print(f"\n[rocprof] Results (full_prompt_len={full_prompt_len}, T={total_tokens}):") + _print_rocprof_summary(stats_path, total_tokens) + else: + print(f"[rocprof] kernel stats not found: {stats_path}", flush=True) + trace_path = output_dir / f"{output_stem}_kernel_trace.csv" + if trace_path.exists(): + print(f"[rocprof] trace file exists: {trace_path}") + + if result.returncode != 0: + print(f"[rocprof] rocprofv3 exited with code {result.returncode}", flush=True) + + return result.returncode + + +# ── rocprofv3 pytest tests ───────────────────────────────────────────── + +class TestRocprof: + """Profile FlyDSL and Triton kernels with rocprofv3.""" + + @pytest.mark.parametrize("full_prompt_len", PERF_SHAPES) + def 
test_rocprof(self, full_prompt_len): + rc = _do_rocprof(full_prompt_len) + assert rc == 0, f"rocprofv3 exited with code {rc}" + + +# ── Main ──────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GDN K5 test / profile") + parser.add_argument("--mode", choices=["test", "rocprof"], default="test", + help="test=pytest (default), rocprof=rocprofv3 profiling") + parser.add_argument("--full-prompt-len", type=int, default=8000) + parser.add_argument("--_rocprof-worker", action="store_true", + help=argparse.SUPPRESS) + args = parser.parse_args() + + if args._rocprof_worker: + _rocprof_worker(args.full_prompt_len) + elif args.mode == "rocprof": + _do_rocprof(args.full_prompt_len) + else: + pytest.main([__file__, "-v", "-s"]) From 316a31c35ccc20f7afa65c95f4c261b473edb920 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Thu, 16 Apr 2026 12:36:24 +0000 Subject: [PATCH 2/8] refine test --- tests/kernels/test_chunk_gated_delta_h.py | 299 +++++++++++++++++----- 1 file changed, 229 insertions(+), 70 deletions(-) diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py index 7560702e6..2eb0292c4 100644 --- a/tests/kernels/test_chunk_gated_delta_h.py +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -5,9 +5,10 @@ Performance: compare FlyDSL kernel against Triton opt3 kernel. Rocprof: profile with rocprofv3 for accurate GPU kernel timing. 
-Runtime parameters derived from Qwen3.5-397B-A17B TP=8 serving config: - K=128, V=128, Hk=16->Hg=2, Hv=64->H=8, BT=64 - max_num_batched_tokens=8192, full_prompt_len=8000 +Runtime parameters derived from Qwen3.5-397B-A17B serving config: + K=128, V=128, Hk=16, Hv=32, BT=64 + TP_LIST=[1,4] -> Hg=Hk/TP, H=Hv/TP (parametrized per test) + max_num_batched_tokens=32768 Usage: cd /workspace/FlyDSL @@ -36,7 +37,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) -from kernels.chunk_gated_delta_h import chunk_gated_delta_rule_fwd_h_flydsl +from kernels.chunk_gated_delta_h import chunk_gated_delta_rule_fwd_h_flydsl, _autotune_cache as _flydsl_autotune_cache # ── Triton opt3 kernel (inlined, no external dependency) ──────────────── @@ -223,15 +224,24 @@ def fwd_h_triton_opt3( # ── Global test configuration ────────────────────────────────────────── - +# # Qwen3 params +# K = 128 +# V = 128 +# Hk = 16 +# Hv = 32 +# TP_LIST = [1, 4] +# BT = 64 + +# Qwen3.5 params K = 128 V = 128 -Hg = 2 -H = 8 +Hk = 16 +Hv = 64 +TP_LIST = [4, 8] BT = 64 -MAX_NUM_BATCHED_TOKENS = 8192 -FULL_PROMPT_LENS = [8000] +MAX_NUM_BATCHED_TOKENS = 32768 +FULL_PROMPT_LENS = [1024, 2048, 4096, 8192] NUM_WARMUP = 10 NUM_ITERS = 200 @@ -257,8 +267,10 @@ def _build_cu_seqlens(context_lens, device="cuda"): return scheduled_q_lens, cu_seqlens -def _make_inputs(context_lens, dtype=torch.bfloat16, device="cuda", +def _make_inputs(context_lens, tp=1, dtype=torch.bfloat16, device="cuda", with_initial_state=True): + Hg = Hk // tp + H = Hv // tp scheduled_q_lens, cu_seqlens = _build_cu_seqlens(context_lens, device=device) T_total = int(cu_seqlens[-1].item()) N = len(scheduled_q_lens) @@ -369,15 +381,22 @@ def _normalize_opt_v_new(vn_opt): return vn_opt.permute(0, 2, 1, 3).contiguous() +PERF_SHAPES = [ + pytest.param(tp, fpl, id=f"TP{tp}_full{fpl}") + for tp in TP_LIST + for fpl in FULL_PROMPT_LENS +] + + # ── Correctness tests ─────────────────────────────────────────────────── class 
TestCorrectness: """Correctness against PyTorch reference.""" - @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) - def test_correctness_flydsl(self, full_prompt_len): + @pytest.mark.parametrize("tp, full_prompt_len", PERF_SHAPES) + def test_correctness_flydsl(self, tp, full_prompt_len): context_lens = _build_context_lens(full_prompt_len) - k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) h_fly, vn_fly, fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( k, w_c, u_c, g=g, initial_state=h0, @@ -396,10 +415,10 @@ def test_correctness_flydsl(self, full_prompt_len): torch.testing.assert_close( fs_fly.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) - @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) - def test_correctness_triton_opt3(self, full_prompt_len): + @pytest.mark.parametrize("tp, full_prompt_len", PERF_SHAPES) + def test_correctness_triton_opt3(self, tp, full_prompt_len): context_lens = _build_context_lens(full_prompt_len) - k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) h_tri, vn_tri, fs_tri = fwd_h_triton_opt3( k, w_c, u_c, g=g, initial_state=h0, @@ -418,11 +437,11 @@ def test_correctness_triton_opt3(self, full_prompt_len): torch.testing.assert_close( fs_tri.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) - @pytest.mark.parametrize("full_prompt_len", FULL_PROMPT_LENS) - def test_correctness_flydsl_vs_triton(self, full_prompt_len): + @pytest.mark.parametrize("tp, full_prompt_len", PERF_SHAPES) + def test_correctness_flydsl_vs_triton(self, tp, full_prompt_len): """Direct comparison between FlyDSL and Triton opt3 kernels.""" context_lens = _build_context_lens(full_prompt_len) - k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) h_fly, vn_fly, 
fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( k, w_c, u_c, g=g, initial_state=h0, @@ -452,7 +471,7 @@ def _report(name, a, b): f"median={median_val:.6f} " f"p99={p99_val:.6f}") - print(f"\n[FlyDSL vs Triton opt3 full_prompt_len={full_prompt_len}]") + print(f"\n[FlyDSL vs Triton opt3 TP={tp} full_prompt_len={full_prompt_len}]") _report("h", h_fly_f, h_tri_f) _report("v_new", vn_fly_f, vn_tri_f) _report("final_state", fs_fly_f, fs_tri_f) @@ -462,6 +481,31 @@ def _report(name, a, b): torch.testing.assert_close(fs_fly_f, fs_tri_f, atol=1e-1, rtol=1e-1) +# ── Best-config helpers ──────────────────────────────────────────────── + +def _get_flydsl_best_config() -> str: + """Return the last cached FlyDSL autotune result as a short string.""" + if not _flydsl_autotune_cache: + return "N/A" + bv = list(_flydsl_autotune_cache.values())[-1] + return f"BV={bv}" + + +def _get_triton_best_config() -> str: + """Return the Triton autotune best config as a short string.""" + try: + kernel = _triton_fwd_kernel_h_opt3 + autotuner = kernel.fn if hasattr(kernel, "fn") else kernel + cfg = autotuner.best_config + kw = cfg.kwargs + bv = kw.get("BV", "?") + nw = cfg.num_warps + ns = cfg.num_stages + return f"BV={bv},nw={nw},ns={ns}" + except (AttributeError, TypeError): + return "N/A" + + # ── Performance tests ─────────────────────────────────────────────────── def _bench_fn(fn, *args, **kwargs): @@ -482,41 +526,57 @@ def _bench_fn(fn, *args, **kwargs): return s.elapsed_time(e) / NUM_ITERS * 1000 -PERF_SHAPES = [ - pytest.param(fpl, id=f"full{fpl}") - for fpl in FULL_PROMPT_LENS -] - - class TestPerformance: """Performance comparison: FlyDSL vs Triton opt3.""" - @pytest.mark.parametrize("full_prompt_len", PERF_SHAPES) - def test_perf_comparison(self, full_prompt_len): + _results: list[dict] = [] + + @pytest.mark.parametrize("tp, full_prompt_len", PERF_SHAPES) + def test_perf_comparison(self, tp, full_prompt_len): context_lens = _build_context_lens(full_prompt_len) k, w_orig, u_orig, w_c, 
u_c, g, h0, cu, scheduled_q_lens = _make_inputs( - context_lens) + context_lens, tp=tp) total_tokens = int(cu[-1].item()) + num_seqs = len(context_lens) + Hg = Hk // tp + H = Hv // tp - # FlyDSL kernel us_fly = _bench_fn( chunk_gated_delta_rule_fwd_h_flydsl, k, w_c, u_c, g=g, initial_state=h0, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) - print(f"\n[K5 FlyDSL T={total_tokens}] {us_fly:.2f} us") + print(f"\n[K5 FlyDSL TP={tp} Hg={Hg} H={H} T={total_tokens}] {us_fly:.2f} us") - # Triton opt3 kernel for comparison us_triton = _bench_fn( fwd_h_triton_opt3, k, w_c, u_c, g=g, initial_state=h0, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) speedup = us_triton / us_fly if us_fly > 0 else float('inf') - print(f"[K5 Triton opt3 T={total_tokens}] {us_triton:.2f} us") + print(f"[K5 Triton opt3 TP={tp} T={total_tokens}] {us_triton:.2f} us") print(f"[Speedup FlyDSL/Triton] {speedup:.3f}x") + TestPerformance._results.append({ + "tp": tp, + "Hg": Hg, + "H": H, + "full_prompt_len": full_prompt_len, + "num_seqs": num_seqs, + "flydsl_us": us_fly, + "flydsl_cfg": _get_flydsl_best_config(), + "triton_us": us_triton, + "triton_cfg": _get_triton_best_config(), + "speedup": speedup, + }) + + @pytest.fixture(autouse=True, scope="class") + def _print_table(self): + TestPerformance._results = [] + yield + _print_summary_table(TestPerformance._results, title="CUDA Event Performance Summary") + # ── rocprofv3 profiling infrastructure ────────────────────────────────── @@ -551,7 +611,7 @@ def _roctx_thread_id(lib): return int(tid.value) -def _rocprof_worker(full_prompt_len): +def _rocprof_worker(full_prompt_len, tp=1): """Inner worker: runs under rocprofv3 --selected-regions. Profiling starts paused. 
We warmup both kernels, then @@ -564,7 +624,7 @@ def _rocprof_worker(full_prompt_len): tid = _roctx_thread_id(roctx) context_lens = _build_context_lens(full_prompt_len) - k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) total_tokens = int(cu[-1].item()) run_fly = lambda: chunk_gated_delta_rule_fwd_h_flydsl( @@ -618,8 +678,27 @@ def _parse_kernel_stats(stats_path: Path) -> dict[str, dict]: return result -def _print_rocprof_summary(stats_path: Path, total_tokens: int): - """Print a formatted summary from rocprofv3 kernel_stats CSV.""" +def _extract_kernel_us(stats: dict, kname: str) -> dict | None: + """Find a kernel entry in parsed stats and return timing dict, or None.""" + entry = stats.get(kname) + if entry is None: + for name in stats: + if kname in name: + entry = stats[name] + break + if entry is None: + return None + return { + "avg_us": float(entry["AverageNs"]) / 1000, + "min_us": float(entry["MinNs"]) / 1000, + "max_us": float(entry["MaxNs"]) / 1000, + "calls": int(entry["Calls"]), + "total_ms": float(entry["TotalDurationNs"]) / 1e6, + } + + +def _print_rocprof_summary(stats_path: Path, total_tokens: int) -> dict | None: + """Print a formatted summary and return {flydsl_us, triton_us, speedup} or None.""" stats = _parse_kernel_stats(stats_path) targets = [ @@ -629,33 +708,27 @@ def _print_rocprof_summary(stats_path: Path, total_tokens: int): results = {} for label, kname in targets: - entry = stats.get(kname) - if entry is None: - for name in stats: - if kname in name: - entry = stats[name] - break - if entry is None: + t = _extract_kernel_us(stats, kname) + if t is None: print(f" {label}: kernel '{kname}' not found in stats") continue - - avg_ns = float(entry["AverageNs"]) - min_ns = float(entry["MinNs"]) - max_ns = float(entry["MaxNs"]) - calls = int(entry["Calls"]) - total_ns = float(entry["TotalDurationNs"]) - results[label] = avg_ns - + results[label] 
= t print(f" {label} ({kname}):") - print(f" Calls: {calls}") - print(f" Average: {avg_ns / 1000:.2f} us ({avg_ns:.0f} ns)") - print(f" Min: {min_ns / 1000:.2f} us") - print(f" Max: {max_ns / 1000:.2f} us") - print(f" Total: {total_ns / 1e6:.2f} ms") + print(f" Calls: {t['calls']}") + print(f" Average: {t['avg_us']:.2f} us") + print(f" Min: {t['min_us']:.2f} us") + print(f" Max: {t['max_us']:.2f} us") + print(f" Total: {t['total_ms']:.2f} ms") + row = None if "FlyDSL" in results and "Triton opt3" in results: - speedup = results["Triton opt3"] / results["FlyDSL"] + speedup = results["Triton opt3"]["avg_us"] / results["FlyDSL"]["avg_us"] print(f"\n Speedup (FlyDSL vs Triton): {speedup:.3f}x") + row = { + "flydsl_us": results["FlyDSL"]["avg_us"], + "triton_us": results["Triton opt3"]["avg_us"], + "speedup": speedup, + } if not stats: print(" WARNING: no kernels found in stats file") @@ -664,13 +737,21 @@ def _print_rocprof_summary(stats_path: Path, total_tokens: int): for name in sorted(stats.keys()): print(f" {name}") + return row + -def _do_rocprof(full_prompt_len): - """Outer driver: launches rocprofv3 wrapping this script in --_rocprof-worker mode.""" +def _do_rocprof(full_prompt_len, tp=1) -> tuple[int, dict | None]: + """Outer driver: launches rocprofv3 wrapping this script in --_rocprof-worker mode. + + Returns (returncode, row_dict_or_None). + row_dict keys: tp, Hg, H, full_prompt_len, num_seqs, flydsl_us, triton_us, speedup. 
+ """ + Hg = Hk // tp + H = Hv // tp repo_root = Path(__file__).resolve().parent.parent.parent output_dir = repo_root / "rocprof_output" output_dir.mkdir(parents=True, exist_ok=True) - output_stem = f"gdn_k5_fpl{full_prompt_len}" + output_stem = f"gdn_k5_tp{tp}_fpl{full_prompt_len}" env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" @@ -679,6 +760,7 @@ def _do_rocprof(full_prompt_len): "python3", "-u", str(Path(__file__).resolve()), "--_rocprof-worker", "--full-prompt-len", str(full_prompt_len), + "--tp", str(tp), ] rocprof_cmd = [ "rocprofv3", @@ -694,15 +776,26 @@ def _do_rocprof(full_prompt_len): context_lens = _build_context_lens(full_prompt_len) total_tokens = sum(context_lens) + num_seqs = len(context_lens) - print(f"\n[rocprof] full_prompt_len={full_prompt_len}, T={total_tokens}") + print(f"\n[rocprof] TP={tp} Hg={Hg} H={H} full_prompt_len={full_prompt_len}, T={total_tokens}") print(f"[rocprof] cmd: {' '.join(rocprof_cmd)}", flush=True) result = subprocess.run(rocprof_cmd, cwd=repo_root, env=env) + row = None stats_path = output_dir / f"{output_stem}_kernel_stats.csv" if stats_path.exists(): - print(f"\n[rocprof] Results (full_prompt_len={full_prompt_len}, T={total_tokens}):") - _print_rocprof_summary(stats_path, total_tokens) + print(f"\n[rocprof] Results (TP={tp} full_prompt_len={full_prompt_len}, T={total_tokens}):") + perf = _print_rocprof_summary(stats_path, total_tokens) + if perf is not None: + row = { + "tp": tp, + "Hg": Hg, + "H": H, + "full_prompt_len": full_prompt_len, + "num_seqs": num_seqs, + **perf, + } else: print(f"[rocprof] kernel stats not found: {stats_path}", flush=True) trace_path = output_dir / f"{output_stem}_kernel_trace.csv" @@ -712,7 +805,48 @@ def _do_rocprof(full_prompt_len): if result.returncode != 0: print(f"[rocprof] rocprofv3 exited with code {result.returncode}", flush=True) - return result.returncode + return result.returncode, row + + +def _print_summary_table(rows: list[dict], title: str = "Performance Summary"): + 
"""Print a formatted summary table from collected benchmark rows. + + Rows are grouped by TP value, each group gets its own sub-table. + Each row dict keys: tp, Hg, H, full_prompt_len, num_seqs, + flydsl_us, flydsl_cfg, triton_us, triton_cfg, speedup. + """ + if not rows: + return + + from itertools import groupby + + w = 138 + sep = "-" * w + print(f"\n{'=' * w}") + print(f" {title}") + print(f" Fixed: K={K}, V={V}, Hk={Hk}, Hv={Hv}, BT={BT}, max_tokens={MAX_NUM_BATCHED_TOKENS}") + print(f"{'=' * w}") + + sorted_rows = sorted(rows, key=lambda r: (r.get("tp", 1), r["full_prompt_len"])) + for tp_val, group in groupby(sorted_rows, key=lambda r: r.get("tp", 1)): + group_rows = list(group) + Hg_val = group_rows[0].get("Hg", "?") + H_val = group_rows[0].get("H", "?") + print(f"\n TP={tp_val} (Hg={Hg_val}, H={H_val})") + print(f" {'FullPromptLen':>13} {'NumSeqs':>8} " + f"{'FlyDSL(us)':>11} {'FlyDSL BestCfg':>15} " + f"{'Triton(us)':>11} {'Triton BestCfg':>22} " + f"{'Speedup':>8}") + print(f" {sep}") + for r in group_rows: + fly_cfg = r.get("flydsl_cfg", "N/A") + tri_cfg = r.get("triton_cfg", "N/A") + print(f" {r['full_prompt_len']:>13} {r['num_seqs']:>8} " + f"{r['flydsl_us']:>11.2f} {fly_cfg:>15} " + f"{r['triton_us']:>11.2f} {tri_cfg:>22} " + f"{r['speedup']:>7.3f}x") + + print(f"{'=' * w}\n") # ── rocprofv3 pytest tests ───────────────────────────────────────────── @@ -720,11 +854,35 @@ def _do_rocprof(full_prompt_len): class TestRocprof: """Profile FlyDSL and Triton kernels with rocprofv3.""" - @pytest.mark.parametrize("full_prompt_len", PERF_SHAPES) - def test_rocprof(self, full_prompt_len): - rc = _do_rocprof(full_prompt_len) + _results: list[dict] = [] + + @pytest.mark.parametrize("tp, full_prompt_len", PERF_SHAPES) + def test_rocprof(self, tp, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) + chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, 
initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + fwd_h_triton_opt3( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + torch.cuda.synchronize() + + rc, row = _do_rocprof(full_prompt_len, tp=tp) + if row is not None: + row["flydsl_cfg"] = _get_flydsl_best_config() + row["triton_cfg"] = _get_triton_best_config() + TestRocprof._results.append(row) assert rc == 0, f"rocprofv3 exited with code {rc}" + @pytest.fixture(autouse=True, scope="class") + def _print_table(self): + TestRocprof._results = [] + yield + _print_summary_table(TestRocprof._results, title="Rocprof Performance Summary") + # ── Main ──────────────────────────────────────────────────────────────── @@ -733,13 +891,14 @@ def test_rocprof(self, full_prompt_len): parser.add_argument("--mode", choices=["test", "rocprof"], default="test", help="test=pytest (default), rocprof=rocprofv3 profiling") parser.add_argument("--full-prompt-len", type=int, default=8000) + parser.add_argument("--tp", type=int, default=1) parser.add_argument("--_rocprof-worker", action="store_true", help=argparse.SUPPRESS) args = parser.parse_args() if args._rocprof_worker: - _rocprof_worker(args.full_prompt_len) + _rocprof_worker(args.full_prompt_len, tp=args.tp) elif args.mode == "rocprof": - _do_rocprof(args.full_prompt_len) + _do_rocprof(args.full_prompt_len, tp=args.tp) else: pytest.main([__file__, "-v", "-s"]) From 19e1cd8ecdbf6e9f30fad8df2f4326c7e8b05639 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Mon, 20 Apr 2026 02:23:12 +0000 Subject: [PATCH 3/8] Support vk test --- tests/kernels/test_chunk_gated_delta_h.py | 563 ++++++++++++++++++++-- 1 file changed, 523 insertions(+), 40 deletions(-) diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py index 2eb0292c4..76bb42692 100644 --- a/tests/kernels/test_chunk_gated_delta_h.py +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -25,6 +25,7 @@ import 
def fwd_h_triton_opt3(
    k, w, u, g=None, gk=None, initial_state=None,
    output_final_state=False, chunk_size=64, save_new_value=True,
    cu_seqlens=None, wu_contiguous=False,
):
    """Run the opt3 kernel with VK-ordered ``[..., V, K]`` hidden states.

    The raw kernel is KV-native, so the optional ``initial_state`` is
    transposed on the way in and ``h`` / ``final_state`` on the way out.
    """

    def _flip_last_two(t):
        # Swap the trailing [K, V] <-> [V, K] dims; keep None as-is.
        return None if t is None else t.transpose(-2, -1).contiguous()

    h_kv, v_new, fs_kv = _fwd_h_triton_opt3_kv(
        k, w, u, g=g, gk=gk, initial_state=_flip_last_two(initial_state),
        output_final_state=output_final_state,
        chunk_size=chunk_size, save_new_value=save_new_value,
        cu_seqlens=cu_seqlens, wu_contiguous=wu_contiguous,
    )
    return _flip_last_two(h_kv), v_new, _flip_last_two(fs_kv)
"STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [1, 2, 3, 4] + for BV in [16, 32, 64] + ], + key=["H", "K", "V", "BT"], + use_cuda_graph=_use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_opt_vk( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + T_flat, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + b_h1 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([BV, 64], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([BV, 64], dtype=tl.float32) + + h += ((boh * H + i_h) * V * K).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + if IS_VARLEN: + w += ((i_h * T_flat + bos) * K).to(tl.int64) + else: + w += (((i_n * H + i_h) * T_flat) * K).to(tl.int64) + if IS_VARLEN: + v += ((i_h * T_flat + bos) * V).to(tl.int64) + else: + v += (((i_n * H + i_h) * T_flat) * V).to(tl.int64) + if SAVE_NEW_VALUE: + if IS_VARLEN: + v_new += ((i_h * T_flat + bos) * V).to(tl.int64) + else: + v_new += (((i_n * H + i_h) 
* T_flat) * V).to(tl.int64) + stride_v = V + stride_h = H * V * K + stride_k = Hg * K + stride_w = K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * V * K + if STORE_FINAL_STATE: + ht = ht + i_nh * V * K + + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + p_h1 = tl.make_block_ptr( + h + i_t.to(tl.int64) * stride_h, + (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0), + ) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr( + h + i_t.to(tl.int64) * stride_h, + (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0), + ) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr( + h + i_t.to(tl.int64) * stride_h, + (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0), + ) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr( + h + i_t.to(tl.int64) * stride_h, + (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0), + ) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, 
boundary_check=(0, 1)) + b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr( + w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0) + ) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype)) + p_v = tl.make_block_ptr( + v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_vn = tl.make_block_ptr( + v_new, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0) + ) + tl.store(p_vn, b_v.to(p_vn.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t.to(tl.int64) + 1) * BT, T) - 1 + if USE_G: + m_t = (i_t.to(tl.int64) * BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr( + g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,) + ) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v = b_v * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None] + b_g_last = tl.exp(b_g_last) + b_h1 *= b_g_last + if K > 64: + b_h2 *= b_g_last + if K > 128: + b_h3 *= b_g_last + if K > 192: + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), other=0.0, + ) + b_h1 *= tl.exp(b_gk_last1)[None, :] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), other=0.0, + ) + b_h2 *= tl.exp(b_gk_last2)[None, :] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), other=0.0, + ) + b_h3 *= tl.exp(b_gk_last3)[None, :] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + 
def chunk_gated_delta_rule_fwd_h_opt_vk(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = _FLA_CHUNK_SIZE_OPT_VK,
    save_new_value: bool = True,
    cu_seqlens: torch.Tensor | None = None,
    chunk_indices: torch.Tensor | None = None,
    chunk_offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Host launcher for the opt_vk Triton hidden-state kernel.

    Hidden-state tensors (``h`` / ``initial_state`` / ``final_state``) use the
    VK layout ``[..., V, K]``; ``k`` is ``[B, T, Hg, K]`` and ``w`` / ``u``
    are head-major with the flattened token dim at ``shape[2]`` (``T_flat``).

    Returns ``(h, v_new, final_state)``; ``v_new`` is ``None`` unless
    ``save_new_value`` and ``final_state`` is ``None`` unless
    ``output_final_state``.
    """
    B, T, Hg, K = k.shape
    H = w.shape[1]
    V = u.shape[-1]
    # Flattened token count of w/u; differs from T only in packed layouts.
    T_flat = w.shape[2]
    BT = chunk_size

    # Varlen path needs per-chunk (seq, chunk) indices; compute if not given.
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = _prepare_chunk_indices(cu_seqlens, chunk_size)

    if cu_seqlens is None:
        # Fixed-length batch: B sequences, NT chunks each, no offsets needed.
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        # Varlen: one program row per (sequence, head); NT is the global
        # chunk count across all sequences.
        N, NT = len(cu_seqlens) - 1, len(chunk_indices)
        if chunk_offsets is None:
            chunk_offsets = _prepare_chunk_offsets(cu_seqlens, BT)

    assert K <= 256, "Current kernel does not support head dimension larger than 256."

    # Per-chunk hidden-state snapshots in VK order, dtype follows k.
    h = k.new_empty(B, NT, H, V, K)
    final_state = k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
    v_new = k.new_empty(B, H, T_flat, V, dtype=u.dtype) if save_new_value else None

    def grid(meta):
        # One program per (V-tile, sequence*head); BV comes from autotune.
        return (triton.cdiv(V, meta["BV"]), N * H)

    # NOTE: u is passed as the kernel's `v` argument (values to correct).
    chunk_gated_delta_rule_fwd_kernel_h_opt_vk[grid](
        k=k, v=u, w=w, v_new=v_new,
        g=g, gk=gk,
        h=h, h0=initial_state, ht=final_state,
        cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets,
        T=T, T_flat=T_flat,
        H=H, Hg=Hg, K=K, V=V, BT=BT,
    )
    return h, v_new, final_state
+ """ + h0_kv = initial_state.transpose(-2, -1).contiguous() if initial_state is not None else None + h_kv, v_new, fs_kv = chunk_gated_delta_rule_fwd_h_flydsl( + k, w, u, g=g, initial_state=h0_kv, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, wu_contiguous=wu_contiguous, + ) + h_vk = h_kv.transpose(-2, -1).contiguous() + fs_vk = fs_kv.transpose(-2, -1).contiguous() if fs_kv is not None else None + return h_vk, v_new, fs_vk + # ── Global test configuration ────────────────────────────────────────── # # Qwen3 params @@ -237,11 +602,12 @@ def fwd_h_triton_opt3( V = 128 Hk = 16 Hv = 64 -TP_LIST = [4, 8] +TP_LIST = [8] BT = 64 MAX_NUM_BATCHED_TOKENS = 32768 -FULL_PROMPT_LENS = [1024, 2048, 4096, 8192] +# FULL_PROMPT_LENS = [1024, 2048, 4096, 8192] +FULL_PROMPT_LENS = [8192] NUM_WARMUP = 10 NUM_ITERS = 200 @@ -287,7 +653,7 @@ def _make_inputs(context_lens, tp=1, dtype=torch.bfloat16, device="cuda", initial_state = None if with_initial_state: - initial_state = torch.randn(N, H, K, V, dtype=torch.float32, device=device) * 0.01 + initial_state = torch.randn(N, H, V, K, dtype=torch.float32, device=device) * 0.01 return k, w_orig, u_orig, w_c, u_c, g, initial_state, cu_seqlens, scheduled_q_lens @@ -312,11 +678,11 @@ def ref_chunk_gated_delta_rule_fwd_h( NT = sum(triton.cdiv(int(seq_len), BT_dim) for seq_len in seq_lens) gqa_ratio = H_dim // Hg_dim - h_out = k.new_zeros(B, NT, H_dim, K_dim, V_dim, dtype=torch.float32) + h_out = k.new_zeros(B, NT, H_dim, V_dim, K_dim, dtype=torch.float32) v_new_out = torch.zeros_like(u, dtype=torch.float32) N = len(cu_seqlens) - 1 if cu_seqlens is not None else B - final_state = torch.zeros(N, H_dim, K_dim, V_dim, dtype=torch.float32, + final_state = torch.zeros(N, H_dim, V_dim, K_dim, dtype=torch.float32, device=k.device) if output_final_state else None for b_idx in range(B): @@ -333,7 +699,8 @@ def ref_chunk_gated_delta_rule_fwd_h( for i_h in range(H_dim): i_hg = i_h // gqa_ratio - h_state = torch.zeros(K_dim, V_dim, 
dtype=torch.float32, + # h_state in VK layout: [V, K] + h_state = torch.zeros(V_dim, K_dim, dtype=torch.float32, device=k.device) if initial_state is not None: h_state = initial_state[seq_idx, i_h].float().clone() @@ -347,7 +714,8 @@ def ref_chunk_gated_delta_rule_fwd_h( w_chunk = w[b_idx, bos + t_start:bos + t_end, i_h].float() u_chunk = u[b_idx, bos + t_start:bos + t_end, i_h].float() - b_v = u_chunk - w_chunk @ h_state + # h_state is [V,K], need [K,V] for w @ h: w[T,K] @ h[K,V] + b_v = u_chunk - w_chunk @ h_state.T v_new_out[b_idx, bos + t_start:bos + t_end, i_h] = b_v last_idx = bos + t_end - 1 @@ -366,7 +734,8 @@ def ref_chunk_gated_delta_rule_fwd_h( h_state = h_state * torch.exp(g_last) k_chunk = k[b_idx, bos + t_start:bos + t_end, i_hg].float() b_v_gated_cast = b_v_gated.to(k.dtype).float() - h_state = h_state + k_chunk.T @ b_v_gated_cast + # h[V,K] += (k^T @ v_new)^T = v_new^T @ k + h_state = h_state + b_v_gated_cast.T @ k_chunk if output_final_state: final_state[seq_idx, i_h] = h_state @@ -398,9 +767,9 @@ def test_correctness_flydsl(self, tp, full_prompt_len): context_lens = _build_context_lens(full_prompt_len) k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) - h_fly, vn_fly, fs_fly = chunk_gated_delta_rule_fwd_h_flydsl( + h_fly, vn_fly, fs_fly = fwd_h_flydsl( k, w_c, u_c, g=g, initial_state=h0, - output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + output_final_state=True, cu_seqlens=cu, ) h_ref, vn_ref, fs_ref = ref_chunk_gated_delta_rule_fwd_h( k, w_orig, u_orig, g=g, initial_state=h0, @@ -437,15 +806,37 @@ def test_correctness_triton_opt3(self, tp, full_prompt_len): torch.testing.assert_close( fs_tri.float(), fs_ref.float(), atol=1e-1, rtol=1e-1) + @pytest.mark.parametrize("tp, full_prompt_len", PERF_SHAPES) + def test_correctness_triton_opt_vk(self, tp, full_prompt_len): + context_lens = _build_context_lens(full_prompt_len) + k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) + + h_vk, 
def _get_triton_opt_vk_best_config() -> str:
    """Return the Triton opt_vk autotune best config as a short string."""
    try:
        # The decorator stack is heuristics(autotune(jit(fn))): the outer
        # heuristics wrapper exposes the Autotuner via `.fn`; if there is
        # no wrapper, the object itself is the autotuner.
        obj = chunk_gated_delta_rule_fwd_kernel_h_opt_vk
        tuner = getattr(obj, "fn", obj)
        best = tuner.best_config
        bv = best.kwargs.get("BV", "?")
        return f"BV={bv},nw={best.num_warps},ns={best.num_stages}"
    except (AttributeError, TypeError):
        # No autotune run yet (or unexpected wrapper shape) — report N/A.
        return "N/A"
+ h0_kv = h0.transpose(-2, -1).contiguous() if h0 is not None else None + us_fly = _bench_fn( chunk_gated_delta_rule_fwd_h_flydsl, - k, w_c, u_c, g=g, initial_state=h0, + k, w_c, u_c, g=g, initial_state=h0_kv, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) print(f"\n[K5 FlyDSL TP={tp} Hg={Hg} H={H} T={total_tokens}] {us_fly:.2f} us") us_triton = _bench_fn( - fwd_h_triton_opt3, - k, w_c, u_c, g=g, initial_state=h0, + _fwd_h_triton_opt3_kv, + k, w_c, u_c, g=g, initial_state=h0_kv, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) speedup = us_triton / us_fly if us_fly > 0 else float('inf') print(f"[K5 Triton opt3 TP={tp} T={total_tokens}] {us_triton:.2f} us") - print(f"[Speedup FlyDSL/Triton] {speedup:.3f}x") + print(f"[Speedup FlyDSL/Triton opt3] {speedup:.3f}x") + + us_triton_vk = _bench_fn( + chunk_gated_delta_rule_fwd_h_opt_vk, + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + speedup_vk = us_triton_vk / us_fly if us_fly > 0 else float('inf') + print(f"[K5 Triton opt_vk TP={tp} T={total_tokens}] {us_triton_vk:.2f} us") + print(f"[Speedup FlyDSL/Triton opt_vk] {speedup_vk:.3f}x") TestPerformance._results.append({ "tp": tp, @@ -569,6 +987,9 @@ def test_perf_comparison(self, tp, full_prompt_len): "triton_us": us_triton, "triton_cfg": _get_triton_best_config(), "speedup": speedup, + "triton_opt_vk_us": us_triton_vk, + "triton_opt_vk_cfg": _get_triton_opt_vk_best_config(), + "speedup_vk": speedup_vk, }) @pytest.fixture(autouse=True, scope="class") @@ -582,6 +1003,7 @@ def _print_table(self): TARGET_KERNEL_FLYDSL = "chunk_gdn_fwd_h_opt3" TARGET_KERNEL_TRITON = "_triton_fwd_kernel_h_opt3" +TARGET_KERNEL_TRITON_OPT_VK = "chunk_gated_delta_rule_fwd_kernel_h_opt_vk" def _load_roctx_library(): @@ -611,7 +1033,7 @@ def _roctx_thread_id(lib): return int(tid.value) -def _rocprof_worker(full_prompt_len, tp=1): +def _rocprof_worker(full_prompt_len, tp=1, config_path: str | None = None): """Inner worker: runs under 
rocprofv3 --selected-regions. Profiling starts paused. We warmup both kernels, then @@ -627,8 +1049,11 @@ def _rocprof_worker(full_prompt_len, tp=1): k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) total_tokens = int(cu[-1].item()) + # Pre-transpose h0 for KV-layout kernels (outside profiling region) + h0_kv = h0.transpose(-2, -1).contiguous() if h0 is not None else None + run_fly = lambda: chunk_gated_delta_rule_fwd_h_flydsl( - k, w_c, u_c, g=g, initial_state=h0, + k, w_c, u_c, g=g, initial_state=h0_kv, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) @@ -648,9 +1073,9 @@ def _rocprof_worker(full_prompt_len, tp=1): roctx.roctxProfilerPause(tid) print(f"[rocprof-worker] FlyDSL: {NUM_ITERS} iterations done", flush=True) - # Triton opt3 - run_tri = lambda: fwd_h_triton_opt3( - k, w_c, u_c, g=g, initial_state=h0, + # Triton opt3 (KV layout, h0 pre-transposed) + run_tri = lambda: _fwd_h_triton_opt3_kv( + k, w_c, u_c, g=g, initial_state=h0_kv, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) @@ -666,7 +1091,34 @@ def _rocprof_worker(full_prompt_len, tp=1): torch.cuda.synchronize() roctx.roctxRangePop() roctx.roctxProfilerPause(tid) - print(f"[rocprof-worker] Triton: {NUM_ITERS} iterations done", flush=True) + print(f"[rocprof-worker] Triton opt3: {NUM_ITERS} iterations done", flush=True) + + # Triton opt_vk (VK layout, h0 directly) + run_tri_vk = lambda: chunk_gated_delta_rule_fwd_h_opt_vk( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + + print(f"[rocprof-worker] Warmup Triton opt_vk ...", flush=True) + for _ in range(NUM_WARMUP): + run_tri_vk() + torch.cuda.synchronize() + + roctx.roctxProfilerResume(tid) + roctx.roctxRangePushA(b"triton_k5_opt_vk_bench") + for _ in range(NUM_ITERS): + run_tri_vk() + torch.cuda.synchronize() + roctx.roctxRangePop() + roctx.roctxProfilerPause(tid) + print(f"[rocprof-worker] Triton opt_vk: {NUM_ITERS} iterations done", flush=True) + + if 
config_path is not None: + Path(config_path).write_text(json.dumps({ + "flydsl_cfg": _get_flydsl_best_config(), + "triton_cfg": _get_triton_best_config(), + "triton_opt_vk_cfg": _get_triton_opt_vk_best_config(), + })) def _parse_kernel_stats(stats_path: Path) -> dict[str, dict]: @@ -704,6 +1156,7 @@ def _print_rocprof_summary(stats_path: Path, total_tokens: int) -> dict | None: targets = [ ("FlyDSL", TARGET_KERNEL_FLYDSL), ("Triton opt3", TARGET_KERNEL_TRITON), + ("Triton opt_vk", TARGET_KERNEL_TRITON_OPT_VK), ] results = {} @@ -723,12 +1176,17 @@ def _print_rocprof_summary(stats_path: Path, total_tokens: int) -> dict | None: row = None if "FlyDSL" in results and "Triton opt3" in results: speedup = results["Triton opt3"]["avg_us"] / results["FlyDSL"]["avg_us"] - print(f"\n Speedup (FlyDSL vs Triton): {speedup:.3f}x") + print(f"\n Speedup (FlyDSL vs Triton opt3): {speedup:.3f}x") row = { "flydsl_us": results["FlyDSL"]["avg_us"], "triton_us": results["Triton opt3"]["avg_us"], "speedup": speedup, } + if "Triton opt_vk" in results: + speedup_vk = results["Triton opt_vk"]["avg_us"] / results["FlyDSL"]["avg_us"] + print(f" Speedup (FlyDSL vs Triton opt_vk): {speedup_vk:.3f}x") + row["triton_opt_vk_us"] = results["Triton opt_vk"]["avg_us"] + row["speedup_vk"] = speedup_vk if not stats: print(" WARNING: no kernels found in stats file") @@ -752,6 +1210,9 @@ def _do_rocprof(full_prompt_len, tp=1) -> tuple[int, dict | None]: output_dir = repo_root / "rocprof_output" output_dir.mkdir(parents=True, exist_ok=True) output_stem = f"gdn_k5_tp{tp}_fpl{full_prompt_len}" + config_path = output_dir / f"{output_stem}_best_cfg.json" + if config_path.exists(): + config_path.unlink() env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" @@ -761,6 +1222,7 @@ def _do_rocprof(full_prompt_len, tp=1) -> tuple[int, dict | None]: "--_rocprof-worker", "--full-prompt-len", str(full_prompt_len), "--tp", str(tp), + "--rocprof-config-path", str(config_path), ] rocprof_cmd = [ "rocprofv3", @@ -796,6 
+1258,8 @@ def _do_rocprof(full_prompt_len, tp=1) -> tuple[int, dict | None]: "num_seqs": num_seqs, **perf, } + if config_path.exists(): + row.update(json.loads(config_path.read_text())) else: print(f"[rocprof] kernel stats not found: {stats_path}", flush=True) trace_path = output_dir / f"{output_stem}_kernel_trace.csv" @@ -812,15 +1276,17 @@ def _print_summary_table(rows: list[dict], title: str = "Performance Summary"): """Print a formatted summary table from collected benchmark rows. Rows are grouped by TP value, each group gets its own sub-table. - Each row dict keys: tp, Hg, H, full_prompt_len, num_seqs, - flydsl_us, flydsl_cfg, triton_us, triton_cfg, speedup. + Supports both 2-kernel (FlyDSL + Triton opt3) and 3-kernel + (FlyDSL + Triton opt3 + Triton opt_vk) result sets. """ if not rows: return from itertools import groupby - w = 138 + has_opt_vk = any("triton_opt_vk_us" in r for r in rows) + + w = 200 if has_opt_vk else 138 sep = "-" * w print(f"\n{'=' * w}") print(f" {title}") @@ -833,18 +1299,28 @@ def _print_summary_table(rows: list[dict], title: str = "Performance Summary"): Hg_val = group_rows[0].get("Hg", "?") H_val = group_rows[0].get("H", "?") print(f"\n TP={tp_val} (Hg={Hg_val}, H={H_val})") - print(f" {'FullPromptLen':>13} {'NumSeqs':>8} " - f"{'FlyDSL(us)':>11} {'FlyDSL BestCfg':>15} " - f"{'Triton(us)':>11} {'Triton BestCfg':>22} " - f"{'Speedup':>8}") + hdr = (f" {'FullPromptLen':>13} {'NumSeqs':>8} " + f"{'FlyDSL(us)':>11} {'FlyDSL BestCfg':>15} " + f"{'Triton(us)':>11} {'Triton BestCfg':>22} " + f"{'Speedup':>8}") + if has_opt_vk: + hdr += (f" {'TritonVK(us)':>13} {'TritonVK BestCfg':>22} " + f"{'SpeedupVK':>10}") + print(hdr) print(f" {sep}") for r in group_rows: fly_cfg = r.get("flydsl_cfg", "N/A") tri_cfg = r.get("triton_cfg", "N/A") - print(f" {r['full_prompt_len']:>13} {r['num_seqs']:>8} " - f"{r['flydsl_us']:>11.2f} {fly_cfg:>15} " - f"{r['triton_us']:>11.2f} {tri_cfg:>22} " - f"{r['speedup']:>7.3f}x") + line = (f" 
{r['full_prompt_len']:>13} {r['num_seqs']:>8} " + f"{r['flydsl_us']:>11.2f} {fly_cfg:>15} " + f"{r['triton_us']:>11.2f} {tri_cfg:>22} " + f"{r['speedup']:>7.3f}x") + if has_opt_vk: + vk_us = r.get("triton_opt_vk_us", 0) + vk_cfg = r.get("triton_opt_vk_cfg", "N/A") + vk_speedup = r.get("speedup_vk", 0) + line += f" {vk_us:>13.2f} {vk_cfg:>22} {vk_speedup:>9.3f}x" + print(line) print(f"{'=' * w}\n") @@ -860,20 +1336,25 @@ class TestRocprof: def test_rocprof(self, tp, full_prompt_len): context_lens = _build_context_lens(full_prompt_len) k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) - chunk_gated_delta_rule_fwd_h_flydsl( + fwd_h_flydsl( k, w_c, u_c, g=g, initial_state=h0, - output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + output_final_state=True, cu_seqlens=cu, ) fwd_h_triton_opt3( k, w_c, u_c, g=g, initial_state=h0, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) + fwd_h_triton_opt_vk( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) torch.cuda.synchronize() rc, row = _do_rocprof(full_prompt_len, tp=tp) if row is not None: - row["flydsl_cfg"] = _get_flydsl_best_config() - row["triton_cfg"] = _get_triton_best_config() + row.setdefault("flydsl_cfg", _get_flydsl_best_config()) + row.setdefault("triton_cfg", _get_triton_best_config()) + row.setdefault("triton_opt_vk_cfg", _get_triton_opt_vk_best_config()) TestRocprof._results.append(row) assert rc == 0, f"rocprofv3 exited with code {rc}" @@ -894,10 +1375,12 @@ def _print_table(self): parser.add_argument("--tp", type=int, default=1) parser.add_argument("--_rocprof-worker", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--rocprof-config-path", type=str, default=None, + help=argparse.SUPPRESS) args = parser.parse_args() if args._rocprof_worker: - _rocprof_worker(args.full_prompt_len, tp=args.tp) + _rocprof_worker(args.full_prompt_len, tp=args.tp, config_path=args.rocprof_config_path) elif args.mode == "rocprof": 
_do_rocprof(args.full_prompt_len, tp=args.tp) else: From 3c8a920c4cc6e56ba1bb866922c2f13b0ad9f7ce Mon Sep 17 00:00:00 2001 From: huizzhan Date: Mon, 20 Apr 2026 02:47:34 +0000 Subject: [PATCH 4/8] Add flydsl vk impl --- kernels/chunk_gated_delta_h.py | 25 +++++++++++++---------- tests/kernels/test_chunk_gated_delta_h.py | 22 +++++++------------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py index 740535fb7..5dc78712d 100644 --- a/kernels/chunk_gated_delta_h.py +++ b/kernels/chunk_gated_delta_h.py @@ -270,9 +270,9 @@ def _ds_read_tr_bf16x4(lds_byte_offset): boh = i_n * NT # ── Base pointer offsets (element counts) ── - # h: [B, NT, H, K, V] — base = (boh*H + i_h) * K * V - h_base = (boh * fx.Int32(H) + i_h) * fx.Int32(K * V) - stride_h = fx.Int32(H * K * V) + # h: [B, NT, H, V, K] (VK) — base = (boh*H + i_h) * V * K + h_base = (boh * fx.Int32(H) + i_h) * fx.Int32(V * K) + stride_h = fx.Int32(H * V * K) # k: [B, T, Hg, K] — base = (bos*Hg + i_h//(H//Hg)) * K gqa_ratio = H // Hg @@ -300,9 +300,9 @@ def _ds_read_tr_bf16x4(lds_byte_offset): vn_base = ((i_n * fx.Int32(H) + i_h) * T_flat) * fx.Int32(V) if USE_INITIAL_STATE: - h0_base = (i_nh * fx.Int32(K * V)) + h0_base = (i_nh * fx.Int32(V * K)) if STORE_FINAL_STATE: - ht_base = (i_nh * fx.Int32(K * V)) + ht_base = (i_nh * fx.Int32(V * K)) # ── MFMA lane mapping for 16x16 tiles ── lane_n = lane % fx.Int32(16) @@ -330,7 +330,7 @@ def _ds_read_tr_bf16x4(lds_byte_offset): h0_elems = [] for elem_i in range_constexpr(4): h0_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - h0_off = h0_base + h0_row * fx.Int32(V) + h0_col + h0_off = h0_base + h0_col * fx.Int32(K) + h0_row h0_elems.append(h0_[fx.Index(h0_off)]) loaded_vec = vector.from_elements(T.f32x4, h0_elems) acc_idx = kb * N_REPEAT + nr @@ -372,7 +372,7 @@ def _ds_read_tr_bf16x4(lds_byte_offset): bf16_val = arith.trunc_f(T.bf16, f32_val) h_row = 
fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - h_off = h_base + i_t_i32 * stride_h + h_row * fx.Int32(V) + h_col + h_off = h_base + i_t_i32 * stride_h + h_col * fx.Int32(K) + h_row h_[fx.Index(h_off)] = bf16_val lds_h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) @@ -576,7 +576,7 @@ def _ds_read_tr_bf16x4(lds_byte_offset): for elem_i in range_constexpr(4): f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) ht_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - ht_off = ht_base + ht_row * fx.Int32(V) + ht_col + ht_off = ht_base + ht_col * fx.Int32(K) + ht_row ht_[fx.Index(ht_off)] = f32_val # ── Host launcher ────────────────────────────────────────────────────── @@ -669,7 +669,10 @@ def chunk_gated_delta_rule_fwd_h_flydsl( wu_contiguous: bool = True, BV: int = 0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - """FlyDSL K5 wrapper with wrapper-level autotune over BV.""" + """FlyDSL K5 wrapper with wrapper-level autotune over BV. + + ``h`` / ``initial_state`` / ``final_state`` are VK-ordered on the last two dims: ``[..., V, K]``. 
+ """ B, T, Hg, K = k.shape BT = chunk_size @@ -695,8 +698,8 @@ def chunk_gated_delta_rule_fwd_h_flydsl( assert K <= 256 - h = k.new_empty(B, NT, H, K, V) - final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + h = k.new_empty(B, NT, H, V, K) + final_state = k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None v_new_buf = k.new_empty(B, H, T_flat, V, dtype=u.dtype) v_new = v_new_buf if save_new_value else None diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py index 76bb42692..f230ed36a 100644 --- a/tests/kernels/test_chunk_gated_delta_h.py +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -572,20 +572,12 @@ def fwd_h_flydsl( k, w, u, g=None, initial_state=None, output_final_state=False, cu_seqlens=None, wu_contiguous=True, ): - """Wrapper for FlyDSL kernel. - - External interface uses VK layout [V, K] for hidden states. - FlyDSL kernel internally uses KV layout [K, V], so we transpose at boundaries. - """ - h0_kv = initial_state.transpose(-2, -1).contiguous() if initial_state is not None else None - h_kv, v_new, fs_kv = chunk_gated_delta_rule_fwd_h_flydsl( - k, w, u, g=g, initial_state=h0_kv, + """FlyDSL K5: h / h0 / final_state are VK [V, K] (same convention as Triton opt_vk).""" + return chunk_gated_delta_rule_fwd_h_flydsl( + k, w, u, g=g, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, wu_contiguous=wu_contiguous, ) - h_vk = h_kv.transpose(-2, -1).contiguous() - fs_vk = fs_kv.transpose(-2, -1).contiguous() if fs_kv is not None else None - return h_vk, v_new, fs_vk # ── Global test configuration ────────────────────────────────────────── @@ -947,12 +939,12 @@ def test_perf_comparison(self, tp, full_prompt_len): Hg = Hk // tp H = Hv // tp - # Pre-transpose h0 for KV-layout kernels (outside benchmark loop) + # Triton opt3 uses KV hidden-state layout; FlyDSL / opt_vk use VK natively. 
h0_kv = h0.transpose(-2, -1).contiguous() if h0 is not None else None us_fly = _bench_fn( chunk_gated_delta_rule_fwd_h_flydsl, - k, w_c, u_c, g=g, initial_state=h0_kv, + k, w_c, u_c, g=g, initial_state=h0, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) @@ -1049,11 +1041,11 @@ def _rocprof_worker(full_prompt_len, tp=1, config_path: str | None = None): k, w_orig, u_orig, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) total_tokens = int(cu[-1].item()) - # Pre-transpose h0 for KV-layout kernels (outside profiling region) + # Triton opt3 (KV) needs transposed h0; FlyDSL uses VK like opt_vk. h0_kv = h0.transpose(-2, -1).contiguous() if h0 is not None else None run_fly = lambda: chunk_gated_delta_rule_fwd_h_flydsl( - k, w_c, u_c, g=g, initial_state=h0_kv, + k, w_c, u_c, g=g, initial_state=h0, output_final_state=True, cu_seqlens=cu, wu_contiguous=True, ) From 46c08f642e3d012216332397a861c386ba28b660 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Mon, 20 Apr 2026 03:22:13 +0000 Subject: [PATCH 5/8] Use best autotune perf --- kernels/chunk_gated_delta_h.py | 22 ++- tests/kernels/test_chunk_gated_delta_h.py | 219 +++++++++++++++++++++- 2 files changed, 232 insertions(+), 9 deletions(-) diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py index 5dc78712d..6498f382c 100644 --- a/kernels/chunk_gated_delta_h.py +++ b/kernels/chunk_gated_delta_h.py @@ -359,7 +359,9 @@ def _ds_read_tr_bf16x4(lds_byte_offset): w_prefetch_all.append(w_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) w_prefetch_lds_all.append(row * fx.Int32(LDS_W_STRIDE) + fx.Int32(kb * 64) + load_col_base) - # ── Store h snapshot to global + LDS (w[0] loads in flight) ── + # ── Store h snapshot to LDS only (K-major rows × V tile; feeds MFMA ds_read). ── + # Global VK layout is v*K+k; naive per-lane stores stride by K and lose coalescing. + # After a full-tile barrier, write global h in linear VK order (k consecutive). 
for kb in range_constexpr(NUM_K_BLOCKS): for nr in range_constexpr(N_REPEAT): acc_idx = kb * N_REPEAT + nr @@ -371,14 +373,22 @@ def _ds_read_tr_bf16x4(lds_byte_offset): f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) bf16_val = arith.trunc_f(T.bf16, f32_val) - h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - h_off = h_base + i_t_i32 * stride_h + h_col * fx.Int32(K) + h_row - h_[fx.Index(h_off)] = bf16_val - lds_h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) lds_h_idx = lds_h_row * fx.Int32(BV) + lds_h_col lds_h[fx.Index(lds_h_idx)] = bf16_val + gpu.barrier() + + for vk_base in range_constexpr(0, LDS_H_ELEMS, BLOCK_THREADS): + linear = fx.Int32(vk_base) + tid + k_idx = linear % fx.Int32(K) + v_loc = linear // fx.Int32(K) + lds_read_idx = k_idx * fx.Int32(BV) + v_loc + bf16_tile = lds_h[fx.Index(lds_read_idx)] + v_global = i_v * fx.Int32(BV) + v_loc + h_off = h_base + i_t_i32 * stride_h + v_global * fx.Int32(K) + k_idx + h_[fx.Index(h_off)] = bf16_tile + # ── Store all w K-blocks to LDS in one batch ── for i_wp in range_constexpr(NUM_K_BLOCKS * NUM_LOAD_BATCHES_64): lds_w.vec_store((fx.Index(w_prefetch_lds_all[i_wp]),), w_prefetch_all[i_wp], LOAD_VEC_WIDTH) @@ -623,7 +633,7 @@ def launch_gdn_h( _compiled_kernels = {} _autotune_cache = {} # (shape_key) -> best BV -_BV_CANDIDATES = [16, 32, 64] +_BV_CANDIDATES = [16] _AUTOTUNE_WARMUP = 5 _AUTOTUNE_ITERS = 25 diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py index f230ed36a..ee4096da2 100644 --- a/tests/kernels/test_chunk_gated_delta_h.py +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -17,6 +17,9 @@ python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Perf" python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "Rocprof" + # Autotune best config + IR/asm (Triton opt_vk + FlyDSL) into 
tests/kernels/k5_autotune_artifacts/: + FLYDSL_K5_AUTOTUNE_ARTIFACTS=1 python3 -m pytest tests/kernels/test_chunk_gated_delta_h.py -v -s -k "dump_autotune" + # Direct rocprofv3 profiling (without pytest): python3 tests/kernels/test_chunk_gated_delta_h.py --mode rocprof python3 tests/kernels/test_chunk_gated_delta_h.py --mode rocprof --full-prompt-len 1000 @@ -261,9 +264,9 @@ def fwd_h_triton_opt3( @triton.autotune( configs=[ triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] - for num_stages in [1, 2, 3, 4] - for BV in [16, 32, 64] + for num_warps in [4] + for num_stages in [3] + for BV in [16] ], key=["H", "K", "V", "BT"], use_cuda_graph=_use_cuda_graph, @@ -904,6 +907,169 @@ def _get_triton_opt_vk_best_config() -> str: return "N/A" +def _k5_autotune_artifacts_enabled() -> bool: + return os.environ.get("FLYDSL_K5_AUTOTUNE_ARTIFACTS", "").lower() in ("1", "true", "yes", "on") + + +def _k5_autotune_artifact_root() -> Path: + return Path(__file__).resolve().parent / "k5_autotune_artifacts" + + +def _triton_opt_vk_jit_launch_bundle( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None, + gk: torch.Tensor | None, + initial_state: torch.Tensor | None, + output_final_state: bool, + save_new_value: bool, + cu_seqlens: torch.Tensor | None, +): + """Mirror ``chunk_gated_delta_rule_fwd_h_opt_vk`` setup for ``JITFunction.run(..., warmup=True)``.""" + B, T, Hg, K = k.shape + H = w.shape[1] + V = u.shape[-1] + T_flat = w.shape[2] + BT = _FLA_CHUNK_SIZE_OPT_VK + + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT = len(cu_seqlens) - 1, len(_prepare_chunk_indices(cu_seqlens, BT)) + chunk_offsets = _prepare_chunk_offsets(cu_seqlens, BT) + + h = k.new_empty(B, NT, H, V, K) + final_state = k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None + v_new = k.new_empty(B, H, T_flat, V, dtype=u.dtype) if save_new_value else None + + def grid(meta): + 
return (triton.cdiv(V, meta["BV"]), N * H) + + launch_kw = dict( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + T_flat=T_flat, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return grid, launch_kw + + +def _dump_triton_opt_vk_ir_asm( + out_dir: Path, + k: torch.Tensor, + w_c: torch.Tensor, + u_c: torch.Tensor, + g: torch.Tensor | None, + gk: torch.Tensor | None, + h0: torch.Tensor | None, + cu: torch.Tensor | None, +) -> None: + """Write Triton compiler IR stages + HSACO for ``chunk_gated_delta_rule_fwd_kernel_h_opt_vk`` best config.""" + heur = chunk_gated_delta_rule_fwd_kernel_h_opt_vk + auto = heur.fn + jit = auto.fn + cfg = auto.best_config + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "best_config.txt").write_text( + f"Triton chunk_gated_delta_rule_fwd_kernel_h_opt_vk best_config:\n{cfg}\n", encoding="utf-8" + ) + grid, launch_kw = _triton_opt_vk_jit_launch_bundle( + k=k, + w=w_c, + u=u_c, + g=g, + gk=gk, + initial_state=h0, + output_final_state=True, + save_new_value=True, + cu_seqlens=cu, + ) + ctx = dict(launch_kw) + for v, heur_fn in heur.values.items(): + ctx[v] = heur_fn({**dict(zip(heur.arg_names, [])), **ctx}) + merged = {**ctx, **cfg.all_kwargs()} + ck = jit.run(grid=grid, warmup=True, **merged) + if ck is None: + (out_dir / "README.txt").write_text("jit.run(warmup=True) returned None.\n", encoding="utf-8") + return + for name, payload in ck.asm.items(): + if isinstance(payload, (bytes, bytearray, memoryview)): + (out_dir / f"{name}.bin").write_bytes(bytes(payload)) + else: + (out_dir / f"{name}.txt").write_text(str(payload), encoding="utf-8") + + +def _dump_flydsl_ir_asm_for_best_bv(out_dir: Path, cg_mod, k, w_c, u_c, g, h0, cu) -> None: + """Recompile FlyDSL K5 with ``BV`` fixed to autotuned value and ``FLYDSL_DUMP_IR=1`` (MLIR stages + ``*_final_isa.s``).""" + out_dir.mkdir(parents=True, exist_ok=True) + shape_key = None + 
best_bv = None + for key, bv in cg_mod._autotune_cache.items(): + if key[0] == K and key[1] == V: # match global K,V + shape_key, best_bv = key, bv + break + if shape_key is None and cg_mod._autotune_cache: + shape_key, best_bv = next(iter(cg_mod._autotune_cache.items())) + if best_bv is None: + (out_dir / "README.txt").write_text("FlyDSL _autotune_cache is empty.\n", encoding="utf-8") + return + + (out_dir / "best_config.txt").write_text( + f"FlyDSL chunk_gated_delta_rule_fwd_h_flydsl autotune (BV only):\n" + f" shape_key={shape_key!r}\n BV={best_bv}\n" + f"(MLIR gpu.func name is typically chunk_gdn_fwd_h_opt3 — see 00_origin.mlir)\n", + encoding="utf-8", + ) + cg_mod._compiled_kernels.clear() + prev_dump = os.environ.get("FLYDSL_DUMP_IR") + prev_dir = os.environ.get("FLYDSL_DUMP_DIR") + prev_rt_cache = os.environ.get("FLYDSL_RUNTIME_ENABLE_CACHE") + try: + os.environ["FLYDSL_DUMP_IR"] = "1" + os.environ["FLYDSL_DUMP_DIR"] = str(out_dir.resolve()) + os.environ["FLYDSL_RUNTIME_ENABLE_CACHE"] = "0" + chunk_gated_delta_rule_fwd_h_flydsl( + k, + w_c, + u_c, + g=g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu, + wu_contiguous=True, + BV=int(best_bv), + ) + torch.cuda.synchronize() + finally: + if prev_dump is None: + os.environ.pop("FLYDSL_DUMP_IR", None) + else: + os.environ["FLYDSL_DUMP_IR"] = prev_dump + if prev_dir is None: + os.environ.pop("FLYDSL_DUMP_DIR", None) + else: + os.environ["FLYDSL_DUMP_DIR"] = prev_dir + if prev_rt_cache is None: + os.environ.pop("FLYDSL_RUNTIME_ENABLE_CACHE", None) + else: + os.environ["FLYDSL_RUNTIME_ENABLE_CACHE"] = prev_rt_cache + + # ── Performance tests ─────────────────────────────────────────────────── def _bench_fn(fn, *args, **kwargs): @@ -1357,6 +1523,53 @@ def _print_table(self): _print_summary_table(TestRocprof._results, title="Rocprof Performance Summary") +class TestK5AutotuneArtifacts: + """Dump autotune winners and compiler artifacts (opt-in via ``FLYDSL_K5_AUTOTUNE_ARTIFACTS=1``).""" + + 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA/HIP required") + def test_dump_autotune_best_configs_and_ir_asm(self): + if not _k5_autotune_artifacts_enabled(): + pytest.skip("set FLYDSL_K5_AUTOTUNE_ARTIFACTS=1 to dump configs / IR / asm") + + import kernels.chunk_gated_delta_h as cg_k5 + + fpl = FULL_PROMPT_LENS[0] if FULL_PROMPT_LENS else 8192 + tp = TP_LIST[0] if TP_LIST else 1 + context_lens = _build_context_lens(fpl) + k, _, _, w_c, u_c, g, h0, cu, _ = _make_inputs(context_lens, tp=tp) + + root = _k5_autotune_artifact_root() + root.mkdir(parents=True, exist_ok=True) + tri_dir = root / "triton_chunk_gated_delta_rule_fwd_kernel_h_opt_vk" + fly_dir = root / "flydsl_chunk_gated_delta_rule_fwd_h_flydsl" + + # --- Triton opt_vk: autotune (no artifact dump yet) --- + fwd_h_triton_opt_vk( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, + ) + torch.cuda.synchronize() + tri_cfg = _get_triton_opt_vk_best_config() + print(f"\n[K5 artifacts] Triton opt_vk autotune best: {tri_cfg}") + _dump_triton_opt_vk_ir_asm(tri_dir, k=k, w_c=w_c, u_c=u_c, g=g, gk=None, h0=h0, cu=cu) + print(f"[K5 artifacts] Triton IR/asm -> {tri_dir}") + + # --- FlyDSL: autotune BV, then single recompile with FLYDSL_DUMP_IR --- + cg_k5._compiled_kernels.clear() + cg_k5._autotune_cache.clear() + chunk_gated_delta_rule_fwd_h_flydsl( + k, w_c, u_c, g=g, initial_state=h0, + output_final_state=True, cu_seqlens=cu, wu_contiguous=True, + ) + torch.cuda.synchronize() + fly_cfg = _get_flydsl_best_config() + print(f"[K5 artifacts] FlyDSL BV autotune best: {fly_cfg}") + _dump_flydsl_ir_asm_for_best_bv(fly_dir, cg_k5, k, w_c, u_c, g, h0, cu) + subdirs = [p.name for p in fly_dir.iterdir() if p.is_dir()] + extra = f" subdirs={subdirs!r}" if subdirs else "" + print(f"[K5 artifacts] FlyDSL MLIR stages + final_isa.s -> {fly_dir}/" f"{extra}\n") + + # ── Main ──────────────────────────────────────────────────────────────── if __name__ == "__main__": From 
0d475742f9c27551f3a2dac91e9a6776f9bdb06e Mon Sep 17 00:00:00 2001 From: huizzhan Date: Mon, 20 Apr 2026 04:14:26 +0000 Subject: [PATCH 6/8] Add prefetch --- kernels/chunk_gated_delta_h.py | 96 ++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 38 deletions(-) diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py index 6498f382c..cffb80d45 100644 --- a/kernels/chunk_gated_delta_h.py +++ b/kernels/chunk_gated_delta_h.py @@ -336,37 +336,55 @@ def _ds_read_tr_bf16x4(lds_byte_offset): acc_idx = kb * N_REPEAT + nr h_accs[acc_idx] = arith.addf(h_accs[acc_idx], loaded_vec) - # ── Main chunk loop ── - init_state = [_to_raw(v) for v in h_accs] + # ── Software-pipelined main chunk loop ── + # Prefetch strategy: overlap next iteration's w global loads with current + # iteration's state-update MFMA, and overlap all k K-blocks global loads + # with delta-correction MFMA. w prefetch data is carried across iterations + # via scf.for loop-carried values. + # + # Timeline per iteration: + # [h snapshot store] → [w LDS store (from prefetch)] → barrier + # → [delta correction MFMA ‖ k ALL K-blocks load + g/u load] + # → [v_new + gating] → [vn+k LDS store] → barrier + # → [state update MFMA ‖ NEXT w load (prefetch)] + # → yield + + NUM_W_LOADS = NUM_K_BLOCKS * NUM_LOAD_BATCHES_64 + + # ── Prologue: pre-load first chunk's w data ── + i_t0_i32 = fx.Int32(0) + w_prefetch_init = [] + for kb in range_constexpr(NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t0_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = w_base + safe_row * stride_w + fx.Int32(kb * 64) + load_col_base + w_prefetch_init.append(w_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + + init_state = [_to_raw(v) for v in h_accs] + [_to_raw(v) for v in w_prefetch_init] c_zero = 
arith.index(0) c_one = arith.index(1) nt_idx = arith.index_cast(T.index, NT) for i_t, state in range(c_zero, nt_idx, c_one, init=init_state): - h_accs_in = list(state) + h_accs_in = list(state[:NUM_H_ACCS]) + w_prefetch_all = list(state[NUM_H_ACCS:]) i_t_i32 = arith.index_cast(T.i32, i_t) - # ── 1. Prefetch all w K-blocks from global (overlap with h snapshot store) ── - w_prefetch_all = [] + # ── 1. Compute w LDS offsets (w data already prefetched) ── w_prefetch_lds_all = [] for kb in range_constexpr(NUM_K_BLOCKS): for batch in range_constexpr(NUM_LOAD_BATCHES_64): row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch - abs_row = i_t_i32 * fx.Int32(BT) + row - in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) - safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) - g_off = w_base + safe_row * stride_w + fx.Int32(kb * 64) + load_col_base - w_prefetch_all.append(w_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) w_prefetch_lds_all.append(row * fx.Int32(LDS_W_STRIDE) + fx.Int32(kb * 64) + load_col_base) - # ── Store h snapshot to LDS only (K-major rows × V tile; feeds MFMA ds_read). ── - # Global VK layout is v*K+k; naive per-lane stores stride by K and lose coalescing. - # After a full-tile barrier, write global h in linear VK order (k consecutive). 
+ # ── Store h snapshot to LDS ── for kb in range_constexpr(NUM_K_BLOCKS): for nr in range_constexpr(N_REPEAT): acc_idx = kb * N_REPEAT + nr acc_val = h_accs_in[acc_idx] - h_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n lds_h_col = fx.Int32(nr * 16) + lane_n for elem_i in range_constexpr(4): @@ -389,24 +407,25 @@ def _ds_read_tr_bf16x4(lds_byte_offset): h_off = h_base + i_t_i32 * stride_h + v_global * fx.Int32(K) + k_idx h_[fx.Index(h_off)] = bf16_tile - # ── Store all w K-blocks to LDS in one batch ── - for i_wp in range_constexpr(NUM_K_BLOCKS * NUM_LOAD_BATCHES_64): + # ── Store prefetched w to LDS (data already in registers from previous iter/prologue) ── + for i_wp in range_constexpr(NUM_W_LOADS): lds_w.vec_store((fx.Index(w_prefetch_lds_all[i_wp]),), w_prefetch_all[i_wp], LOAD_VEC_WIDTH) gpu.barrier() # ── 2. Delta correction: b_v = w @ h, then v_new = u - b_v ── - # Prefetch k[0] and u values during MFMA (overlap global loads with compute) + # Prefetch ALL k K-blocks during MFMA (not just k[0]) k_prefetch = [] k_prefetch_lds = [] - for batch in range_constexpr(NUM_LOAD_BATCHES_64): - row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch - abs_row = i_t_i32 * fx.Int32(BT) + row - in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) - safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) - g_off = k_base + safe_row * stride_k + fx.Int32(0 * 64) + load_col_base - k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) - k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + load_col_base) + for kb in range_constexpr(NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = k_base + safe_row * stride_k + fx.Int32(kb * 64) + load_col_base + 
k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + fx.Int32(kb * 64) + load_col_base) # Prefetch g values (overlap with MFMA below) if USE_G: @@ -520,17 +539,6 @@ def _ds_read_tr_bf16x4(lds_byte_offset): # ── 4. State update: h += k^T @ v_new_gated ── BT_STEPS = BT // WMMA_K - # Prefetch remaining k K-blocks (k[0] already prefetched during delta correction) - for kb_extra in range_constexpr(1, NUM_K_BLOCKS): - for batch in range_constexpr(NUM_LOAD_BATCHES_64): - row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch - abs_row = i_t_i32 * fx.Int32(BT) + row - in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) - safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) - g_off = k_base + safe_row * stride_k + fx.Int32(kb_extra * 64) + load_col_base - k_prefetch.append(k_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) - k_prefetch_lds.append(row * fx.Int32(LDS_K_STRIDE) + fx.Int32(kb_extra * 64) + load_col_base) - # Store gated v_new + all k K-blocks to LDS in one batch, single barrier for nr in range_constexpr(N_REPEAT): vn_val = vn_frags[nr] @@ -547,6 +555,18 @@ def _ds_read_tr_bf16x4(lds_byte_offset): gpu.barrier() + # ── Prefetch NEXT iteration's w during state update MFMA ── + next_i_t_i32 = i_t_i32 + fx.Int32(1) + w_next_prefetch = [] + for kb in range_constexpr(NUM_K_BLOCKS): + for batch in range_constexpr(NUM_LOAD_BATCHES_64): + row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch + abs_row = next_i_t_i32 * fx.Int32(BT) + row + in_bounds = arith.cmpi(arith.CmpIPredicate.slt, abs_row, T_local) + safe_row = arith.select(in_bounds, abs_row, fx.Int32(0)) + g_off = w_base + safe_row * stride_w + fx.Int32(kb * 64) + load_col_base + w_next_prefetch.append(w_.vec_load((fx.Index(g_off),), LOAD_VEC_WIDTH)) + for kb in range_constexpr(NUM_K_BLOCKS): for bt_s in range_constexpr(BT_STEPS): k_col_tr = wid * fx.Int32(16) + tr_col_sub * fx.Int32(4) @@ -571,9 +591,9 @@ def 
_ds_read_tr_bf16x4(lds_byte_offset): acc_idx = kb * N_REPEAT + nr h_accs_in[acc_idx] = _mfma_bf16_16x16x32(k_a_frag, vn_b_frag, h_accs_in[acc_idx]) - results = yield [_to_raw(v) for v in h_accs_in] + results = yield [_to_raw(v) for v in h_accs_in] + [_to_raw(v) for v in w_next_prefetch] - h_accs_final = list(results) + h_accs_final = list(results[:NUM_H_ACCS]) # ── Epilogue: store final state ── if STORE_FINAL_STATE: From 0e8afa14f73b0129060b74a09c6ecd6d14600d12 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Tue, 21 Apr 2026 08:28:56 +0000 Subject: [PATCH 7/8] Refine --- kernels/chunk_gated_delta_h.py | 6 +++--- tests/kernels/test_chunk_gated_delta_h.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py index cffb80d45..670150412 100644 --- a/kernels/chunk_gated_delta_h.py +++ b/kernels/chunk_gated_delta_h.py @@ -59,8 +59,8 @@ def _fast_exp(x): def _mfma_bf16_16x16x32(a_bf16x8, b_bf16x8, acc_f32x4): """Single mfma_f32_16x16x32_bf16 instruction.""" return rocdl.mfma_f32_16x16x32_bf16( - T.f32x4, a_bf16x8, b_bf16x8, acc_f32x4, 0, 0, 0 - ).res + T.f32x4, [a_bf16x8, b_bf16x8, acc_f32x4, 0, 0, 0] + ) # ── Utility helpers ────────────────────────────────────────────────────── @@ -653,7 +653,7 @@ def launch_gdn_h( _compiled_kernels = {} _autotune_cache = {} # (shape_key) -> best BV -_BV_CANDIDATES = [16] +_BV_CANDIDATES = [16, 32, 64] _AUTOTUNE_WARMUP = 5 _AUTOTUNE_ITERS = 25 diff --git a/tests/kernels/test_chunk_gated_delta_h.py b/tests/kernels/test_chunk_gated_delta_h.py index ee4096da2..193f1a219 100644 --- a/tests/kernels/test_chunk_gated_delta_h.py +++ b/tests/kernels/test_chunk_gated_delta_h.py @@ -264,11 +264,11 @@ def fwd_h_triton_opt3( @triton.autotune( configs=[ triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [4] - for num_stages in [3] - for BV in [16] + for num_warps in [2, 4] + for num_stages in [1, 2, 3, 4] + for BV in [16, 32, 64] 
], - key=["H", "K", "V", "BT"], + key=["H", "K", "V", "BT", "IS_VARLEN"], use_cuda_graph=_use_cuda_graph, ) @triton.jit(do_not_specialize=["T"]) From d4643e0e3f4c16f46a2e456fd7a87a2b323004d1 Mon Sep 17 00:00:00 2001 From: huizzhan Date: Tue, 28 Apr 2026 09:02:44 +0000 Subject: [PATCH 8/8] 1k bv64 opt --- kernels/chunk_gated_delta_h.py | 100 ++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/kernels/chunk_gated_delta_h.py b/kernels/chunk_gated_delta_h.py index 670150412..dc035abb5 100644 --- a/kernels/chunk_gated_delta_h.py +++ b/kernels/chunk_gated_delta_h.py @@ -135,15 +135,33 @@ def compile_chunk_gated_delta_h( LDS_K_ELEMS = BT * LDS_K_STRIDE LDS_K_BYTES = LDS_K_ELEMS * 2 - LDS_VN_STRIDE = BV + # OPT-D: lds_vn stride padding to break LDS bank conflicts on + # ds_write_b128 / ds_read_tr_b16 (avg 91.3 cyc/hit at BV=32 due to 2-way + # bank conflict on 64-byte rows). 8-byte padding shifts row alignment. + LDS_VN_PAD = 4 # 4 bf16 = 8 bytes + LDS_VN_STRIDE = BV + LDS_VN_PAD LDS_VN_ELEMS = BT * LDS_VN_STRIDE LDS_VN_BYTES = LDS_VN_ELEMS * 2 - LDS_H_STRIDE = BV + # OPT-H (replacement): lds_h stride padding. Original BV-stride causes + # heavy 2-way bank conflict on the ds_read_u16 path that materializes + # the h snapshot to HBM (avg 86.7 cyc/hit at BV=32, 11.6 % of total stall + # — see flydsl_vs_triton_1k_gap_analysis.md §2). +8 B padding shifts the + # K-row alignment and breaks the conflict. Cost: +K*8 bytes LDS (1 KB at + # BV=32, 0.5 KB at BV=16; well within budget for both 1k and 8k). + LDS_H_PAD = 4 # 4 bf16 = 8 bytes + LDS_H_STRIDE = BV + LDS_H_PAD LDS_H_ELEMS = K * LDS_H_STRIDE LDS_H_BYTES = LDS_H_ELEMS * 2 - allocator = SmemAllocator(None, arch="gfx942", global_sym_name="gdn_h_smem") + # OPT bumped: any IR-affecting change in this body should bump this rev so + # the FlyDSL JIT disk cache (~/.flydsl/cache/launch_gdn_h_*) invalidates. 
+ _K5_KERNEL_REVISION = 2 # OPT-D + OPT-7 + OPT-F + OPT-H + OPT-4 (2026-04-28) + + allocator = SmemAllocator( + None, arch="gfx942", + global_sym_name=f"gdn_h_smem_v{_K5_KERNEL_REVISION}", + ) lds_w_offset = allocator._align(allocator.ptr, 16) allocator.ptr = lds_w_offset + LDS_W_BYTES lds_k_offset = allocator._align(allocator.ptr, 16) @@ -323,16 +341,22 @@ def _ds_read_tr_bf16x4(lds_byte_offset): h_accs.append(acc_zero) # ── Load initial state if provided ── + # OPT-F: 4 × scalar f32 load → 1 × buffer_load_dwordx4 (16 bytes). + # h0 is [V, K] so K is innermost; elem_i ∈ [0, 4) hits 4 consecutive K + # positions which are contiguous in memory. Original code emitted 4 + # separate buffer_load_dword followed by a chain of s_waitcnt vmcnt(N) + # in the prologue (see flydsl_vs_triton_1k_gap_analysis.md §7.6). if USE_INITIAL_STATE: for kb in range_constexpr(NUM_K_BLOCKS): for nr in range_constexpr(N_REPEAT): h0_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n - h0_elems = [] - for elem_i in range_constexpr(4): - h0_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - h0_off = h0_base + h0_col * fx.Int32(K) + h0_row - h0_elems.append(h0_[fx.Index(h0_off)]) - loaded_vec = vector.from_elements(T.f32x4, h0_elems) + h0_row_base = ( + fx.Int32(kb * 64) + + wid * fx.Int32(16) + + lane_m_base * fx.Int32(4) + ) + h0_off_base = h0_base + h0_col * fx.Int32(K) + h0_row_base + loaded_vec = h0_.vec_load((fx.Index(h0_off_base),), 4) acc_idx = kb * N_REPEAT + nr h_accs[acc_idx] = arith.addf(h_accs[acc_idx], loaded_vec) @@ -374,13 +398,28 @@ def _ds_read_tr_bf16x4(lds_byte_offset): i_t_i32 = arith.index_cast(T.i32, i_t) # ── 1. Compute w LDS offsets (w data already prefetched) ── + # OPT-4: XOR swizzle to break 64-way bank conflict on lds_w (avg + # 41.8 cyc/hit at BV=32 vs Triton's 10.7). Pattern flips bits 3..5 + # of col with (row & 7), at 8-element bf16 granularity matching + # LOAD_VEC_WIDTH. 
Read path (W A frag below) applies the SAME + # swizzle. lds_k is NOT swizzled because ds_read_tr_b16 spans 4 + # rows per instruction and a row-dependent XOR mask would break + # the hardware transpose. w_prefetch_lds_all = [] for kb in range_constexpr(NUM_K_BLOCKS): for batch in range_constexpr(NUM_LOAD_BATCHES_64): row = fx.Int32(batch * ROWS_PER_BATCH_64) + load_row_in_batch - w_prefetch_lds_all.append(row * fx.Int32(LDS_W_STRIDE) + fx.Int32(kb * 64) + load_col_base) + col = fx.Int32(kb * 64) + load_col_base + swz_col = _xor_swizzle(row, col) + w_prefetch_lds_all.append( + row * fx.Int32(LDS_W_STRIDE) + swz_col + ) # ── Store h snapshot to LDS ── + # OPT-H (lds_h padding): lds_h layout is [K, BV+LDS_H_PAD]; write + # uses LDS_H_STRIDE so the K-rows are spaced by (BV+pad) bf16 + # instead of BV bf16, which breaks the 2-way bank conflict on + # the ds_read_u16 path below. for kb in range_constexpr(NUM_K_BLOCKS): for nr in range_constexpr(N_REPEAT): acc_idx = kb * N_REPEAT + nr @@ -392,16 +431,21 @@ def _ds_read_tr_bf16x4(lds_byte_offset): bf16_val = arith.trunc_f(T.bf16, f32_val) lds_h_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - lds_h_idx = lds_h_row * fx.Int32(BV) + lds_h_col + lds_h_idx = lds_h_row * fx.Int32(LDS_H_STRIDE) + lds_h_col lds_h[fx.Index(lds_h_idx)] = bf16_val gpu.barrier() - for vk_base in range_constexpr(0, LDS_H_ELEMS, BLOCK_THREADS): + # LDS → HBM transpose: each thread copies one bf16 element per pass. + # Iteration count is K * BV / BLOCK_THREADS (NOT LDS_H_ELEMS, which + # includes padding). Reading uses LDS_H_STRIDE so we hit the same + # padded row layout as the writer. 
+ VK_TOTAL = K * BV # actual elements (excluding padding) + for vk_base in range_constexpr(0, VK_TOTAL, BLOCK_THREADS): linear = fx.Int32(vk_base) + tid k_idx = linear % fx.Int32(K) v_loc = linear // fx.Int32(K) - lds_read_idx = k_idx * fx.Int32(BV) + v_loc + lds_read_idx = k_idx * fx.Int32(LDS_H_STRIDE) + v_loc bf16_tile = lds_h[fx.Index(lds_read_idx)] v_global = i_v * fx.Int32(BV) + v_loc h_off = h_base + i_t_i32 * stride_h + v_global * fx.Int32(K) + k_idx @@ -467,6 +511,8 @@ def _ds_read_tr_bf16x4(lds_byte_offset): for ks in range_constexpr(K_STEPS_PER_BLOCK): w_lds_row_idx = wid_idx * arith.index(16) + lane_n_idx w_lds_col_idx = arith.index(kb * 64 + ks * WMMA_K) + lane_m_base_idx * arith.index(8) + # OPT-4: same XOR swizzle as the write side. + w_lds_col_idx = _xor_swizzle_idx(w_lds_row_idx, w_lds_col_idx) w_lds_idx = w_lds_row_idx * arith.index(LDS_W_STRIDE) + w_lds_col_idx a_frag = _lds_vec_read_w_bf16x8(w_lds_idx) @@ -475,11 +521,14 @@ def _ds_read_tr_bf16x4(lds_byte_offset): for nr in range_constexpr(N_REPEAT): h_k_row = fx.Int32(global_ks * WMMA_K) + lane_m_base * fx.Int32(8) + tr_k_group h_v_col = fx.Int32(nr * 16) + tr_col_sub * fx.Int32(4) - h_lds_elem = h_k_row * fx.Int32(BV) + h_v_col + # OPT-H: stride is LDS_H_STRIDE = BV + LDS_H_PAD + h_lds_elem = h_k_row * fx.Int32(LDS_H_STRIDE) + h_v_col h_lds_byte = h_lds_elem * fx.Int32(2) + fx.Int32(lds_h_offset) h_lo = _ds_read_tr_bf16x4(h_lds_byte) - h_hi = _ds_read_tr_bf16x4(h_lds_byte + fx.Int32(4 * BV * 2)) + h_hi = _ds_read_tr_bf16x4( + h_lds_byte + fx.Int32(4 * LDS_H_STRIDE * 2) + ) b_frag = vector.shuffle(h_lo, h_hi, [0, 1, 2, 3, 4, 5, 6, 7]) bv_accs[nr] = _mfma_bf16_16x16x32(a_frag, b_frag, bv_accs[nr]) @@ -585,7 +634,10 @@ def _ds_read_tr_bf16x4(lds_byte_offset): vn_lds_byte = vn_lds_elem * fx.Int32(2) + fx.Int32(lds_vn_offset) vn_lo = _ds_read_tr_bf16x4(vn_lds_byte) - vn_hi = _ds_read_tr_bf16x4(vn_lds_byte + fx.Int32(4 * BV * 2)) + # OPT-D: stride is LDS_VN_STRIDE = BV + LDS_VN_PAD + vn_hi = 
_ds_read_tr_bf16x4( + vn_lds_byte + fx.Int32(4 * LDS_VN_STRIDE * 2) + ) vn_b_frag = vector.shuffle(vn_lo, vn_hi, [0, 1, 2, 3, 4, 5, 6, 7]) acc_idx = kb * N_REPEAT + nr @@ -596,6 +648,10 @@ def _ds_read_tr_bf16x4(lds_byte_offset): h_accs_final = list(results[:NUM_H_ACCS]) # ── Epilogue: store final state ── + # OPT-7: 4 × scalar f32 store → 1 × buffer_store_dwordx4 (16 bytes). + # Per-lane elem_i ∈ [0, 4) maps to consecutive K-rows in the [V, K] + # HBM layout (ht_row = base + elem_i), so the 4 f32 values are + # contiguous in memory. acc_val is f32x4 with element i at K offset i. if STORE_FINAL_STATE: for kb in range_constexpr(NUM_K_BLOCKS): for nr in range_constexpr(N_REPEAT): @@ -603,11 +659,13 @@ def _ds_read_tr_bf16x4(lds_byte_offset): acc_val = h_accs_final[acc_idx] ht_col = i_v * fx.Int32(BV) + fx.Int32(nr * 16) + lane_n - for elem_i in range_constexpr(4): - f32_val = vector.extract(acc_val, static_position=[elem_i], dynamic_position=[]) - ht_row = fx.Int32(kb * 64) + wid * fx.Int32(16) + lane_m_base * fx.Int32(4) + fx.Int32(elem_i) - ht_off = ht_base + ht_col * fx.Int32(K) + ht_row - ht_[fx.Index(ht_off)] = f32_val + ht_row_base = ( + fx.Int32(kb * 64) + + wid * fx.Int32(16) + + lane_m_base * fx.Int32(4) + ) + ht_off_base = ht_base + ht_col * fx.Int32(K) + ht_row_base + ht_.vec_store((fx.Index(ht_off_base),), acc_val, 4) # ── Host launcher ────────────────────────────────────────────────────── @flyc.jit