Skip to content

Add FMoE run_config mismatch diagnostics#3111

Open
zihaomu wants to merge 1 commit into
ROCm:mainfrom
zihaomu:pr/aiter-fmoe-run-config-mismatch-diagnostics
Open

Add FMoE run_config mismatch diagnostics#3111
zihaomu wants to merge 1 commit into
ROCm:mainfrom
zihaomu:pr/aiter-fmoe-run-config-mismatch-diagnostics

Conversation

@zihaomu
Copy link
Copy Markdown
Member

@zihaomu zihaomu commented May 11, 2026

Motivation

Add compact diagnostics for FMoE run_config compare failures, including mismatch counts, absolute error stats, norms, and top mismatched values.

Technical Details

  • Add tensor_compare_diagnostics() for compact failure summaries.
  • Report output/reference shape, dtype, and numel.
  • Sample up to 4096 evenly distributed elements instead of scanning the full tensor.
  • Report sampled mismatch count, sampled max absolute error, sampled mean absolute error, and a few mismatch
    examples.
  • Reuse the existing checkAllclose() err ratio instead of recomputing a full-tensor comparison.
  • Only run diagnostics on failure paths, including all-zero output and mismatch status.
  • Wrap diagnostics in try/except so diagnostic failures do not crash the tuner.

Befor this patch:

When your met FMoE tuner error, you get like this:

mismatch:err_ratio=0.9970(>0.5)

With this patch:

  err_ratio=...
  mismatch=...
  max_abs=...
  mean_abs=...
  out_norm=...
  ref_norm=...
  top=[idx:out=...,ref=...,abs=...]

Test Result

Submission Checklist

@zihaomu zihaomu requested review from a team and Copilot May 11, 2026 02:22
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3111 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds compact, failure-only diagnostics to the FMoE run_config validation path so that tuner “mismatch” and “all-zero output” failures include actionable tensor comparison summaries rather than only an err_ratio.

Changes:

  • Introduces tensor_compare_diagnostics() to summarize shape/dtype/numel plus sampled mismatch/error statistics and a few example mismatches.
  • Enhances run_config() failure statuses (“all zeros” and “mismatch”) to append the new diagnostics output.
  • Keeps the existing checkAllclose()-computed error ratio as the primary pass/fail metric.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +134 to +158
parts = [
f"shape=out{tuple(res.shape)},ref{tuple(ref.shape)}",
f"dtype=out={res.dtype},ref={ref.dtype}",
f"numel=out={res_numel},ref={ref_numel}",
]

if sample_count:
sample_idx = torch.arange(
sample_count, device=ref_flat.device, dtype=torch.long
)
sample_idx = sample_idx * total // sample_count
ref_f = ref_flat[sample_idx].float()
res_f = res_flat[sample_idx].float()
delta = (res_f - ref_f).abs()
close = torch.isclose(res_f, ref_f, rtol=rtol, atol=atol)
mismatch = ~close

sample_mismatch = int(mismatch.sum().item())
parts.extend(
[
f"sampled={sample_count}/{total}",
f"sample_mismatch={sample_mismatch}/{sample_count}",
f"sample_max_abs={float(delta.max().item()):.6g}",
f"sample_mean_abs={float(delta.mean().item()):.6g}",
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants