Skip to content

Gfx1250 moe merge#3122

Open
XingerZhu wants to merge 7 commits intomainfrom
gfx1250_moe_merge
Open

Gfx1250 moe merge#3122
XingerZhu wants to merge 7 commits intomainfrom
gfx1250_moe_merge

Conversation

@XingerZhu
Copy link
Copy Markdown

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

XingerZhu and others added 6 commits May 9, 2026 09:41
…0 workaround

Pure infrastructure plumbing required for the gfx1250 MoE path; no
per-arch code yet. Splits cleanly from the kernel/dispatcher commits
that follow.

* aiter/ops/flydsl/utils.py: harden ``is_flydsl_available`` to actually
  load ``flydsl._mlir._mlir_libs._mlirDialectsFly`` (PEP-420 namespace
  pkg defeats ``find_spec``); cache via ``lru_cache``.
* aiter/ops/quant.py: add Triton ``_per_1x32_fp8_e8m0_quant_kernel`` +
  ``_per_1x32_f8_e8m0_quant_triton`` host wrapper (pure addition, does
  not touch upstream ``per_1x32_i4_quant`` / ``rotate_*`` ops).
* aiter/test_common.py: add ``_bool_all_safe`` and route
  ``checkAllclose`` through it -- gfx1250 (PyTorch 2.10 + ROCm 7.2)
  hangs in the bool ``Tensor.all()`` reduction kernel.
* aiter/jit/core.py: (1) ``_match_type`` so the dispatcher accepts
  ``torch.Tensor`` for an ``aiter_tensor_t`` annotation and vice versa;
  (2) ``_develop_module_ok`` -- only run torch->aiter_tensor_t
  conversion + ``_set_current_hip_stream`` when the loaded module
  actually exposes that pybind symbol (legacy ``module_quant`` etc. are
  tagged ``develop=True`` but use classic ``torch.Tensor`` signatures).
  Upstream ``pd.concat`` empty-frame guard left untouched.
* aiter/jit/optCompilerConfig.json: add ``-D__Float4_e2m1fn_x2=1`` to
  ``module_quant`` hipcc flags (pin the macro across torch versions).

Co-authored-by: Cursor <cursoragent@cursor.com>
Vendored from FlyDSL. None of these files are imported on non-gfx1250
paths (the dispatcher in commit 4 gates on ``get_gfx() == "gfx1250"``
and ``is_flydsl_available()``), so this is a pure addition.

* gemm_common_gfx1250.py: shared launch / TDM helpers used by all
  gfx1250 GEMM-family kernels (block index, descriptor builders, LDS
  layout, ``_finalize_alloc_and_launch_2d`` w/ waves_per_eu + cluster
  + 3D grid plumbing).
* moe_gemm_2stage_common_gfx1250.py: MoE-specific common (sorted-tid
  LDS fast path, optional bias, SwiGLU emit, split-K epilogue,
  TDM hoist).
* moe_gemm_2stage_mxscale_gfx1250.py: per_1x32 MXFP8/MXFP4 MXScale
  kernel (the main a8w4 / fp8 SwiGLU MoE GEMM on gfx1250).
* moe_gemm_2stage_wmma_gfx1250.py: WMMA-based bf16/fp8 path.
* pipeline_utils.py: compile-time tail-plan helper for N-stage
  multi-buffer software pipelining + LDS reuse fence calculation.

Co-authored-by: Cursor <cursoragent@cursor.com>
Pure additive change. Adds the shape-alignment / pad-cache utilities the
gfx1250 stage1/stage2 wrappers (added in commit 4) need to legalize odd
model_dim / inter_dim like GPT-OSS 2880 against the FlyDSL mxscale
kernels' tile_k=128 + tile_n divisibility constraints.

Touches only:
* ``import os`` at the top.
* ``_MXSCALE_FORMAT_PACK`` -- (pack_a, pack_b, weight_is_preshuffled)
  per in_dtype.
* ``_MXSCALE_PAD_CACHE`` + ``_mxscale_pad_cache_{key,get,put}`` --
  data_ptr-keyed FIFO cache (default 512MB budget, env-tunable via
  ``AITER_GFX1250_PAD_CACHE_MAX_BYTES`` / disable via
  ``AITER_GFX1250_DISABLE_PAD_CACHE=1``) so repeated fused_moe calls
  with the same weight tensor don't redo the ~100MB memcpy.
* ``_mxscale_align_up``, ``_mxscale_pick_tile_n`` -- mirrors FlyDSL's
  own ``bench_resolve_tiles`` heuristic.
* ``_mxscale_zero_pad_last`` -- pads on the last dim, going through a
  uint8 view for sub-byte / e8m0 dtypes that ``F.pad`` doesn't
  implement natively.
* ``_mxscale_pad_weight_k`` -- preshuffle-aware K-pad that appends
  whole 16x16 zero tiles instead of bytes-inside-shuffle-groups.

Explicitly preserved upstream code (NOT touched in this commit):
* int4_bf16 dispatch (``get_flydsl_stage{1,2}_kernels_int4_bf16``,
  the ``elif a_dtype=="bf16" and b_dtype=="int4"`` branches).
* ``_get_compiled_silu_fused`` extended signature
  (``act`` / ``enable_bias``) and the ``_get_compiled_swiglu`` cache.
* split-K + bias post-processing path in ``flydsl_moe_stage1``
  (``topk_ids``, ``swiglu_and_mul_bias``, ``silu_and_mul_bias`` etc.).
* ``_run_compiled`` (only its docstring differs upstream).

Co-authored-by: Cursor <cursoragent@cursor.com>
Top-level
- Add `_USE_OPUS_MOE_SORTING` (`AITER_USE_OPUS_MOE_SORTING`) without
  removing upstream `_USE_CK_MOE_SORTING` /
  `_USE_GENERIC_SWIGLU_MXFP4_LAYOUT` / `_SWIGLU_MXFP4_BF16_BOUND`. The
  gfx1250 torch fallback is reachable via `AITER_USE_CK_MOE_SORTING=1`
  (caller asks for non-opus, gfx1250 short-circuit then routes to
  `_moe_sorting_torch_gfx1250`).
- `_moe_sorting_torch_gfx1250(...)`: pure-torch, vectorised moe_sorting
  fallback that matches `moe_sorting_torch_native` semantics:
  `(topk_idx<<24)|token_idx` packing, per-expert `block_size` padding,
  local-expert-id tile fill, and `num_valid_ids = [padded_slots, num_local_tokens]`.
  Needed because both `aiter.moe_sorting_opus_fwd` (HSA queue completion
  signal never raised -> infinite wait) and `aiter.moe_sorting_fwd`
  (CK-tile, NULL kernel pointer on gfx1250) are broken on gfx1250.

Dispatcher
- `_moe_sorting_impl` short-circuits to the torch fallback when
  `get_gfx() == "gfx1250" and not use_opus`.
- `moe_sorting()` combines both env vars when computing `use_opus`
  (gfx1250 + `AITER_USE_CK_MOE_SORTING=1` -> torch fallback path).
- `q_dtype_a` decision table prepends two gfx1250 branches:
  fp8 weight -> fp8 activation; SwiGLU -> fp8 above
  `bf16_fp8_bound`, fp4x2 below. Sits before the existing
  `_USE_GENERIC_SWIGLU_MXFP4_LAYOUT` switch so the GPT-OSS path on
  gfx950/gfx942 is unaffected.
- `KERNEL_DICT["gfx1250"] = {}` placeholder (never populated; keeps the
  per-gfx 1-stage lookup explicit so future gfx1250 1-stage entries land
  in an obvious spot).

Entry points
- `_gfx1250_data_format(q_dtype_a, q_dtype_w, q_type, dtype)`: maps
  aiter quant params to FlyDSL kernel format string
  (`fp4`/`fp8`/`a8w4`/`bf16`/`fp16`).
- `_ensure_flydsl_kernels_path()`: inserts
  `aiter/aiter/ops/flydsl/` into `sys.path` so the vendored
  `from kernels.pipeline_utils import ...` style imports resolve.
- `_gfx1250_moe_stage1` / `_gfx1250_moe_stage2`: host-side wrappers
  around the FlyDSL JIT kernels. Handle K/N alignment (zero-pad K to
  tile_k via `_mxscale_*` helpers, pick a tile_n that divides both N
  dims), zero-init the `out` buffer (FlyDSL kernel only writes slots
  listed in `sorted_token_ids`; padding slots leak as inf/nan into
  stage2), view E8M0 scales as raw uint8 (FlyDSL DLPack does not yet
  map dtype 14 to MLIR), forward bias when stage1 SwiGLU is on, and
  surface `AITER_GFX1250_*` debug knobs.

`get_2stage_cfgs`
- Top-of-function gfx1250 direct dispatch: when
  `is_flydsl_available()` and `_gfx1250_data_format(...)` is non-None,
  return a `MOEMetadata` wrapping `_gfx1250_moe_stage{1,2}` with
  default `block_m=32`, `tile_n=tile_k=128` for mxscale and 64 for
  WMMA, bypassing the tuned-CSV path entirely.
- gfx1250 safety net inside the regular path: if a tuned cfg's
  `kernelName1`/`kernelName2` starts with `moe_ck2stages` (CK is not
  built for gfx1250 -> NULL kernel pointer segfault), drop the cfg and
  fall back to default heuristics.

`fused_moe_2stages`
- Stage1 fp8/per_1x32 quant branch: add a `get_gfx() != "gfx1250"`
  guard so gfx1250 falls through to the dedicated elif which mirrors
  the FlyDSL UT `_per_1x32_fp8_quant` exactly (e4m3fn bytes via
  `fp4_utils._f32_to_floatx_unpacked`, scale max=240, no NaN
  sentinel poisoning the K-sum).
- Stage2 quant: matching gfx1250 elif keeping scale_x in source-token
  layout (FlyDSL stage2 gathers per-token scale via `sorted_token_ids`
  internally, so a pre-sorted tile layout would address the wrong
  bytes).
- Upstream `topk_ids` / `topk_weights` forwarding into split-K bias
  paths is preserved.

Notes
- `isShuffled` keeps the upstream `getattr(w1, ...) or getattr(w2, ...)`
  formulation; gfx1250 does not consult it (routed via dedicated entry
  point above), so there is no need to weaken the CK fallback signal.

Co-authored-by: Cursor <cursoragent@cursor.com>
…nobs

Helpers (top of file)
- _gfx1250_fp8_round_trip_bf16(x): per-1x32 mxfp8 quant+dequant
  (e4m3fn bytes, e8m0 scale, dtype_max=240).  Mirrors FlyDSL UT
  ``_per_1x32_fp8_quant`` / ``_dequant_blockscale_fp8`` so the bf16
  reference path sees the same precision loss the kernel does;
  otherwise checkAllclose drifts by ~0.5 per K=3072 element on a8w4
  despite the kernel being numerically correct.
- _preshuffle_b_16x16_dtype_safe(fp4u, b, rows, cols): wrapper around
  FlyDSL's ``preshuffle_b_16x16`` that view-casts packed sub-byte
  dtypes (``torch.float4_e2m1fn_x2``) through ``uint8`` for the
  ``permute(...).contiguous()`` step, since rocm/torch 2.10 has no
  ``copy_`` impl for those dtypes.  No-op for >=1-byte dtypes.
- _gfx1250_a8w4_default_kpad(model_dim, inter_dim): K-adaptive
  default for ``-hip`` so GPT-OSS-class K~3000 still lands inside
  the FlyDSL verdict thresholds.  Static (192,128) covers ~6% of
  K=2880 -> verdict fails by ~25% mismatch / 0.61 logits_diff;
  ~K/4 (rounded to 128) / inter/4 (rounded to 64) gives ~75%
  effective K with margin past both verdict thresholds.

test_fmoe
- per_1x32 skip relaxed: gfx1250 joins gfx950 (test now runs on
  both arches for per_1x32 quant types).
- _input_scale=0.2 for gfx1250 fp8/fp4 + per_1x32 (mirrors UT
  init_scale=0.2): default unit-stddev randn() saturates the
  K-sum into bf16 range and the ref disagrees by 100x; UT scales
  inputs down for exactly this reason.
- w2 also scaled by 1/sqrt(inter_dim) when input is scaled.
- Quant section: keep upstream i4x2 (a16wi4) and a16w4 paths;
  insert gfx1250 fp8 round-trip into the per_1x32 a16w4/a8w4 a1
  branch so the bf16 reference matches kernel arithmetic.
- Bias section: keep upstream ``i4x2 -> bias=None`` branch; for
  gfx1250 fp8/fp4 + per_1x32, disable bias when actType != Swiglu
  to match the FlyDSL stage1/stage2 epilogue's actual bias support.
- Pre-shuffle: insert a top-of-chain ``_gfx1250_flydsl_eligible``
  branch (fp4 -> raw weights+raw scales; fp8/a8w4 -> preshuffle_b_16x16
  + raw scales).  Upstream a16w4 (shuffle_weight_a16w4) and a16wi4
  (pack_int8_to_packed_int4 + shuffle_scale_for_int4) branches stay.

Accuracy verdict
- ``--no-perftest``: early-return None before run_perftest so the
  outer driver drops the case from the result table.  Use case is
  exercising the gfx1250 front-end on a host whose Triton
  activation-quant kernel doesn't yet codegen for gfx1250.
- ck_atol/rtol per FlyDSL precision class: a8w4 -> 0.5/0.5,
  fp4 -> 0.25/0.5, fp8 -> 0.25/0.25.  Default 1e-2 for everything
  else (upstream behaviour).
- Verdict: FlyDSL paths use UT-style "mismatch_ratio < 5% OR
  logits_diff < threshold" PASS rule (FlyDSL test_common.verify_output).
  Per-precision logits_thr: a8w4=0.5, fp4=0.25, fp8=0.05.  PASS sets
  err=0 so the markdown summary surfaces a green row; FAIL still
  asserts under strict_accuracy.  Non-FlyDSL paths keep upstream
  strict-error logic.
- Return dict gains ``err`` alongside upstream ``us`` /
  ``logits_diff`` so the verdict block can normalise mismatch ratio.

CLI
- ``-hip`` default flipped from ``[(192,128)]`` to ``None`` so the
  K-adaptive ``_gfx1250_a8w4_default_kpad`` kicks in (still
  overridable: ``-hip 0,0`` etc.).
- ``--no-perftest`` (above).
- ``--w_fp4_kernel`` / ``--wfp4`` alias: surfaces
  ``AITER_GFX1250_W_FP4_KERNEL=1`` env var so kernel selection can
  opt into the FlyDSL fp4-weight kernel.  Currently advisory.
- ``--num_buffers`` (1..4): pipeline buffer depth knob, mirrors
  FlyDSL UT.  Sets ``AITER_GFX1250_NUM_BUFFERS`` unconditionally so
  reruns from a parent shell don't inherit a stale value.
- ``AITER_GFX1250_DEBUG=1`` gates two debug prints (which path the
  test took, what the preshuffled w1/w2 dtype/contiguity ended up).

Notes
- Upstream features kept intact: int4 (a8w4), per_128x128 quant,
  a16w4 (shuffle_weight_a16w4), a16wi4 (per_1x32 + i4x2),
  ``_runtime_swiglu_mxfp4_q_dtype_a`` CSV row filter,
  ``_PER1X32_BF16_I4`` legacy iteration, ``shuffle_scale_for_int4``
  + ``pack_int8_to_packed_int4`` imports.
- ``import math`` (stdlib) replaces aiter_new's ``import math as _math``
  inline pattern.
- ``return {...}`` now carries ``err`` AND ``logits_diff`` (strict
  superset of upstream: upstream consumers checked
  ``logits_diff`` only; FlyDSL gfx1250 verdict uses ``err``).

Co-authored-by: Cursor <cursoragent@cursor.com>
Default ``testGraph`` for ``perftest()`` and ``run_perftest()`` is
flipped from ``False`` to a sentinel ``None``, which resolves to:
  * True  on gfx1250
  * False everywhere else (preserves upstream behaviour)

Rationale: on gfx1250 the FlyDSL kernel wrappers are
Python -> ctypes -> hipModuleLaunchKernel per call.  At the MoE-GEMM
batch sizes the test sweep covers, the per-iter Python overhead is on
the same order as kernel time, so eager-launch ``us2`` does not
reflect what actually ships through compiled stacks (vLLM / sgl-lang)
that wrap the call in a HIP graph.  The graph path gives a materially
more useful number for tuning + accuracy verdict.

Callers that pass ``testGraph=True/False`` explicitly are unaffected
(no change in behaviour); only the implicit default flips and only on
gfx1250.

Note on dead code cleanup
-------------------------
The plan called for deleting an explicit dead block in
``fused_moe.py`` (``if get_gfx() == "gfx1250":`` nested inside an
outer ``and get_gfx() != "gfx1250":`` guard, present in the
aiter_new branch).  That block was already filtered out during commit 4
(the gfx1250 branch lives in a dedicated elif below the guard, so the
inner block had no reachable execution path and was dropped at copy
time).  Verified ``rg "if get_gfx\(\) == .gfx1250."`` returns only the
two legitimate top-level guards (``_moe_sorting_impl`` short-circuit
and ``get_2stage_cfgs`` direct dispatch).

Co-authored-by: Cursor <cursoragent@cursor.com>
@XingerZhu XingerZhu requested review from a team and Copilot May 11, 2026 07:40
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3122 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR integrates and wires up gfx1250-specific 2-stage MoE support by routing gfx1250 execution through FlyDSL kernel wrappers (including WMMA and mxscale variants), adding shape-alignment/padding utilities, and updating tests/bench infra to accommodate gfx1250 numerical and runtime characteristics.

Changes:

  • Add gfx1250 FlyDSL dispatch path in aiter.fused_moe (including gfx1250 quantization, sorting fallback behavior, and padding/alignment support).
  • Introduce/extend gfx1250 FlyDSL kernel implementations and shared utilities (pipeline helpers, WMMA kernels, common gfx1250 MoE helpers).
  • Update MoE 2-stage tests and perf/test utilities to better reflect gfx1250 numerical behavior and runtime overheads (graph default, tolerance/verdict logic, preshuffle helpers).

Reviewed changes

Copilot reviewed 12 out of 13 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
op_tests/test_moe_2stage.py Updates MoE 2-stage test sweep and reference/tolerance logic for gfx1250/FlyDSL paths; adds helper utilities and new CLI knobs.
aiter/test_common.py Makes testGraph default arch-aware and adjusts allclose boolean reduction to avoid gfx1250 hang.
aiter/ops/quant.py Adds a Triton per-1x32 fp8 e8m0 quantization kernel used by gfx1250 paths.
aiter/ops/flydsl/utils.py Strengthens FlyDSL availability detection by requiring loadable compiled extensions.
aiter/ops/flydsl/moe_kernels.py Adds mxscale tile selection + K-padding utilities and a VRAM-bounded pad cache to support gfx1250 mxscale constraints.
aiter/ops/flydsl/kernels/pipeline_utils.py Adds vendored pipeline utilities required by gfx1250 multi-buffer pipeline mode.
aiter/ops/flydsl/kernels/moe_gemm_2stage_wmma_gfx1250.py Introduces gfx1250 WMMA fp16/bf16 2-stage MoE kernels and compile entry points.
aiter/ops/flydsl/kernels/moe_gemm_2stage_common_gfx1250.py Adds shared gfx1250 MoE utilities (tiling, epilogues, mxscale loaders/pipeline planning, bias/activation helpers).
aiter/ops/flydsl/kernels/gemm_common_gfx1250.py Adds shared gfx1250 GEMM utilities (LDS helpers, pipeline fences, epilogue stores).
aiter/jit/optCompilerConfig.json Adds a gfx1250-related HIP compile define alongside existing FP8 enablement.
aiter/jit/core.py Improves pybind arg type-checking (torch.Tensor vs aiter_tensor_t) and guards develop-mode stream-setting logic.
aiter/fused_moe.py Adds gfx1250-specialized moe_sorting fallback and gfx1250 FlyDSL stage1/stage2 routing + quantization/padding/bias handling.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/ops/quant.py
Comment on lines +42 to +66
DTYPE_MAX: tl.constexpr = 2.0**DTYPE_MAX_POW2

offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask_n = offs_n < N

x_offs = pid_m * stride_x_m + offs_n * stride_x_n
x = tl.load(x_ptr + x_offs, mask=mask_n, other=0.0).to(tl.float32)

amax = tl.max(tl.abs(x))
raw = amax / DTYPE_MAX

raw_i32 = raw.to(tl.int32, bitcast=True)
exponent = ((raw_i32 >> 23) & 0xFF).to(tl.uint8)
round_bit = (raw_i32 & 0x400000) > 0
sticky = ((raw_i32 & 0x200000) > 0) | ((raw_i32 & 0x1FFFFF) > 0) | (exponent > 0)
exponent = tl.where(round_bit & sticky, exponent + 1, exponent)

scale_f32 = (exponent.to(tl.int32) << 23).to(tl.float32, bitcast=True)
y = x / scale_f32
y = tl.clamp(y, -DTYPE_MAX, DTYPE_MAX)

y_offs = pid_m * stride_y_m + offs_n * stride_y_n
tl.store(y_ptr + y_offs, y.to(y_ptr.type.element_ty), mask=mask_n)

tl.store(scale_ptr + pid_m * stride_s_m + pid_n * stride_s_n, exponent)
Comment on lines +51 to +56
e8m0 scale, dtype_max=240) and immediately dequantise back to bf16.

This matches the activation precision the FlyDSL gfx1250 a8w4 / fp8
MoE GEMM kernel sees, so torch references that consume bf16 inputs
can be compared against the kernel's quantised output without the
K-sum noise overwhelming checkAllclose's atol=1e-2.
Comment thread aiter/fused_moe.py
Comment on lines +47 to +71
"""Pure-torch moe_sorting fallback for gfx1250 (vectorised, no host syncs).

Both ``aiter.moe_sorting_opus_fwd`` (precompiled opus kernel) and
``aiter.moe_sorting_fwd`` (CK-tile) are broken on gfx1250 today:

* The opus kernel submits to the HSA queue but never raises the
completion signal, so any subsequent op that synchronises against
that stream busy-waits inside ``rocr::core::InterruptSignal::WaitRelaxed``
forever.
* The CK-tile path calls into Composable Kernel which simply isn't
compiled for gfx1250 (NULL kernel pointer -> segfault).

The semantics match the reference ``moe_sorting_torch_native`` in
``aiter/op_tests/test_moe_sorting.py``:
* tokens belonging to expert ``e`` are placed contiguously inside
``sorted_ids``, padded to a multiple of ``block_size``;
* ``sorted_ids[i] = (topk_idx << 24) | token_idx`` for valid slots,
``init_val = (topk << 24) | M`` for padding slots;
* ``sorted_expert_ids[b] = local_expert_idx`` (i.e. expert id with
masked-out experts removed) for the b-th tile;
* ``num_valid_ids = [num_valid_sorted_slots, num_valid_input_tokens]``.

Implementation is fully vectorised on the GPU -- O(num_experts) python-side
sync, but those are at most a few cheap shape lookups, not 256 ``.item()``
+ ``torch.where`` calls per forward.
Comment thread aiter/fused_moe.py
Comment on lines +194 to +201
# gfx1250: prefer the prebuilt opus moe-sorting kernel when it works,
# because the pure-torch fallback (_moe_sorting_torch_gfx1250) computes
# padded slot counts that disagree with the FlyDSL stage2 kernel layout
# (it over-counts blocks by ~1% which leaves stage2 atomic_add writing
# to garbage rows -> output stays at zero). Only fall back to torch if
# the user explicitly disables opus (AITER_USE_OPUS_MOE_SORTING=0 +
# AITER_USE_CK_MOE_SORTING=1) or the opus kernel is not loadable on
# this build.
moe_gemm_2stage_common_gfx1250.py
- ``_Stage1GateUpPackedWrapper._get_packed_operands`` cache key:
  switch from ``(arg_w.data_ptr(), arg_scale_w.data_ptr() if hasattr ...)``
  to ``(id(arg_w), id(arg_scale_w))``. ``data_ptr()`` is undefined for
  Python-side wrapper objects (and for sub-byte packed views in some
  torch builds), and the previous ``hasattr`` fallback degenerated to
  ``id()`` for those branches anyway.  Use ``id()`` for both legs so the
  key is uniformly defined and matches the lifetime semantics the cache
  already relies on (the wrapper holds the operands alive).

moe_gemm_2stage_mxscale_gfx1250.py
- ``_issue_all_loads`` (stage1 + stage2 mxscale codegen): reorder the B
  TDM load and scalar (A scale / E8M0) loads based on ``is_fp4``.
  * fp4: scalar loads first, then B TDM. fp4 weight tile is half the
    bytes of a8w4/fp8 so the B-TDM issue isn't the long pole; issuing
    the scalar loads first lets the load fence move forward and the
    next-iter B descriptor get bumped sooner, giving the WMMA more
    overlap room.
  * fp8 / a8w4: keep the historical order (B TDM first, then scalars)
    -- the larger B tile dominates and benefits from being launched as
    early as possible.
  Two identical edits, one per stage.

moe_gemm_2stage_wmma_gfx1250.py
- ``_compile_stage2_wmma_kernel_impl``: add a ``_keep_const_expr_ref =
  const_expr`` no-op in the kernel body.  ASTRewriter strips
  ``const_expr(...)`` from ``if`` tests as part of its constant-folding
  pass; if every other reference to ``const_expr`` lives only inside an
  ``if const_expr(...):`` test, the rewrite leaves zero references in
  the rewritten body, ``co_freevars`` shrinks by one, and CPython
  rejects ``f.__code__ = new_f_code_o`` because the original
  ``__closure__`` length no longer matches the new code object's
  freevars count.  Keeping one explicit reference (annotated F841 to
  silence linters) preserves the freevars slot.

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants