Skip to content

Commit a0dcb76

Browse files
ocg-goodfireclaude
andcommitted
refactor(3-pool): SUM-grad convention for replicated-param gradient assembly
Replace the per-instance "pre-scale to survive a downstream reduce" patches (4 recurring bugs, latest PR #545's PPGD x n_ci) with a single convention: every data-parallel gradient reduction is SUM, and each producer emits a partial sum normalized only by the honest global count, carrying NO pool-size transport factor. SUM(partials) = total, so no producer needs any pool's size. Changes: - portals: all_reduce_ci_fn_grads and all_reduce_grads_in_block flip AVG -> SUM. - step_ppgd: V/U and CI collapse to one scale; the #545 x n_ci is deleted. - step_layerwise: stoch denom drops /n_ci (one scale now serves both CI leaves and V/U). faith + broadcast-PPGD V/U are "contribute once" (block leader only) so the SUM lands them exactly once instead of n_per_block x. - step_ci: imp-min uses the detached-global-residual trick instead of the autograd-aware all_reduce, so its backward is a local partial (no reliance on the old AVG cancelling an n_ci factor). - grad-clip n_replicas unchanged (counts distinct params for the global norm, independent of the reduce op). The replica count does not vanish for replicated contributions — it relocates from a numeric scale factor into a structural placement (contribute-once) or a graph decision (detach the global term). See SUM_GRAD_CONVENTION.md for the honest verdict on whether this is simpler than per-destination scale-splitting. Validated by a new distributed grad check at a non-square topology (n_ci=4 != n_per_block=2, 2 blocks, n_ppgd=2) with ALL loss terms enabled: fully-reduced CI-fn and V/U grads match a single-process full-batch reference (mean dist/ref = 1.000000, worst rel err 4e-07). Sensitivity confirmed: AVG on the CI reduce -> 1/n_ci on CI grads; faith on all ranks -> n_per_block x on V/U. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent db54c3e commit a0dcb76

9 files changed

Lines changed: 916 additions & 128 deletions

param_decomp_lab/tests/test_three_pool_grad_check_distributed.py

Lines changed: 577 additions & 0 deletions
Large diffs are not rendered by default.

param_decomp_lab/tests/test_three_pool_grad_scaling.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
"""Regression test for PPGD -> CI gradient scaling under a multi-rank CI pool.
1+
"""Regression test for PPGD gradient scaling under the SUM-grad convention.
22
3-
The CI pool ends each step with an AVG all-reduce over its ``n_ci`` ranks
4-
(``all_reduce_ci_fn_grads``). PPGD's CI grad is injected per-position on a single
5-
CI rank, so that AVG divides it by ``n_ci``; ``_scale_grads`` must pre-multiply
6-
the CI grad by ``n_ci`` to compensate (the LW stoch path does the same via its
7-
``/ n_ci`` denom). V/U never hits that AVG, so it keeps the plain scale.
3+
Under the SUM-grad convention (``three_pool/SUM_GRAD_CONVENTION.md``) every
4+
data-parallel gradient reduction is SUM, so a producer's grad is a partial sum
5+
normalized only by the honest global count — it carries NO pool-size transport
6+
factor. For PPGD this means V/U and CI now share ONE scale
7+
``coeff_ppgd / n_examples_global``: the old ``* n_ci`` on the CI grad (which
8+
compensated for the CI-pool AVG-reduce, PR #545) is gone, because the CI-pool
9+
reduce is now a SUM.
810
9-
At ``n_ci=1`` the factor is a no-op — which is why the bug was invisible to the
10-
8-GPU (``n_ci=1``) configs and only bit production (``n_ci`` = 16 / 24).
11+
Sources stay per-rank-local (``1 / n_examples_local``).
1112
"""
1213

1314
from types import SimpleNamespace
@@ -35,18 +36,34 @@ def _ones() -> dict[str, torch.Tensor]:
3536
return {"site": torch.ones(2, 2)}
3637

3738

38-
def test_ppgd_ci_grad_carries_extra_n_ci_factor() -> None:
39+
def test_ppgd_vu_and_ci_share_one_scale() -> None:
3940
n_ci, n_ppgd, n_examples_local, coeff = 4, 2, 8, 3.0
4041
raw = RawGrads(v=_ones(), u=_ones(), ci=_ones(), sources=_ones())
4142

4243
_scale_grads(
4344
raw, n_examples_local, _fake_ctx(n_ci=n_ci, n_ppgd=n_ppgd), _fake_cfg(coeff_ppgd=coeff)
4445
)
4546

46-
vu_scale = coeff / (n_examples_local * n_ppgd)
47-
torch.testing.assert_close(raw.v["site"], torch.full((2, 2), vu_scale))
48-
torch.testing.assert_close(raw.u["site"], torch.full((2, 2), vu_scale))
49-
# CI must carry the extra * n_ci to survive the CI-pool AVG all-reduce.
50-
torch.testing.assert_close(raw.ci["site"], torch.full((2, 2), vu_scale * n_ci))
47+
# V/U and CI are both partial sums under the SUM convention: one scale.
48+
shared_scale = coeff / (n_examples_local * n_ppgd)
49+
torch.testing.assert_close(raw.v["site"], torch.full((2, 2), shared_scale))
50+
torch.testing.assert_close(raw.u["site"], torch.full((2, 2), shared_scale))
51+
torch.testing.assert_close(raw.ci["site"], torch.full((2, 2), shared_scale))
5152
# Sources: 1 / n_examples_local only — no coeff, no 1/n_ppgd, no n_ci.
5253
torch.testing.assert_close(raw.sources["site"], torch.full((2, 2), 1.0 / n_examples_local))
54+
55+
56+
def test_ppgd_ci_scale_is_independent_of_n_ci() -> None:
57+
"""The defining property of the SUM convention: the CI grad scale does not
58+
depend on ``n_ci`` (the old patch multiplied it by ``n_ci``)."""
59+
n_ppgd, n_examples_local, coeff = 2, 8, 3.0
60+
scales: list[float] = []
61+
for n_ci in (1, 4, 16):
62+
raw = RawGrads(v=_ones(), u=_ones(), ci=_ones(), sources=_ones())
63+
_scale_grads(
64+
raw, n_examples_local, _fake_ctx(n_ci=n_ci, n_ppgd=n_ppgd), _fake_cfg(coeff_ppgd=coeff)
65+
)
66+
scales.append(raw.ci["site"][0, 0].item())
67+
assert scales[0] == scales[1] == scales[2], (
68+
f"CI scale must be independent of n_ci under SUM convention; got {scales}"
69+
)

param_decomp_lab/tests/test_three_pool_routing_plan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def test_denom_matches_single_pool_normalization(plan: RoutingPlan) -> None:
124124
coeff_stoch=1.0,
125125
n_est=n_forwards,
126126
n_per_block=1,
127-
n_ci=1,
128127
strategy=strategy,
129128
bf16_autocast_enabled=False,
130129
)

param_decomp_lab/three_pool/CLAUDE.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,22 @@ docstring in `optimize.py` for the data-handling contract.
2020
| `step_{ci,layerwise,ppgd}.py` | per-pool step functions |
2121
| `routing_plan.py` | `RoutingPlan` (`PerSitePlan` \| `SubsetRoutingPlan`) — how each LW block turns its owned sites into a list of recon forwards |
2222
| `eval_step.py` | 3-pool eval pass (PPGD pool runs metrics; others barrier through) |
23+
| `SUM_GRAD_CONVENTION.md` | the gradient-assembly scaling convention (proposal) |
24+
25+
## Gradient-assembly scaling: the SUM convention
26+
27+
See `SUM_GRAD_CONVENTION.md` for the full derivation. Summary: every
28+
data-parallel gradient reduction is **SUM** (`all_reduce_ci_fn_grads`,
29+
`all_reduce_grads_in_block`, and PPGD's V/U reduce). Each producer emits a
30+
*partial sum* normalized only by the honest GLOBAL count — NO `n_ci` /
31+
`n_per_block` transport factor. `SUM(partials) = total`, so no producer needs a
32+
pool's size. The REPLICATED contributions are handled structurally rather than by
33+
a replica-count divide: faith + broadcast-PPGD V/U **contribute once** (emitted
34+
on the block leader only), and imp-min uses the **detached-global-residual** trick
35+
(`S = local + (all_reduce_sum(local.detach()) - local.detach())`) so its backward
36+
is a local partial. The grad-clip `n_replicas` is unchanged — it counts distinct
37+
params for the global norm, independent of the grad-reduce op. Validated by
38+
`tests/test_three_pool_grad_check_distributed.py` (non-square, all loss terms).
2339

2440
## Checkpoint save: partials on the loop, consolidation off it
2541

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# SUM-grad convention (proposal)
2+
3+
A structural redesign of the 3-pool gradient-assembly scaling, replacing the
4+
per-instance "pre-scale to survive a downstream reduction" patches (4 recurring
5+
bugs, latest = PR #545's PPGD `×n_ci`) with a single convention.
6+
7+
## The bug class
8+
9+
The CI-fn weights and the V/U weights are each REPLICATED across ranks. Their
10+
gradients are assembled from multiple producers (stoch, faith, imp-min, ppgd)
11+
and reduced across ranks. Every recurring bug had the same shape: a producer
12+
pre-scaled its gradient by a *pool-size factor* (`n_ci`, `n_per_block`) so its
13+
contribution would survive a downstream AVG-reduce it couldn't locally see. That
14+
factor leaks pool-size knowledge into the gradient VALUES, and a single
15+
differentiated scalar that feeds two destinations with different reductions
16+
(stoch → CI leaves ÷n_ci AND V/U ÷n_per_block) is guaranteed wrong on a
17+
non-square topology.
18+
19+
## The convention
20+
21+
**Every gradient crossing a cross-rank reduction is a partial SUM, normalized
22+
only by the honest GLOBAL count (global examples/positions × sites), carrying NO
23+
pool-size transport factor. All data-parallel gradient reductions are SUM.**
24+
25+
Partial sums compose: `SUM(partials) = total`. So no producer needs to know any
26+
pool's size; the only normalization is the honest global count, which is locally
27+
derivable (`P_global = n_positions_local × n_per_block`, `n_examples_global =
28+
n_examples_local × n_ppgd`). The conversion factor that turns a local count into
29+
a global count is NOT a transport factor — it is part of computing the honest
30+
denominator, and it disappears entirely on a square topology only by coincidence.
31+
32+
### Consequences
33+
34+
1. **The grad all-reduce is SUM.** After an *all*-reduce every rank holds the
35+
identical value either way; under SUM that value is the TOTAL, which equals
36+
the single-pool gradient *because each producer already divided by the global
37+
count*. The optimizer steps on it directly.
38+
2. **`cross_pool_clip_grad_norm(n_replicas)` is UNCHANGED.** Subtle: this divide
39+
is about counting DISTINCT parameters for the global norm, not about the grad
40+
reduce. After the in-pool *all*-reduce (SUM or AVG), every replica holds the
41+
IDENTICAL grad; the pool-wide sq-SUM therefore counts each block's params
42+
`n_per_block` times either way, so the `n_replicas` dedup stays. (The grad
43+
VALUE differs — SUM gives the single-pool total, AVG gave total/n_per_block —
44+
but the replica COUNT being summed is the same.)
45+
3. **stoch's one scale feeds both destinations.** CI leaves (→ CI pool, SUM) and
46+
V/U (→ LW block, SUM) both want the same partial-sum scale
47+
`coeff_stoch / (P_global × n_sites_total)`. The double-duty bug is structurally
48+
impossible now: there is only one correct scale and it serves both.
49+
4. **PPGD's `×n_ci` DIES.** V/U and CI both want `coeff_ppgd / n_examples_global`;
50+
the CI path no longer needs the extra `×n_ci` to survive an AVG. The two
51+
collapse to one scale — the shape the V/U path (which never had a bug) always
52+
had.
53+
54+
## The wrinkle: replicated contributions
55+
56+
The convention is clean for genuine DP partials (disjoint batch slices). It does
57+
NOT, by itself, handle REPLICATED contributions — gradients that are IDENTICAL on
58+
every rank in the reduction group because they were computed from replicated
59+
inputs rather than a disjoint data slice:
60+
61+
- **faith V/U** (`_faithfulness_loss`): computed from the replicated V/U weights →
62+
identical on every block rank → under SUM, `n_per_block×` too big.
63+
- **broadcast PPGD V/U**: sum-reduced within PPGD then broadcast to all block
64+
ranks → identical on every block rank → same `n_per_block×` problem.
65+
- **imp-min CI**: the autograd-aware `dist_fn.all_reduce(SUM)` backward
66+
SUM-reduces the *replicated* upstream gradient across the CI pool, leaving each
67+
rank with `n_ci×` its true partial. Under the old AVG this was exactly the
68+
factor that made it correct; under SUM it is `n_ci×` too big.
69+
70+
Three ways to handle each:
71+
72+
(a) **Divide the replicated contribution by the replica count before the SUM.**
73+
Rejected: this REINTRODUCES the pool-size factor into a producer — exactly
74+
what the convention abolishes. It only relocates the factor.
75+
(b) **Contribute once.** Compute the replicated contribution on a single rank
76+
(the block leader) so there is no replica to undo. Chosen for faith and
77+
broadcast PPGD V/U.
78+
(c) **Detached-global-residual.** Make the forward value global but the backward
79+
flow only through the local contribution:
80+
`S = local + (all_reduce_sum(local.detach()) - local.detach())`.
81+
Forward `S = global_sum`; backward `∂S/∂local = 1`, no cross-rank term, so
82+
each rank gets its TRUE partial which SUM-composes. Chosen for imp-min
83+
(its loss genuinely needs the global sum inside the `log2`, so option (b)
84+
doesn't apply — it isn't a replica, it's a global reduction).
85+
86+
### faith / broadcast PPGD → contribute once (option b)
87+
88+
- **faith**: run the faith backward on the **block leader only**. The leader's
89+
`.grad` then carries the full single-pool faith grad once; non-leaders
90+
contribute zero faith. After the block SUM every rank holds it exactly once.
91+
Faith is already divided by `numel_global`, so the leader's value is already
92+
the single-pool grad — no further scaling.
93+
- **broadcast PPGD V/U**: skip the in-block broadcast; the block **leader** adds
94+
the received PPGD grad to its `.grad`, non-leaders add nothing. After the block
95+
SUM every rank holds it once.
96+
97+
These two changes mean the block all-reduce SUM now combines ONLY:
98+
`leader_faith + leader_ppgd + Σ_ranks stoch_partial_r` = the single-pool total.
99+
100+
### imp-min → detached-global-residual (option c)
101+
102+
`_importance_minimality_loss` replaces the autograd-aware `dist_fn.all_reduce`
103+
with the detached-global-residual on `per_component_sums`. Forward identical
104+
(global sum inside `finalize_imp_min`'s `log2` and mean), backward flows only
105+
through this rank's local CI values → a true partial → SUM-composes under the CI
106+
pool's SUM all-reduce. The `×n_ci` knowledge leaves the imp path entirely.
107+
108+
## The honest verdict
109+
110+
Does the SUM convention ELIMINATE pool-size knowledge from producers?
111+
112+
- **From the data-parallel producers: YES.** stoch, ppgd V/U, ppgd CI all lose
113+
every `n_ci` / `n_per_block` *transport* factor. The #545 `×n_ci` is deleted.
114+
stoch's two destinations collapse to one scale. The remaining `n_per_block` in
115+
stoch's denom is not a transport factor — it is the `local→global` position
116+
count conversion, which any honest global normalization needs.
117+
- **From the replicated contributions: NO — but it RELOCATES the count to a
118+
structurally honest place.** faith and broadcast-PPGD no longer *scale* by
119+
`n_per_block`; instead they *contribute once* (a topology fact: "this grad is
120+
replicated, emit it on one rank"). imp-min no longer *relies on AVG to cancel*
121+
`n_ci`; instead it *states* "my backward is a local partial" via the residual
122+
trick. The replica count does not appear as a numeric factor in any producer's
123+
gradient value — it appears as a *placement* decision (which rank emits) or a
124+
*graph* decision (detach the cross-rank term).
125+
126+
Net: the convention is **not a free win** — replicated contributions still need
127+
the system to know they are replicated. But it converts an error-prone numeric
128+
coupling ("multiply by the size of a pool you can't see, to survive a reduce that
129+
happens elsewhere") into a local, inspectable structural statement ("this is a
130+
partial; emit it once" / "this is replicated; detach the global term"). That is a
131+
genuine simplification for the DP majority and a clearer, harder-to-get-wrong
132+
encoding for the replicated minority — not a lateral move.

param_decomp_lab/three_pool/portals.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -439,44 +439,39 @@ def send(self, role: PPGDRole, v_grads: dict[str, Tensor], u_grads: dict[str, Te
439439
def recv(
440440
self, role: LWRole, v_templates: dict[str, Tensor], u_templates: dict[str, Tensor]
441441
) -> tuple[dict[str, Tensor], dict[str, Tensor]]:
442-
"""Block leader recvs g_VU for owned sites from PPGD leader, then
443-
in-block broadcasts so all replicas see the same grad."""
442+
"""Block leader recvs g_VU for owned sites from PPGD leader; non-leaders
443+
get nothing.
444+
445+
Contribute-once (see ``SUM_GRAD_CONVENTION.md``): PPGD's grad is identical
446+
across block replicas, so under the block SUM-reduce it must land on
447+
exactly ONE rank. We add it to the leader's ``.grad`` only and skip the
448+
old in-block broadcast — the SUM then distributes it to every replica
449+
exactly once. Non-leaders return empty dicts and add nothing.
450+
"""
451+
if not role.is_block_leader:
452+
return {}, {}
453+
454+
my_sites = role.owned_sites
455+
packed_numel = sum(v_templates[s].numel() + u_templates[s].numel() for s in my_sites)
456+
sample = v_templates[my_sites[0]]
457+
packed = torch.empty(packed_numel, dtype=WIRE_DTYPE, device=sample.device)
458+
ppgd_leader = self.world.ppgd_ranks[0]
459+
with time_nccl_op("GradVuFromPPGD.recv:recv"):
460+
dist.recv(packed, src=ppgd_leader, group=self.world.cross_pool_p2p_group)
444461
v_grads: dict[str, Tensor] = {}
445462
u_grads: dict[str, Tensor] = {}
446-
447-
if role.is_block_leader:
448-
my_sites = role.owned_sites
449-
packed_numel = sum(v_templates[s].numel() + u_templates[s].numel() for s in my_sites)
450-
sample = v_templates[my_sites[0]]
451-
packed = torch.empty(packed_numel, dtype=WIRE_DTYPE, device=sample.device)
452-
ppgd_leader = self.world.ppgd_ranks[0]
453-
with time_nccl_op("GradVuFromPPGD.recv:recv"):
454-
dist.recv(packed, src=ppgd_leader, group=self.world.cross_pool_p2p_group)
455-
offset = 0
456-
for s in my_sites:
457-
v_n = v_templates[s].numel()
458-
u_n = u_templates[s].numel()
459-
v_grads[s] = (
460-
packed[offset : offset + v_n].view_as(v_templates[s]).to(v_templates[s].dtype)
461-
)
462-
offset += v_n
463-
u_grads[s] = (
464-
packed[offset : offset + u_n].view_as(u_templates[s]).to(u_templates[s].dtype)
465-
)
466-
offset += u_n
467-
else:
468-
for s in role.owned_sites:
469-
v_grads[s] = torch.empty_like(v_templates[s])
470-
u_grads[s] = torch.empty_like(u_templates[s])
471-
472-
block_group = self.world.block_group_groups[role.block_idx]
473-
block_leader_rank = self.world.layerwise_block_groups[role.block_idx].leader
474-
with time_nccl_op("GradVuFromPPGD.recv:in_block_bcast"):
475-
for s in role.owned_sites:
476-
v_grads[s] = v_grads[s].contiguous()
477-
u_grads[s] = u_grads[s].contiguous()
478-
dist.broadcast(v_grads[s], src=block_leader_rank, group=block_group)
479-
dist.broadcast(u_grads[s], src=block_leader_rank, group=block_group)
463+
offset = 0
464+
for s in my_sites:
465+
v_n = v_templates[s].numel()
466+
u_n = u_templates[s].numel()
467+
v_grads[s] = (
468+
packed[offset : offset + v_n].view_as(v_templates[s]).to(v_templates[s].dtype)
469+
)
470+
offset += v_n
471+
u_grads[s] = (
472+
packed[offset : offset + u_n].view_as(u_templates[s]).to(u_templates[s].dtype)
473+
)
474+
offset += u_n
480475
return v_grads, u_grads
481476

482477

@@ -616,12 +611,18 @@ def _bucketed_all_reduce(
616611

617612

618613
def all_reduce_ci_fn_grads(world: World, params: Iterable[nn.Parameter]) -> None:
619-
"""CI in-pool AVG-reduce on CI fn grads (standard DDP). No-op for 1-rank pool."""
614+
"""CI in-pool SUM-reduce on CI fn grads. No-op for 1-rank pool.
615+
616+
SUM, not AVG (see ``SUM_GRAD_CONVENTION.md``): each producer's CI grad is a
617+
partial sum already normalized by the honest global count, so the cross-rank
618+
SUM reassembles the single-pool total directly. No producer pre-scales by
619+
``n_ci`` to survive this reduce.
620+
"""
620621
if dist.get_world_size(world.ci_pool_group) <= 1:
621622
return
622623
_bucketed_all_reduce(
623624
(p.grad for p in params if p.grad is not None),
624-
dist.ReduceOp.AVG,
625+
dist.ReduceOp.SUM,
625626
world.ci_pool_group,
626627
"all_reduce_ci_fn_grads",
627628
)
@@ -635,8 +636,16 @@ def sum_reduce_ppgd_grads(world: World, grads: Iterable[Tensor]) -> None:
635636

636637

637638
def all_reduce_grads_in_block(world: World, role: LWRole, params: Iterable[nn.Parameter]) -> None:
638-
"""LW in-block DDP AVG-reduce over V/U + faithfulness grads (async buckets,
639-
wait + copy back). No-op when the block group is 1-rank or there are no grads."""
639+
"""LW in-block SUM-reduce over V/U grads (async buckets, wait + copy back).
640+
No-op when the block group is 1-rank or there are no grads.
641+
642+
SUM, not AVG (see ``SUM_GRAD_CONVENTION.md``): the per-rank stoch grad is a
643+
partial sum over a disjoint position slice, normalized by the honest global
644+
count, so the cross-rank SUM reassembles the single-pool total. The
645+
REPLICATED contributions (faith, broadcast PPGD grad) are emitted on the
646+
block leader ONLY — contribute-once — so they survive the SUM exactly once
647+
without any ``n_per_block`` pre-scaling.
648+
"""
640649
block_group = world.block_group_groups[role.block_idx]
641650
if dist.get_world_size(block_group) <= 1:
642651
return
@@ -652,7 +661,7 @@ def all_reduce_grads_in_block(world: World, role: LWRole, params: Iterable[nn.Pa
652661
with time_nccl_op("all_reduce_grads_in_block"):
653662
for bucket in buckets.values():
654663
flat = _flatten_dense_tensors(bucket)
655-
w = dist.all_reduce(flat, op=dist.ReduceOp.AVG, group=block_group, async_op=True)
664+
w = dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=block_group, async_op=True)
656665
assert w is not None
657666
states.append((bucket, flat, w))
658667
for bucket, flat, w in states:

0 commit comments

Comments
 (0)