Commit 6508c77

switch MoRI dispatch-floor fix from full-file overlay to in-place patch (vendor image diverges from upstream)
The full-file moriep.py overlay crashed the scheduler at init:

    AttributeError: 'MoriEPDispatcher' object has no attribute 'expert_mask_gpu'
    RuntimeError: Rank 0 scheduler died during initialization

Root cause of the failure: the lmsysorg/sglang-rocm:v0.5.12.post1 image ships a DOWNSTREAM-patched moriep.py (class MoriEPDispatcher, extra attrs like expert_mask_gpu) that diverges from the upstream v0.5.12.post1 tag. The overlay was byte-identical to the upstream tag (md5 ac626f5459...), so bind-mounting it reverted the AMD additions -> AttributeError. (The overlay DID mount and the floor DID fire -- "[MORI floor] num_max_dispatch_tokens_per_rank=32 < 256; clamping" -- so the fix value is right; only the delivery was wrong.)

Fix: replace the overlay with patches/apply_moriep_dispatch_floor.py, a surgical in-place patch run by server_sglang.sh inside the container. It edits the image's own moriep.py, injecting `num_max_dispatch_tokens_per_rank = max(..., 256)` after the dispatch-token env read (line-based, balanced-paren end detection, class-agnostic, idempotent, fail-loud-but-non-fatal with a diagnostic dump of the image's actual source). This preserves all vendor downstream code. The fix value (256) is unchanged and proven (env-clamp run gsm8k 0.94).

Upstream: sgl-project/sglang#27194, ROCm/mori#356.
1 parent 998408b commit 6508c77
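
The "balanced-paren end detection" mentioned in the commit message can be illustrated in isolation. The following is a minimal sketch, not the committed code: `find_call_end` is a hypothetical helper name, the sample lines are a stand-in for the vendor file (the env var name is made up), and unlike the committed script this version returns `None` when the parentheses never balance instead of falling back to the start line.

```python
def find_call_end(lines, start):
    """Return the index of the line on which the call opened at lines[start] closes.

    Counts '(' minus ')' per line and stops once the running depth returns to
    zero. Parentheses inside string literals or comments are counted as well,
    which is a known limitation of a purely line-based approach.
    """
    depth = 0
    for j in range(start, len(lines)):
        depth += lines[j].count("(") - lines[j].count(")")
        if depth <= 0:
            return j
    return None  # never balanced: the caller should refuse to patch


# Hypothetical stand-in for the env read in the image's moriep.py.
sample = [
    "        self.num_max_dispatch_tokens_per_rank = get_int_env_var(\n",
    '            "EXAMPLE_ENV_NAME", 128\n',
    "        )\n",
    "        self.other = 1\n",
]
end = find_call_end(sample, 0)
print(f"call ends on sample line index {end}; insert the floor at index {end + 1}")
```

Inserting after the closing line, rather than after the line that matched the regex, is what keeps the injected statement outside a multi-line call.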

6 files changed

Lines changed: 167 additions & 1171 deletions

benchmarks/multi_node/amd_utils/job.slurm

Lines changed: 0 additions & 21 deletions
@@ -79,27 +79,6 @@ if [[ "${MORI_CONN_PATCH:-auto}" != "skip" ]] \
     echo "[job.slurm] auto-applied MoRI conn.py overlay: ${_MORI_PATCH_FILE}"
 fi
 
-# ── MoRI dispatch-buffer corruption fix: moriep.py overlay ────────────
-# sglang v0.5.12.post1 silently corrupts the MoRI EP dispatch path when the
-# per-rank dispatch buffer (num_max_dispatch_tokens_per_rank) is small: the
-# receive buffer is sized worldSize*maxNumInpTokenPerRank and the only overflow
-# guard is an assert() compiled out in release builds, so low concurrency
-# (e.g. conc-64 DEP8+MTP3 -> 32 tokens) yields out-of-bounds writes and gsm8k=0.
-# The overlay floors num_max_dispatch_tokens_per_rank to 256 at its env read
-# (the single source of truth for kernel selection + buffer sizing). The base
-# file is byte-identical to upstream v0.5.12.post1 (md5 ac626f5459...), so the
-# overlay is a +22-line diff. See patches/README.md and sgl-project/sglang#27194.
-_MORIEP_PATCH_FILE="$DI_REPO_DIR/benchmarks/multi_node/amd_utils/patches/moriep.py"
-_MORIEP_PATCH_TARGET="/sgl-workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/moriep.py"
-if [[ "${MORIEP_PATCH:-auto}" != "skip" ]] \
-   && [[ -f "$_MORIEP_PATCH_FILE" ]] \
-   && [[ "${DOCKER_IMAGE_NAME:-}" == *"v0.5.12.post1"* ]] \
-   && [[ "${EXTRA_DOCKER_MOUNTS:-}" != *"$_MORIEP_PATCH_TARGET"* ]]; then
-    EXTRA_DOCKER_MOUNTS="${EXTRA_DOCKER_MOUNTS:-} -v ${_MORIEP_PATCH_FILE}:${_MORIEP_PATCH_TARGET}:ro"
-    export EXTRA_DOCKER_MOUNTS
-    echo "[job.slurm] auto-applied MoRI moriep.py dispatch-floor overlay: ${_MORIEP_PATCH_FILE}"
-fi
-
 xP="${xP:-1}"
 yD="${yD:-1}"

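
The removed block mounted the overlay whenever the image tag matched, without checking that the file inside the image was the same base the overlay had been forked from. As a hedged illustration only (this check is not part of the commit, and `md5_of` and `overlay_is_safe` are hypothetical names), a digest comparison along these lines would flag a vendor-modified file before a full-file overlay replaces it:

```python
import hashlib
import os
import tempfile


def md5_of(path):
    """Return the hex MD5 digest of a file, read in chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()


def overlay_is_safe(installed_path, expected_base_md5):
    """A full-file overlay only preserves behavior if the installed file is
    byte-identical to the base the overlay was derived from."""
    return md5_of(installed_path) == expected_base_md5


# Demo with made-up contents standing in for the image's file.
with tempfile.TemporaryDirectory() as tmp:
    installed = os.path.join(tmp, "moriep.py")
    with open(installed, "wb") as f:
        f.write(b"vendor-patched contents\n")
    upstream_md5 = hashlib.md5(b"upstream contents\n").hexdigest()
    vendor_md5 = hashlib.md5(b"vendor-patched contents\n").hexdigest()
    safe = overlay_is_safe(installed, upstream_md5)  # image diverges from upstream
    same = overlay_is_safe(installed, vendor_md5)    # digest matches the image
    print(f"overlay safe against upstream base: {safe}; against image: {same}")
```

The in-place patch adopted by this commit sidesteps the question, since it never replaces the vendor file wholesale.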
benchmarks/multi_node/amd_utils/patches/README.md

Lines changed: 29 additions & 15 deletions
@@ -60,16 +60,26 @@ This is a stop-gap. The proper upstream fix is to migrate MoRI to the
 plural `state_types: List[StateType]` API (full design + diff in
 `scripts/sglang_disagg/docs/03-upstream-pr-proposal.md`).
 
-## `moriep.py`
-
-Overlays
-`/sgl-workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/moriep.py`.
-
-Source: forked from `lmsysorg/sglang-rocm:v0.5.12.post1-*` (sglang
-[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1)).
-The base file is **byte-identical to the upstream tag**
-(`md5 ac626f5459a699f9ac953d9d8e71d861`); the overlay is a single
-+22-line insertion in `MoriTokenDispatcher.__init__`.
+## `apply_moriep_dispatch_floor.py` (in-place patch, NOT a bind-mount overlay)
+
+This one is different from `mori_conn.py`: it is a **surgical in-place
+patch script**, not a full-file bind-mount overlay. It is run inside the
+container by `server_sglang.sh` (right after `env.sh`) and edits the
+installed
+`/sgl-workspace/sglang/.../token_dispatcher/moriep.py`
+in place, injecting a single floor after the dispatch-token env read.
+
+**Why not a bind-mount overlay (learned the hard way):** the
+`lmsysorg/sglang-rocm:v0.5.12.post1-*` image ships a **downstream-patched
+`moriep.py`** (class `MoriEPDispatcher`, with attrs such as
+`expert_mask_gpu`) that diverges from the upstream
+[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1)
+tag. A full-file overlay of the upstream file (even one byte-identical to
+the tag, `md5 ac626f5459...`) reverts the AMD additions and crashes the
+scheduler at init: `AttributeError: 'MoriEPDispatcher' object has no
+attribute 'expert_mask_gpu'`. The in-place patch touches only the
+dispatch-token read and preserves all downstream code, so it is robust to
+the vendor fork.
 
 **Bug it fixes:** at low concurrency the MoRI EP dispatch path silently
 corrupts output (decodes fine, acceptance length stays high, but gsm8k
@@ -85,19 +95,23 @@ guard is `assert(destTokId < MaxNumTokensToRecv())`, compiled out under
 `-DNDEBUG`, so the result is silent out-of-bounds writes
 (`internode_v1.cpp` `DispatchIntraNodeBlock`).
 
-The overlay floors `num_max_dispatch_tokens_per_rank` to **256** right at
+The patch floors `num_max_dispatch_tokens_per_rank` to **256** right at
 its env read — the single source of truth that feeds both
 `get_ep_dispatch_configs()` (kernel selection) and the buffer-sizing
-arg. Empirically validated on MI355X (conc-64 DEP8+MTP3):
-dispatch `32 → gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough),
-`256 → 0.94`.
+arg. It is idempotent and fail-loud-but-non-fatal (a structure miss prints
+a clear marker plus the surrounding source and lets the server proceed).
+Empirically validated on MI355X (conc-64 DEP8+MTP3): dispatch `32 →
+gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough), `256 → 0.94`.
 
 This is a stop-gap. The proper upstream fix is in MoRI: size the receive
 buffer from the routing fan-in and turn the compiled-out `assert` into a
 real bounds guard (see [ROCm/mori#356](https://github.com/ROCm/mori/issues/356)).
 The integration-level guard belongs in sglang's `moriep.py`
 ([sgl-project/sglang#27194](https://github.com/sgl-project/sglang/issues/27194)) —
-this overlay is exactly that guard, pending upstream merge.
+this patch is exactly that guard, pending upstream merge. No
+`EXTRA_DOCKER_MOUNTS` wiring is needed; the patch is applied
+unconditionally by `server_sglang.sh` and no-ops when the value is
+already ≥256 (e.g. prefill, which uses 8192).
 
 ## How to enable

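
The sizing arithmetic quoted in the README hunks and in the script's docstring (`64/8*4 = 32`, receive buffer `worldSize * maxNumInpTokenPerRank`) can be worked through numerically. This is a sketch under stated assumptions: the function names are hypothetical, the factor of 4 is read here as the 3 MTP draft tokens plus one, and a world size of 8 is assumed from "DEP8"; none of this is taken from the sglang or MoRI sources.

```python
FLOOR = 256  # value the patch floors to


def decode_dispatch_tokens(concurrency, ep_size, mtp_draft_tokens):
    """Per-rank dispatch token count as quoted in the docstring: 64 / 8 * 4.

    Assumption: the factor 4 is (mtp_draft_tokens + 1).
    """
    return concurrency // ep_size * (mtp_draft_tokens + 1)


def recv_capacity(world_size, tokens_per_rank):
    """Receive buffer capacity per the README: worldSize * maxNumInpTokenPerRank."""
    return world_size * tokens_per_rank


raw = decode_dispatch_tokens(concurrency=64, ep_size=8, mtp_draft_tokens=3)
floored = max(raw, FLOOR)

print(f"unfloored: {raw} tokens/rank, capacity {recv_capacity(8, raw)}")
print(f"floored:   {floored} tokens/rank, capacity {recv_capacity(8, floored)}")

# Prefill already uses 8192 according to the README, so the floor leaves it alone.
prefill = max(8192, FLOOR)
```

Because the floor is a plain `max`, any configuration already at or above 256 is unchanged, which matches the README's note that the patch no-ops for prefill.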
benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+"""Surgically floor the MoRI per-rank dispatch buffer to >=256 in the installed
+sglang `moriep.py`, in place, inside the container.
+
+Why in-place (not a bind-mount overlay): the lmsysorg/sglang-rocm image ships a
+*downstream-patched* moriep.py (class `MoriEPDispatcher`, extra attrs such as
+`expert_mask_gpu`) that diverges from the upstream v0.5.12.post1 tag. A full-file
+overlay of the upstream file reverts those AMD additions and crashes the
+scheduler at init (`AttributeError: ... 'expert_mask_gpu'`). So we patch the
+image's own file and touch only the dispatch-token read.
+
+The bug being fixed: at low concurrency the per-rank dispatch buffer
+(num_max_dispatch_tokens_per_rank -> mori max_num_inp_token_per_rank) collapses
+(conc-64/TP8/MTP3 -> 64/8*4 = 32). MoRI sizes its receive buffer
+MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank (dispatch_combine.hpp;
+max_total_recv_tokens defaults to 0 -> that fallback, and it is a cap not a
+floor). The intra-node dispatch kernel's per-dest atomic counter then overruns
+the buffer; the only guard is assert(destTokId < MaxNumTokensToRecv()) which is
+compiled out under -DNDEBUG -> silent out-of-bounds writes -> output that decodes
+fine (high acceptance length) but is semantically garbage (gsm8k=0).
+
+Empirically on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00,
+64 -> 0.00 (one wavefront insufficient), 256 -> 0.94. We floor to 256.
+
+Idempotent and fail-loud-but-non-fatal: a regex/structure miss prints a clear
+marker and the surrounding source (for diagnosis) but does not abort the server.
+
+Upstream: sgl-project/sglang#27194, ROCm/mori#356.
+"""
+import os
+import re
+import sys
+
+FLOOR = 256
+MARKER = "[InferenceX moriep dispatch floor]"
+TAG = "[moriep-floor]"
+
+
+def find_target():
+    try:
+        import sglang
+    except Exception as e:  # pragma: no cover
+        print(f"{TAG} ERROR: could not import sglang ({e}); NOT patched")
+        return None
+    path = os.path.join(
+        os.path.dirname(sglang.__file__),
+        "srt", "layers", "moe", "token_dispatcher", "moriep.py",
+    )
+    if not os.path.isfile(path):
+        print(f"{TAG} ERROR: moriep.py not found at {path}; NOT patched")
+        return None
+    return path
+
+
+def main():
+    path = find_target()
+    if path is None:
+        return 0  # non-fatal
+
+    with open(path) as f:
+        src = f.read()
+    lines = src.splitlines(keepends=True)
+
+    # Diagnostic: always show where the dispatch-token count is read/used so the
+    # CI log reveals the image's actual file shape even on a clean apply.
+    for i, l in enumerate(lines):
+        if "num_max_dispatch_tokens_per_rank" in l:
+            print(f"{TAG}[diag] {path}:{i + 1}: {l.rstrip()}")
+
+    if MARKER in src:
+        print(f"{TAG} already applied; skipping")
+        return 0
+
+    # Find the assignment that reads the env var, regardless of class name or
+    # formatting: `self.num_max_dispatch_tokens_per_rank = get_int_env_var(`.
+    start = None
+    for i, l in enumerate(lines):
+        if re.search(
+            r"self\.num_max_dispatch_tokens_per_rank\s*=\s*get_int_env_var\s*\(",
+            l,
+        ):
+            start = i
+            break
+    if start is None:
+        print(
+            f"{TAG} ERROR: dispatch-token env read not found in {path}; "
+            f"NOT patched (server will run UNPATCHED -> expect corruption at "
+            f"low conc). See [diag] lines above for the actual source shape."
+        )
+        return 0  # non-fatal: surface loudly but let the run proceed
+
+    # Walk forward to the end of the (possibly multi-line) call by balancing parens.
+    depth = 0
+    end = start
+    for j in range(start, len(lines)):
+        depth += lines[j].count("(") - lines[j].count(")")
+        if depth <= 0:
+            end = j
+            break
+
+    indent = re.match(r"\s*", lines[start]).group(0)
+    floor_block = (
+        f"{indent}# {MARKER} floor to {FLOOR} (warpSize/fan-in safe). MoRI recv buffer\n"
+        f"{indent}# is worldSize*maxNumInpTokenPerRank; values below {FLOOR} silently\n"
+        f"{indent}# corrupt the dispatch path (gsm8k=0). sgl#27194 / mori#356.\n"
+        f"{indent}self.num_max_dispatch_tokens_per_rank = max(\n"
+        f"{indent}    self.num_max_dispatch_tokens_per_rank, {FLOOR}\n"
+        f"{indent})\n"
+    )
+    lines.insert(end + 1, floor_block)
+
+    try:
+        with open(path, "w") as f:
+            f.write("".join(lines))
+    except OSError as e:
+        print(f"{TAG} ERROR: could not write {path} ({e}); NOT patched")
+        return 0
+
+    print(
+        f"{TAG} applied: floored num_max_dispatch_tokens_per_rank to >= {FLOOR} "
+        f"in {path} (after line {end + 1})"
+    )
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())

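
Since the script rewrites a file the server imports at startup, one plausible follow-up is to sanity-check the patched text before relying on it. The sketch below is not part of the commit: `check_patched_source` is a hypothetical helper, and the sample sources are made-up stand-ins for the vendor file. It confirms that the marker appears exactly once (consistent with idempotency) and that the result still parses, which would catch a floor block landing inside an unclosed call.

```python
import ast

MARKER = "[InferenceX moriep dispatch floor]"  # same marker string as the script


def check_patched_source(src):
    """Return a list of problems found in patched source text (empty if none)."""
    problems = []
    count = src.count(MARKER)
    if count != 1:
        problems.append(f"expected marker once, found {count}")
    try:
        ast.parse(src)  # syntax check only; nothing is imported or executed
    except SyntaxError as e:
        problems.append(f"syntax error at line {e.lineno}: {e.msg}")
    return problems


# Floor block placed after the closing paren of the env read: parses cleanly.
good = (
    "class Dispatcher:\n"
    "    def __init__(self):\n"
    "        self.n = get_int_env_var(\n"
    '            "EXAMPLE_ENV_NAME", 128\n'
    "        )\n"
    f"        # {MARKER} floor to 256\n"
    "        self.n = max(self.n, 256)\n"
)

# Floor block placed inside the still-open call: must be rejected.
bad = (
    "class Dispatcher:\n"
    "    def __init__(self):\n"
    "        self.n = get_int_env_var(\n"
    f"        # {MARKER} floor to 256\n"
    "        self.n = max(self.n, 256)\n"
    '            "EXAMPLE_ENV_NAME", 128\n'
    "        )\n"
)

good_problems = check_patched_source(good)
bad_problems = check_patched_source(bad)
print("good:", good_problems)
print("bad: ", bad_problems)
```

`ast.parse` only checks syntax, so undefined names such as `get_int_env_var` in the samples do not matter here.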