diff --git a/param_decomp_lab/tests/test_three_pool_grad_scaling.py b/param_decomp_lab/tests/test_three_pool_grad_scaling.py new file mode 100644 index 000000000..ac4abd81b --- /dev/null +++ b/param_decomp_lab/tests/test_three_pool_grad_scaling.py @@ -0,0 +1,52 @@ +"""Regression test for PPGD -> CI gradient scaling under a multi-rank CI pool. + +The CI pool ends each step with an AVG all-reduce over its ``n_ci`` ranks +(``all_reduce_ci_fn_grads``). PPGD's CI grad is injected per-position on a single +CI rank, so that AVG divides it by ``n_ci``; ``_scale_grads`` must pre-multiply +the CI grad by ``n_ci`` to compensate (the LW stoch path does the same via its +``/ n_ci`` denom). V/U never hits that AVG, so it keeps the plain scale. + +At ``n_ci=1`` the factor is a no-op — which is why the bug was invisible to the +8-GPU (``n_ci=1``) configs and only bit production (``n_ci`` = 16 / 24). +""" + +from types import SimpleNamespace +from typing import cast + +import torch + +from param_decomp_lab.three_pool.context import PPGDContext +from param_decomp_lab.three_pool.runtime import _ThreePoolRuntime +from param_decomp_lab.three_pool.step_ppgd import RawGrads, _scale_grads + + +def _fake_ctx(*, n_ci: int, n_ppgd: int) -> PPGDContext: + # _scale_grads only reads ctx.world.{n_ci,n_ppgd}; cast through object. + fake = SimpleNamespace(world=SimpleNamespace(n_ci=n_ci, n_ppgd=n_ppgd)) + return cast(PPGDContext, cast(object, fake)) + + +def _fake_cfg(*, coeff_ppgd: float) -> _ThreePoolRuntime: + # _scale_grads only reads cfg.coeff_ppgd; cast through object. + return cast(_ThreePoolRuntime, cast(object, SimpleNamespace(coeff_ppgd=coeff_ppgd))) + + +def _ones() -> dict[str, torch.Tensor]: + return {"site": torch.ones(2, 2)} + + +def test_ppgd_ci_grad_carries_extra_n_ci_factor() -> None: + n_ci, n_ppgd, n_examples_local, coeff = 4, 2, 8, 3.0 + raw = RawGrads(v=_ones(), u=_ones(), ci=_ones(), sources=_ones()) + + _scale_grads( + raw, n_examples_local, _fake_ctx(n_ci=n_ci, n_ppgd=n_ppgd), _fake_cfg(coeff_ppgd=coeff) + ) + + vu_scale = coeff / (n_examples_local * n_ppgd) + torch.testing.assert_close(raw.v["site"], torch.full((2, 2), vu_scale)) + torch.testing.assert_close(raw.u["site"], torch.full((2, 2), vu_scale)) + # CI must carry the extra * n_ci to survive the CI-pool AVG all-reduce. + torch.testing.assert_close(raw.ci["site"], torch.full((2, 2), vu_scale * n_ci)) + # Sources: 1 / n_examples_local only — no coeff, no 1/n_ppgd, no n_ci. + torch.testing.assert_close(raw.sources["site"], torch.full((2, 2), 1.0 / n_examples_local)) diff --git a/param_decomp_lab/three_pool/step_ppgd.py b/param_decomp_lab/three_pool/step_ppgd.py index 4b66ef005..b7a18af52 100644 --- a/param_decomp_lab/three_pool/step_ppgd.py +++ b/param_decomp_lab/three_pool/step_ppgd.py @@ -27,8 +27,8 @@ D4. Recon loss with the refined sources (the (N+1)'th forward). D5. Backward: one ``torch.autograd.grad`` over the UNSCALED recon sum yields raw g_VU + g_CI + g_sources; each consumer's own normalization is applied - explicitly afterward (V/U + CI: coeff_ppgd / n_examples_global; sources: - 1 / n_examples_local — no coeff, no 1/n_ppgd). + explicitly afterward (V/U: coeff_ppgd / n_examples_global; CI: that times + n_ci to survive the CI-pool AVG-reduce; sources: 1 / n_examples_local). D5b. Send g_CI to CI pool (every rank, on its batch slice). Peer-to-peer point-to-point — no PPGD-internal reduce needed, so fires immediately after backward to unblock CI's recv-wait sooner. @@ -232,20 +232,27 @@ def _scale_grads( only its batch-slice's contribution, so the per-rank scale carries the 1/n_ppgd_ranks and the in-pool SUM-reduce reassembles the full-batch grad. + V/U and CI share that per-rank scale but diverge by their downstream + reduction: V/U is SUM-reduced in the PPGD pool, but the CI grad is later + AVG-reduced over n_ci ranks on the CI pool (``all_reduce_ci_fn_grads``), which + divides it by n_ci. CI therefore needs an extra * n_ci to survive that AVG — + mirroring the LW stoch path's `/ n_ci` denom. + Sources are per-rank-local adversary state optimized against THIS rank's own - recon mean (recon_sum_loss / n_examples_local) — no coeff, no 1/n_ppgd - (those are V/U-reduction artifacts). Identical to the warmup source grad. + recon mean (recon_sum_loss / n_examples_local) — no coeff, no 1/n_ppgd, no + n_ci. Identical to the warmup source grad. - Folding any one consumer's scaling into the differentiated scalar — as an - earlier version did, dividing by n_ppgd for the V/U reduce — silently - mis-scales the others (it gave the source step a spurious 1/n_ppgd). + Folding any one consumer's scaling into the differentiated scalar silently + mis-scales the others, which is why each is applied explicitly here. """ n_ppgd_ranks = ctx.world.n_ppgd - vu_and_ci_grad_scale = cfg.coeff_ppgd / (n_examples_local * n_ppgd_ranks) + n_ci = ctx.world.n_ci + vu_grad_scale = cfg.coeff_ppgd / (n_examples_local * n_ppgd_ranks) + ci_grad_scale = vu_grad_scale * n_ci source_grad_scale = 1.0 / n_examples_local - _scale_grads_in_place(raw.v, vu_and_ci_grad_scale) - _scale_grads_in_place(raw.u, vu_and_ci_grad_scale) - _scale_grads_in_place(raw.ci, vu_and_ci_grad_scale) + _scale_grads_in_place(raw.v, vu_grad_scale) + _scale_grads_in_place(raw.u, vu_grad_scale) + _scale_grads_in_place(raw.ci, ci_grad_scale) _scale_grads_in_place(raw.sources, source_grad_scale)