Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions param_decomp/metrics/persistent_pgd_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
PersistentPGDSourceScope,
PersistentPGDState,
PGDOptimizerConfig,
PPGDSources,
RepeatAcrossBatchScope,
get_ppgd_mask_infos,
scope_needs_replica_sync,
Expand Down Expand Up @@ -129,7 +128,7 @@ class _PersistentPGDReconBase[
def __init__(self, cfg: TConfig) -> None:
super().__init__(cfg)
self.state: PersistentPGDState | None = None
self._pending_source_grads: PPGDSources | None = None
self._source_step_active = False
# Stash from `load_state_dict` if called before the first `update()` —
# `PersistentPGDState` needs batch_dims, which we only learn from a live ctx.
self._pending_resume_state: dict[str, Any] | None = None
Expand Down Expand Up @@ -250,18 +249,18 @@ def compute(self) -> MetricResult:

@override
def before_backward(self, live_loss: Tensor | None) -> None:
if live_loss is None or self.state is None:
return
grads = self.state.get_grads(live_loss, retain_graph=True)
self._pending_source_grads = self.state.reduce_source_grads(grads)
# Sources are leaves in total_loss; the main backward fills source.grad — no extra pass.
self._source_step_active = live_loss is not None and self.state is not None

@override
def after_backward(self) -> None:
if self._pending_source_grads is None:
if not self._source_step_active:
return
assert self.state is not None
self.state.step(self._pending_source_grads)
self._pending_source_grads = None
assert self.cfg.coeff is not None
grads = self.state.reduce_source_grads(self.state.collect_source_grads(self.cfg.coeff))
self.state.step(grads)
self._source_step_active = False

@override
def state_dict(self) -> dict[str, Any]:
Expand Down
10 changes: 10 additions & 0 deletions param_decomp/metrics/persistent_pgd_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ def get_grads(self, loss: Float[Tensor, ""], retain_graph: bool = True) -> PPGDS
grads = torch.autograd.grad(loss, list(self.sources.values()), retain_graph=retain_graph)
return dict(zip(self.sources.keys(), grads, strict=True))

def collect_source_grads(self, loss_coeff: float) -> PPGDSources:
    """Return raw ``∂loss/∂source`` read from each source's ``.grad``, then clear ``.grad``.

    The main backward is expected to have filled ``source.grad`` scaled by the loss's
    ``loss_coeff``; dividing by ``loss_coeff`` undoes that scaling. Reduce/step the
    result as for ``get_grads`` output.

    Args:
        loss_coeff: Coefficient the loss was multiplied by before the backward pass.
            Must be non-zero.

    Returns:
        Mapping from source name to ``source.grad / loss_coeff``.

    Raises:
        ValueError: If ``loss_coeff`` is zero; the unscaled gradient cannot be recovered
            and the division would yield NaN/inf values.
        AssertionError: If any source has no ``.grad`` after backward.
    """
    if loss_coeff == 0:
        raise ValueError(
            "loss_coeff must be non-zero to recover source grads from a scaled backward"
        )

    grads: PPGDSources = {}
    for name, source in self.sources.items():
        assert source.grad is not None, f"source {name!r}: no grad after backward"
        grads[name] = source.grad / loss_coeff

    # Clear only after every source was read successfully, so a failure above never
    # leaves some sources with ``.grad`` consumed and others still populated.
    for source in self.sources.values():
        source.grad = None
    return grads

def reduce_source_grads(self, grads: PPGDSources) -> PPGDSources:
"""AVG-reduce per-rank source grads over the replica-sync group, else a no-op.

Expand Down
34 changes: 34 additions & 0 deletions param_decomp/tests/test_spd_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,40 @@ def test_scope_needs_replica_sync_classification(self: object) -> None:
assert scope_needs_replica_sync(BroadcastAcrossBatchScope()) is True
assert scope_needs_replica_sync(RepeatAcrossBatchScope(n_sources=2)) is True

def test_collect_source_grads_matches_get_grads(self: object) -> None:
    """Fused path must equal the separate ``get_grads`` pass and must clear ``.grad``.

    ``source.grad / coeff`` read after the main backward of ``coeff * loss`` should match
    the gradients from a dedicated ``get_grads`` call on the same loss.
    """
    fc_weight = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
    model = _make_seq_component_model(weight=fc_weight)
    batch = torch.tensor([[[1.0, 2.0], [0.5, 1.5]]], dtype=torch.float32)
    target_out = torch.tensor([[[0.3, 0.7], [0.9, 0.1]]], dtype=torch.float32)
    ci = {"fc": torch.tensor([[[0.5], [0.5]]], dtype=torch.float32)}

    cfg = PersistentPGDReconLossConfig(
        optimizer=SignPGDConfig(lr_schedule=ScheduleConfig(start_val=0.1)),
        scope=SingleSourceScope(),
    )
    state = _ppgd_state_from_cfg(
        cfg,
        module_to_c=model.module_to_c,
        batch_dims=batch.shape[:2],
        device="cpu",
        use_delta_component=False,
        reconstruction_loss=recon_loss_mse,
    )

    sum_loss, n = state.compute_recon_sum_and_n(
        model=model, batch=batch, target_out=target_out, ci=ci, weight_deltas=None
    )
    loss = sum_loss / n
    # retain_graph=True keeps the graph alive for the fused backward below.
    separate = state.get_grads(loss, retain_graph=True)
    assert any(g.abs().sum() > 0 for g in separate.values()), "grads all zero — vacuous test"

    coeff = 0.5
    (coeff * loss).backward()
    fused = state.collect_source_grads(coeff)

    assert fused.keys() == separate.keys()
    for name, source in state.sources.items():
        assert torch.allclose(fused[name], separate[name], atol=1e-6)
        # backward() accumulates into leaf ``.grad``; a stale value left here would be
        # added to the next step's gradient, so the clearing contract is pinned too.
        assert source.grad is None, f"source {name!r}: .grad not cleared after collect"

def test_masks_persist_across_calls(self: object) -> None:
"""Test that masks persist and accumulate updates across calls."""
fc_weight = torch.tensor([[2.0, 0.0], [0.0, 2.0]], dtype=torch.float32)
Expand Down
Loading