Skip to content

Commit 7ae613c

Browse files
ChangLiu0709cliu1004@amd.comchunfangamdcursoragentfunctionstackx
authored
[AMD] Add Qwen3.5 FP8 MI355X SGLang disaggregated benchmark (#1570)
* Add Qwen3.5 FP8 MI355X SGLang disaggregated benchmark (PR-1). Introduce CI config, model server flags, multinode launch script, and amd_utils plumbing (sudo auto-detect, optional disagg decode TP/DP flags) for qwen3.5-fp8-mi355x-sglang-disagg smoke sweeps on MI355X. Co-authored-by: chunfangamd <chun.fang@amd.com> Co-authored-by: ChangLiu0709 <cliu1004@amd.com> * Add perf-changelog entry for qwen3.5-fp8-mi355x-sglang-disagg. Required when adding new amd-master.yaml benchmark configs (PR #1570). Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: chunfangamd <chun.fang@amd.com> Co-authored-by: ChangLiu0709 <cliu1004@amd.com> * Fix misleading 8k1k comment in qwen3.5-fp8-mi355x-sglang-disagg config. The search-space uses 1P+1D TP8/EP1 with dp-attn, not TP2/EP2 from the aggregated qwen3.5-fp8-mi355x-sglang recipe. Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: chunfangamd <chun.fang@amd.com> Co-authored-by: ChangLiu0709 <cliu1004@amd.com> * qwen3.5-fp8-mi355x-sglang-disagg: 8k1k row uses dp-attn=false With --enable-dp-attention + --moe-a2a-backend mori, sglang auto-promotes moe_ep_size=tp_size=8, but is_deepep_class_backend() excludes MoRI, so num_shared_slots stays at the global value (1) and the (num_experts - num_shared_slots) % moe_ep_size assertion in fused_moe_triton/layer.py fires for Qwen3.5 (512 routed + 1 shared). Track upstream sglang for a fix; flip back to dp-attn=true once MoRI is added to is_deepep_class_backend() or shared-slot accounting is reconciled. Co-authored-by: chunfangamd <chun.fang@amd.com> Co-authored-by: ChangLiu0709 <cliu1004@amd.com> * fix: add FRAMEWORK to check_env_vars in qwen3.5 sglang-disagg script Matches sister sglang-disagg scripts (dsr1_fp8, dsr1_fp4) and the GLM-5 disagg launch script. submit.sh requires FRAMEWORK; surfacing the missing-var failure at the top of the launch script gives a cleaner error than letting it fail deep inside submit.sh. Addresses Cursor Bugbot review comment on PR #1570. Co-authored-by: chunfangamd <chun.fang@amd.com> Co-authored-by: ChangLiu0709 <cliu1004@amd.com> --------- Co-authored-by: cliu1004@amd.com <cliu1004@amd.com@mia1-p01-g18.mia.tensorwave.lan> Co-authored-by: chunfangamd <chun.fang@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: functionstackx <47992694+functionstackx@users.noreply.github.com>
1 parent bfad6df commit 7ae613c

5 files changed

Lines changed: 195 additions & 1 deletion

File tree

.github/configs/amd-master.yaml

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,70 @@ qwen3.5-fp8-mi355x-atom-mtp:
367367
- { tp: 4, ep: 1, conc-start: 4, conc-end: 256, spec-decoding: mtp }
368368
- { tp: 8, ep: 1, conc-start: 4, conc-end: 256, spec-decoding: mtp }
369369

370+
qwen3.5-fp8-mi355x-sglang-disagg:
371+
image: lmsysorg/sglang-rocm:v0.5.11-rocm700-mi35x-20260511
372+
model: Qwen/Qwen3.5-397B-A17B-FP8
373+
model-prefix: qwen3.5
374+
runner: mi355x-disagg
375+
precision: fp8
376+
framework: sglang-disagg
377+
multinode: true
378+
disagg: true
379+
scenarios:
380+
fixed-seq-len:
381+
- isl: 1024
382+
osl: 1024
383+
search-space:
384+
# Matches qwen3.5-fp8-mi355x-sglang TP8/EP1 low-concurrency sweep
385+
- spec-decoding: "none"
386+
conc-list: [ 8, 16, 32, 64, 128, 256, 512 ]
387+
prefill:
388+
num-worker: 1
389+
tp: 8
390+
ep: 1
391+
dp-attn: false
392+
additional-settings:
393+
- "PREFILL_NODES=1"
394+
decode:
395+
num-worker: 1
396+
tp: 8
397+
ep: 1
398+
dp-attn: false
399+
additional-settings:
400+
- "DECODE_NODES=1"
401+
- "DECODE_MTP_SIZE=0"
402+
403+
- isl: 8192
404+
osl: 1024
405+
search-space:
406+
# 1P+1D TP8/EP1 low-concurrency sweep.
407+
# dp-attn intentionally false (matches the 1k1k row): with
408+
# --enable-dp-attention + --moe-a2a-backend mori, sglang auto-promotes
409+
# moe_ep_size=tp_size=8, but is_deepep_class_backend() excludes MoRI,
410+
# so num_shared_slots stays at the global value (1) and the
411+
# (num_experts - num_shared_slots) % moe_ep_size assertion in
412+
# fused_moe_triton/layer.py fires for Qwen3.5 (512 routed + 1 shared).
413+
# Track upstream sglang for a fix; flip back to dp-attn=true once
414+
# MoRI is added to is_deepep_class_backend() or shared-slot
415+
# accounting is reconciled.
416+
- spec-decoding: "none"
417+
conc-list: [ 8, 16, 32, 64, 128, 256, 512 ]
418+
prefill:
419+
num-worker: 1
420+
tp: 8
421+
ep: 1
422+
dp-attn: false
423+
additional-settings:
424+
- "PREFILL_NODES=1"
425+
decode:
426+
num-worker: 1
427+
tp: 8
428+
ep: 1
429+
dp-attn: false
430+
additional-settings:
431+
- "DECODE_NODES=1"
432+
- "DECODE_MTP_SIZE=0"
433+
370434
qwen3.5-fp4-mi355x-sglang:
371435
image: lmsysorg/sglang:v0.5.12-rocm720-mi35x
372436
model: amd/Qwen3.5-397B-A17B-MXFP4

benchmarks/multi_node/amd_utils/models.yaml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,37 @@ DeepSeek-R1-0528:
161161
chunked_prefill_size: 262144
162162
cuda_graph_bs_range: "1-128"
163163

164+
Qwen3.5-397B-A17B-FP8:
165+
base_flags: "--decode-log-interval 1000 --log-level warning --watchdog-timeout 3600 --load-balance-method round_robin --kv-cache-dtype fp8_e4m3 --attention-backend aiter --disaggregation-transfer-backend mori --moe-dense-tp-size 1"
166+
mtp_flags: ""
167+
dp_flags: "--moe-a2a-backend mori --enable-dp-attention --enable-dp-lm-head"
168+
prefill:
169+
mem_fraction_static: 0.8
170+
disable_radix_cache: true
171+
dp:
172+
max_running_requests: 24
173+
chunked_prefill_size: "MORI_MAX_DISPATCH_TOKENS_PREFILL * PREFILL_TP_SIZE"
174+
cuda_graph_bs: "1 2 3"
175+
no_dp:
176+
max_running_requests: 128
177+
chunked_prefill_size: 262144
178+
cuda_graph_bs_range: "1-128"
179+
decode:
180+
mem_fraction_static: 0.85
181+
prefill_round_robin_balance: true
182+
dp:
183+
max_running_requests: 4096
184+
chunked_prefill_size: "MORI_MAX_DISPATCH_TOKENS_DECODE * DECODE_TP_SIZE"
185+
cuda_graph_bs_range: "1-160"
186+
ep_only:
187+
max_running_requests: 256
188+
chunked_prefill_size: 262144
189+
cuda_graph_bs_range: "1-256"
190+
no_dp:
191+
max_running_requests: 128
192+
chunked_prefill_size: 262144
193+
cuda_graph_bs_range: "1-128"
194+
164195
DeepSeek-R1-0528-MXFP4-Preview:
165196
base_flags: "--decode-log-interval 1000 --log-level warning --watchdog-timeout 3600 --ep-dispatch-algorithm fake --load-balance-method round_robin --kv-cache-dtype fp8_e4m3 --attention-backend aiter --disaggregation-transfer-backend mori"
166197
mtp_flags: "--speculative-algorithm NEXTN --speculative-eagle-topk 1"
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#!/usr/bin/env bash
2+
3+
source "$(dirname "$0")/../benchmark_lib.sh"
4+
5+
check_env_vars \
6+
CONC_LIST \
7+
ISL \
8+
OSL \
9+
IMAGE \
10+
SPEC_DECODING \
11+
MODEL_PATH \
12+
PREFILL_NUM_WORKERS \
13+
PREFILL_TP \
14+
PREFILL_EP \
15+
PREFILL_DP_ATTN \
16+
DECODE_NUM_WORKERS \
17+
DECODE_TP \
18+
DECODE_EP \
19+
DECODE_DP_ATTN \
20+
PREFILL_NODES \
21+
DECODE_NODES \
22+
RANDOM_RANGE_RATIO \
23+
FRAMEWORK
24+
25+
if [[ -n "$SLURM_JOB_ID" ]]; then
26+
echo "JOB $SLURM_JOB_ID running on $SLURMD_NODENAME"
27+
fi
28+
29+
set -x
30+
31+
# Use upstreamed multi_node scripts (no external clone needed)
32+
cd "$GITHUB_WORKSPACE/benchmarks/multi_node/amd_utils" || exit 1
33+
34+
# Set up SGL launch script-specific environment variables
35+
export TIME_LIMIT="08:00:00"
36+
export MODEL_PATH=$MODEL_PATH
37+
export MODEL_NAME=$MODEL_NAME
38+
export CONTAINER_IMAGE=$IMAGE
39+
40+
if [[ "${PREFILL_EP:-1}" -eq 1 ]]; then
41+
export PREFILL_ENABLE_EP=false
42+
else
43+
export PREFILL_ENABLE_EP=true
44+
fi
45+
46+
if [[ "$PREFILL_DP_ATTN" == "true" ]]; then
47+
export PREFILL_ENABLE_DP=true
48+
else
49+
export PREFILL_ENABLE_DP=false
50+
fi
51+
52+
if [[ "${DECODE_EP:-1}" -eq 1 ]]; then
53+
export DECODE_ENABLE_EP=false
54+
else
55+
export DECODE_ENABLE_EP=true
56+
fi
57+
58+
if [[ "$DECODE_DP_ATTN" == "true" ]]; then
59+
export DECODE_ENABLE_DP=true
60+
else
61+
export DECODE_ENABLE_DP=false
62+
fi
63+
64+
# Launch jobs based on ISL/OSL
65+
# Replace ' ' in CONC_LIST with 'x' such that the concurrency list is represented
66+
# by a list of numbers delimited by 'x'. This is because of how the underlying launch script
67+
# expects the concurrencies.
68+
JOB_ID=$(bash ./submit.sh $PREFILL_NODES \
69+
$PREFILL_NUM_WORKERS \
70+
$DECODE_NODES \
71+
$DECODE_NUM_WORKERS \
72+
$ISL $OSL "${CONC_LIST// /x}" inf \
73+
${PREFILL_ENABLE_EP} ${PREFILL_ENABLE_DP} \
74+
${DECODE_ENABLE_EP} ${DECODE_ENABLE_DP} \
75+
${PREFILL_TP} ${DECODE_TP} \
76+
${RANDOM_RANGE_RATIO} \
77+
${NODE_LIST:-})
78+
79+
if [[ $? -ne 0 ]]; then
80+
echo "Failed to submit job" >&2
81+
exit 1
82+
fi
83+
84+
echo "$JOB_ID"

perf-changelog.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3171,3 +3171,13 @@
31713171
description:
31723172
- "Validates measured-power aggregation pipeline (PR #1558) on both NVIDIA (H200) and AMD (MI355X) hardware — different SMI tools (nvidia-smi vs amd-smi), different CSV schemas (power.draw [W] vs socket_power), same aggregator. No config change. Entry intentionally kept past merge so run-sweep produces canonical agg JSONs with avg_power_w + joules_per_output_token on main for both vendors, seeding the dashboard's day-zero data."
31733173
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1558
3174+
3175+
- config-keys:
3176+
- qwen3.5-fp8-mi355x-sglang-disagg
3177+
description:
3178+
- "Add Qwen3.5-397B-A17B-FP8 MI355X SGLang disaggregated prefill-decode benchmark"
3179+
- "Image: lmsysorg/sglang-rocm:v0.5.11-rocm700-mi35x-20260511"
3180+
- "1P+1D TP8/EP1 smoke sweep for 1k1k and 8k1k (conc 8-512); MoRI transfer backend"
3181+
- "Add models.yaml server flags and multinode launch script qwen3.5_fp8_mi355x_sglang-disagg.sh"
3182+
- "8k1k row uses dp-attn=false (matches 1k1k): with --enable-dp-attention + --moe-a2a-backend mori, sglang auto-promotes moe_ep_size=tp_size=8, but is_deepep_class_backend() excludes MoRI, so num_shared_slots stays at the global value (1) and the (num_experts - num_shared_slots) % moe_ep_size assertion in fused_moe_triton/layer.py fires for Qwen3.5 (512 routed + 1 shared). Track upstream sglang; flip back to dp-attn=true once MoRI is added to is_deepep_class_backend() or shared-slot accounting is reconciled."
3183+
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1570

runners/launch_mi355x-amds.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ if [[ "$IS_MULTINODE" == "true" ]]; then
5454
# Ensure root-owned files are cleaned up even on early exit to prevent
5555
# EACCES errors when the next GH Actions job checks out on this runner.
5656
# Always preserve slurm logs as CI artifacts for debugging.
57+
# KEEP_LOGS=1 disables the trap entirely (local-debug knob).
5758
cleanup_and_save_logs() {
5859
if [[ -n "${GITHUB_ACTIONS:-}" && -n "${JOB_ID:-}" ]]; then
5960
local art_dir="$GITHUB_WORKSPACE/benchmark_artifacts"
@@ -69,7 +70,11 @@ if [[ "$IS_MULTINODE" == "true" ]]; then
6970
fi
7071
sudo rm -rf "$BENCHMARK_LOGS_DIR" 2>/dev/null || true
7172
}
72-
trap cleanup_and_save_logs EXIT
73+
if [[ "${KEEP_LOGS:-0}" == "1" ]]; then
74+
trap '' EXIT
75+
else
76+
trap cleanup_and_save_logs EXIT
77+
fi
7378

7479
SCRIPT_NAME="${EXP_NAME%%_*}_${PRECISION}_mi355x_${FRAMEWORK}.sh"
7580
if [[ "$FRAMEWORK" == "sglang-disagg" ]] || [[ "$FRAMEWORK" == "vllm-disagg" ]]; then

0 commit comments

Comments
 (0)