diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 1ad705468..cb075b438 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1986,123 +1986,10 @@ dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp: - isl: 8192 osl: 1024 search-space: - # MTP configurations - # 1P1D pure TP8 - - spec-decoding: "mtp" - conc-list: [ 1, 2, 4, 8 ] - prefill: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - # 1P2D TP8 - - spec-decoding: "mtp" - conc-list: [ 2, 4, 8, 16, 32 ] - prefill: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 2 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "DECODE_NODES=2" - - "DECODE_MTP_SIZE=3" - - # 1P2D TP8 - - spec-decoding: "mtp" - conc-list: [ 32, 64 ] - prefill: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 2 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "DECODE_NODES=2" - - "DECODE_MTP_SIZE=3" - - # 1*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 640, 512 ] - prefill: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - # 1*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 256 ] - prefill: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - - # 1*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 128 ] - prefill: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - # 1*DEP8 + 1*DEP8 + # THROWAWAY (not for merge): conc-64 only DEP8+MTP3 to reproduce + # SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK < 256 corruption. + # max(CONC_LIST)=64 → dispatch_tokens=64/8*4=32 → broken All2All kernel. + # 1*DEP8 + 1*DEP8, MTP3 - spec-decoding: "mtp" conc-list: [ 64 ] prefill: @@ -2121,25 +2008,6 @@ dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp: - "DECODE_NODES=1" - "DECODE_MTP_SIZE=3" - # 2*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 1024, 2048, 4096 ] - prefill: - num-worker: 2 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=2" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=1" - # DSv4-Pro FP4 on MI355X via SGLang. Uses a rocm720 mi35x image built off the # amd/deepseek_v4 branch in sgl-project/sglang; the SHA is encoded in the diff --git a/benchmarks/multi_node/amd_utils/patches/README.md b/benchmarks/multi_node/amd_utils/patches/README.md index d9b5de79d..97ab47d26 100644 --- a/benchmarks/multi_node/amd_utils/patches/README.md +++ b/benchmarks/multi_node/amd_utils/patches/README.md @@ -60,6 +60,59 @@ This is a stop-gap. The proper upstream fix is to migrate MoRI to the plural `state_types: List[StateType]` API (full design + diff in `scripts/sglang_disagg/docs/03-upstream-pr-proposal.md`). +## `apply_moriep_dispatch_floor.py` (in-place patch, NOT a bind-mount overlay) + +This one is different from `mori_conn.py`: it is a **surgical in-place +patch script**, not a full-file bind-mount overlay. It is run inside the +container by `server_sglang.sh` (right after `env.sh`) and edits the +installed +`/sgl-workspace/sglang/.../token_dispatcher/moriep.py` +in place, injecting a single floor after the dispatch-token env read. + +**Why not a bind-mount overlay (learned the hard way):** the +`lmsysorg/sglang-rocm:v0.5.12.post1-*` image ships a **downstream-patched +`moriep.py`** (class `MoriEPDispatcher`, with attrs such as +`expert_mask_gpu`) that diverges from the upstream +[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1) +tag. A full-file overlay of the upstream file (even one byte-identical to +the tag, `md5 ac626f5459...`) reverts the AMD additions and crashes the +scheduler at init: `AttributeError: 'MoriEPDispatcher' object has no +attribute 'expert_mask_gpu'`. The in-place patch touches only the +dispatch-token read and preserves all downstream code, so it is robust to +the vendor fork. + +**Bug it fixes:** at low concurrency the MoRI EP dispatch path silently +corrupts output (decodes fine, acceptance length stays high, but gsm8k +drops to 0). The per-rank dispatch buffer +`num_max_dispatch_tokens_per_rank` (→ mori `max_num_inp_token_per_rank`) +is derived by the harness as `max(CONC_LIST)/TP*(MTP+1)`, which collapses +at low conc (conc-64 / TP8 / MTP3 → `64/8*4 = 32`). MoRI sizes its +receive buffer `MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank` +(`max_total_recv_tokens` defaults to 0 → that fallback, and it is a *cap* +not a floor — `dispatch_combine.hpp:126-136`). The intra-node dispatch +kernel's per-dest atomic counter then runs past that buffer; the only +guard is `assert(destTokId < MaxNumTokensToRecv())`, compiled out under +`-DNDEBUG`, so the result is silent out-of-bounds writes +(`internode_v1.cpp` `DispatchIntraNodeBlock`). + +The patch floors `num_max_dispatch_tokens_per_rank` to **256** right at +its env read — the single source of truth that feeds both +`get_ep_dispatch_configs()` (kernel selection) and the buffer-sizing +arg. It is idempotent and fail-loud-but-non-fatal (a structure miss prints +a clear marker plus the surrounding source and lets the server proceed). +Empirically validated on MI355X (conc-64 DEP8+MTP3): dispatch `32 → +gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough), `256 → 0.94`. + +This is a stop-gap. The proper upstream fix is in MoRI: size the receive +buffer from the routing fan-in and turn the compiled-out `assert` into a +real bounds guard (see [ROCm/mori#356](https://github.com/ROCm/mori/issues/356)). +The integration-level guard belongs in sglang's `moriep.py` +([sgl-project/sglang#27194](https://github.com/sgl-project/sglang/issues/27194)) — +this patch is exactly that guard, pending upstream merge. No +`EXTRA_DOCKER_MOUNTS` wiring is needed; the patch is applied +unconditionally by `server_sglang.sh` and no-ops when the value is +already ≥256 (e.g. prefill, which uses 8192). + ## How to enable ```bash diff --git a/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py new file mode 100644 index 000000000..c2c8f5ecb --- /dev/null +++ b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +"""Surgically floor the MoRI per-rank dispatch buffer to >=256 in the installed +sglang `moriep.py`, in place, inside the container. + +Why in-place (not a bind-mount overlay): the lmsysorg/sglang-rocm image ships a +*downstream-patched* moriep.py (class `MoriEPDispatcher`, extra attrs such as +`expert_mask_gpu`) that diverges from the upstream v0.5.12.post1 tag. A full-file +overlay of the upstream file reverts those AMD additions and crashes the +scheduler at init (`AttributeError: ... 'expert_mask_gpu'`). So we patch the +image's own file and touch only the dispatch-token read. + +The bug being fixed: at low concurrency the per-rank dispatch buffer +(num_max_dispatch_tokens_per_rank -> mori max_num_inp_token_per_rank) collapses +(conc-64/TP8/MTP3 -> 64/8*4 = 32). MoRI sizes its receive buffer +MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank (dispatch_combine.hpp; +max_total_recv_tokens defaults to 0 -> that fallback, and it is a cap not a +floor). The intra-node dispatch kernel's per-dest atomic counter then overruns +the buffer; the only guard is assert(destTokId < MaxNumTokensToRecv()) which is +compiled out under -DNDEBUG -> silent out-of-bounds writes -> output that decodes +fine (high acceptance length) but is semantically garbage (gsm8k=0). + +Empirically on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00, +64 -> 0.00 (one wavefront insufficient), 256 -> 0.94. We floor to 256. + +Idempotent and fail-loud-but-non-fatal: a regex/structure miss prints a clear +marker and the surrounding source (for diagnosis) but does not abort the server. + +Upstream: sgl-project/sglang#27194, ROCm/mori#356. +""" +import os +import re +import sys + +FLOOR = 256 +MARKER = "[InferenceX moriep dispatch floor]" +TAG = "[moriep-floor]" + + +def find_target(): + try: + import sglang + except Exception as e: # pragma: no cover + print(f"{TAG} ERROR: could not import sglang ({e}); NOT patched") + return None + + # sglang may be a namespace package (no __init__.py) where __file__ is + # None. Fall through several strategies to locate the package root. + pkg_dir = None + if getattr(sglang, "__file__", None) is not None: + pkg_dir = os.path.dirname(sglang.__file__) + elif getattr(sglang, "__path__", None): + pkg_dir = list(sglang.__path__)[0] + else: + try: + import importlib.util + spec = importlib.util.find_spec("sglang") + if spec and spec.submodule_search_locations: + pkg_dir = list(spec.submodule_search_locations)[0] + except Exception: + pass + + if pkg_dir is None: + print(f"{TAG} ERROR: could not determine sglang install path " + f"(__file__={getattr(sglang, '__file__', '?')}, " + f"__path__={getattr(sglang, '__path__', '?')}); NOT patched") + return None + + rel = os.path.join("srt", "layers", "moe", "token_dispatcher", "moriep.py") + candidates = [ + os.path.join(pkg_dir, rel), + os.path.join(pkg_dir, "python", "sglang", rel), + ] + for path in candidates: + if os.path.isfile(path): + return path + + # Last resort: walk the tree (bounded to 6 levels to avoid scanning /). + for root, _dirs, files in os.walk(pkg_dir): + if root.count(os.sep) - pkg_dir.count(os.sep) > 6: + _dirs.clear() + continue + if "moriep.py" in files: + found = os.path.join(root, "moriep.py") + print(f"{TAG} found moriep.py via walk: {found}") + return found + + print(f"{TAG} ERROR: moriep.py not found under {pkg_dir} " + f"(tried {candidates}); NOT patched") + return None + + +def main(): + path = find_target() + if path is None: + return 0 # non-fatal + + with open(path) as f: + src = f.read() + lines = src.splitlines(keepends=True) + + # Diagnostic: always show where the dispatch-token count is read/used so the + # CI log reveals the image's actual file shape even on a clean apply. + for i, l in enumerate(lines): + if "num_max_dispatch_tokens_per_rank" in l: + print(f"{TAG}[diag] {path}:{i + 1}: {l.rstrip()}") + + if MARKER in src: + print(f"{TAG} already applied; skipping") + return 0 + + # Find the assignment that reads the env var, regardless of class name or + # formatting: `self.num_max_dispatch_tokens_per_rank = get_int_env_var(`. + start = None + for i, l in enumerate(lines): + if re.search( + r"self\.num_max_dispatch_tokens_per_rank\s*=\s*get_int_env_var\s*\(", + l, + ): + start = i + break + if start is None: + print( + f"{TAG} ERROR: dispatch-token env read not found in {path}; " + f"NOT patched (server will run UNPATCHED -> expect corruption at " + f"low conc). See [diag] lines above for the actual source shape." + ) + return 0 # non-fatal: surface loudly but let the run proceed + + # Walk forward to the end of the (possibly multi-line) call by balancing parens. + depth = 0 + end = start + for j in range(start, len(lines)): + depth += lines[j].count("(") - lines[j].count(")") + if depth <= 0: + end = j + break + + indent = re.match(r"\s*", lines[start]).group(0) + floor_block = ( + f"{indent}# {MARKER} floor to {FLOOR} (warpSize/fan-in safe). MoRI recv buffer\n" + f"{indent}# is worldSize*maxNumInpTokenPerRank; values below {FLOOR} silently\n" + f"{indent}# corrupt the dispatch path (gsm8k=0). sgl#27194 / mori#356.\n" + f"{indent}self.num_max_dispatch_tokens_per_rank = max(\n" + f"{indent} self.num_max_dispatch_tokens_per_rank, {FLOOR}\n" + f"{indent})\n" + ) + lines.insert(end + 1, floor_block) + + try: + with open(path, "w") as f: + f.write("".join(lines)) + except OSError as e: + print(f"{TAG} ERROR: could not write {path} ({e}); NOT patched") + return 0 + + print( + f"{TAG} applied: floored num_max_dispatch_tokens_per_rank to >= {FLOOR} " + f"in {path} (after line {end + 1})" + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index c28ccab41..f0bfb9d6f 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -49,6 +49,16 @@ GPUS_PER_NODE="${GPUS_PER_NODE:-8}" source $SGLANG_WS_PATH/setup_deps.sh source $SGLANG_WS_PATH/env.sh +# Root-cause fix for low-concurrency MoRI dispatch-buffer corruption: surgically +# floor num_max_dispatch_tokens_per_rank to >=256 in the installed (vendor-patched) +# sglang moriep.py, in place, before any sglang.launch_server starts. A full-file +# overlay can't be used here because the lmsysorg image ships a downstream-patched +# moriep.py (class MoriEPDispatcher / expert_mask_gpu) that diverges from upstream. +# See patches/apply_moriep_dispatch_floor.py and patches/README.md. +echo "[server_sglang] applying MoRI dispatch-floor patch to installed sglang moriep.py" +python3 "$SGLANG_WS_PATH/patches/apply_moriep_dispatch_floor.py" \ + || echo "[server_sglang] WARN: moriep dispatch-floor patch returned non-zero" + host_ip=$(ip route get 1.1.1.1 | awk '/src/ {print $7}') host_name=$(hostname) @@ -213,7 +223,9 @@ fi if [[ "$DECODE_ENABLE_DP" == "true" ]] && [[ "$DECODE_ENABLE_EP" == "true" ]]; then decode_max_running_requests=$BENCH_MAX_CONC_VALUE decode_dp_ranks=$DECODE_TP_SIZE - MORI_MAX_DISPATCH_TOKENS_DECODE=$((BENCH_MAX_CONC_VALUE / decode_dp_ranks)) + # --max-running-requests is PER DP RANK (not global); each rank can hold + # up to BENCH_MAX_CONC_VALUE requests, so dispatch tokens = that capacity. + MORI_MAX_DISPATCH_TOKENS_DECODE=$BENCH_MAX_CONC_VALUE MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MAX_DISPATCH_TOKENS_DECODE * decode_dp_ranks * 7 / 10)) # Update derived variable SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD=$((MORI_MAX_DISPATCH_TOKENS_DECODE * 2)) @@ -248,6 +260,12 @@ if [[ "$DECODE_MTP_SIZE" -gt 0 ]]; then MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MOE_MAX_INPUT_TOKENS_DECODE * (DECODE_MTP_SIZE + 1))) fi +# NOTE: the low-concurrency MoRI dispatch-buffer corruption (small +# SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK -> silent OOB -> gsm8k=0) is fixed +# at the root cause by the moriep.py overlay (patches/moriep.py, auto-mounted by +# job.slurm), which floors num_max_dispatch_tokens_per_rank to 256 inside sglang. +# The earlier harness-level env clamp here has been removed in favor of that. + # ============================================================================= # Cluster Topology Configuration # ============================================================================= diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 4fd5d53ec..cd2f8d1f1 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3452,6 +3452,12 @@ - "Add 1k1k/8k1k FP8 recipe set under benchmarks/multi_node/srt-slurm-recipes/vllm/minimax-m2.5-fp8/" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1647 +- config-keys: + - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp + description: + - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356. Also fixes the root-cause harness formula (BENCH_MAX_CONC_VALUE/dp_ranks was wrong: --max-running-requests is per-DP-rank, not global) and patcher path resolution for vendor image layout." + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 + - config-keys: - minimaxm2.5-fp8-gb200-dynamo-vllm description: