Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions param_decomp_lab/tests/test_three_pool_grad_scaling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Regression test for PPGD -> CI gradient scaling under a multi-rank CI pool.

The CI pool ends each step with an AVG all-reduce over its ``n_ci`` ranks
(``all_reduce_ci_fn_grads``). PPGD's CI grad is injected per-position on a single
CI rank, so that AVG divides it by ``n_ci``; ``_scale_grads`` must pre-multiply
the CI grad by ``n_ci`` to compensate (the LW stoch path does the same via its
``/ n_ci`` denom). V/U never hits that AVG, so it keeps the plain scale.

At ``n_ci=1`` the factor is a no-op — which is why the bug was invisible to the
8-GPU (``n_ci=1``) configs and only bit production (``n_ci`` = 16 / 24).
"""

from types import SimpleNamespace
from typing import cast

import torch

from param_decomp_lab.three_pool.context import PPGDContext
from param_decomp_lab.three_pool.runtime import _ThreePoolRuntime
from param_decomp_lab.three_pool.step_ppgd import RawGrads, _scale_grads


def _fake_ctx(*, n_ci: int, n_ppgd: int) -> PPGDContext:
# _scale_grads only reads ctx.world.{n_ci,n_ppgd}; cast through object.
fake = SimpleNamespace(world=SimpleNamespace(n_ci=n_ci, n_ppgd=n_ppgd))
return cast(PPGDContext, cast(object, fake))


def _fake_cfg(*, coeff_ppgd: float) -> _ThreePoolRuntime:
# _scale_grads only reads cfg.coeff_ppgd; cast through object.
return cast(_ThreePoolRuntime, cast(object, SimpleNamespace(coeff_ppgd=coeff_ppgd)))


def _ones() -> dict[str, torch.Tensor]:
return {"site": torch.ones(2, 2)}


def test_ppgd_ci_grad_carries_extra_n_ci_factor() -> None:
n_ci, n_ppgd, n_examples_local, coeff = 4, 2, 8, 3.0
raw = RawGrads(v=_ones(), u=_ones(), ci=_ones(), sources=_ones())

_scale_grads(
raw, n_examples_local, _fake_ctx(n_ci=n_ci, n_ppgd=n_ppgd), _fake_cfg(coeff_ppgd=coeff)
)

vu_scale = coeff / (n_examples_local * n_ppgd)
torch.testing.assert_close(raw.v["site"], torch.full((2, 2), vu_scale))
torch.testing.assert_close(raw.u["site"], torch.full((2, 2), vu_scale))
# CI must carry the extra * n_ci to survive the CI-pool AVG all-reduce.
torch.testing.assert_close(raw.ci["site"], torch.full((2, 2), vu_scale * n_ci))
# Sources: 1 / n_examples_local only — no coeff, no 1/n_ppgd, no n_ci.
torch.testing.assert_close(raw.sources["site"], torch.full((2, 2), 1.0 / n_examples_local))
29 changes: 18 additions & 11 deletions param_decomp_lab/three_pool/step_ppgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
D4. Recon loss with the refined sources (the (N+1)'th forward).
D5. Backward: one ``torch.autograd.grad`` over the UNSCALED recon sum yields
raw g_VU + g_CI + g_sources; each consumer's own normalization is applied
explicitly afterward (V/U + CI: coeff_ppgd / n_examples_global; sources:
1 / n_examples_local — no coeff, no 1/n_ppgd).
explicitly afterward (V/U: coeff_ppgd / n_examples_global; CI: that times
n_ci to survive the CI-pool AVG-reduce; sources: 1 / n_examples_local).
D5b. Send g_CI to CI pool (every rank, on its batch slice). Peer-to-peer
point-to-point — no PPGD-internal reduce needed, so fires immediately
after backward to unblock CI's recv-wait sooner.
Expand Down Expand Up @@ -232,20 +232,27 @@ def _scale_grads(
only its batch-slice's contribution, so the per-rank scale carries the
1/n_ppgd_ranks and the in-pool SUM-reduce reassembles the full-batch grad.

V/U and CI share that per-rank scale but diverge by their downstream
reduction: V/U is SUM-reduced in the PPGD pool, but the CI grad is later
AVG-reduced over n_ci ranks on the CI pool (``all_reduce_ci_fn_grads``), which
divides it by n_ci. CI therefore needs an extra * n_ci to survive that AVG —
mirroring the LW stoch path's `/ n_ci` denom.

Sources are per-rank-local adversary state optimized against THIS rank's own
recon mean (recon_sum_loss / n_examples_local) — no coeff, no 1/n_ppgd
(those are V/U-reduction artifacts). Identical to the warmup source grad.
recon mean (recon_sum_loss / n_examples_local) — no coeff, no 1/n_ppgd, no
n_ci. Identical to the warmup source grad.

Folding any one consumer's scaling into the differentiated scalar — as an
earlier version did, dividing by n_ppgd for the V/U reduce — silently
mis-scales the others (it gave the source step a spurious 1/n_ppgd).
Folding any one consumer's scaling into the differentiated scalar silently
mis-scales the others, which is why each is applied explicitly here.
"""
n_ppgd_ranks = ctx.world.n_ppgd
vu_and_ci_grad_scale = cfg.coeff_ppgd / (n_examples_local * n_ppgd_ranks)
n_ci = ctx.world.n_ci
vu_grad_scale = cfg.coeff_ppgd / (n_examples_local * n_ppgd_ranks)
ci_grad_scale = vu_grad_scale * n_ci
source_grad_scale = 1.0 / n_examples_local
_scale_grads_in_place(raw.v, vu_and_ci_grad_scale)
_scale_grads_in_place(raw.u, vu_and_ci_grad_scale)
_scale_grads_in_place(raw.ci, vu_and_ci_grad_scale)
_scale_grads_in_place(raw.v, vu_grad_scale)
_scale_grads_in_place(raw.u, vu_grad_scale)
_scale_grads_in_place(raw.ci, ci_grad_scale)
_scale_grads_in_place(raw.sources, source_grad_scale)


Expand Down
Loading