Commit b3d6df2

[qwen3_5_moe][ci] Track export GPU peak memory and gate it in CI
## Summary

Add a GPU memory regression guard so that the Qwen3.5 MoE export keeps fitting on consumer-grade 24 GB GPUs (RTX 4090 / 3090 / A5000 …).

## What this diff does

1. `examples/models/qwen3_5_moe/export.py`
   - Reset CUDA peak memory stats at the start of the CUDA backend setup.
   - At the end of `main()`, when running with `--backend cuda`, print a stable, machine-parseable marker line:
     `EXPORT_GPU_PEAK_MEMORY_MB: <peak_in_MB>`
     This makes the actual peak GPU memory consumed by the entire load + quantize + lower pipeline visible to both humans and CI.
2. `.ci/scripts/export_model_artifact.sh` (qwen3_5_moe path)
   - Tee the export output to a temp log.
   - Grep the `EXPORT_GPU_PEAK_MEMORY_MB` marker and compare against `EXPORT_GPU_PEAK_MB_LIMIT` (default 20480 MB = 20 GB; overridable via env var).
   - Fail the job with an explanatory error if the budget is exceeded, so any future regression that reintroduces the ~18 GB unnecessary GPU clone (or comparable leak) is caught at PR time rather than silently breaking 24 GB-class GPUs.

## Notes

- Current measured peak with the CUDA backend memory fixes (see prior commit on this branch) is ~18 GB, leaving ~2 GB headroom under the 20 GB limit. Without those fixes the peak shoots to ~37 GB and CI will fail loudly.
- The threshold is intentionally tighter than the 24 GB physical cap to leave room for measurement noise and small allocator overhead.

## Test Plan

- Manual: ran `python -m executorch.examples.models.qwen3_5_moe.export --prequantized <hqq-int4-bundle> --backend cuda` and confirmed the marker line is printed at the end with a sensible value (~18 GB).
- Manual: simulated CI gate logic locally with the marker line and confirmed both the success path and the failure path (forced threshold below the actual peak) behave as expected.
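The budget arithmetic in the Notes can be checked with a short sketch. The peaks are the approximate figures quoted above (about 18 GB with the fixes, about 37 GB without), and the helper name `headroom_mb` is hypothetical; it is not part of this commit.

```python
MB_PER_GB = 1024

# Default limit used by the CI script, in MB (20 GB).
EXPORT_GPU_PEAK_MB_LIMIT = 20 * MB_PER_GB


def headroom_mb(peak_mb: float, limit_mb: float = EXPORT_GPU_PEAK_MB_LIMIT) -> float:
    """Return the remaining budget in MB; a negative value means the gate fails."""
    return limit_mb - peak_mb


# Approximate peaks from the commit message, converted to MB.
with_fixes = headroom_mb(18 * MB_PER_GB)      # roughly 2 GB of headroom
without_fixes = headroom_mb(37 * MB_PER_GB)   # roughly 17 GB over budget

print(f"limit: {EXPORT_GPU_PEAK_MB_LIMIT} MB")
print(f"with fixes: {with_fixes:+.0f} MB")
print(f"without fixes: {without_fixes:+.0f} MB")
```

Under these assumed figures the fixed export passes with about 2048 MB to spare, while the unfixed one overshoots the limit by a wide margin.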
1 parent e3751bc commit b3d6df2

2 files changed

Lines changed: 41 additions & 1 deletion

.ci/scripts/export_model_artifact.sh

Lines changed: 27 additions & 1 deletion
@@ -415,12 +415,38 @@ if [ "$MODEL_NAME" = "qwen3_5_moe" ]; then
 
   # Export to .pte/.ptd (short cache dir avoids objcopy symbol length issues)
   echo "::group::Export"
+  EXPORT_LOG=$(mktemp)
   TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
     python -m executorch.examples.models.qwen3_5_moe.export \
       --prequantized "$LOCAL_MODEL_DIR" \
-      --output-dir "${OUTPUT_DIR}"
+      --output-dir "${OUTPUT_DIR}" 2>&1 | tee "$EXPORT_LOG"
+  EXPORT_RC=${PIPESTATUS[0]}
   echo "::endgroup::"
 
+  if [ "$EXPORT_RC" -ne 0 ]; then
+    echo "ERROR: Qwen3.5 MoE export failed (exit $EXPORT_RC)"
+    rm -f "$EXPORT_LOG"
+    exit "$EXPORT_RC"
+  fi
+
+  # Gate peak GPU memory so we keep the export viable on consumer GPUs
+  # (e.g. RTX 4090 with 24 GB). The export script prints a machine-
+  # parseable marker line "EXPORT_GPU_PEAK_MEMORY_MB: <float>".
+  EXPORT_GPU_PEAK_MB_LIMIT="${EXPORT_GPU_PEAK_MB_LIMIT:-20480}"
+  PEAK_LINE=$(grep -E '^EXPORT_GPU_PEAK_MEMORY_MB:' "$EXPORT_LOG" | tail -1)
+  rm -f "$EXPORT_LOG"
+  if [ -z "$PEAK_LINE" ]; then
+    echo "ERROR: export did not emit EXPORT_GPU_PEAK_MEMORY_MB marker; cannot enforce GPU memory budget"
+    exit 1
+  fi
+  PEAK_MB=$(echo "$PEAK_LINE" | awk '{print $2}')
+  echo "Export GPU peak memory: ${PEAK_MB} MB (limit ${EXPORT_GPU_PEAK_MB_LIMIT} MB)"
+  if awk -v p="$PEAK_MB" -v l="$EXPORT_GPU_PEAK_MB_LIMIT" 'BEGIN{exit !(p>l)}'; then
+    echo "ERROR: export exceeded GPU memory budget (${PEAK_MB} MB > ${EXPORT_GPU_PEAK_MB_LIMIT} MB)"
+    echo "  — this would prevent the model from being exported on a 24 GB consumer GPU."
+    exit 1
+  fi
+
   test -f "${OUTPUT_DIR}/model.pte"
   test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd"
   ls -al "${OUTPUT_DIR}"
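For readers who want to reproduce the gate locally, as the Test Plan describes, the shell logic above can be approximated in Python. This is a minimal sketch, not part of the commit: the function names `parse_peak_mb` and `check_budget` are hypothetical, and only the marker name and the 20480 MB default come from the diff.

```python
import re

# Matches the marker at the start of a line and captures the second field,
# similar to `grep -E '^EXPORT_GPU_PEAK_MEMORY_MB:'` followed by `awk '{print $2}'`.
MARKER_RE = re.compile(r"^EXPORT_GPU_PEAK_MEMORY_MB:[ \t]*(\S+)", re.MULTILINE)


def parse_peak_mb(log_text: str):
    """Return the last marker value in the log as a float, or None if absent."""
    matches = MARKER_RE.findall(log_text)
    if not matches:
        return None
    return float(matches[-1])  # last occurrence wins, like `tail -1`


def check_budget(log_text: str, limit_mb: float = 20480.0) -> bool:
    """Return True if the peak is within budget; raise if the marker is missing."""
    peak_mb = parse_peak_mb(log_text)
    if peak_mb is None:
        raise ValueError("export did not emit EXPORT_GPU_PEAK_MEMORY_MB marker")
    return peak_mb <= limit_mb  # the shell gate fails only when peak > limit


if __name__ == "__main__":
    sample_log = "lowering...\nEXPORT_GPU_PEAK_MEMORY_MB: 18432.00\n"
    print(check_budget(sample_log))                   # within the default budget
    print(check_budget(sample_log, limit_mb=1000.0))  # forced failure path
```

Forcing `limit_mb` below the measured peak exercises the failure path in the same way the Test Plan describes for the shell version.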

examples/models/qwen3_5_moe/export.py

Lines changed: 14 additions & 0 deletions
@@ -967,6 +967,13 @@ def main(): # noqa: C901
         # Register FLA Triton kernel (CUDA only)
         import executorch.backends.cuda.triton.kernels  # noqa: F401
 
+        # Reset peak GPU memory stats so we can report the actual peak
+        # consumed during the export pipeline (load + quantize + lowering)
+        # at the very end. This is also gated by CI to make sure low-VRAM
+        # GPUs (e.g. RTX 4090, 24 GB) can still complete the export.
+        if torch.cuda.is_available():
+            torch.cuda.reset_peak_memory_stats(0)
+
     if args.backend == "mlx":
         if args.prequantized:
             parser.error("--prequantized is not supported with --backend mlx")
@@ -989,6 +996,13 @@ def main(): # noqa: C901
 
     export_and_lower(model, config, args)
 
+    # Report peak GPU memory consumed during the export so CI / users can
+    # gate this against a known budget (e.g. 24 GB consumer GPUs).
+    if args.backend == "cuda" and torch.cuda.is_available():
+        peak_mb = torch.cuda.max_memory_allocated(0) / (1024 * 1024)
+        # Stable, machine-parseable marker for CI grep.
+        print(f"EXPORT_GPU_PEAK_MEMORY_MB: {peak_mb:.2f}")
+
 
 if __name__ == "__main__":
     main()
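The marker emission above can be isolated into a small, testable sketch. The helper names `format_peak_marker` and `report_peak` are hypothetical and not part of the commit. The sketch assumes PyTorch may be absent and then simply reports nothing. Only `torch.cuda.is_available` and `torch.cuda.max_memory_allocated` are taken from the diff.

```python
def format_peak_marker(peak_bytes: int) -> str:
    """Format a byte count as the marker line, in MB with two decimals."""
    peak_mb = peak_bytes / (1024 * 1024)
    return f"EXPORT_GPU_PEAK_MEMORY_MB: {peak_mb:.2f}"


def report_peak(device: int = 0):
    """Print the marker for `device` when CUDA is usable; return the line or None."""
    try:
        import torch  # optional here so the sketch also runs without PyTorch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    line = format_peak_marker(torch.cuda.max_memory_allocated(device))
    print(line)
    return line


if __name__ == "__main__":
    print(format_peak_marker(18 * 1024**3))
    report_peak()
```

Note that `torch.cuda.max_memory_allocated` reports the peak memory occupied by tensors through PyTorch's caching allocator. Memory that the allocator has reserved but not handed out, and memory allocated outside PyTorch, is not counted, which is one reason to keep the limit below the physical 24 GB.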
