# Patch summary (commentary only; text before the first "diff --git" header is
# not part of any hunk).
#
# kernels/moe_gemm_2stage.py -- compile_moe_gemm1 and compile_moe_gemm2 receive
# the same set of changes:
#   * When `is_int8 and _is_gfx950`, `mfma_i32_k64` is looked up with
#     getattr(rocdl, "mfma_i32_16x16x64_i8", None); `_use_int8_k64` is true
#     only when that lookup is not None.
#   * load_b_tile gains a `_use_int8_k64` branch: for each (ku, ni) it calls
#     load_b_pack for ki = 2*ku and ki = 2*ku + 1, builds a T.i64x2 vector from
#     the two results and bitcasts it to T.i32x4. The returned tile is
#     b_tile[ku][ni] (one value per ni) instead of the (packs0, packs1) pair
#     produced by the pre-existing else-branch.
#   * The compute step gains a `_use_int8_k64` branch: for each (ku, mi) it
#     takes (a0, a1) from a0_prefetch when ku == 0 and mi == 0 and a prefetch
#     exists, otherwise from lds_load_packs_k64; merges them into one T.i32x4
#     value; then calls mfma_i32_k64 once per ni for each accumulator list
#     (gate_list and up_list in gemm1, acc_list in gemm2) with the three
#     trailing operands set to 0.
#   * `is_even_odd_split = not int4_bf16_single_field and not _use_int8_k64`
#     now selects `_fields_per_ku` (2 when split, else 1) and the single-field
#     branch of _flatten_b_tile / _unflatten_b_tile.
#
# python/flydsl/expr/rocdl/__init__.py:
#   * `_ods_mfma_i32_16x16x64_i8` is captured with globals().get(..., None).
#   * A @traced_op wrapper `mfma_i32_16x16x64_i8` raises AttributeError when
#     that captured value is None; otherwise it splits `operands` with
#     _split_mfma_operands and returns the `.result` of the captured op.
#
# NOTE(review): the wrapper above is defined unconditionally. If `rocdl` in
# moe_gemm_2stage.py refers to this module, the getattr(...) lookup presumably
# never yields None, so `_use_int8_k64` would be true for every gfx950 int8
# compile and a missing ODS op would surface as the wrapper's AttributeError at
# call time rather than selecting the K=32 path -- TODO confirm.
# NOTE(review): hot_loop_scheduler is not modified in the hunks shown here; the
# K=64 branch issues one MFMA call where the (packs0, packs1) layout implies
# two K=32 steps, so verify that any per-iteration instruction counts it
# relies on still hold -- TODO confirm.
# NOTE(review): the K=64 branch reuses lds_load_packs_k64 and a0_prefetch,
# whose definitions are outside this patch; presumably they return two
# 64-bit packs as T.i64x2 construction requires -- verify.
diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index b16b5fb4..566de720 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -214,6 +214,11 @@ def compile_moe_gemm1( "(or `rocdl.mfma_i32_16x16x32_i8`)." ) + mfma_i32_k64 = None + if is_int8 and _is_gfx950: + mfma_i32_k64 = getattr(rocdl, "mfma_i32_16x16x64_i8", None) + _use_int8_k64 = is_int8 and _is_gfx950 and (mfma_i32_k64 is not None) + mfma_f32_bf16_k16 = None if is_bf16: mfma_f32_bf16_k16 = getattr(rocdl, "mfma_f32_16x16x16bf16_1k", None) or getattr( @@ -719,8 +724,22 @@ def load_b_tile(base_k, blk_list, intra_list): raw_ku.append(raw) raw_data.append(raw_ku) return raw_data + elif const_expr(_use_int8_k64): + # gfx950 int8 K=64: merge two K32 loads into one i32x4 per ni + b_tile = [] + for ku in range_constexpr(k_unroll): + packs = [] + for ni in range_constexpr(num_acc_n): + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni, blk_list, intra_list) + b1 = load_b_pack(base_k, ki1, ni, blk_list, intra_list) + b_merged = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [b0, b1])) + packs.append(b_merged) + b_tile.append(packs) + return b_tile else: - # fp8/int8/bf16/fp16: original code path + # fp8/int8(gfx942)/bf16/fp16: original code path b_tile = [] for ku in range_constexpr(k_unroll): packs0 = [] @@ -734,7 +753,7 @@ def load_b_tile(base_k, blk_list, intra_list): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - + acc_gate = [acc_init] * (num_acc_n * m_repeat) acc_up = [acc_init] * (num_acc_n * m_repeat) @@ -956,6 +975,35 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): p_idx, p_g, p_u, p_sc_g, p_sc_u = _pending_gate_up gate_list[p_idx] = _acc_scaled_f32(gate_list[p_idx], p_g, p_sc_g) up_list[p_idx] = _acc_scaled_f32(up_list[p_idx], p_u, p_sc_u) + elif const_expr(_use_int8_k64): + # gfx950 int8 K=64: single MFMA instruction per K=64 step + def mfma_k64_one(acc_in, a_full, b_full): + return 
mfma_i32_k64(T.i32x4, [a_full, b_full, acc_in, 0, 0, 0]) + + for ku in range_constexpr(k_unroll): + b_gate_packs = b_gate_tile_in[ku] + b_up_packs = b_up_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + a_full = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [a0, a1])) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + gate_list[acc_idx] = mfma_k64_one( + gate_list[acc_idx], a_full, b_gate_packs[ni] + ) + up_list[acc_idx] = mfma_k64_one( + up_list[acc_idx], a_full, b_up_packs[ni] + ) else: for ku in range_constexpr(k_unroll): b_gate_packs0, b_gate_packs1 = b_gate_tile_in[ku] @@ -1069,7 +1117,8 @@ def hot_loop_scheduler(): # Flattened as: [even_0..N, odd_0..N] → 2 * num_acc_n values # int4_bf16_single_field = is_int4_bf16 and not is_int4_bf16_groupwise - _fields_per_ku = 1 if int4_bf16_single_field else 2 + is_even_odd_split = not int4_bf16_single_field and not _use_int8_k64 + _fields_per_ku = 2 if is_even_odd_split else 1 _vals_per_b_tile = k_unroll * _fields_per_ku * num_acc_n def _flatten_b_tile(b_tile): @@ -1080,8 +1129,7 @@ def _flatten_b_tile(b_tile): # [(packed, scale), ...] → [packed_0..N, scale_0..N] flat.extend(t[0] for t in ku_entry) flat.extend(t[1] for t in ku_entry) - elif const_expr(int4_bf16_single_field): - # [raw_i64, ...] 
→ [raw_0..N] + elif const_expr(not is_even_odd_split): flat.extend(ku_entry) else: # (packs_even, packs_odd) → [even_0..N, odd_0..N] @@ -1099,7 +1147,7 @@ def _unflatten_b_tile(vals): scales = list(vals[idx:idx + num_acc_n]) idx += num_acc_n b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)]) - elif const_expr(int4_bf16_single_field): + elif const_expr(not is_even_odd_split): b_tile.append(list(vals[idx:idx + num_acc_n])) idx += num_acc_n else: @@ -1767,6 +1815,11 @@ def compile_moe_gemm2( "(or `rocdl.mfma_i32_16x16x32_i8`)." ) + mfma_i32_k64 = None + if is_int8 and _is_gfx950: + mfma_i32_k64 = getattr(rocdl, "mfma_i32_16x16x64_i8", None) + _use_int8_k64 = is_int8 and _is_gfx950 and (mfma_i32_k64 is not None) + mfma_f32_bf16_k16 = None if is_bf16: mfma_f32_bf16_k16 = getattr(rocdl, "mfma_f32_16x16x16bf16_1k", None) or getattr( @@ -2279,8 +2332,22 @@ def load_b_tile(base_k): raw_ku.append(raw) raw_data.append(raw_ku) return raw_data + elif const_expr(_use_int8_k64): + # gfx950 int8 K=64: merge two K32 loads into one i32x4 per ni + b_tile = [] + for ku in range_constexpr(k_unroll): + packs = [] + for ni in range_constexpr(num_acc_n): + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni) + b1 = load_b_pack(base_k, ki1, ni) + b_merged = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [b0, b1])) + packs.append(b_merged) + b_tile.append(packs) + return b_tile else: - # fp8/int8/bf16/fp16: original code path + # fp8/int8(gfx942)/bf16/fp16: original code path b_tile = [] for ku in range_constexpr(k_unroll): packs0 = [] @@ -2294,7 +2361,7 @@ def load_b_tile(base_k): packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - + # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v): for i in range_constexpr(num_x_loads): @@ -2498,6 +2565,31 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): if const_expr(_pending_acc is 
not None): p_idx, p_tmp, p_sc = _pending_acc acc_list[p_idx] = _acc_scaled_f32(acc_list[p_idx], p_tmp, p_sc) + elif const_expr(_use_int8_k64): + # gfx950 int8 K=64: single MFMA instruction per K=64 step + def mfma_k64_one(acc_in, a_full, b_full): + return mfma_i32_k64(T.i32x4, [a_full, b_full, acc_in, 0, 0, 0]) + + for ku in range_constexpr(k_unroll): + b_packs = b_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + a_full = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [a0, a1])) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + acc_list[acc_idx] = mfma_k64_one( + acc_list[acc_idx], a_full, b_packs[ni] + ) else: for ku in range_constexpr(k_unroll): b_packs0, b_packs1 = b_tile_in[ku] @@ -2645,7 +2737,8 @@ def hot_loop_scheduler(): # B-tile data layout per k_unroll entry (3 variants): # See gemm1 _flatten_b_tile for full layout documentation. 
int4_bf16_single_field = is_int4_bf16 and not is_int4_bf16_groupwise - _fields_per_ku = 1 if int4_bf16_single_field else 2 + is_even_odd_split = not int4_bf16_single_field and not _use_int8_k64 + _fields_per_ku = 2 if is_even_odd_split else 1 _vals_per_b_tile = k_unroll * _fields_per_ku * num_acc_n _n_acc = m_repeat * num_acc_n _p_b = _n_acc @@ -2658,7 +2751,7 @@ def _flatten_b_tile(b_tile): if const_expr(is_int4_bf16_groupwise): flat.extend(t[0] for t in ku_entry) flat.extend(t[1] for t in ku_entry) - elif const_expr(int4_bf16_single_field): + elif const_expr(not is_even_odd_split): flat.extend(ku_entry) else: flat.extend(ku_entry[0]) @@ -2675,7 +2768,7 @@ def _unflatten_b_tile(vals): scales = list(vals[idx:idx + num_acc_n]) idx += num_acc_n b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)]) - elif const_expr(int4_bf16_single_field): + elif const_expr(not is_even_odd_split): b_tile.append(list(vals[idx:idx + num_acc_n])) idx += num_acc_n else: diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 56ab93d8..728b2494 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -35,6 +35,7 @@ _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 _ods_mfma_i32_16x16x32_i8 = mfma_i32_16x16x32_i8 +_ods_mfma_i32_16x16x64_i8 = globals().get("mfma_i32_16x16x64_i8", None) _ods_mfma_f32_16x16x32_f16 = globals().get("mfma_f32_16x16x32_f16", None) _ods_mfma_f32_16x16x32_bf16 = globals().get("mfma_f32_16x16x32_bf16", None) _ods_mfma_scale_f32_16x16x128_f8f6f4 = ( @@ -114,6 +115,14 @@ def mfma_i32_16x16x32_i8(result_type, operands, *, loc=None, ip=None): return _ods_mfma_i32_16x16x32_i8(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip).result +@traced_op +def mfma_i32_16x16x64_i8(result_type, operands, *, loc=None, ip=None): + if _ods_mfma_i32_16x16x64_i8 is None: + raise 
AttributeError("ROCDL op not found: mfma_i32_16x16x64_i8 (gfx950+)") + a, b, c, cbsz, abid, blgp = _split_mfma_operands(operands, loc=loc) + return _ods_mfma_i32_16x16x64_i8(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip).result + + @traced_op def mfma_f32_16x16x32_f16(result_type, operands, *, loc=None, ip=None): if _ods_mfma_f32_16x16x32_f16 is None: