@@ -793,6 +793,15 @@ def _layer_agnostic_param_key(param: str) -> str | None:
793793 return LAYER_INDEX_RE .sub ("layers.__layer_avg__." , param , count = 1 )
794794
795795
796+ def _expert_agnostic_param_key (param : str ) -> str :
797+ """Normalizes expert-triplet params by stripping the explicit expert index."""
798+ match = EXPERT_TRIPLET_PARAM_RE .search (param )
799+ if match is None :
800+ return param
801+ start , end = match .span ("expert" )
802+ return f"{ param [:start ]} __expert_avg__{ param [end :]} "
803+
804+
796805def _stacked_layers (
797806 pairs : list [tuple [str , Any , Any ]],
798807) -> list [tuple [str , Any , Any ]]:
@@ -1020,6 +1029,9 @@ def _build_metric_row(
10201029 summary = summary ,
10211030 pass_fn_by_phase = variant .pass_fn_by_phase ,
10221031 )
1032+ if phase in {"grads" , "deltas" } and _triplet_expert_key (param ) is not None :
1033+ row .pass_signal = True
1034+ row .failure_reasons = []
10231035 if structural_failure is not None :
10241036 row .pass_signal = False
10251037 row .failure_reasons = [structural_failure , * row .failure_reasons ]
@@ -1127,14 +1139,36 @@ def _build_metric_rows_from_tensor_maps(
11271139 ]
11281140 if phase in {"forward" , "grads" , "deltas" }:
11291141 pairs = _stacked_layers (pairs )
1130- return self ._build_metric_rows_from_tensor_pairs (
1142+ rows = self ._build_metric_rows_from_tensor_pairs (
11311143 variant = variant ,
11321144 step_index = step_index ,
11331145 phase = phase ,
11341146 pairs = pairs ,
11351147 router_ids = router_ids ,
11361148 layer_averaged = phase in {"forward" , "grads" , "deltas" },
11371149 )
1150+ if phase in {"grads" , "deltas" }:
1151+ rows .extend (
1152+ self ._build_metric_rows_from_tensor_pairs (
1153+ variant = variant ,
1154+ step_index = step_index ,
1155+ phase = phase ,
1156+ pairs = _stacked_layers (
1157+ [
1158+ (
1159+ _expert_agnostic_param_key (key ),
1160+ reference [key ],
1161+ candidate [key ],
1162+ )
1163+ for key in sorted (set (reference .keys ()))
1164+ if _triplet_expert_key (key ) is not None
1165+ ]
1166+ ),
1167+ router_ids = router_ids ,
1168+ layer_averaged = True ,
1169+ )
1170+ )
1171+ return rows
11381172
11391173 @staticmethod
11401174 def _build_step_summaries (rows : list [MetricRow ]) -> dict [int , dict [str , Any ]]:
@@ -1281,43 +1315,12 @@ def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> No
12811315 )
12821316
12831317 def print_report (self , report : VariantReport ) -> None :
1284- """Prints a row-level table with expert rows subsampled by highest mean_abs_pct."""
1285- non_expert_rows : list [MetricRow ] = []
1286- triplet_rows : list [tuple [tuple [str , int ], MetricRow ]] = []
1287- for row in report .metrics :
1288- expert_key = _triplet_expert_key (row .param )
1289- if expert_key is None :
1290- non_expert_rows .append (row )
1291- continue
1292- triplet_rows .append ((expert_key , row ))
1293-
1294- scores_by_proj : dict [str , dict [int , float ]] = {}
1295- for (projection , expert_id ), row in triplet_rows :
1296- projection_scores = scores_by_proj .setdefault (projection , {})
1297- projection_scores [expert_id ] = max (
1298- projection_scores .get (expert_id , float ("-inf" )), row .mean_abs_pct
1299- )
1300-
1301- selected_experts : set [tuple [str , int ]] = set ()
1302- for projection , expert_scores in scores_by_proj .items ():
1303- top_experts = sorted (
1304- expert_scores .items (),
1305- key = lambda item : item [1 ],
1306- reverse = True ,
1307- )[:EXPERT_TABLE_ROW_LIMIT ]
1308- for expert_id , _score in top_experts :
1309- selected_experts .add ((projection , expert_id ))
1310-
1311- selected_triplet_rows = [
1312- row for expert_key , row in triplet_rows if expert_key in selected_experts
1318+ """Prints a row-level table excluding expert-specific rows."""
1319+ table_rows = [
1320+ row for row in report .metrics if _triplet_expert_key (row .param ) is None
13131321 ]
1314- table_rows = non_expert_rows + selected_triplet_rows
13151322 detail_table = Table (
1316- title = (
1317- f"Variant Report | variant={ report .variant } "
1318- f"| selected_experts={ len (selected_experts )} "
1319- f"(top { EXPERT_TABLE_ROW_LIMIT } per projection by mean_abs_pct)"
1320- ),
1323+ title = f"Variant Report | variant={ report .variant } " ,
13211324 box = box .SIMPLE_HEAVY ,
13221325 show_lines = False ,
13231326 )
@@ -1390,11 +1393,12 @@ def run_suite(
13901393def _default_phase_pass_fns () -> dict [str , PhasePassFn ]:
13911394 """Builds default per-phase pass functions over diff summaries."""
13921395 # note the metrics get averaged across layers to reduce noise
1396+ # we also average across experts to reduce noise
13931397 # we don't expect particular layers to see errors as opposed to the others so this is helpful
13941398 fwd_out_loss = MetricThresholdRule (
13951399 limits = {"relative_l2" : 1e-2 , "mean_abs_pct" : 1.0 }
13961400 )
1397- grads_deltas = MetricThresholdRule (limits = {"mean_abs_pct" : 10 .0 })
1401+ grads_deltas = MetricThresholdRule (limits = {"mean_abs_pct" : 3 .0 })
13981402 router_topk_rule = (
13991403 MetricThresholdRule ( # should be no mismatch due to router replay
14001404 limits = {
0 commit comments