Skip to content

Commit d3cf6af

Browse files
SCAO AuthorsCopilot
andcommitted
feat: add CI/CD DDP step + multi-GPU support with sync_preconditioner()
CI/CD (.github/workflows/ci.yml): - Add DDP test step (gloo/CPU, 2 processes) after main pytest run Multi-GPU / DDP (scao/optimizer.py, scao/preconditioner.py): - Add _broadcast_precond() function: broadcasts all preconditioner state (eigenfactors U_l/S_l/U_r/S_r, EMA accumulators, int8 scale factors, step counter, adaptive rank k) from rank 0 to all other ranks via torch.distributed.broadcast; handles Kronecker / block-diagonal / diagonal modes and rank mismatch after checkpoint loading - Add SCAO.sync_preconditioner(process_group=None): broadcasts exp_avg, exp_avg_sq, step, and preconditioner tensors for every parameter; emits RuntimeWarning if dist is not initialised; no-op during single-GPU training - Add DDP section to module docstring: recommended async_precond=False, checkpoint-resume pattern, torchrun usage DDP tests (scao/tests/test_ddp.py): - test_ddp_converges: 2-process gloo/CPU, quadratic loss, verify loss decreases over 30 steps on both ranks - test_sync_preconditioner: inject zeroed U_l on rank 1, call sync, verify both ranks have identical U_l norm via dist.all_gather Multi-GPU benchmark (scripts/bench_ddp.py): - torchrun-compatible script: NCCL (GPU) or gloo (CPU) backend auto-selected - AdamW vs SCAO vs SCAO+int8 comparison with DDP-wrapped GPT model - Per-GPU batch size, world_size-scaled throughput reporting - Saves results_ddp_<scale>.csv and _curves.csv (rank 0 only) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 81b60b3 commit d3cf6af

5 files changed

Lines changed: 710 additions & 3 deletions

File tree

.github/workflows/ci.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ jobs:
3737
run: mypy scao/ --ignore-missing-imports --exclude 'scao/(benchmarks|tests|cuda)' || true
3838

3939
- name: Run tests
40-
run: pytest scao/tests/ -v --tb=short --cov=scao --cov-report=xml --ignore=scao/tests/test_profiling.py
40+
run: pytest scao/tests/ -v --tb=short --cov=scao --cov-report=xml --ignore=scao/tests/test_profiling.py --ignore=scao/tests/test_ddp.py
41+
42+
- name: DDP tests (gloo / CPU, 2 processes)
43+
run: pytest scao/tests/test_ddp.py -v --tb=short
4144

4245
- name: Upload coverage
4346
uses: codecov/codecov-action@v4

scao/optimizer.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,26 @@
3636
(``update_curvature()``) is decorated with ``@torch.compiler.disable`` because
3737
it runs infrequently and contains non-traceable Python control flow.
3838
39+
DistributedDataParallel (DDP)
40+
-----------------------------
41+
SCAO works out-of-the-box with ``torch.nn.parallel.DistributedDataParallel``.
42+
DDP all-reduces gradients automatically before ``optimizer.step()`` is called,
43+
so all ranks see identical gradients and optimizer state stays synchronised
44+
without any extra steps during normal training.
45+
46+
Recommended DDP configuration::
47+
48+
model = torch.nn.parallel.DistributedDataParallel(model)
49+
optimizer = SCAO(model.parameters(), lr=1e-3, async_precond=False)
50+
51+
* Set ``async_precond=False`` to avoid CUDA stream conflicts with NCCL
52+
all-reduce operations on the same device.
53+
54+
* After loading a checkpoint on rank 0, broadcast state to all ranks::
55+
56+
optimizer.load_state_dict(torch.load("ckpt.pt", map_location="cpu"))
57+
optimizer.sync_preconditioner() # broadcast from rank 0 → all ranks
58+
3959
Usage
4060
-----
4161
from scao import SCAO
@@ -65,7 +85,7 @@
6585
from torch import Tensor
6686
from torch.optim import Optimizer
6787

68-
from .preconditioner import SparsePreconditioner
88+
from .preconditioner import SparsePreconditioner, _broadcast_precond
6989

7090

7191
class SCAO(Optimizer):
@@ -443,9 +463,64 @@ def synchronize_precond(self) -> None:
443463
self._precond_stream.synchronize()
444464

445465
# ------------------------------------------------------------------
446-
# Callback registration
466+
# Distributed: sync preconditioner state across ranks
447467
# ------------------------------------------------------------------
448468

469+
def sync_preconditioner(
470+
self,
471+
process_group: "torch.distributed.ProcessGroup | None" = None,
472+
) -> None:
473+
"""
474+
Broadcast all optimizer state from rank 0 to every other rank.
475+
476+
Call this after loading a checkpoint on rank 0 before resuming
477+
distributed training, or any time you suspect optimizer state may
478+
have diverged across ranks (e.g. after a rank restart).
479+
480+
During normal DDP training you do **not** need to call this — DDP
481+
all-reduces gradients before ``step()`` so all ranks receive identical
482+
updates and state stays in sync automatically.
483+
484+
Args:
485+
process_group: the process group to use for collective operations.
486+
Defaults to the global default group.
487+
488+
Example::
489+
490+
# After loading checkpoint on rank 0:
491+
if dist.get_rank() == 0:
492+
optimizer.load_state_dict(torch.load("ckpt.pt"))
493+
optimizer.sync_preconditioner()
494+
"""
495+
import torch.distributed as dist
496+
497+
if not dist.is_available() or not dist.is_initialized():
498+
warnings.warn(
499+
"sync_preconditioner() called but torch.distributed is not initialised. "
500+
"Call torch.distributed.init_process_group() first.",
501+
RuntimeWarning,
502+
stacklevel=2,
503+
)
504+
return
505+
506+
for state in self.state.values():
507+
# Sync first- and second-moment tensors.
508+
for key in ("exp_avg", "exp_avg_sq"):
509+
if key in state:
510+
dist.broadcast(state[key], src=0, group=process_group)
511+
512+
# Sync per-step counter.
513+
if "step" in state:
514+
step_t = torch.tensor([state["step"]], dtype=torch.int64)
515+
dist.broadcast(step_t, src=0, group=process_group)
516+
state["step"] = int(step_t.item())
517+
518+
# Sync preconditioner tensors (eigenfactors, EMA accumulators).
519+
precond: SparsePreconditioner | None = state.get("preconditioner")
520+
if precond is not None:
521+
_broadcast_precond(precond, process_group)
522+
523+
449524
def add_callback(self, callback) -> None:
450525
"""
451526
Register a monitoring callback.

scao/preconditioner.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,75 @@ def load_state_dict(self, state: dict) -> None:
552552
self.S_r = state["S_r"].to(device=self.device, dtype=_PRECOND_DTYPE)
553553
else:
554554
self.diag_ema.copy_(state["diag_ema"])
555+
556+
557+
def _broadcast_precond(
558+
precond: "SparsePreconditioner",
559+
process_group: "torch.distributed.ProcessGroup | None" = None,
560+
) -> None:
561+
"""
562+
Broadcast all preconditioner state tensors from rank 0 to all ranks.
563+
564+
Handles all three preconditioner modes (Kronecker, block-diagonal, diagonal)
565+
and both EMA storage formats (float32 and int8). Also syncs the step counter
566+
and adaptive rank ``k`` so that subsequent updates remain numerically identical
567+
across all ranks.
568+
569+
Args:
570+
precond: the SparsePreconditioner instance to synchronise.
571+
process_group: optional process group (default: the global default group).
572+
573+
Notes:
574+
This function is called by ``SCAO.sync_preconditioner()``. It is not
575+
intended to be called directly unless you manage the distributed state
576+
yourself.
577+
"""
578+
import torch.distributed as dist
579+
580+
# Sync step counter from rank 0.
581+
step_t = torch.tensor([precond.precond_step], dtype=torch.int64, device=precond.device)
582+
dist.broadcast(step_t, src=0, group=process_group)
583+
precond.precond_step = int(step_t.item())
584+
585+
if precond.use_block_diagonal:
586+
for blk in precond._blocks:
587+
_broadcast_precond(blk, process_group)
588+
return
589+
590+
if precond.use_kronecker:
591+
# Sync the adaptive rank k; non-rank-0 processes must resize tensors if
592+
# the checkpoint was saved at a different rank than their current state.
593+
k_t = torch.tensor([precond.k], dtype=torch.int64, device=precond.device)
594+
dist.broadcast(k_t, src=0, group=process_group)
595+
k_new = int(k_t.item())
596+
597+
if k_new != precond.k:
598+
precond.k = k_new
599+
precond.U_l = torch.empty(precond.m, k_new, dtype=_PRECOND_DTYPE, device=precond.device)
600+
precond.S_l = torch.empty(k_new, dtype=_PRECOND_DTYPE, device=precond.device)
601+
precond.U_r = torch.empty(precond.n, k_new, dtype=_PRECOND_DTYPE, device=precond.device)
602+
precond.S_r = torch.empty(k_new, dtype=_PRECOND_DTYPE, device=precond.device)
603+
604+
# Broadcast EMA accumulators.
605+
if precond.use_int8_ema:
606+
dist.broadcast(precond.L_ema_q, src=0, group=process_group)
607+
dist.broadcast(precond.R_ema_q, src=0, group=process_group)
608+
# Scale factors are Python floats; wrap as tensors for broadcast.
609+
for attr in ("L_ema_scale", "R_ema_scale"):
610+
t = torch.tensor([getattr(precond, attr)], device=precond.device)
611+
dist.broadcast(t, src=0, group=process_group)
612+
setattr(precond, attr, float(t.item()))
613+
else:
614+
dist.broadcast(precond.L_ema, src=0, group=process_group)
615+
dist.broadcast(precond.R_ema, src=0, group=process_group)
616+
617+
# Broadcast eigenfactors (in-place: tensors already have the right shape).
618+
dist.broadcast(precond.U_l, src=0, group=process_group)
619+
dist.broadcast(precond.S_l, src=0, group=process_group)
620+
dist.broadcast(precond.U_r, src=0, group=process_group)
621+
dist.broadcast(precond.S_r, src=0, group=process_group)
622+
623+
else:
624+
# Diagonal fallback
625+
dist.broadcast(precond.diag_ema, src=0, group=process_group)
626+

0 commit comments

Comments
 (0)