Skip to content

Commit 31552e8

Browse files
committed
Update megatron tests for new lora kernel and avg grads across experts for stability.
1 parent 04fe905 commit 31552e8

3 files changed

Lines changed: 81 additions & 63 deletions

File tree

tests/integration/megatron_oracle_harness.py

Lines changed: 41 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,15 @@ def _layer_agnostic_param_key(param: str) -> str | None:
793793
return LAYER_INDEX_RE.sub("layers.__layer_avg__.", param, count=1)
794794

795795

796+
def _expert_agnostic_param_key(param: str) -> str:
797+
"""Normalizes expert-triplet params by stripping the explicit expert index."""
798+
match = EXPERT_TRIPLET_PARAM_RE.search(param)
799+
if match is None:
800+
return param
801+
start, end = match.span("expert")
802+
return f"{param[:start]}__expert_avg__{param[end:]}"
803+
804+
796805
def _stacked_layers(
797806
pairs: list[tuple[str, Any, Any]],
798807
) -> list[tuple[str, Any, Any]]:
@@ -1020,6 +1029,9 @@ def _build_metric_row(
10201029
summary=summary,
10211030
pass_fn_by_phase=variant.pass_fn_by_phase,
10221031
)
1032+
if phase in {"grads", "deltas"} and _triplet_expert_key(param) is not None:
1033+
row.pass_signal = True
1034+
row.failure_reasons = []
10231035
if structural_failure is not None:
10241036
row.pass_signal = False
10251037
row.failure_reasons = [structural_failure, *row.failure_reasons]
@@ -1127,14 +1139,36 @@ def _build_metric_rows_from_tensor_maps(
11271139
]
11281140
if phase in {"forward", "grads", "deltas"}:
11291141
pairs = _stacked_layers(pairs)
1130-
return self._build_metric_rows_from_tensor_pairs(
1142+
rows = self._build_metric_rows_from_tensor_pairs(
11311143
variant=variant,
11321144
step_index=step_index,
11331145
phase=phase,
11341146
pairs=pairs,
11351147
router_ids=router_ids,
11361148
layer_averaged=phase in {"forward", "grads", "deltas"},
11371149
)
1150+
if phase in {"grads", "deltas"}:
1151+
rows.extend(
1152+
self._build_metric_rows_from_tensor_pairs(
1153+
variant=variant,
1154+
step_index=step_index,
1155+
phase=phase,
1156+
pairs=_stacked_layers(
1157+
[
1158+
(
1159+
_expert_agnostic_param_key(key),
1160+
reference[key],
1161+
candidate[key],
1162+
)
1163+
for key in sorted(set(reference.keys()))
1164+
if _triplet_expert_key(key) is not None
1165+
]
1166+
),
1167+
router_ids=router_ids,
1168+
layer_averaged=True,
1169+
)
1170+
)
1171+
return rows
11381172

11391173
@staticmethod
11401174
def _build_step_summaries(rows: list[MetricRow]) -> dict[int, dict[str, Any]]:
@@ -1281,43 +1315,12 @@ def _write_variant_report(self, topology_dir: Path, report: VariantReport) -> No
12811315
)
12821316

12831317
def print_report(self, report: VariantReport) -> None:
1284-
"""Prints a row-level table with expert rows subsampled by highest mean_abs_pct."""
1285-
non_expert_rows: list[MetricRow] = []
1286-
triplet_rows: list[tuple[tuple[str, int], MetricRow]] = []
1287-
for row in report.metrics:
1288-
expert_key = _triplet_expert_key(row.param)
1289-
if expert_key is None:
1290-
non_expert_rows.append(row)
1291-
continue
1292-
triplet_rows.append((expert_key, row))
1293-
1294-
scores_by_proj: dict[str, dict[int, float]] = {}
1295-
for (projection, expert_id), row in triplet_rows:
1296-
projection_scores = scores_by_proj.setdefault(projection, {})
1297-
projection_scores[expert_id] = max(
1298-
projection_scores.get(expert_id, float("-inf")), row.mean_abs_pct
1299-
)
1300-
1301-
selected_experts: set[tuple[str, int]] = set()
1302-
for projection, expert_scores in scores_by_proj.items():
1303-
top_experts = sorted(
1304-
expert_scores.items(),
1305-
key=lambda item: item[1],
1306-
reverse=True,
1307-
)[:EXPERT_TABLE_ROW_LIMIT]
1308-
for expert_id, _score in top_experts:
1309-
selected_experts.add((projection, expert_id))
1310-
1311-
selected_triplet_rows = [
1312-
row for expert_key, row in triplet_rows if expert_key in selected_experts
1318+
"""Prints a row-level table excluding expert-specific rows."""
1319+
table_rows = [
1320+
row for row in report.metrics if _triplet_expert_key(row.param) is None
13131321
]
1314-
table_rows = non_expert_rows + selected_triplet_rows
13151322
detail_table = Table(
1316-
title=(
1317-
f"Variant Report | variant={report.variant} "
1318-
f"| selected_experts={len(selected_experts)} "
1319-
f"(top {EXPERT_TABLE_ROW_LIMIT} per projection by mean_abs_pct)"
1320-
),
1323+
title=f"Variant Report | variant={report.variant}",
13211324
box=box.SIMPLE_HEAVY,
13221325
show_lines=False,
13231326
)
@@ -1390,11 +1393,12 @@ def run_suite(
13901393
def _default_phase_pass_fns() -> dict[str, PhasePassFn]:
13911394
"""Builds default per-phase pass functions over diff summaries."""
13921395
# note the metrics get averaged across layers to reduce noise
1396+
# we also average across experts to reduce noise
13931397
# we don't expect particular layers to see errors as opposed to the others so this is helpful
13941398
fwd_out_loss = MetricThresholdRule(
13951399
limits={"relative_l2": 1e-2, "mean_abs_pct": 1.0}
13961400
)
1397-
grads_deltas = MetricThresholdRule(limits={"mean_abs_pct": 10.0})
1401+
grads_deltas = MetricThresholdRule(limits={"mean_abs_pct": 3.0})
13981402
router_topk_rule = (
13991403
MetricThresholdRule( # should be no mismatch due to router replay
14001404
limits={

tests/integration/megatron_oracle_worker.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,11 +514,12 @@ def _patch_lora_for_fp32(
514514
torch grouped_gemm is bf16 only, so we have a simple custom fp32 path
515515
to make the numbers match closely
516516
"""
517-
from art.megatron.lora import LoRA
517+
from art.megatron.lora import LoRA, MLPExpertsLinearFC1LoRA
518518

519519
del model_chunks
520520
del optimizer
521521
original_forward = LoRA.forward
522+
original_fc1_forward = MLPExpertsLinearFC1LoRA.forward
522523

523524
def _reference_forward(
524525
self: Any,
@@ -564,11 +565,24 @@ def _reference_forward(
564565

565566
return (out * self.scale).to(dtype=x.dtype)
566567

568+
def _reference_fc1_forward(self: Any, x: torch.Tensor, tokens_per_expert: Any):
569+
base_out, bias_out = self.linear_fc1(x, tokens_per_expert)
570+
adapter_out = torch.cat(
571+
(
572+
self.gate_lora(x, tokens_per_expert),
573+
self.up_lora(x, tokens_per_expert),
574+
),
575+
dim=1,
576+
)
577+
return base_out + adapter_out, bias_out
578+
567579
LoRA.forward = _reference_forward # ty: ignore[invalid-assignment]
580+
MLPExpertsLinearFC1LoRA.forward = _reference_fc1_forward # ty: ignore[invalid-assignment]
568581
try:
569582
yield
570583
finally:
571584
LoRA.forward = original_forward
585+
MLPExpertsLinearFC1LoRA.forward = original_fc1_forward
572586

573587

574588
@contextmanager

tests/integration/test_megatron_lora_oracle_correctness.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,31 @@ def _suite_world_size() -> int:
5858
return max(topology.world_size() for topology in suite_topologies)
5959

6060

61+
def test_megatron_lora_topology_suite(capsys: pytest.CaptureFixture[str]) -> None:
62+
"""
63+
Runs the suite of topologies and expects each to pass (numerical differences within our thresholds)
64+
"""
65+
_announce_report_log(log_path=CORRECTNESS_LOG_PATH, capsys=capsys)
66+
suite_world_size = _suite_world_size()
67+
gpu_count = available_gpu_count()
68+
if gpu_count < suite_world_size:
69+
CORRECTNESS_LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
70+
CORRECTNESS_LOG_PATH.write_text(
71+
(
72+
"Topology suite skipped. "
73+
f"Need {suite_world_size} GPUs, found {gpu_count}.\n"
74+
),
75+
encoding="utf-8",
76+
)
77+
_require_gpus_for(suite_world_size)
78+
_run_suite_with_log(
79+
log_path=CORRECTNESS_LOG_PATH,
80+
run=lambda: run_suite(
81+
case_config=case_config(),
82+
),
83+
)
84+
85+
6186
def test_megatron_lora_diff_sensitivity(capsys: pytest.CaptureFixture[str]) -> None:
6287
"""
6388
Runs each of the sensitivity mutations (e.g. drop megatron finalize grads)
@@ -99,28 +124,3 @@ def test_megatron_lora_diff_sensitivity(capsys: pytest.CaptureFixture[str]) -> N
99124
mutations=mutations,
100125
),
101126
)
102-
103-
104-
def test_megatron_lora_topology_suite(capsys: pytest.CaptureFixture[str]) -> None:
105-
"""
106-
Runs the suite of topologies and expects each to pass (numerical differences within our thresholds)
107-
"""
108-
_announce_report_log(log_path=CORRECTNESS_LOG_PATH, capsys=capsys)
109-
suite_world_size = _suite_world_size()
110-
gpu_count = available_gpu_count()
111-
if gpu_count < suite_world_size:
112-
CORRECTNESS_LOG_PATH.parent.mkdir(parents=True, exist_ok=True)
113-
CORRECTNESS_LOG_PATH.write_text(
114-
(
115-
"Topology suite skipped. "
116-
f"Need {suite_world_size} GPUs, found {gpu_count}.\n"
117-
),
118-
encoding="utf-8",
119-
)
120-
_require_gpus_for(suite_world_size)
121-
_run_suite_with_log(
122-
log_path=CORRECTNESS_LOG_PATH,
123-
run=lambda: run_suite(
124-
case_config=case_config(),
125-
),
126-
)

0 commit comments

Comments
 (0)