Skip to content

Commit d04634e

Browse files
committed
feat(power): per-worker prefill/decode power attribution + role-split joules
Layers a per-worker breakdown on top of the cluster-wide multinode aggregation in the parent PR #1574.

New agg JSON fields (additive — all existing keys are preserved bit-for-bit for backward compatibility):

- workers: [{role, worker_idx, num_gpus, avg_power_w}, ...], where role ∈ "prefill" / "decode" / "agg" / "frontend". Each (role, idx) aggregates across all CSVs for that worker — a multi-node TP=16 decode worker on 4 nodes produces one workers entry with num_gpus=16.
- prefill_avg_power_w, decode_avg_power_w (disagg only): weighted per-GPU averages within each role.
- joules_per_input_token = prefill_energy / total_input_tokens and joules_per_output_token_decode = decode_energy / total_output_tokens: disagg-only role-split metrics.

The existing joules_per_output_token and joules_per_total_token keep their cluster-wide semantics, so the chart won't shift on existing data.

The worker → CSV mapping is by filename: srt-slurm's perfmon (companion change on SemiAnalysisAI/srt-slurm c4c86dc) writes `perf_samples_<role>_w<worker_idx>_<host>.csv`. Unlabeled filenames (old single-CSV format) silently emit an empty workers list and skip the role split — cluster-wide metrics are unchanged in that case.

77/77 tests pass (68 existing + 9 new — per-worker grouping, multi-node worker aggregation, mixed labeled/unlabeled inputs, disagg E2E with role split, agg E2E omitting disagg-only fields, bit-for-bit backward compat for old-format callers).
1 parent 6da2f1b commit d04634e

2 files changed

Lines changed: 339 additions & 9 deletions

File tree

utils/aggregate_power.py

Lines changed: 142 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@
4848
_GPU_INDEX_COL_RE = re.compile(r"^(index|gpu|gpu_id|gpu_index|card|device)$", re.IGNORECASE)
4949
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")
5050

51+
# Matches perf_samples_<role>_w<worker_idx>_<host>.csv as written by srt-slurm's
52+
# perfmon (SemiAnalysisAI/srt-slurm:feat/inferencex-perfmon). Hostnames can
53+
# contain underscores and digits, so the role and idx are anchored before the
54+
# host portion. Old-format filenames (perf_samples_<host>.csv) don't match
55+
# and fall through to the unlabeled cluster-wide path.
56+
_FILENAME_ROLE_RE = re.compile(r"^perf_samples_(?P<role>[a-z]+)_w(?P<idx>\d+)_(?P<host>.+)$")
57+
58+
59+
def _parse_role_from_filename(path: Path) -> tuple[str | None, int | None]:
60+
"""Return (role, worker_idx) parsed from the CSV stem, or (None, None).
61+
62+
Role is one of "prefill", "decode", "agg", "frontend" depending on what
63+
srt-slurm's _start_perf_monitor labels the node with. Unlabeled filenames
64+
(old format) return (None, None) so callers can treat them as cluster-wide
65+
contributions without per-worker attribution.
66+
"""
67+
m = _FILENAME_ROLE_RE.match(path.stem)
68+
if not m:
69+
return None, None
70+
return m.group("role"), int(m.group("idx"))
71+
5172

5273
def _parse_timestamp(value: str) -> float | None:
5374
"""Best-effort timestamp parse to Unix epoch seconds (local wall clock).
@@ -209,6 +230,70 @@ def aggregate_power(
209230
return mean(per_sample_mean_per_gpu), num_gpus
210231

211232

233+
def aggregate_power_per_worker(
    csv_path: Path | Iterable[Path],
    start_unix: float,
    end_unix: float,
) -> dict | None:
    """Compute the cluster-wide power average plus a per-worker breakdown.

    The result is a dict with three keys:

    - ``cluster_avg_power_w``: the first element of the tuple returned by
      ``aggregate_power`` over all paths.
    - ``cluster_num_gpus``: the second element of that same tuple.
    - ``workers``: a list of ``{role, worker_idx, num_gpus, avg_power_w}``
      dicts, one per (role, worker_idx) group, ordered by that key. The
      per-worker ``avg_power_w`` is rounded to 3 decimals. The list is empty
      when no filename carries a role label.

    Paths are grouped by the (role, worker_idx) that
    ``_parse_role_from_filename`` extracts from each filename
    (``perf_samples_<role>_w<idx>_<host>.csv``), so several CSVs belonging to
    one multi-node worker are aggregated together. Paths without a label are
    left out of ``workers`` but still count toward the cluster-wide numbers,
    because the cluster-wide call receives every path.

    Returns None when there are no paths, when the window is empty or
    inverted (``end_unix <= start_unix``), or when the cluster-wide
    ``aggregate_power`` call returns None.
    """
    if isinstance(csv_path, Path):
        all_paths = [csv_path]
    else:
        all_paths = list(csv_path)
    if not all_paths or end_unix <= start_unix:
        return None

    cluster_result = aggregate_power(all_paths, start_unix, end_unix)
    if cluster_result is None:
        return None
    cluster_avg, cluster_n = cluster_result

    # Bucket the labeled CSVs by worker identity; unlabeled ones are ignored here.
    by_worker: dict[tuple[str, int], list[Path]] = {}
    for csv_file in all_paths:
        role, worker_idx = _parse_role_from_filename(csv_file)
        if role is not None and worker_idx is not None:
            by_worker.setdefault((role, worker_idx), []).append(csv_file)

    workers: list[dict] = []
    for key in sorted(by_worker):
        worker_result = aggregate_power(by_worker[key], start_unix, end_unix)
        if worker_result is None:
            # No usable samples for this worker in the window — leave it out.
            continue
        worker_avg, worker_gpus = worker_result
        role, worker_idx = key
        entry = {
            "role": role,
            "worker_idx": worker_idx,
            "num_gpus": worker_gpus,
            "avg_power_w": round(worker_avg, 3),
        }
        workers.append(entry)

    return {
        "cluster_avg_power_w": cluster_avg,
        "cluster_num_gpus": cluster_n,
        "workers": workers,
    }
295+
296+
212297
def _load_bench_window(
213298
bench_result_path: Path,
214299
) -> tuple[float, float, float, int, int] | None:
@@ -285,12 +370,21 @@ def patch_agg_result(
285370
avg_power_w: float,
286371
joules_per_output_token: float,
287372
joules_per_total_token: float,
373+
extras: dict | None = None,
288374
) -> None:
289-
"""Read the agg JSON, add the three power keys, and write it back atomically."""
375+
"""Read the agg JSON, add the three base power keys + any extras, write back atomically.
376+
377+
``extras`` is merged after the base three keys. Used for per-worker
378+
breakdowns (``workers``) and role-split energy metrics
379+
(``joules_per_input_token``, ``joules_per_output_token_decode``,
380+
``prefill_avg_power_w``, ``decode_avg_power_w``) on disagg runs.
381+
"""
290382
data = json.loads(agg_path.read_text(encoding="utf-8"))
291383
data["avg_power_w"] = round(avg_power_w, 3)
292384
data["joules_per_output_token"] = round(joules_per_output_token, 6)
293385
data["joules_per_total_token"] = round(joules_per_total_token, 6)
386+
if extras:
387+
data.update(extras)
294388
tmp_path = agg_path.with_suffix(agg_path.suffix + ".tmp")
295389
tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
296390
tmp_path.replace(agg_path)
@@ -307,7 +401,7 @@ def run(csv_path: Path | Iterable[Path], bench_result: Path, agg_result: Path) -
307401
start, end, duration, total_output, total_input = window
308402

309403
paths = [csv_path] if isinstance(csv_path, Path) else list(csv_path)
310-
result = aggregate_power(paths, start, end)
404+
result = aggregate_power_per_worker(paths, start, end)
311405
if result is None:
312406
label = str(paths[0]) if len(paths) == 1 else f"{len(paths)} CSVs"
313407
print(
@@ -316,18 +410,48 @@ def run(csv_path: Path | Iterable[Path], bench_result: Path, agg_result: Path) -
316410
file=sys.stderr,
317411
)
318412
return 0
319-
avg_power_w, num_gpus = result
413+
avg_power_w = result["cluster_avg_power_w"]
414+
num_gpus = result["cluster_num_gpus"]
415+
workers = result["workers"]
320416

321-
# Joules consumed by the system during the bench window, divided by either
322-
# output tokens (for generation-cost metrics) or all tokens (for whole-
323-
# workload efficiency).
417+
# Cluster-wide energy and per-token metrics (existing behavior, unchanged).
324418
total_system_energy_j = avg_power_w * num_gpus * duration
325419
joules_per_output_token = total_system_energy_j / total_output
326420
total_tokens = total_output + total_input
327421
joules_per_total_token = (
328422
total_system_energy_j / total_tokens if total_tokens > 0 else joules_per_output_token
329423
)
330424

425+
# Per-role breakdown — only emitted when filenames had role labels (i.e.
426+
# srt-slurm's perfmon was the source). Role-split energy is only meaningful
427+
# for disagg runs (both prefill and decode workers present); aggregated
428+
# runs and frontend-only nodes don't contribute to the role-split fields.
429+
extras: dict = {}
430+
if workers:
431+
extras["workers"] = workers
432+
prefill = [w for w in workers if w["role"] == "prefill"]
433+
decode = [w for w in workers if w["role"] == "decode"]
434+
if prefill and decode:
435+
prefill_gpus = sum(w["num_gpus"] for w in prefill)
436+
decode_gpus = sum(w["num_gpus"] for w in decode)
437+
prefill_energy_j = sum(
438+
w["avg_power_w"] * w["num_gpus"] * duration for w in prefill
439+
)
440+
decode_energy_j = sum(
441+
w["avg_power_w"] * w["num_gpus"] * duration for w in decode
442+
)
443+
extras["prefill_avg_power_w"] = round(
444+
sum(w["avg_power_w"] * w["num_gpus"] for w in prefill) / prefill_gpus, 3
445+
)
446+
extras["decode_avg_power_w"] = round(
447+
sum(w["avg_power_w"] * w["num_gpus"] for w in decode) / decode_gpus, 3
448+
)
449+
if total_input > 0:
450+
extras["joules_per_input_token"] = round(prefill_energy_j / total_input, 6)
451+
extras["joules_per_output_token_decode"] = round(
452+
decode_energy_j / total_output, 6
453+
)
454+
331455
if not agg_result.is_file():
332456
print(
333457
f"[aggregate_power] Agg result {agg_result} missing — cannot patch",
@@ -337,18 +461,27 @@ def run(csv_path: Path | Iterable[Path], bench_result: Path, agg_result: Path) -
337461

338462
try:
339463
patch_agg_result(
340-
agg_result, avg_power_w, joules_per_output_token, joules_per_total_token
464+
agg_result,
465+
avg_power_w,
466+
joules_per_output_token,
467+
joules_per_total_token,
468+
extras=extras,
341469
)
342470
except (OSError, json.JSONDecodeError) as exc:
343471
print(f"[aggregate_power] Failed to patch {agg_result}: {exc}", file=sys.stderr)
344472
return 0
345473

474+
role_summary = (
475+
f" prefill={extras['prefill_avg_power_w']:.0f}W decode={extras['decode_avg_power_w']:.0f}W"
476+
if "prefill_avg_power_w" in extras and "decode_avg_power_w" in extras
477+
else ""
478+
)
346479
print(
347-
f"[aggregate_power] avg_power_w={avg_power_w:.2f} (per GPU, n={num_gpus}) "
480+
f"[aggregate_power] avg_power_w={avg_power_w:.2f} (per GPU, n={num_gpus}){role_summary} "
348481
f"joules_per_output_token={joules_per_output_token:.4f} "
349482
f"joules_per_total_token={joules_per_total_token:.4f} "
350483
f"duration={duration:.1f}s output_tokens={total_output} input_tokens={total_input} "
351-
f"-> {agg_result}"
484+
f"workers={len(workers)} -> {agg_result}"
352485
)
353486
return 0
354487

0 commit comments

Comments
 (0)