4848_GPU_INDEX_COL_RE = re .compile (r"^(index|gpu|gpu_id|gpu_index|card|device)$" , re .IGNORECASE )
4949_NUMBER_RE = re .compile (r"-?\d+(?:\.\d+)?" )
5050
51+ # Matches perf_samples_<role>_w<worker_idx>_<host>.csv as written by srt-slurm's
52+ # perfmon (SemiAnalysisAI/srt-slurm:feat/inferencex-perfmon). Hostnames can
53+ # contain underscores and digits, so the role and idx are anchored before the
54+ # host portion. Old-format filenames (perf_samples_<host>.csv) don't match
55+ # and fall through to the unlabeled cluster-wide path.
56+ _FILENAME_ROLE_RE = re .compile (r"^perf_samples_(?P<role>[a-z]+)_w(?P<idx>\d+)_(?P<host>.+)$" )
57+
58+
59+ def _parse_role_from_filename (path : Path ) -> tuple [str | None , int | None ]:
60+ """Return (role, worker_idx) parsed from the CSV stem, or (None, None).
61+
62+ Role is one of "prefill", "decode", "agg", "frontend" depending on what
63+ srt-slurm's _start_perf_monitor labels the node with. Unlabeled filenames
64+ (old format) return (None, None) so callers can treat them as cluster-wide
65+ contributions without per-worker attribution.
66+ """
67+ m = _FILENAME_ROLE_RE .match (path .stem )
68+ if not m :
69+ return None , None
70+ return m .group ("role" ), int (m .group ("idx" ))
71+
5172
5273def _parse_timestamp (value : str ) -> float | None :
5374 """Best-effort timestamp parse to Unix epoch seconds (local wall clock).
@@ -209,6 +230,70 @@ def aggregate_power(
209230 return mean (per_sample_mean_per_gpu ), num_gpus
210231
211232
233+ def aggregate_power_per_worker (
234+ csv_path : Path | Iterable [Path ],
235+ start_unix : float ,
236+ end_unix : float ,
237+ ) -> dict | None :
238+ """Aggregate measured power both cluster-wide and per worker.
239+
240+ Returns a dict with:
241+
242+ - cluster_avg_power_w: same number as aggregate_power's first tuple element
243+ - cluster_num_gpus: same as aggregate_power's second tuple element
244+ - workers: list of {role, worker_idx, num_gpus, avg_power_w}
245+ dicts — one per (role, worker_idx) group derived
246+ from CSV filenames. Empty list when no filenames
247+ match the labeled format (single-node single-CSV
248+ input, or older perfmon writing unlabeled paths).
249+
250+ Worker grouping is by filename: each path's role + worker_idx are parsed
251+ from ``perf_samples_<role>_w<idx>_<host>.csv``. Multiple CSVs sharing the
252+ same (role, worker_idx) — e.g. a multi-node TP=16 worker spanning 4 nodes —
253+ aggregate together as one worker. Unlabeled paths are silently dropped
254+ from the per-worker output but still contribute to the cluster-wide
255+ average via the underlying aggregate_power call.
256+
257+ Returns None when the cluster-wide aggregation returns None.
258+ """
259+ paths = [csv_path ] if isinstance (csv_path , Path ) else list (csv_path )
260+ if not paths or end_unix <= start_unix :
261+ return None
262+
263+ cluster = aggregate_power (paths , start_unix , end_unix )
264+ if cluster is None :
265+ return None
266+ cluster_avg , cluster_n = cluster
267+
268+ # Group paths by (role, worker_idx); silently skip files whose names
269+ # don't match the labeled format.
270+ groups : dict [tuple [str , int ], list [Path ]] = {}
271+ for p in paths :
272+ role , idx = _parse_role_from_filename (p )
273+ if role is None or idx is None :
274+ continue
275+ groups .setdefault ((role , idx ), []).append (p )
276+
277+ workers : list [dict ] = []
278+ for (role , idx ), group_paths in sorted (groups .items ()):
279+ result = aggregate_power (group_paths , start_unix , end_unix )
280+ if result is None :
281+ continue
282+ avg , n = result
283+ workers .append ({
284+ "role" : role ,
285+ "worker_idx" : idx ,
286+ "num_gpus" : n ,
287+ "avg_power_w" : round (avg , 3 ),
288+ })
289+
290+ return {
291+ "cluster_avg_power_w" : cluster_avg ,
292+ "cluster_num_gpus" : cluster_n ,
293+ "workers" : workers ,
294+ }
295+
296+
212297def _load_bench_window (
213298 bench_result_path : Path ,
214299) -> tuple [float , float , float , int , int ] | None :
@@ -285,12 +370,21 @@ def patch_agg_result(
285370 avg_power_w : float ,
286371 joules_per_output_token : float ,
287372 joules_per_total_token : float ,
373+ extras : dict | None = None ,
288374) -> None :
289- """Read the agg JSON, add the three power keys, and write it back atomically."""
375+ """Read the agg JSON, add the three base power keys + any extras, write back atomically.
376+
377+ ``extras`` is merged after the base three keys. Used for per-worker
378+ breakdowns (``workers``) and role-split energy metrics
379+ (``joules_per_input_token``, ``joules_per_output_token_decode``,
380+ ``prefill_avg_power_w``, ``decode_avg_power_w``) on disagg runs.
381+ """
290382 data = json .loads (agg_path .read_text (encoding = "utf-8" ))
291383 data ["avg_power_w" ] = round (avg_power_w , 3 )
292384 data ["joules_per_output_token" ] = round (joules_per_output_token , 6 )
293385 data ["joules_per_total_token" ] = round (joules_per_total_token , 6 )
386+ if extras :
387+ data .update (extras )
294388 tmp_path = agg_path .with_suffix (agg_path .suffix + ".tmp" )
295389 tmp_path .write_text (json .dumps (data , indent = 2 ), encoding = "utf-8" )
296390 tmp_path .replace (agg_path )
@@ -307,7 +401,7 @@ def run(csv_path: Path | Iterable[Path], bench_result: Path, agg_result: Path) -
307401 start , end , duration , total_output , total_input = window
308402
309403 paths = [csv_path ] if isinstance (csv_path , Path ) else list (csv_path )
310- result = aggregate_power (paths , start , end )
404+ result = aggregate_power_per_worker (paths , start , end )
311405 if result is None :
312406 label = str (paths [0 ]) if len (paths ) == 1 else f"{ len (paths )} CSVs"
313407 print (
@@ -316,18 +410,48 @@ def run(csv_path: Path | Iterable[Path], bench_result: Path, agg_result: Path) -
316410 file = sys .stderr ,
317411 )
318412 return 0
319- avg_power_w , num_gpus = result
413+ avg_power_w = result ["cluster_avg_power_w" ]
414+ num_gpus = result ["cluster_num_gpus" ]
415+ workers = result ["workers" ]
320416
321- # Joules consumed by the system during the bench window, divided by either
322- # output tokens (for generation-cost metrics) or all tokens (for whole-
323- # workload efficiency).
417+ # Cluster-wide energy and per-token metrics (existing behavior, unchanged).
324418 total_system_energy_j = avg_power_w * num_gpus * duration
325419 joules_per_output_token = total_system_energy_j / total_output
326420 total_tokens = total_output + total_input
327421 joules_per_total_token = (
328422 total_system_energy_j / total_tokens if total_tokens > 0 else joules_per_output_token
329423 )
330424
425+ # Per-role breakdown — only emitted when filenames had role labels (i.e.
426+ # srt-slurm's perfmon was the source). Role-split energy is only meaningful
427+ # for disagg runs (both prefill and decode workers present); aggregated
428+ # runs and frontend-only nodes don't contribute to the role-split fields.
429+ extras : dict = {}
430+ if workers :
431+ extras ["workers" ] = workers
432+ prefill = [w for w in workers if w ["role" ] == "prefill" ]
433+ decode = [w for w in workers if w ["role" ] == "decode" ]
434+ if prefill and decode :
435+ prefill_gpus = sum (w ["num_gpus" ] for w in prefill )
436+ decode_gpus = sum (w ["num_gpus" ] for w in decode )
437+ prefill_energy_j = sum (
438+ w ["avg_power_w" ] * w ["num_gpus" ] * duration for w in prefill
439+ )
440+ decode_energy_j = sum (
441+ w ["avg_power_w" ] * w ["num_gpus" ] * duration for w in decode
442+ )
443+ extras ["prefill_avg_power_w" ] = round (
444+ sum (w ["avg_power_w" ] * w ["num_gpus" ] for w in prefill ) / prefill_gpus , 3
445+ )
446+ extras ["decode_avg_power_w" ] = round (
447+ sum (w ["avg_power_w" ] * w ["num_gpus" ] for w in decode ) / decode_gpus , 3
448+ )
449+ if total_input > 0 :
450+ extras ["joules_per_input_token" ] = round (prefill_energy_j / total_input , 6 )
451+ extras ["joules_per_output_token_decode" ] = round (
452+ decode_energy_j / total_output , 6
453+ )
454+
331455 if not agg_result .is_file ():
332456 print (
333457 f"[aggregate_power] Agg result { agg_result } missing — cannot patch" ,
@@ -337,18 +461,27 @@ def run(csv_path: Path | Iterable[Path], bench_result: Path, agg_result: Path) -
337461
338462 try :
339463 patch_agg_result (
340- agg_result , avg_power_w , joules_per_output_token , joules_per_total_token
464+ agg_result ,
465+ avg_power_w ,
466+ joules_per_output_token ,
467+ joules_per_total_token ,
468+ extras = extras ,
341469 )
342470 except (OSError , json .JSONDecodeError ) as exc :
343471 print (f"[aggregate_power] Failed to patch { agg_result } : { exc } " , file = sys .stderr )
344472 return 0
345473
474+ role_summary = (
475+ f" prefill={ extras ['prefill_avg_power_w' ]:.0f} W decode={ extras ['decode_avg_power_w' ]:.0f} W"
476+ if "prefill_avg_power_w" in extras and "decode_avg_power_w" in extras
477+ else ""
478+ )
346479 print (
347- f"[aggregate_power] avg_power_w={ avg_power_w :.2f} (per GPU, n={ num_gpus } ) "
480+ f"[aggregate_power] avg_power_w={ avg_power_w :.2f} (per GPU, n={ num_gpus } ){ role_summary } "
348481 f"joules_per_output_token={ joules_per_output_token :.4f} "
349482 f"joules_per_total_token={ joules_per_total_token :.4f} "
350483 f"duration={ duration :.1f} s output_tokens={ total_output } input_tokens={ total_input } "
351- f"-> { agg_result } "
484+ f"workers= { len ( workers ) } -> { agg_result } "
352485 )
353486 return 0
354487
0 commit comments