diff --git a/param_decomp/metrics/persistent_pgd_recon.py b/param_decomp/metrics/persistent_pgd_recon.py index 498dee22f..4f4b0311f 100644 --- a/param_decomp/metrics/persistent_pgd_recon.py +++ b/param_decomp/metrics/persistent_pgd_recon.py @@ -30,7 +30,6 @@ PersistentPGDSourceScope, PersistentPGDState, PGDOptimizerConfig, - PPGDSources, RepeatAcrossBatchScope, get_ppgd_mask_infos, scope_needs_replica_sync, @@ -129,7 +128,7 @@ class _PersistentPGDReconBase[ def __init__(self, cfg: TConfig) -> None: super().__init__(cfg) self.state: PersistentPGDState | None = None - self._pending_source_grads: PPGDSources | None = None + self._source_step_active = False # Stash from `load_state_dict` if called before the first `update()` — # `PersistentPGDState` needs batch_dims, which we only learn from a live ctx. self._pending_resume_state: dict[str, Any] | None = None @@ -250,18 +249,18 @@ def compute(self) -> MetricResult: @override def before_backward(self, live_loss: Tensor | None) -> None: - if live_loss is None or self.state is None: - return - grads = self.state.get_grads(live_loss, retain_graph=True) - self._pending_source_grads = self.state.reduce_source_grads(grads) + # Sources are leaves in total_loss; the main backward fills source.grad — no extra pass. + self._source_step_active = live_loss is not None and self.state is not None @override def after_backward(self) -> None: - if self._pending_source_grads is None: + if not self._source_step_active: return assert self.state is not None - self.state.step(self._pending_source_grads) - self._pending_source_grads = None + assert self.cfg.coeff is not None + grads = self.state.reduce_source_grads(self.state.collect_source_grads(self.cfg.coeff)) + self.state.step(grads) + self._source_step_active = False @override def state_dict(self) -> dict[str, Any]: diff --git a/param_decomp/metrics/persistent_pgd_state.py b/param_decomp/metrics/persistent_pgd_state.py index d15a9ad4c..efb47f7ee 100644 --- a/param_decomp/metrics/persistent_pgd_state.py +++ b/param_decomp/metrics/persistent_pgd_state.py @@ -293,6 +293,16 @@ def get_grads(self, loss: Float[Tensor, ""], retain_graph: bool = True) -> PPGDS grads = torch.autograd.grad(loss, list(self.sources.values()), retain_graph=retain_graph) return dict(zip(self.sources.keys(), grads, strict=True)) + def collect_source_grads(self, loss_coeff: float) -> PPGDSources: + """Raw ``∂loss/∂source`` read from ``.grad`` (the main backward filled it, scaled by the + loss's ``loss_coeff``); clears ``.grad``. Reduce/step as for ``get_grads`` output.""" + grads: PPGDSources = {} + for name, source in self.sources.items(): + assert source.grad is not None, f"source {name!r}: no grad after backward" + grads[name] = source.grad / loss_coeff + source.grad = None + return grads + def reduce_source_grads(self, grads: PPGDSources) -> PPGDSources: """AVG-reduce per-rank source grads over the replica-sync group, else a no-op. diff --git a/param_decomp/tests/test_spd_losses.py b/param_decomp/tests/test_spd_losses.py index 6708d7ba2..2e24c5389 100644 --- a/param_decomp/tests/test_spd_losses.py +++ b/param_decomp/tests/test_spd_losses.py @@ -824,6 +824,40 @@ def test_scope_needs_replica_sync_classification(self: object) -> None: assert scope_needs_replica_sync(BroadcastAcrossBatchScope()) is True assert scope_needs_replica_sync(RepeatAcrossBatchScope(n_sources=2)) is True + def test_collect_source_grads_matches_get_grads(self: object) -> None: + # Fused backward: source.grad/coeff from the main backward == the separate get_grads pass. + fc_weight = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + model = _make_seq_component_model(weight=fc_weight) + batch = torch.tensor([[[1.0, 2.0], [0.5, 1.5]]], dtype=torch.float32) + target_out = torch.tensor([[[0.3, 0.7], [0.9, 0.1]]], dtype=torch.float32) + ci = {"fc": torch.tensor([[[0.5], [0.5]]], dtype=torch.float32)} + + cfg = PersistentPGDReconLossConfig( + optimizer=SignPGDConfig(lr_schedule=ScheduleConfig(start_val=0.1)), + scope=SingleSourceScope(), + ) + state = _ppgd_state_from_cfg( + cfg, + module_to_c=model.module_to_c, + batch_dims=batch.shape[:2], + device="cpu", + use_delta_component=False, + reconstruction_loss=recon_loss_mse, + ) + + sum_loss, n = state.compute_recon_sum_and_n( + model=model, batch=batch, target_out=target_out, ci=ci, weight_deltas=None + ) + loss = sum_loss / n + separate = state.get_grads(loss, retain_graph=True) + assert any(g.abs().sum() > 0 for g in separate.values()), "grads all zero — vacuous test" + + coeff = 0.5 + (coeff * loss).backward() + fused = state.collect_source_grads(coeff) + for name in state.sources: + assert torch.allclose(fused[name], separate[name], atol=1e-6) + def test_masks_persist_across_calls(self: object) -> None: """Test that masks persist and accumulate updates across calls.""" fc_weight = torch.tensor([[2.0, 0.0], [0.0, 2.0]], dtype=torch.float32)