Skip to content

Commit 1d3f38d

Browse files
Add allreduce tensor shape logging for profiling
- Add `benchmarks/patches/inject_ar_shape_logging.py`: patches SGLang's parallel_state.py and custom_all_reduce.py inside the container to log tensor .shape, .dtype, and byte size on rank 0 for every allreduce call
- Modify `dsr1_fp4_mi355x.sh` to run the injection when AR_SHAPE_LOGGING=1
- Add `ar-shape-logging` input to profile.yml workflow

Closes #1005

Co-authored-by: functionstackx <functionstackx@users.noreply.github.com>
1 parent 0ee5ff3 commit 1d3f38d

5 files changed

Lines changed: 322 additions & 0 deletions

File tree

.github/workflows/profile.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ on:
2222
required: false
2323
type: boolean
2424
default: false
25+
ar-shape-logging:
26+
description: "Enable allreduce tensor shape logging (AR_SHAPE_LOGGING)"
27+
required: false
28+
type: boolean
29+
default: false
2530
ref:
2631
description: "Ref (branch/sha) to checkout"
2732
required: false
@@ -117,6 +122,7 @@ jobs:
117122
DISAGG: ${{ matrix.config.disagg }}
118123
MOE_DEBUG: '0'
119124
MOE_DEBUG_LOG: ${{ (inputs.moe-debug) && '/workspace/moe_debug.tp0.log' || '' }}
125+
AR_SHAPE_LOGGING: ${{ (inputs.ar-shape-logging) && '1' || '0' }}
120126
steps:
121127
- name: Resource cleanup
122128
run: |
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
Monkey-patch SGLang's GroupCoordinator.all_reduce to log tensor shapes
3+
entering the custom allreduce kernel (cross_device_reduce_2stage).
4+
5+
Usage: Set PYTHONPATH to include the directory containing sitecustomize.py
6+
which imports this module, OR call patch() directly before launching SGLang.
7+
8+
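Per-call lines are printed to stdout on rank 0 (capped at 200 lines); an aggregated summary is written at process exit to /workspace/allreduce_shapes.log (override the path with AR_SHAPE_LOG).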
9+
After the run, the log can be post-processed to get unique shapes and counts.
10+
"""
11+
12+
import atexit
13+
import collections
14+
import os
15+
16+
# Call counts keyed by (shape tuple, dtype string, method label); incremented
# by the patched wrappers and read by _print_summary.
_shape_counts = collections.Counter()
17+
# NOTE(review): assigned here but not referenced anywhere else in the visible code.
_log_file = None
18+
# Original GroupCoordinator.all_reduce, saved by patch() before it is replaced.
_original_all_reduce = None
19+
# Original GroupCoordinator._all_reduce_out_place, saved by patch().
_original_all_reduce_out_place = None
20+
# Set to True by patch() so a second call returns without re-patching.
_patched = False
21+
# Limit per-call logging to avoid flooding stdout; summary is printed at exit.
22+
_MAX_LOG_LINES = 200
23+
# Number of per-call lines printed so far; shared by both patched wrappers.
_log_line_count = 0
24+
25+
26+
def _get_rank():
    """Return the torch.distributed rank of this process, defaulting to 0.

    The default of 0 is used when torch.distributed cannot be imported, when
    the process group is not initialized, or when the lookup raises.
    """
    rank = 0
    try:
        import torch.distributed as dist

        initialized = dist.is_initialized()
        if initialized:
            rank = dist.get_rank()
    except Exception:
        # Best-effort lookup: any failure is treated as "rank 0".
        rank = 0
    return rank
34+
35+
36+
def _patched_all_reduce_out_place(self, input_, outplace_all_reduce_method):
    """Record and optionally print the tensor's shape, then delegate.

    On rank 0 every call is tallied in ``_shape_counts`` under
    ``(shape, dtype, method)``; a per-call line is printed only while fewer
    than ``_MAX_LOG_LINES`` lines have been emitted. The saved original
    ``_all_reduce_out_place`` is always called and its result returned.
    """
    global _log_line_count

    on_rank_zero = _get_rank() == 0
    if on_rank_zero:
        key = (tuple(input_.shape), str(input_.dtype), outplace_all_reduce_method)
        _shape_counts[key] += 1

        under_cap = _log_line_count < _MAX_LOG_LINES
        if under_cap:
            message = (
                f"[AR_SHAPE] method={outplace_all_reduce_method} "
                f"shape={list(input_.shape)} dtype={input_.dtype} "
                f"numel={input_.numel()} bytes={input_.numel() * input_.element_size()}"
            )
            print(message, flush=True)
            _log_line_count += 1

    return _original_all_reduce_out_place(self, input_, outplace_all_reduce_method)
52+
53+
54+
def _patched_all_reduce(self, input_):
    """Log the shape of every tensor entering ``all_reduce``, then delegate.

    On rank 0 each call is tallied in ``_shape_counts`` under
    ``(shape, dtype, "all")``. Counting is unconditional so the exit summary
    reflects every call; only the per-call print is capped at
    ``_MAX_LOG_LINES`` lines (a cap shared with the out-of-place wrapper).

    Returns whatever the saved original ``all_reduce`` returns.
    """
    global _log_line_count
    if _get_rank() == 0:
        # Count first, independent of the print cap: previously the cap also
        # gated counting, so the summary silently stopped growing.
        shape_key = (tuple(input_.shape), str(input_.dtype), "all")
        _shape_counts[shape_key] += 1
        if _log_line_count < _MAX_LOG_LINES:
            print(
                f"[AR_SHAPE_ENTRY] shape={list(input_.shape)} dtype={input_.dtype} "
                f"numel={input_.numel()} bytes={input_.numel() * input_.element_size()}",
                flush=True,
            )
            _log_line_count += 1
    return _original_all_reduce(self, input_)
69+
70+
71+
def _dtype_element_size(dtype):
    """Return the bytes per element implied by a dtype string.

    The bit width is read from the digits following ``float``/``int``/
    ``complex`` in names such as ``"torch.float64"``, ``"torch.bfloat16"``,
    ``"torch.uint8"`` or ``"torch.float8_e4m3fn"``. Sub-byte widths are
    rounded up to 1. ``bool`` maps to 1; anything unrecognized falls back to
    2 (the previous bf16 default).
    """
    import re

    width = re.search(r"(?:float|int|complex)(\d+)", dtype)
    if width:
        return max(1, int(width.group(1)) // 8)
    if "bool" in dtype:
        return 1
    return 2


def _print_summary():
    """Print the aggregated shape summary on rank 0 and write it to a file.

    Does nothing on other ranks or when no calls were recorded. The summary
    is printed to stdout and written to the path in ``AR_SHAPE_LOG``
    (default ``/workspace/allreduce_shapes.log``); a failed write is
    reported on stdout rather than raised.
    """
    if _get_rank() != 0 or not _shape_counts:
        return

    log_path = os.environ.get("AR_SHAPE_LOG", "/workspace/allreduce_shapes.log")
    lines = [
        "\n" + "=" * 80,
        "[AR_SHAPE_SUMMARY] AllReduce tensor shapes (rank 0):",
        f"{'Count':>8} {'Method':<12} {'Shape':<30} {'Dtype':<16} {'Bytes':<12}",
        "-" * 80,
    ]

    for (shape, dtype, method), count in _shape_counts.most_common():
        numel = 1
        for dim in shape:
            numel *= dim
        nbytes = numel * _dtype_element_size(dtype)
        # str(method) keeps the width format spec valid for non-string labels.
        lines.append(
            f"{count:>8} {str(method):<12} {str(list(shape)):<30} {dtype:<16} {nbytes:<12}"
        )

    lines.append("=" * 80)
    summary = "\n".join(lines)
    print(summary, flush=True)

    try:
        with open(log_path, "w") as f:
            f.write(summary + "\n")
        print(f"[AR_SHAPE] Summary written to {log_path}", flush=True)
    except OSError as e:
        print(f"[AR_SHAPE] Failed to write log: {e}", flush=True)
110+
111+
112+
def patch():
    """Install the logging wrappers on SGLang's GroupCoordinator.

    Saves the original ``all_reduce`` and ``_all_reduce_out_place`` in module
    globals, replaces them with the logging wrappers, and registers the exit
    summary. A repeated call is a no-op; if SGLang cannot be imported a
    message is printed and nothing is patched.
    """
    global _original_all_reduce, _original_all_reduce_out_place, _patched

    if _patched:
        return

    try:
        from sglang.srt.distributed.parallel_state import GroupCoordinator
    except ImportError:
        print("[AR_SHAPE] Could not import GroupCoordinator, skipping patch", flush=True)
        return

    # Keep handles to the originals before overwriting them.
    _original_all_reduce = GroupCoordinator.all_reduce
    _original_all_reduce_out_place = GroupCoordinator._all_reduce_out_place

    setattr(GroupCoordinator, "all_reduce", _patched_all_reduce)
    setattr(GroupCoordinator, "_all_reduce_out_place", _patched_all_reduce_out_place)
    _patched = True

    atexit.register(_print_summary)
    print("[AR_SHAPE] Monkey-patch installed: logging allreduce tensor shapes on rank 0", flush=True)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Inject allreduce shape logging into SGLang's parallel_state.py at runtime.
4+
5+
Patches GroupCoordinator._all_reduce_out_place to print tensor shapes on rank 0.
6+
This patches the actual source file inside the container so that all worker
7+
processes (forked by SGLang) pick up the change.
8+
9+
Usage: python3 inject_ar_shape_logging.py
10+
"""
11+
import importlib
12+
import os
13+
import re
14+
import sys
15+
import textwrap
16+
17+
18+
_AR_SHAPE_MARKER = "# [AR_SHAPE_LOG]"


def _inject_shape_logging(src: str, target_method: str, log_tag: str):
    """Insert shape logging after ``target_method``'s def line in ``src``.

    Returns ``(status, new_src)`` where ``status`` is ``"missing"`` (method
    not found), ``"already"`` (this method already starts with the marker)
    or ``"patched"``. ``new_src`` equals ``src`` unless status is
    ``"patched"``.
    """
    # Match an indented "def <name>(self, ...) [-> T]:" line and capture its
    # indentation so the injected body can be aligned with the real method.
    pattern = re.compile(
        rf"^([ \t]+)def {re.escape(target_method)}\(self[^)]*\)[^:]*:.*\n",
        re.MULTILINE,
    )
    match = pattern.search(src)
    if not match:
        return "missing", src

    # Per-method check: a file-wide marker search would skip the second
    # method when two targets live in the same file.
    if src[match.end():].lstrip().startswith(_AR_SHAPE_MARKER):
        return "already", src

    # First argument after self is taken to be the tensor.
    sig_match = re.search(
        rf"def {re.escape(target_method)}\(self,\s*(\w+)", src
    )
    tensor_name = sig_match.group(1) if sig_match else "input_"

    log_code = textwrap.dedent(f"""\
        # [AR_SHAPE_LOG] Injected shape logging
        try:
            import torch.distributed as _dist
            if not _dist.is_initialized() or _dist.get_rank() == 0:
                _s = list({tensor_name}.shape)
                _b = {tensor_name}.numel() * {tensor_name}.element_size()
                print(f"[AR_SHAPE] {log_tag} shape={{_s}} dtype={{{tensor_name}.dtype}} bytes={{_b}}", flush=True)
        except Exception:
            pass
        """)

    def_indent = match.group(1)
    body_indent = def_indent + ("\t" if "\t" in def_indent else "    ")
    end_of_def = match.end()
    new_src = src[:end_of_def] + textwrap.indent(log_code, body_indent) + src[end_of_def:]
    return "patched", new_src


def find_and_patch(module_path: str, target_method: str, log_tag: str) -> bool:
    """Find a Python module file and inject shape logging into a method.

    Args:
        module_path: Dotted module name; it is imported to locate its file.
        target_method: Name of the method (taking ``self``) to instrument.
        log_tag: Label embedded in the injected ``[AR_SHAPE]`` print line.

    Returns:
        True if the method is patched (now or previously); False if the
        module cannot be imported, its file is missing, or the method is
        not found.
    """
    try:
        mod = importlib.import_module(module_path)
        filepath = mod.__file__
    except (ImportError, AttributeError) as e:
        print(f"[AR_SHAPE] Could not import {module_path}: {e}")
        return False

    if not filepath or not os.path.exists(filepath):
        print(f"[AR_SHAPE] File not found for {module_path}")
        return False

    with open(filepath, "r", encoding="utf-8") as f:
        src = f.read()

    status, new_src = _inject_shape_logging(src, target_method, log_tag)
    if status == "missing":
        print(f"[AR_SHAPE] Could not find {target_method} in {filepath}")
        return False
    if status == "already":
        print(f"[AR_SHAPE] Already patched: {filepath}")
        return True

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(new_src)
    print(f"[AR_SHAPE] Patched {target_method} in {filepath}")
    return True
77+
78+
79+
def patch_parallel_state():
    """Inject logging into ``_all_reduce_out_place`` in SGLang's parallel_state.py.

    Returns the boolean result of ``find_and_patch``.
    """
    patched = find_and_patch(
        module_path="sglang.srt.distributed.parallel_state",
        target_method="_all_reduce_out_place",
        log_tag="out_place",
    )
    return patched
86+
87+
88+
def patch_sglang_custom_ar():
    """Inject logging into ``all_reduce_unreg`` in SGLang's custom_all_reduce.py.

    Returns the boolean result of ``find_and_patch``.
    """
    patched = find_and_patch(
        module_path="sglang.srt.distributed.device_communicators.custom_all_reduce",
        target_method="all_reduce_unreg",
        log_tag="sglang_unreg",
    )
    return patched
95+
96+
97+
def patch_aiter_custom_ar():
    """Inject logging into ``all_reduce_unreg`` in aiter's custom_all_reduce.py.

    Returns the boolean result of ``find_and_patch``.
    """
    patched = find_and_patch(
        module_path="aiter.dist.device_communicators.custom_all_reduce",
        target_method="all_reduce_unreg",
        log_tag="aiter_unreg",
    )
    return patched
104+
105+
106+
def patch_top_level_all_reduce():
    """Inject logging into ``all_reduce`` in SGLang's parallel_state.py.

    Returns the boolean result of ``find_and_patch``.
    """
    patched = find_and_patch(
        module_path="sglang.srt.distributed.parallel_state",
        target_method="all_reduce",
        log_tag="entry",
    )
    return patched
113+
114+
115+
if __name__ == "__main__":
    print("[AR_SHAPE] Starting allreduce shape logging injection...")

    # Applied in this order:
    #   1. the top-level all_reduce entry point,
    #   2. the out-of-place path,
    #   3. the low-level unreg call in sglang, then in aiter.
    _injection_steps = (
        patch_top_level_all_reduce,
        patch_parallel_state,
        patch_sglang_custom_ar,
        patch_aiter_custom_ar,
    )
    for _step in _injection_steps:
        _step()

    print("[AR_SHAPE] Done. Shape logs will appear as [AR_SHAPE] lines in server output.")
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Auto-patch SGLang allreduce shape logging via import hook.
2+
3+
When AR_SHAPE_LOGGING=1, installs a meta-path finder that waits for
4+
sglang.srt.distributed.parallel_state to be imported, then applies the
5+
monkey-patch to log tensor shapes entering the custom allreduce kernel.
6+
"""
7+
import importlib
8+
import os
9+
import sys
10+
11+
12+
if os.environ.get("AR_SHAPE_LOGGING") == "1":
    import importlib.util

    def _apply_allreduce_patch():
        """Import allreduce_shape_logger from this file's directory and call patch().

        Failures are printed, not raised, so a broken logger never blocks the
        import that triggered the hook.
        """
        try:
            _patch_dir = os.path.dirname(os.path.abspath(__file__))
            if _patch_dir not in sys.path:
                sys.path.insert(0, _patch_dir)
            import allreduce_shape_logger
            allreduce_shape_logger.patch()
        except Exception as e:
            print(f"[AR_SHAPE] Deferred patch failed: {e}", flush=True)

    class _PatchAfterExecLoader:
        """Loader wrapper that applies the patch once the real module has executed."""

        def __init__(self, wrapped):
            self._wrapped = wrapped

        def create_module(self, spec):
            return self._wrapped.create_module(spec)

        def exec_module(self, module):
            # Let the real import happen, then apply the patch.
            self._wrapped.exec_module(module)
            _apply_allreduce_patch()

        def __getattr__(self, name):
            # Delegate everything else (get_code, get_filename, ...) to the real loader.
            return getattr(self._wrapped, name)

    class _AllReducePatchFinder:
        """Meta-path finder that triggers patching after parallel_state is imported.

        Implements ``find_spec``: the legacy ``find_module``/``load_module``
        hooks are no longer consulted by the import system on Python 3.12+.
        """
        _target = "sglang.srt.distributed.parallel_state"
        _done = False

        def find_spec(self, fullname, path=None, target=None):
            if self._done or fullname != self._target:
                return None

            # Remove ourselves so the lookup below does not recurse into us.
            self._done = True
            if self in sys.meta_path:
                sys.meta_path.remove(self)

            try:
                spec = importlib.util.find_spec(fullname)
            except (ImportError, ValueError):
                # Fall through to the normal import machinery.
                return None
            if spec is None or not hasattr(spec.loader, "exec_module"):
                return spec

            spec.loader = _PatchAfterExecLoader(spec.loader)
            return spec

    sys.meta_path.insert(0, _AllReducePatchFinder())
    print("[AR_SHAPE] Import hook installed, will patch after parallel_state loads", flush=True)

benchmarks/single_node/dsr1_fp4_mi355x.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ hf download "$MODEL"
2020
export SGLANG_USE_AITER=1
2121
export ROCM_QUICK_REDUCE_QUANTIZATION=INT4
2222

23+
# Optionally inject allreduce shape logging (set AR_SHAPE_LOGGING=1 to enable)
24+
if [[ "${AR_SHAPE_LOGGING:-}" == "1" ]]; then
25+
echo "[AR_SHAPE] Injecting allreduce shape logger..."
26+
python3 /workspace/benchmarks/patches/inject_ar_shape_logging.py
27+
echo "[AR_SHAPE] Injection complete"
28+
fi
29+
2330
PREFILL_SIZE=196608
2431
if [[ "$ISL" == "8192" && "$OSL" == "1024" ]]; then
2532
if [[ "$CONC" -gt "32" ]]; then

0 commit comments

Comments
 (0)