Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 104 additions & 11 deletions kernels/moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ def compile_moe_gemm1(
"(or `rocdl.mfma_i32_16x16x32_i8`)."
)

mfma_i32_k64 = None
if is_int8 and _is_gfx950:
mfma_i32_k64 = getattr(rocdl, "mfma_i32_16x16x64_i8", None)
_use_int8_k64 = is_int8 and _is_gfx950 and (mfma_i32_k64 is not None)

mfma_f32_bf16_k16 = None
if is_bf16:
mfma_f32_bf16_k16 = getattr(rocdl, "mfma_f32_16x16x16bf16_1k", None) or getattr(
Expand Down Expand Up @@ -719,8 +724,22 @@ def load_b_tile(base_k, blk_list, intra_list):
raw_ku.append(raw)
raw_data.append(raw_ku)
return raw_data
elif const_expr(_use_int8_k64):
# gfx950 int8 K=64: merge two K32 loads into one i32x4 per ni
b_tile = []
for ku in range_constexpr(k_unroll):
packs = []
for ni in range_constexpr(num_acc_n):
ki0 = (ku * 2) + 0
ki1 = (ku * 2) + 1
b0 = load_b_pack(base_k, ki0, ni, blk_list, intra_list)
b1 = load_b_pack(base_k, ki1, ni, blk_list, intra_list)
b_merged = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [b0, b1]))
packs.append(b_merged)
b_tile.append(packs)
return b_tile
else:
# fp8/int8/bf16/fp16: original code path
# fp8/int8(gfx942)/bf16/fp16: original code path
b_tile = []
for ku in range_constexpr(k_unroll):
packs0 = []
Expand All @@ -734,7 +753,7 @@ def load_b_tile(base_k, blk_list, intra_list):
packs1.append(b1)
b_tile.append((packs0, packs1))
return b_tile

acc_gate = [acc_init] * (num_acc_n * m_repeat)
acc_up = [acc_init] * (num_acc_n * m_repeat)

Expand Down Expand Up @@ -956,6 +975,35 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val):
p_idx, p_g, p_u, p_sc_g, p_sc_u = _pending_gate_up
gate_list[p_idx] = _acc_scaled_f32(gate_list[p_idx], p_g, p_sc_g)
up_list[p_idx] = _acc_scaled_f32(up_list[p_idx], p_u, p_sc_u)
elif const_expr(_use_int8_k64):
# gfx950 int8 K=64: single MFMA instruction per K=64 step
def mfma_k64_one(acc_in, a_full, b_full):
return mfma_i32_k64(T.i32x4, [a_full, b_full, acc_in, 0, 0, 0])

for ku in range_constexpr(k_unroll):
b_gate_packs = b_gate_tile_in[ku]
b_up_packs = b_up_tile_in[ku]
ki64 = arith.index(ku * 64)
col_base = col_offset_base_bytes + ki64

for mi in range_constexpr(m_repeat):
mi_val = arith.index(mi * 16)
curr_row_a_lds = row_a_lds + mi_val

if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)):
a0, a1 = a0_prefetch
else:
a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base)
a_full = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [a0, a1]))

for ni in range_constexpr(num_acc_n):
acc_idx = mi * num_acc_n + ni
gate_list[acc_idx] = mfma_k64_one(
gate_list[acc_idx], a_full, b_gate_packs[ni]
)
up_list[acc_idx] = mfma_k64_one(
up_list[acc_idx], a_full, b_up_packs[ni]
)
else:
for ku in range_constexpr(k_unroll):
b_gate_packs0, b_gate_packs1 = b_gate_tile_in[ku]
Expand Down Expand Up @@ -1069,7 +1117,8 @@ def hot_loop_scheduler():
# Flattened as: [even_0..N, odd_0..N] → 2 * num_acc_n values
#
int4_bf16_single_field = is_int4_bf16 and not is_int4_bf16_groupwise
_fields_per_ku = 1 if int4_bf16_single_field else 2
is_even_odd_split = not int4_bf16_single_field and not _use_int8_k64
_fields_per_ku = 2 if is_even_odd_split else 1
_vals_per_b_tile = k_unroll * _fields_per_ku * num_acc_n

def _flatten_b_tile(b_tile):
Expand All @@ -1080,8 +1129,7 @@ def _flatten_b_tile(b_tile):
# [(packed, scale), ...] → [packed_0..N, scale_0..N]
flat.extend(t[0] for t in ku_entry)
flat.extend(t[1] for t in ku_entry)
elif const_expr(int4_bf16_single_field):
# [raw_i64, ...] → [raw_0..N]
elif const_expr(not is_even_odd_split):
flat.extend(ku_entry)
else:
# (packs_even, packs_odd) → [even_0..N, odd_0..N]
Expand All @@ -1099,7 +1147,7 @@ def _unflatten_b_tile(vals):
scales = list(vals[idx:idx + num_acc_n])
idx += num_acc_n
b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)])
elif const_expr(int4_bf16_single_field):
elif const_expr(not is_even_odd_split):
b_tile.append(list(vals[idx:idx + num_acc_n]))
idx += num_acc_n
else:
Expand Down Expand Up @@ -1767,6 +1815,11 @@ def compile_moe_gemm2(
"(or `rocdl.mfma_i32_16x16x32_i8`)."
)

mfma_i32_k64 = None
if is_int8 and _is_gfx950:
mfma_i32_k64 = getattr(rocdl, "mfma_i32_16x16x64_i8", None)
_use_int8_k64 = is_int8 and _is_gfx950 and (mfma_i32_k64 is not None)

mfma_f32_bf16_k16 = None
if is_bf16:
mfma_f32_bf16_k16 = getattr(rocdl, "mfma_f32_16x16x16bf16_1k", None) or getattr(
Expand Down Expand Up @@ -2279,8 +2332,22 @@ def load_b_tile(base_k):
raw_ku.append(raw)
raw_data.append(raw_ku)
return raw_data
elif const_expr(_use_int8_k64):
# gfx950 int8 K=64: merge two K32 loads into one i32x4 per ni
b_tile = []
for ku in range_constexpr(k_unroll):
packs = []
for ni in range_constexpr(num_acc_n):
ki0 = (ku * 2) + 0
ki1 = (ku * 2) + 1
b0 = load_b_pack(base_k, ki0, ni)
b1 = load_b_pack(base_k, ki1, ni)
b_merged = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [b0, b1]))
packs.append(b_merged)
b_tile.append(packs)
return b_tile
else:
# fp8/int8/bf16/fp16: original code path
# fp8/int8(gfx942)/bf16/fp16: original code path
b_tile = []
for ku in range_constexpr(k_unroll):
packs0 = []
Expand All @@ -2294,7 +2361,7 @@ def load_b_tile(base_k):
packs1.append(b1)
b_tile.append((packs0, packs1))
return b_tile

# ---- Pipeline helpers: store X tile to LDS with ping-pong base ----
def store_x_tile_to_lds(vec_x_in_parts, lds_base, x_load_bytes_v):
for i in range_constexpr(num_x_loads):
Expand Down Expand Up @@ -2498,6 +2565,31 @@ def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val):
if const_expr(_pending_acc is not None):
p_idx, p_tmp, p_sc = _pending_acc
acc_list[p_idx] = _acc_scaled_f32(acc_list[p_idx], p_tmp, p_sc)
elif const_expr(_use_int8_k64):
# gfx950 int8 K=64: single MFMA instruction per K=64 step
def mfma_k64_one(acc_in, a_full, b_full):
return mfma_i32_k64(T.i32x4, [a_full, b_full, acc_in, 0, 0, 0])

for ku in range_constexpr(k_unroll):
b_packs = b_tile_in[ku]
ki64 = arith.index(ku * 64)
col_base = col_offset_base_bytes + ki64

for mi in range_constexpr(m_repeat):
mi_val = arith.index(mi * 16)
curr_row_a_lds = row_a_lds + mi_val

if const_expr((a0_prefetch is not None) and (ku == 0) and (mi == 0)):
a0, a1 = a0_prefetch
else:
a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base)
a_full = vector.bitcast(T.i32x4, vector.from_elements(T.i64x2, [a0, a1]))

for ni in range_constexpr(num_acc_n):
acc_idx = mi * num_acc_n + ni
acc_list[acc_idx] = mfma_k64_one(
acc_list[acc_idx], a_full, b_packs[ni]
)
else:
for ku in range_constexpr(k_unroll):
b_packs0, b_packs1 = b_tile_in[ku]
Expand Down Expand Up @@ -2645,7 +2737,8 @@ def hot_loop_scheduler():
# B-tile data layout per k_unroll entry (3 variants):
# See gemm1 _flatten_b_tile for full layout documentation.
int4_bf16_single_field = is_int4_bf16 and not is_int4_bf16_groupwise
_fields_per_ku = 1 if int4_bf16_single_field else 2
is_even_odd_split = not int4_bf16_single_field and not _use_int8_k64
_fields_per_ku = 2 if is_even_odd_split else 1
_vals_per_b_tile = k_unroll * _fields_per_ku * num_acc_n
_n_acc = m_repeat * num_acc_n
_p_b = _n_acc
Expand All @@ -2658,7 +2751,7 @@ def _flatten_b_tile(b_tile):
if const_expr(is_int4_bf16_groupwise):
flat.extend(t[0] for t in ku_entry)
flat.extend(t[1] for t in ku_entry)
elif const_expr(int4_bf16_single_field):
elif const_expr(not is_even_odd_split):
flat.extend(ku_entry)
else:
flat.extend(ku_entry[0])
Expand All @@ -2675,7 +2768,7 @@ def _unflatten_b_tile(vals):
scales = list(vals[idx:idx + num_acc_n])
idx += num_acc_n
b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)])
elif const_expr(int4_bf16_single_field):
elif const_expr(not is_even_odd_split):
b_tile.append(list(vals[idx:idx + num_acc_n]))
idx += num_acc_n
else:
Expand Down
9 changes: 9 additions & 0 deletions python/flydsl/expr/rocdl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None)
_ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8
_ods_mfma_i32_16x16x32_i8 = mfma_i32_16x16x32_i8
_ods_mfma_i32_16x16x64_i8 = globals().get("mfma_i32_16x16x64_i8", None)
_ods_mfma_f32_16x16x32_f16 = globals().get("mfma_f32_16x16x32_f16", None)
_ods_mfma_f32_16x16x32_bf16 = globals().get("mfma_f32_16x16x32_bf16", None)
_ods_mfma_scale_f32_16x16x128_f8f6f4 = (
Expand Down Expand Up @@ -114,6 +115,14 @@ def mfma_i32_16x16x32_i8(result_type, operands, *, loc=None, ip=None):
return _ods_mfma_i32_16x16x32_i8(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip).result


@traced_op
def mfma_i32_16x16x64_i8(result_type, operands, *, loc=None, ip=None):
    """Build a ``rocdl.mfma_i32_16x16x64_i8`` op and return its result value.

    Thin wrapper over the generated builder (``_ods_`` prefix suggests ODS —
    presumably codegen'd from the ROCDL dialect tablegen). Raises
    ``AttributeError`` when the op is absent from this MLIR build; per the
    error text, the K=64 int8 MFMA exists on gfx950+ only.
    """
    builder = _ods_mfma_i32_16x16x64_i8
    if builder is None:
        raise AttributeError("ROCDL op not found: mfma_i32_16x16x64_i8 (gfx950+)")
    a, b, c, cbsz, abid, blgp = _split_mfma_operands(operands, loc=loc)
    op = builder(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip)
    return op.result


@traced_op
def mfma_f32_16x16x32_f16(result_type, operands, *, loc=None, ip=None):
if _ods_mfma_f32_16x16x32_f16 is None:
Expand Down
Loading