Skip to content

Commit ad07d03

Browse files
Fix MoE shape logging to print immediately to stderr
The previous approach relied on atexit flush which doesn't fire when the server process is killed. Now logs each new unique shape immediately to stderr (visible in workflow logs) and to the debug file. Co-authored-by: functionstackx <functionstackx@users.noreply.github.com>
1 parent 3f7a42f commit ad07d03

1 file changed

Lines changed: 15 additions & 22 deletions

File tree

benchmarks/benchmark_lib.sh

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -507,18 +507,14 @@ import os as _os, sys as _sys
507507
508508
_MOE_LOG = _os.environ.get("MOE_DEBUG_LOG", "")
509509
if _MOE_LOG:
510-
import atexit as _atexit, threading as _threading
510+
import threading as _threading
511511
512512
_moe_lock = _threading.Lock()
513513
_moe_shapes = {} # key -> count
514-
_moe_logged = 0
515514
_MOE_MAX_LOG = 200 # stop collecting after this many unique shapes
516515
517516
def _log_moe_shapes(hidden, w1, w2, topk_w, topk_ids, **kw):
518-
"""Collect shape tuples from fused_moe calls."""
519-
global _moe_logged
520-
if _moe_logged >= _MOE_MAX_LOG:
521-
return
517+
"""Log shape tuples from fused_moe calls immediately to stderr + file."""
522518
key = (
523519
"hidden=" + str(tuple(hidden.shape)),
524520
"w1=" + str(tuple(w1.shape)),
@@ -529,22 +525,19 @@ if _MOE_LOG:
529525
"w1_dtype=" + str(w1.dtype),
530526
)
531527
with _moe_lock:
532-
_moe_shapes[key] = _moe_shapes.get(key, 0) + 1
533-
_moe_logged = len(_moe_shapes)
534-
535-
def _flush_moe_log():
536-
if not _moe_shapes:
537-
return
538-
try:
539-
with open(_MOE_LOG, "a") as f:
540-
f.write("=== fused_moe shape log (rank 0) ===\n")
541-
for key, cnt in sorted(_moe_shapes.items(), key=lambda x: -x[1]):
542-
f.write(f" count={cnt:>5d} {' '.join(key)}\n")
543-
f.write(f"=== total unique shapes: {len(_moe_shapes)} ===\n")
544-
except Exception as e:
545-
print(f"[MOE_DEBUG] flush error: {e}", file=_sys.stderr)
546-
547-
_atexit.register(_flush_moe_log)
528+
prev = _moe_shapes.get(key, 0)
529+
_moe_shapes[key] = prev + 1
530+
if len(_moe_shapes) > _MOE_MAX_LOG:
531+
return
532+
# Log new shapes immediately + periodic count updates
533+
if prev == 0 or prev in (10, 100, 1000):
534+
msg = f"[MOE_SHAPE] count={prev+1:>5d} {' '.join(key)}"
535+
print(msg, file=_sys.stderr, flush=True)
536+
try:
537+
with open(_MOE_LOG, "a") as f:
538+
f.write(msg + "\n")
539+
except Exception:
540+
pass
548541
549542
# ---------- Patch aiter.fused_moe.fused_moe (AMD/ROCm path) ----------
550543
def _try_patch_aiter():

0 commit comments

Comments
 (0)