Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
3b9e810
throwaway: conc-64 gsm8k eval for DEP8+MTP3 to reproduce dispatch tok…
Oseltamivir Jun 3, 2026
45f69f5
trigger sweep
Oseltamivir Jun 3, 2026
9983cc0
add dispatch token clamp (>=256) and run benchmark+eval at conc-64
Oseltamivir Jun 3, 2026
139f646
revert clamp for Run B (without fix) benchmark+eval at conc-64
Oseltamivir Jun 3, 2026
906a9ae
sed-patch moriep.py to clamp dispatch tokens >= 64 (warpSize) and run…
Oseltamivir Jun 3, 2026
12dadc1
clamp MoRI dispatch tokens to warpSize floor (64) and run benchmark+e…
Oseltamivir Jun 3, 2026
04ceb08
Merge remote-tracking branch 'origin/main' into dsr1-dep8-mtp3-conc64…
Oseltamivir Jun 3, 2026
272be18
raise MoRI dispatch-buffer floor to 256 (warpSize=64 proven insuffici…
Oseltamivir Jun 3, 2026
998408b
fix MoRI dispatch corruption at the root: moriep.py overlay floors di…
Oseltamivir Jun 4, 2026
6508c77
switch MoRI dispatch-floor fix from full-file overlay to in-place pat…
Oseltamivir Jun 4, 2026
79bb67a
fix moriep dispatch-floor patcher crash when sglang.__file__ is None
Oseltamivir Jun 4, 2026
9b50d69
note namespace-package patcher fix in changelog (re-trigger sweep)
Oseltamivir Jun 4, 2026
ddaae82
Merge branch 'main' into dsr1-dep8-mtp3-conc64-eval
Oseltamivir Jun 4, 2026
9fb87c1
fix moriep patcher path: handle vendor image python/sglang/ layout
Oseltamivir Jun 4, 2026
59dd6c3
update changelog: note patcher path fix (re-trigger sweep)
Oseltamivir Jun 4, 2026
fe96b05
fix MoRI dispatch token formula: --max-running-requests is per DP rank
Oseltamivir Jun 4, 2026
1fdb89f
update changelog: root-cause formula fix (re-trigger sweep)
Oseltamivir Jun 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 4 additions & 136 deletions .github/configs/amd-master.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1986,123 +1986,10 @@ dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp:
- isl: 8192
osl: 1024
search-space:
# MTP configurations
# 1P1D pure TP8
- spec-decoding: "mtp"
conc-list: [ 1, 2, 4, 8 ]
prefill:
num-worker: 1
tp: 8
ep: 1
dp-attn: false
additional-settings:
- "PREFILL_NODES=1"
decode:
num-worker: 1
tp: 8
ep: 1
dp-attn: false
additional-settings:
- "DECODE_NODES=1"
- "DECODE_MTP_SIZE=3"

# 1P2D TP8
Comment thread
cursor[bot] marked this conversation as resolved.
- spec-decoding: "mtp"
conc-list: [ 2, 4, 8, 16, 32 ]
prefill:
num-worker: 1
tp: 8
ep: 1
dp-attn: false
additional-settings:
- "PREFILL_NODES=1"
decode:
num-worker: 2
tp: 8
ep: 1
dp-attn: false
additional-settings:
- "DECODE_NODES=2"
- "DECODE_MTP_SIZE=3"

# 1P2D TP8
- spec-decoding: "mtp"
conc-list: [ 32, 64 ]
prefill:
num-worker: 1
tp: 8
ep: 1
dp-attn: false
additional-settings:
- "PREFILL_NODES=1"
decode:
num-worker: 2
tp: 8
ep: 1
dp-attn: false
additional-settings:
- "DECODE_NODES=2"
- "DECODE_MTP_SIZE=3"

# 1*DEP8 + 1*DEP8
- spec-decoding: "mtp"
conc-list: [ 640, 512 ]
prefill:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "PREFILL_NODES=1"
decode:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "DECODE_NODES=1"
- "DECODE_MTP_SIZE=3"

# 1*DEP8 + 1*DEP8
- spec-decoding: "mtp"
conc-list: [ 256 ]
prefill:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "PREFILL_NODES=1"
decode:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "DECODE_NODES=1"
- "DECODE_MTP_SIZE=3"


# 1*DEP8 + 1*DEP8
- spec-decoding: "mtp"
conc-list: [ 128 ]
prefill:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "PREFILL_NODES=1"
decode:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "DECODE_NODES=1"
- "DECODE_MTP_SIZE=3"

# 1*DEP8 + 1*DEP8
# THROWAWAY (not for merge): conc-64 only DEP8+MTP3 to reproduce
# SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK < 256 corruption.
# max(CONC_LIST)=64 → dispatch_tokens=64/8*4=32 → broken All2All kernel.
# 1*DEP8 + 1*DEP8, MTP3
- spec-decoding: "mtp"
conc-list: [ 64 ]
prefill:
Expand All @@ -2121,25 +2008,6 @@ dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp:
- "DECODE_NODES=1"
- "DECODE_MTP_SIZE=3"

# 2*DEP8 + 1*DEP8
- spec-decoding: "mtp"
conc-list: [ 1024, 2048, 4096 ]
prefill:
num-worker: 2
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "PREFILL_NODES=2"
decode:
num-worker: 1
tp: 8
ep: 8
dp-attn: true
additional-settings:
- "DECODE_NODES=1"
- "DECODE_MTP_SIZE=1"


# DSv4-Pro FP4 on MI355X via SGLang. Uses a rocm720 mi35x image built off the
# amd/deepseek_v4 branch in sgl-project/sglang; the SHA is encoded in the
Expand Down
53 changes: 53 additions & 0 deletions benchmarks/multi_node/amd_utils/patches/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,59 @@ This is a stop-gap. The proper upstream fix is to migrate MoRI to the
plural `state_types: List[StateType]` API (full design + diff in
`scripts/sglang_disagg/docs/03-upstream-pr-proposal.md`).

## `apply_moriep_dispatch_floor.py` (in-place patch, NOT a bind-mount overlay)

This one is different from `mori_conn.py`: it is a **surgical in-place
patch script**, not a full-file bind-mount overlay. It is run inside the
container by `server_sglang.sh` (right after `env.sh`) and edits the
installed
`/sgl-workspace/sglang/.../token_dispatcher/moriep.py`
in place, injecting a single floor after the dispatch-token env read.

**Why not a bind-mount overlay (learned the hard way):** the
`lmsysorg/sglang-rocm:v0.5.12.post1-*` image ships a **downstream-patched
`moriep.py`** (class `MoriEPDispatcher`, with attrs such as
`expert_mask_gpu`) that diverges from the upstream
[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1)
tag. A full-file overlay of the upstream file (even one byte-identical to
the tag, `md5 ac626f5459...`) reverts the AMD additions and crashes the
scheduler at init: `AttributeError: 'MoriEPDispatcher' object has no
attribute 'expert_mask_gpu'`. The in-place patch touches only the
dispatch-token read and preserves all downstream code, so it is robust to
the vendor fork.

**Bug it fixes:** at low concurrency the MoRI EP dispatch path silently
corrupts output (decodes fine, acceptance length stays high, but gsm8k
drops to 0). The per-rank dispatch buffer
`num_max_dispatch_tokens_per_rank` (→ mori `max_num_inp_token_per_rank`)
is derived by the harness as `max(CONC_LIST)/TP*(MTP+1)`, which collapses
at low conc (conc-64 / TP8 / MTP3 → `64/8*4 = 32`). MoRI sizes its
receive buffer `MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank`
(`max_total_recv_tokens` defaults to 0 → that fallback, and it is a *cap*
not a floor — `dispatch_combine.hpp:126-136`). The intra-node dispatch
kernel's per-dest atomic counter then runs past that buffer; the only
guard is `assert(destTokId < MaxNumTokensToRecv())`, compiled out under
`-DNDEBUG`, so the result is silent out-of-bounds writes
(`internode_v1.cpp` `DispatchIntraNodeBlock`).

The patch floors `num_max_dispatch_tokens_per_rank` to **256** right at
its env read — the single source of truth that feeds both
`get_ep_dispatch_configs()` (kernel selection) and the buffer-sizing
arg. It is idempotent and fail-loud-but-non-fatal (a structure miss prints
a clear marker plus the surrounding source and lets the server proceed).
Empirically validated on MI355X (conc-64 DEP8+MTP3): dispatch `32 →
gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough), `256 → 0.94`.

This is a stop-gap. The proper upstream fix is in MoRI: size the receive
buffer from the routing fan-in and turn the compiled-out `assert` into a
real bounds guard (see [ROCm/mori#356](https://github.com/ROCm/mori/issues/356)).
The integration-level guard belongs in sglang's `moriep.py`
([sgl-project/sglang#27194](https://github.com/sgl-project/sglang/issues/27194)) —
this patch is exactly that guard, pending upstream merge. No
`EXTRA_DOCKER_MOUNTS` wiring is needed; the patch is applied
unconditionally by `server_sglang.sh` and no-ops when the value is
already ≥256 (e.g. prefill, which uses 8192).

## How to enable

```bash
Expand Down
164 changes: 164 additions & 0 deletions benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""Surgically floor the MoRI per-rank dispatch buffer to >=256 in the installed
sglang `moriep.py`, in place, inside the container.

Why in-place (not a bind-mount overlay): the lmsysorg/sglang-rocm image ships a
*downstream-patched* moriep.py (class `MoriEPDispatcher`, extra attrs such as
`expert_mask_gpu`) that diverges from the upstream v0.5.12.post1 tag. A full-file
overlay of the upstream file reverts those AMD additions and crashes the
scheduler at init (`AttributeError: ... 'expert_mask_gpu'`). So we patch the
image's own file and touch only the dispatch-token read.

The bug being fixed: at low concurrency the per-rank dispatch buffer
(num_max_dispatch_tokens_per_rank -> mori max_num_inp_token_per_rank) collapses
(conc-64/TP8/MTP3 -> 64/8*4 = 32). MoRI sizes its receive buffer
MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank (dispatch_combine.hpp;
max_total_recv_tokens defaults to 0 -> that fallback, and it is a cap not a
floor). The intra-node dispatch kernel's per-dest atomic counter then overruns
the buffer; the only guard is assert(destTokId < MaxNumTokensToRecv()) which is
compiled out under -DNDEBUG -> silent out-of-bounds writes -> output that decodes
fine (high acceptance length) but is semantically garbage (gsm8k=0).

Empirically on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00,
64 -> 0.00 (one wavefront insufficient), 256 -> 0.94. We floor to 256.

Idempotent and fail-loud-but-non-fatal: a regex/structure miss prints a clear
marker and the surrounding source (for diagnosis) but does not abort the server.

Upstream: sgl-project/sglang#27194, ROCm/mori#356.
"""
import os
import re
import sys

FLOOR = 256
MARKER = "[InferenceX moriep dispatch floor]"
TAG = "[moriep-floor]"


def find_target():
try:
import sglang
except Exception as e: # pragma: no cover
print(f"{TAG} ERROR: could not import sglang ({e}); NOT patched")
return None

# sglang may be a namespace package (no __init__.py) where __file__ is
# None. Fall through several strategies to locate the package root.
pkg_dir = None
if getattr(sglang, "__file__", None) is not None:
pkg_dir = os.path.dirname(sglang.__file__)
elif getattr(sglang, "__path__", None):
pkg_dir = list(sglang.__path__)[0]
else:
try:
import importlib.util
spec = importlib.util.find_spec("sglang")
if spec and spec.submodule_search_locations:
pkg_dir = list(spec.submodule_search_locations)[0]
except Exception:
pass

if pkg_dir is None:
print(f"{TAG} ERROR: could not determine sglang install path "
f"(__file__={getattr(sglang, '__file__', '?')}, "
f"__path__={getattr(sglang, '__path__', '?')}); NOT patched")
return None

rel = os.path.join("srt", "layers", "moe", "token_dispatcher", "moriep.py")
candidates = [
os.path.join(pkg_dir, rel),
os.path.join(pkg_dir, "python", "sglang", rel),
]
for path in candidates:
if os.path.isfile(path):
return path

# Last resort: walk the tree (bounded to 6 levels to avoid scanning /).
for root, _dirs, files in os.walk(pkg_dir):
if root.count(os.sep) - pkg_dir.count(os.sep) > 6:
_dirs.clear()
continue
if "moriep.py" in files:
found = os.path.join(root, "moriep.py")
print(f"{TAG} found moriep.py via walk: {found}")
return found

print(f"{TAG} ERROR: moriep.py not found under {pkg_dir} "
f"(tried {candidates}); NOT patched")
return None


def main():
path = find_target()
if path is None:
return 0 # non-fatal

with open(path) as f:
src = f.read()
lines = src.splitlines(keepends=True)

# Diagnostic: always show where the dispatch-token count is read/used so the
# CI log reveals the image's actual file shape even on a clean apply.
for i, l in enumerate(lines):
if "num_max_dispatch_tokens_per_rank" in l:
print(f"{TAG}[diag] {path}:{i + 1}: {l.rstrip()}")

if MARKER in src:
print(f"{TAG} already applied; skipping")
return 0

# Find the assignment that reads the env var, regardless of class name or
# formatting: `self.num_max_dispatch_tokens_per_rank = get_int_env_var(`.
start = None
for i, l in enumerate(lines):
if re.search(
r"self\.num_max_dispatch_tokens_per_rank\s*=\s*get_int_env_var\s*\(",
l,
):
start = i
break
if start is None:
print(
f"{TAG} ERROR: dispatch-token env read not found in {path}; "
f"NOT patched (server will run UNPATCHED -> expect corruption at "
f"low conc). See [diag] lines above for the actual source shape."
)
return 0 # non-fatal: surface loudly but let the run proceed

# Walk forward to the end of the (possibly multi-line) call by balancing parens.
depth = 0
end = start
for j in range(start, len(lines)):
depth += lines[j].count("(") - lines[j].count(")")
if depth <= 0:
end = j
break

indent = re.match(r"\s*", lines[start]).group(0)
floor_block = (
f"{indent}# {MARKER} floor to {FLOOR} (warpSize/fan-in safe). MoRI recv buffer\n"
f"{indent}# is worldSize*maxNumInpTokenPerRank; values below {FLOOR} silently\n"
f"{indent}# corrupt the dispatch path (gsm8k=0). sgl#27194 / mori#356.\n"
f"{indent}self.num_max_dispatch_tokens_per_rank = max(\n"
f"{indent} self.num_max_dispatch_tokens_per_rank, {FLOOR}\n"
f"{indent})\n"
)
lines.insert(end + 1, floor_block)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unbalanced parens corrupt moriep insert

Low Severity

If parenthesis balancing never reaches depth <= 0 before EOF, end stays at start and the floor block is inserted on the line after the opening get_int_env_var(, which can splice Python inside a multi-line call and break moriep.py without reporting failure.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 59dd6c3. Configure here.


try:
with open(path, "w") as f:
f.write("".join(lines))
except OSError as e:
print(f"{TAG} ERROR: could not write {path} ({e}); NOT patched")
return 0

print(
f"{TAG} applied: floored num_max_dispatch_tokens_per_rank to >= {FLOOR} "
f"in {path} (after line {end + 1})"
)
return 0
Comment thread
cursor[bot] marked this conversation as resolved.


if __name__ == "__main__":
sys.exit(main())
Loading
Loading