Skip to content

Commit c4e397d

Browse files
committed
glm5-fp8-mi355x-sglang-disagg: bump to v0.5.12.post1 image and patch DSA state-index path
amd-master.yaml - Image: rocm/sgl-dev:sglang-0.5.9-rocm720-mi35x-mori-0402 -> lmsysorg/sglang-rocm:v0.5.12.post1-rocm720-mi35x-20260523 (matches qwen3.5-fp8-mi355x-sglang-disagg; the older 0.5.9 image is no longer the reference build for hybrid-attention disagg models on MI355X.) - Scenarios: collapse the four legacy "top/middle/bottom/small-scale" search-spaces per ISL into a single 1P+1D TP=8 EP=1 dp-attn=false entry with the standard conc-list [8, 16, 32, 64, 128, 256, 512] for both 1k1k and 8k1k. dp-attn=false avoids the fused_moe_triton/layer.py:209 shared-slot assertion that --enable-dp-attention + --moe-a2a-backend mori triggers for GLM-5 (256 routed + 1 shared expert; (256-1) % 8 = 7 != 0). The collapsed layout mirrors the qwen3.5-fp8-mi355x-sglang-disagg shape so the same CI matrix-expansion logic applies to both. patches/mori_conn.py - Add patch #4: rank + length normalization in MoriKVReceiver._send_swa_dsa_state, immediately before the group_concurrent_contiguous call. For GLM-5 (single DSA component), upstream hands dst_state_indices as a 2-D (1, N) array while src_state_indices is 1-D length 1; the existing [:common_len] slice operates only on the outer axis, leaving the rank mismatched. np.diff then produces (1, N-1) vs (0,), which can't broadcast and crashes with "operands could not be broadcast together with shapes (1,12) (0,)". The fix ravels both indices to 1-D and re-truncates to common length so np.diff outputs compatible 1-D arrays. One-shot log gates the warning to once per receiver class. - Verified end-to-end: glm5-fp8-mi355x-sglang-disagg gsm8k flexible-extract = 0.9704 +/- 0.0047 glm5-fp8-mi355x-sglang-disagg gsm8k strict-match = 0.9712 +/- 0.0046 qwen3.5-fp8-mi355x-sglang-disagg gsm8k (regression) = 0.9780 +/- 0.004 Patch #4 fires zero times on the Qwen3.5 Mamba path (it lives inside _send_swa_dsa_state, never called for Mamba); patches #1-#3 behavior is unchanged. patches/README.md - Document patch #4 alongside the existing three. Cross-link the full bug analysis at scripts/sglang_disagg/docs_glm5/01-bug-analysis.md and the gsm8k verification at scripts/sglang_disagg/docs_glm5/02-fix-and-verification.md.
1 parent 688ebe6 commit c4e397d

3 files changed

Lines changed: 50 additions & 107 deletions

File tree

.github/configs/amd-master.yaml

Lines changed: 3 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ glm5-fp8-mi355x-sglang-mtp:
481481
- { tp: 8, conc-start: 4, conc-end: 8, spec-decoding: mtp }
482482

483483
glm5-fp8-mi355x-sglang-disagg:
484-
image: rocm/sgl-dev:sglang-0.5.9-rocm720-mi35x-mori-0402
484+
image: lmsysorg/sglang-rocm:v0.5.12.post1-rocm720-mi35x-20260523
485485
model: zai-org/GLM-5-FP8
486486
model-prefix: glm5
487487
runner: mi355x-disagg
@@ -494,75 +494,15 @@ glm5-fp8-mi355x-sglang-disagg:
494494
- isl: 1024
495495
osl: 1024
496496
search-space:
497-
# "Top of curve" (1 prefill worker at TP8 and 1 decode worker at DEP16 across 2 nodes)
498497
- spec-decoding: "none"
499-
conc-list: [ 1024, 2048 ]
498+
conc-list: [ 8, 16, 32, 64, 128, 256, 512 ]
500499
prefill:
501500
num-worker: 1
502501
tp: 8
503502
ep: 1
504503
dp-attn: false
505504
additional-settings:
506505
- "PREFILL_NODES=1"
507-
decode:
508-
num-worker: 1
509-
tp: 8
510-
ep: 8
511-
dp-attn: true
512-
additional-settings:
513-
- "DECODE_NODES=2"
514-
- "DECODE_MTP_SIZE=0"
515-
516-
# "Middle of curve" (1 prefill worker at TP8 and 2 decode workers at DEP8)
517-
- spec-decoding: "none"
518-
conc-list: [ 1536, 1024, 512 ]
519-
prefill:
520-
num-worker: 1
521-
tp: 8
522-
ep: 1
523-
dp-attn: false
524-
additional-settings:
525-
- "PREFILL_NODES=1"
526-
decode:
527-
num-worker: 2
528-
tp: 8
529-
ep: 8
530-
dp-attn: true
531-
additional-settings:
532-
- "DECODE_NODES=2"
533-
- "DECODE_MTP_SIZE=0"
534-
535-
# "Bottom of curve" (1 prefill worker at TP8 and 2 decode workers at TP8)
536-
- spec-decoding: "none"
537-
conc-list: [ 256, 128, 64, 32, 16, 8, 4 ]
538-
prefill:
539-
num-worker: 1
540-
tp: 8
541-
ep: 1
542-
dp-attn: false
543-
additional-settings:
544-
- "PREFILL_NODES=1"
545-
546-
decode:
547-
num-worker: 2
548-
tp: 8
549-
ep: 1
550-
dp-attn: false
551-
additional-settings:
552-
- "DECODE_NODES=2"
553-
- "DECODE_MTP_SIZE=0"
554-
555-
# "Small scale" (1 prefill worker at TP4 and 1 decode worker at TP8)
556-
- spec-decoding: "none"
557-
conc-list: [ 64, 32, 16, 8, 4, 2, 1 ]
558-
prefill:
559-
num-worker: 1
560-
tp: 4
561-
ep: 1
562-
dp-attn: false
563-
additional-settings:
564-
- "PREFILL_NODES=1"
565-
566506
decode:
567507
num-worker: 1
568508
tp: 8
@@ -575,56 +515,15 @@ glm5-fp8-mi355x-sglang-disagg:
575515
- isl: 8192
576516
osl: 1024
577517
search-space:
578-
# "Top of curve" (2 prefill workers at DEP8 and 1 decode worker at DEP8)
579-
- spec-decoding: "none"
580-
conc-list: [ 1024, 2048 ]
581-
prefill:
582-
num-worker: 2
583-
tp: 8
584-
ep: 8
585-
dp-attn: true
586-
additional-settings:
587-
- "PREFILL_NODES=2"
588-
decode:
589-
num-worker: 1
590-
tp: 8
591-
ep: 8
592-
dp-attn: true
593-
additional-settings:
594-
- "DECODE_NODES=1"
595-
- "DECODE_MTP_SIZE=0"
596-
597-
# "Bottom of curve" (1 prefill worker at TP8 and 2 decode workers at TP8)
598518
- spec-decoding: "none"
599-
conc-list: [ 256, 128, 64, 32, 16, 8, 4 ]
519+
conc-list: [ 8, 16, 32, 64, 128, 256, 512 ]
600520
prefill:
601521
num-worker: 1
602522
tp: 8
603523
ep: 1
604524
dp-attn: false
605525
additional-settings:
606526
- "PREFILL_NODES=1"
607-
608-
decode:
609-
num-worker: 2
610-
tp: 8
611-
ep: 1
612-
dp-attn: false
613-
additional-settings:
614-
- "DECODE_NODES=2"
615-
- "DECODE_MTP_SIZE=0"
616-
617-
# "Small scale" (1 prefill worker at TP4 and 1 decode worker at TP8)
618-
- spec-decoding: "none"
619-
conc-list: [ 64, 32, 16, 8, 4, 2, 1 ]
620-
prefill:
621-
num-worker: 1
622-
tp: 4
623-
ep: 1
624-
dp-attn: false
625-
additional-settings:
626-
- "PREFILL_NODES=1"
627-
628527
decode:
629528
num-worker: 1
630529
tp: 8

benchmarks/multi_node/amd_utils/patches/README.md

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ Overlays
2020
Source: forked from the file shipped in
2121
`lmsysorg/sglang-rocm:v0.5.12.post1-rocm720-mi35x-20260523`
2222
(sglang [v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1)).
23-
Three logical edits, all confined to `MoriKVReceiver.send_state` and
24-
`MoriKVReceiver._register_kv_args`:
23+
Four logical edits, all confined to `MoriKVReceiver.send_state`,
24+
`MoriKVReceiver._register_kv_args`, and
25+
`MoriKVReceiver._send_swa_dsa_state`:
2526

2627
1. **Sender flatten** — handle the framework's nested
2728
`state_item_lens: List[List[int]]` instead of crashing in the
@@ -37,11 +38,23 @@ Three logical edits, all confined to `MoriKVReceiver.send_state` and
3738
`send_state`, so the existing per-tensor index arithmetic
3839
(`state_item_lens[i]`) and length checks
3940
(`len(state_item_lens) == len(state_mem_descs)`) keep working.
41+
4. **DSA index rank+length normalization** — inside
42+
`_send_swa_dsa_state`, before the `group_concurrent_contiguous`
43+
call, ravel both `src_state_indices` and `dst_state_indices` to 1-D
44+
and re-truncate to common length. Upstream's existing truncation
45+
only slices the outer axis, leaving 2-D `(1, N)` arrays unchanged
46+
and triggering an `np.diff` broadcasting error
47+
(`shapes (1,12) (0,)`) for GLM-5 (single-DSA-component) prefill
48+
traffic. See
49+
`scripts/sglang_disagg/docs_glm5/01-bug-analysis.md` for the full
50+
write-up.
4051

4152
Verified passing GSM8K = 0.978 ± 0.004 on Qwen3.5-397B-A17B-FP8 1P+1D
4253
TP=8 dp-attn=false (matches and slightly exceeds upstream
4354
[PR #22665](https://github.com/sgl-project/sglang/pull/22665)'s
44-
reported 0.970 GSM8K on the bf16 baseline).
55+
reported 0.970 GSM8K on the bf16 baseline). GLM-5 (DSA) verification
56+
in progress under
57+
`scripts/sglang_disagg/docs_glm5/02-fix-and-verification.md`.
4558

4659
This is a stop-gap. The proper upstream fix is to migrate MoRI to the
4760
plural `state_types: List[StateType]` API (full design + diff in

benchmarks/multi_node/amd_utils/patches/mori_conn.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,37 @@ def _send_swa_dsa_state(
11481148
src_state_indices = src_state_indices[:common_len]
11491149
dst_state_indices = dst_state_indices[:common_len]
11501150

1151+
# ── BEGIN PATCH #4: rank + length normalization ──────────────────
1152+
# Bug: for DSA single-component models (e.g. GLM-5), upstream may
1153+
# hand us `dst_state_indices` as a 2-D array of shape (1, N) or
1154+
# as a List[int]/List[np.ndarray]. The earlier `[:common_len]`
1155+
# slice operates only on the outer axis, so a (1, 13) input stays
1156+
# (1, 13). `group_concurrent_contiguous` then runs `np.diff` on
1157+
# arrays of mismatched rank ((1, N-1) vs (0,)) and crashes with
1158+
# "operands could not be broadcast together with shapes (1,12) (0,)".
1159+
# Flatten both sides to 1-D and re-align lengths so np.diff produces
1160+
# 1-D arrays of equal length.
1161+
src_state_indices = np.asarray(src_state_indices).ravel()
1162+
dst_state_indices = np.asarray(dst_state_indices).ravel()
1163+
if len(src_state_indices) != len(dst_state_indices):
1164+
new_common = min(len(src_state_indices), len(dst_state_indices))
1165+
if not getattr(self.__class__, "_logged_dsa_index_flatten", False):
1166+
try:
1167+
logger.warning(
1168+
"[mori-patch] DSA state-indices ravel/realign for %s: "
1169+
"src=%d dst=%d -> common=%d (one-shot log)",
1170+
state_type,
1171+
len(src_state_indices),
1172+
len(dst_state_indices),
1173+
new_common,
1174+
)
1175+
except Exception:
1176+
pass
1177+
self.__class__._logged_dsa_index_flatten = True
1178+
src_state_indices = src_state_indices[:new_common]
1179+
dst_state_indices = dst_state_indices[:new_common]
1180+
# ── END PATCH #4 ─────────────────────────────────────────────────
1181+
11511182
# Group contiguous indices and issue per-tensor transfers
11521183
grouped_plan = GroupedIndexPlan.from_groups(
11531184
*group_concurrent_contiguous(src_state_indices, dst_state_indices)

0 commit comments

Comments
 (0)