Skip to content

Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle#560

Open
ocg-goodfire wants to merge 557 commits into
mainfrom
feature/jax
Open

Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle#560
ocg-goodfire wants to merge 557 commits into
mainfrom
feature/jax

Conversation

@ocg-goodfire

@ocg-goodfire ocg-goodfire commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Top-line

Migrates VPD fully to JAX — train and analyze in one framework — and retires the torch stack to a git-tagged oracle. Squash-merges to main once, at the end. Net: ~430 commits, +24k / −50k LOC (mostly deletion).

Key decisions

  1. One framework: JAX. Training and all analysis run in JAX. Torch is the battle-tested oracle, preserved at tags torch-oracle / torch-oracle-npool; JAX conforms to it (SPEC.md is the normative contract, numeric seams default to matching torch, goldens prove it).
  2. Consumers read JAX runs natively (orbax + DecomposedModel via open_jax_run) — harvest, clustering, autointerp, slow/offline eval, app. The JAX→torch export bridge is dead.
  3. App is a read-only viewer. Attribution graphs, circuit-opt / editing, PGD intervention are dropped (recoverable from git). App backend imports zero torch.
  4. Generic model interface. DecomposedModel — ordered sites + pure fns clean_output / site_inputs / masked_output / weight_deltas / masked_site_outputs — generic over input/output/recon-loss with [B,T,d] as the fixed waist.
  5. Recon unified as plan × mask-source strategy (make_plan + chunking helpers); loss is KL on final logits. Hidden-acts recon is a separate eval diagnostic over the masked_site_outputs seam (amends SPEC S31), not a recon-grid variant.
  6. Strict bar: no # type: ignore / Any / cast without sign-off; make check-jax gated in pre-commit; fail-fast, types-first.

Status — functionally complete; green

JAX trainer; all consumer ports; read-only app; dropped-feature deletion; DecomposedLM → DecomposedModel rename; hidden-acts eval port; llama8b loader; type-debt → 0 + pre-commit gate; a code-review + fix pass (#854/#855/#856) and a first dead-trainer deletion (#857). Suites green: ~415 lab + ~166 jax at 1 and 4 devices; make type / check-jax clean; torch↔JAX per-term equivalence + stacked-parity trajectory goldens pass bit-unmodified; validated end-to-end on SimpleMLP pile run p-761bc061.

Remaining before main-merge (each gated — see commit history / memory)

  • Retire torch offline-eval (offline_eval.py / pd-offline-eval / jsp-export; rewire _submit_offline_evaljsp-slow-eval). Gated on parity-validating JAX slow-eval vs torch on a real llama8b run — no current-format llama8b run is loadable, so this needs a fresh run.
  • Bulk param_decomp/ core deletion (metrics/ tree, train_step, …). Gated on the above (the live offline-eval path still imports them). Bridge/capstone surface stays.
  • Harvest accumulator → numpy (last torch in harvest). Gated on the async_eval (in-loop autointerp) decision.
  • Capstone: torch→jax run adapter so old torch runs load — the deliberate final step.
  • Scope call: TMS / ResidMLP / vendored / pretrain still torch (separate domains; likely not this PR).
  • Surfaced: PGD scope-vocabulary convergence (deferred — stored-run compat).

Reading guide

param_decomp_jax/jax_single_pool/SPEC.mdlm.pyrecon.pytrain.py. TRANSITION.md = the settled plan; LOSS_PARITY_DESIGN.md = recon unification. VPD paper rides separately as #562.

🤖 Generated with Claude Code

ocg-goodfire and others added 30 commits June 10, 2026 18:44
Repo convention keeps the WANDB key in .env; the offline runner never goes
through init_pd_run / get_wandb_entity (the reference yaml pins the entity),
so nothing else loads it.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The streaming fineweb loader leaves non-daemon parquet/arrow reader threads
that block interpreter shutdown indefinitely (repro: iterate the eval loader
on CPU and return) — both validation jobs sat RUNNING on a B200 for 25+ min
after printing results. os._exit(0) once stdout + wandb are flushed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…rt -> pd-offline-eval -> cleanup) per committed checkpoint
…e per-consumer all-to-alls

The CI head's out_w is ΣC-sharded so its output is born C-sharded; GSPMD then
resharded it separately for every consumer (each chunk forward + PPGD + imp-min,
fwd and bwd) — at 36 sites those ~1.2 GB all-to-all buffers dominated jit_step's
temp arena (109.46 GiB OOM at 32 GPU, job 50542; XLA dump memprobe_mc_50581).
batch_sharded_ci reshards lower+upper once at the producer. Pure sharding
constraint — values exact, trajectories unchanged. AOT probe (L20-31 C=8192 bl1
remat-on, 8 GPU): temp 67.0 -> 49.7 GiB.
…memory table + recommendation

Harvested from the stranded remat agent's completed jobs (50467/50468 A/B,
50475-50478 multi-chunk probes) plus this session's post-pin re-probe.
…ig route

Converter relaxations, each asserted or printed: raw-HF Llama target spec
(transformers.LlamaForCausalLM == vendored weights, export-bridge verified),
'model.'-prefixed site patterns, fp32 weights_dtype accepted with a loud
bf16-frozen divergence note; wandb entity honestly nullable (upstream
'entity: null' = API-key default).

configs/torch/llama8b_l18_C30k_200k_1pool.yaml = torch run p-19645bf7's snapshot
with four header-documented edits: pretokenized data @ seq 2048 (upstream streamed
raw text @ 512), ci attn max_len 2048, PPGD scope broadcast_across_batch (upstream
per_batch_per_position — not implemented here, deliberate), save_every 5000.
…uncher (C30k 16-GPU arena exceeded the 75% default pool)
…nlocks 32-GPU mesh vs gcd-capped 16)

AOT probe: bl4 remat-off arena 70.3 GiB — fits the 32-GPU launch comfortably.
…s deliberately uncompensated)

AOT probes (8 GPU): bl8 remat-off 110.5 GiB (the launch shape, 64 GPU), bl4 off
70.3, bl12 off 149.3 (32-GPU B=384 ruled out — no margin for a multi-day run).
load_run_dir_config rebuilds the config from the pinned copies (wrapper as
config.yaml + torch yaml as torch_config.yaml), so torch-wrapper runs export and
offline-eval like native ones. The wrapper's launch-relative path field is ignored
in favor of the pinned torch yaml.
…4x partial bump for the 4x batch; header edit 7)
…dispatch on runtime.topology)

The C49k JAX run's reference yaml is single-pool; everything this module touches
is shared between the schemas (and PDConfig is run_eval_pass's native type).
…e launcher

CUDA_ERROR_STREAM_CAPTURE_INVALIDATED killed 4 jobs tonight across disjoint
allocations (50453/50525/50676/50743) — systematic in the capture path, not bad
nodes. Torch never exercises CUDA graphs, which is why the torch trainer needed no
such resilience. ~1-3% launch-overhead cost.
…ot just the batch shell

--signal=B:TERM@300 only signals the batch script; the trainer's SIGTERM-save
handler never fired and preemption (job 50818) hard-killed the ranks after the
grace window — losing everything since the last save_every checkpoint.
jsp-export writes vendored-layout keys regardless of the run yaml's target spec,
so eval must build VendoredLlama (same weights) with model.-prefix patterns
stripped; frozen forced bf16 = what the JAX run actually trains with (its
documented fp32-yaml divergence), and the more faithful eval reference. Unblocks
the C49k offline-eval loop (ckpt-5000 attempt died on the raw-HF dispatch).
…val args (C49k needs --micro-batch-size 4: mb8 OOMs at 157.6+7.8 GiB on a 178 GiB B200)
Per Oli: resilient-launch loops are an anti-pattern this stack shouldn't need. The
one recurring crash class is fixed at the source (command buffers off, ~0% cost);
the loop demonstrably contributed nothing across last night's incidents — it
burned all 5 attempts on each deterministic OOM, could not escape a suspect node
(retries reuse the allocation), and every actual recovery was a diagnosed
resubmission. Plain --requeue + the now-working SIGTERM save = the torch model.
…one eval job per checkpoint save

Replaces the polling sidecar chain (which self-multiplied and hid failures): rank 0
fires sbatch offline_eval_once.sbatch right after each save, fire-and-forget (a
failed submission must not kill a multi-day run — the one place graceful handling
beats fail-fast; it prints loudly instead). Per-run serialization via
--job-name=jsp-oeval-<run> + --dependency=singleton; the run-dir marker dedups
across requeues; optional <run_dir>/offline_eval_args carries run-specific flags
(C49k: --micro-batch-size 4).
…yout, torch postprocess contract

Wrapper gains run_id (p-<8hex>, the torch generate_run_id convention; authored
into the wrapper so resume derives the same identity and the byte-compare pins
it): run dir = out_dir/<run_id>, wandb id = run_id with run_name as the display
name. The pinned torch yaml is additionally written as experiment_config.yaml
(SavedLMRun's contract name), and the push-triggered eval job materializes
model_<step>.pth (safetensors -> torch.save, pruned to newest 2) — together
making a JAX run dir a first-class citizen of harvest/app/postprocess with zero
special-casing. wandb config gains a jax_runtime section (actual device count
etc.) since the upstream yaml's runtime.dp describes the torch run, not this one
(it read 'dp: 32' while running on 64 GPUs). run_id is optional ONLY for the
pre-scheme live C49k wrapper — collapse after it migrates.
…rity fixtures

gen_stacked_fixtures.py runs against the PRE-restructure feature/jax-single-pool-pd
code (stacked DecompVU, LayerRange) and pins, on the tiny L3-5 MLP target: target
weights, per-site V/U, CI-fn/source init leaves, clean/masked/site-input/weight-delta
forwards (S2/S3/S4/N2), and a 2-step training trajectory (metrics + final V/U +
adversary sources, S13-S15). The generated npz was produced from the live base-branch
checkout's venv; the script is excluded from this distribution's basedpyright (it
deliberately imports the old API and only runs on the base branch).

The follow-up site-generality restructure must reproduce these to SPEC D4
reassociation tolerance (clean logits bit-identical).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…e/up/down, per-site C

Generalizes the Llama target from contiguous-layer MLP-only to any per-layer matrix
site, each with its own C. The trainer (train.py) was already site-generic; this
rebuilds the target + plumbing around per-site state:

* llama8b.py — DecompVU becomes {site: (V (d_in,C_s), U (C_s,d_out))} fp32 masters
  (same init scale rule: V~N(0,d_in^-0.5), U~N(0,C^-0.5)); Target is a uniform
  SuffixLayer list from first_decomposed_layer(sites) to the LM head — layers with no
  sites run the plain frozen block, and any site absent from the spec computes the
  exact frozen x@W path (S2: masked_logits(live=()) stays bit-identical to
  clean_logits, S3). q/k/v sites are decomposed BEFORE RoPE/SDPA via FrozenAttn.core
  (head-count reshapes asserted); o applies to the attention output. Clean CI inputs
  (S4): q/k/v <- post-LN1 residual, o <- pre-o_proj attn output, gate/up <- post-LN2
  residual, down <- silu(gate)*up. weight_deltas stays fp32 W - V@U per site (N2).
  Canonical site order = layer-ascending x KIND_ORDER (computation order); site names
  keep the torch-module-path convention layers.{i}.{self_attn,mlp}.{kind}_proj.

* llama8b_sharding.py — per-site placement: V shards C on axis 1 (P(None,'dp')),
  U on axis 0 (P('dp',None)); per-site C % mesh-size asserted; the
  jit-with-out_shardings init pattern and batch_sharded_ci are untouched.

* config.py / torch_config.py / run.py / run_state.py / export.py — TargetConfig
  carries canonical (site, C) pairs; the native {first_layer,last_layer,C} YAML maps
  to the MLP family (unchanged behavior); the torch route accepts any
  (model.)layers.N.{self_attn,mlp}.{kind}_proj target with per-site C (contiguity /
  uniform-C / all-three-kinds asserts dropped; identity-target and other-pattern
  refusals kept). Export reads V/U per site and renames any decomposed matrix's
  frozen weight to .target_weight — attention sites included (the torch vendored
  ComponentLlamaAttention componentizes q/k/v/o, so attention-site export is fully
  supported; round-trip-verified against the real LMComponentModel via the new
  l18_attn fixture case, heterogeneous C, key parity exact).

Equivalence with the stacked implementation (tests/stacked_parity, fixtures pinned
from feature/jax-single-pool-pd): clean/masked/site-input/weight-delta forwards
BIT-identical; 2-step train trajectory rel <= ~2.5e-6 (D4; sole divergence is
clip-global-norm leaf-order reassociation). Cross-framework loss-term equivalence
(tests/equivalence) unchanged and green.

invariance_check gains an abs tolerance floor (1e-7) alongside rel 1e-4: the per-leaf
grad-norm diagnostics include ~1e-4-magnitude norms whose cross-shard reduction
cancellation can graze rel 1e-4 (observed: one ci-fn wv leaf, abs err ~1e-8, step-0
only, non-growing) while every loss term sits at ~1e-6. Verified at 4 and 8 sim
devices (worst loss-term rel 2.5e-6 / 7.1e-6).

Checkpoints are NOT cross-compatible with stacked-era runs (components pytree layout
+ V/U init RNG derivation changed).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…EC §3 site prose

SPEC.md needed no semantic change: §1/§3 already define a site as any selected weight
matrix. Amended the one §3 orientation sentence that described the *implementation* as
contiguous-MLP-only (now: any per-layer q/k/v/o/gate/up/down matrix, per-site C; q/k/v
decomposed before RoPE/SDPA) and added the attention-site examples to the §4.1
site_inputs comment (S4). No invariant changed.

AUDIT.md records: stacked->per-site parity evidence, checkpoint non-compatibility
across the restructure, and the verify_export_torch production-numerics pass becoming
measure-only (the documented GELU/eps divergence is amplified on the tiny attention
fixture; the asserted jax-matched pass is the mapping proof).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…orch PGDReconLoss parity)

AdversaryConfig = SourceAdamConfig | FreshPGDConfig selects the recon adversary
variant in make_train_step: PPGD keeps its persistent sources + Adam moments +
fused (n_warmup+1)-th ascent (SPEC S13/S14, unchanged); FreshPGD samples per-site
sources every step (unique_per_datapoint (B,T,C+1) or shared_across_batch
(1,1,C+1)), runs n_steps of step_size*sign(grad) with clamp [0,1], and carries NO
cross-step state (TrainState.sources stays empty; the post-backward source
cotangent is unused). Metrics key 'pgd' -> wandb train/loss/PGDReconLoss.
ExperimentConfig.ppgd renamed to .adversary (native yaml key 'ppgd:' unchanged);
the torch-config route accepts PGDReconLoss as the adversary slot. Matches Dan's
p-a005ed60 config (n_steps=1, step_size=1, init random, unique_per_datapoint).
…vs replicated vs bridged) + remaining gaps in dependency order
…LM implementation

The t-9d2b8f02 pile model (4L, d768, 6 heads, GELU MLP, plain rotate-half
RoPE, tied head) as a decomposition target. llama_simple_mlp.py mirrors
llama8b.py's seams: per-site maybe-masked matmuls for all six site kinds
(h.{i}.attn.{q,k,v,o}_proj before-RoPE/after, h.{i}.mlp.c_fc before the
GELU, down_proj after), frozen residual-start suffix + embedding/blocks
prefix, fp32 weight deltas (N2), live=() bit-equal to clean (S2/S3).

Registration: TargetConfig | LlamaSimpleMLPTargetConfig union in config.py
(native target block dispatches on pretrain_run_path); torch_config accepts
kind: pretrained LlamaSimpleMLP specs with h.* wildcard decomposition
patterns expanded over the checkpoint's n_layer; run.py builds the target
by kind and threads prefix_residual_fn; export/offline-eval guarded
llama8b-only.

Weights load from the torch pretrain cache via a one-off .pt -> safetensors
conversion (tools/convert_llama_simple_mlp_checkpoint.py, torch venv; tied
lm_head stored once as wte.weight).

Torch-fixture equivalence (tests/simple_mlp_equivalence/, fp32): tiny
random model (GQA repeat=2) max abs logits diff ~2e-7; real t-9d2b8f02
weights ~5e-5 — pins RoPE construction, tanh-GELU, GQA grouping, rms eps.
Tiny-config site tests mirror test_llama8b.py (masked identity, live=()
bitwise, mixed attn+mlp train step with heterogeneous C: S13/S15/S9/N1).
Suite green at 1 and 8 sim devices; basedpyright + ruff clean.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
EvalConfig.l0_groups carries the fnmatch site patterns; make_eval_step resolves
members at build time (unmatched group refuses) and emits the torch group keys
(l0/<thr>_<group> = SUM of member-site L0s). Unblocks Dan's pile config whose
eval block groups per-layer + total.
Scope names now spell the stored source shape in tensor order: single_source->c,
broadcast_across_batch->sc, repeat_across_batch->nsc, per_batch_per_position->bsc;
PGD MaskScope -> {c, bc, bsc}. Classes renamed to match (CScope/SCScope/NSCScope/
BSCScope). Code + tests + reference yamls swept together; 514 tests green.
Stored experiment_config.yaml files predate the shape-spelled scope names;
BeforeValidator aliases exactly the literals that occur in stored data
(broadcast_across_batch->sc, per_batch_per_position->bsc; mask scopes
shared_across_batch->c, unique_per_datapoint->bsc). Delete once stored runs
are migrated.
# Conflicts:
#	param_decomp/metrics/base.py
#	param_decomp/metrics/importance_minimality.py
#	param_decomp/optimize.py
#	param_decomp/tests/metrics/test_importance_minimality_loss.py
#	param_decomp_lab/experiments/lm/run.py
#	param_decomp_lab/run_sink.py
#	param_decomp_lab/tests/test_resumption.py
ocg-goodfire and others added 30 commits June 25, 2026 13:32
…emory lever for batch)

Add an enablable `remat_ci_fn` runtime flag (mirrors `remat_recon_forwards`) that wraps
the CI-fn forward in `eqx.filter_checkpoint`, recomputing it in the backward instead of
storing its activations. The CI-fn forward (4-block transformer × n_chunks) is the one
hot-path component with NO checkpointing today, and its activations scale with batch —
so this is the main activation-memory lever for training at larger batch on big targets.

Pure memory/compute trade, no algorithm effect: verified `remat_ci_fn` True vs False give
identical loss (bit-level: max Δloss 1.9e-6, max param-update Δ 1.9e-6 — float-reassoc
noise from the recompute) on the tiny-llama step.

Threaded: configs.py (RuntimeConfig field, default False) → run.py engine param + wandb
record → train.py make_train_step (`apply_ci_fn` = `eqx.filter_checkpoint(_apply_ci_fn)`
when set) → the three composition roots (`built.runtime.remat_ci_fn`). Test/experiment
callers pass `remat_ci_fn=False` explicitly (the arg is required, like remat_recon_forwards).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…L-port

Integrates the two upstream commits:
- 23881ea sources/sources_opt_state -> single `adversaries: dict[str,PersistentAdversary]`
- 8dfb861 smooth-L0 (Geman–McClure) imp-min penalty (annealed_imp_min_param / imp_min_terms)

Resolution: the residual-first-class removal STAYS (this branch's core change). train.py
collided because both refactors rewrote the recon/adversary core and upstream still used
`residual` — resolved by taking upstream's train.py (adversaries + smooth-L0 structure)
and re-applying this branch's two changes on top: (1) residual->batch / embed-internal
(`leading` derived from a tap, not the opaque token batch), (2) `remat_ci_fn` (the CI-fn
activation-checkpoint lever). The 4 residual-start experiment/tool files this branch
deleted (llama8b_real, mem_probe, migrate_c49k_checkpoint, verify_c49k_migration) stay
deleted — they targeted the old residual API and are unreferenced.

Validated: make type 0/0/0; core suite 259 passed / 3 skipped / 11 xfailed (the xfails are
the pre-existing residual-fed goldens), incl. stacked-parity, equivalence goldens, the
adversary/PPGD tests, and upstream's new smooth-L0 test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…h collectives (#889)

The cluster sets NCCL_DEBUG=INFO / NCCL_DEBUG_SUBSYS=ALL in the node environment, which
logs every NCCL collective — the slurm logs hit tens of GB per run (~100% NCCL lines).
slurm.py already intends NCCL_DEBUG=WARN, but pd-lm's _RANK_ENV never exported it, so the
inherited INFO/ALL won. Export WARN in the rank env to override it.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Ports #543 to the JAX trainer. The rolled imp-min `lp + beta·entropy` becomes two
independently-coefficiented terms computed from one shared `(c+eps)^p` pass:

  imp  = Σ_c f_c                       (imp.coeff)
  freq = Σ_c f_c · log2(1 + a'·f_c)     (freq.coeff), a' = reference_token_count

The frequency normalizer is now EXPLICIT (`reference_token_count`) instead of the
implicit global `B·T`, so the penalty's curvature is invariant to batch size at a
fixed firing rate — batch and frequency-penalty strength become independently
tunable. Setting `a' = B·T` reproduces the old behavior exactly; coefficients
transfer as `freq.coeff = old imp.coeff · beta`.

Tied form: nested `frequency: (coeff, reference_token_count) | None` on BOTH imp
configs (`ImportanceMinimalityLoss` + `SmoothL0`), reusing the shared per-component
sums in one pass — not a separate flat `LossTerm` (which would refetch the sums and
break the self-contained-term model). `beta` is removed.

SPEC S7/S8 amended + new S8'. Migrations: all in-repo YAMLs + test/experiment
configs (`freq.coeff = imp.coeff·beta`, `reference_token_count = global B·T`;
`beta:0` → no freq block). Equivalence goldens preserved WITHOUT regen
(`a' = n_positions` reproduces the old `entropy` bit-for-bit).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01JDynbo3BeS7AEwu8iyje6F
… render

test_in_loop_slow_tier_fires_on_cadence_without_stalling timed a window that,
because submit() serializes renders via join() to cap one in-flight, was
dominated by ~3 back-to-back matplotlib renders (4s each) rather than the
dispatch cost it claims to bound. The test drives slow evals with zero
train-step gap, so the one-in-flight join blocks instead of being the no-op it
is in the real loop (slow_every=3000 train steps separate two slow evals, so a
render always finishes off-thread first). Under CI's -n4 oversubscription each
render stretched ~3.4x and the 3-render window blew past the 30s budget (40.5s);
in isolation and on a faster box it passes.

Measure only what the loop pays — the collective accumulate (~0.3s incl. compile)
plus the submit dispatch — by joining the renderer BETWEEN submits, outside the
timed window, mirroring the real train-step gap. Budget unchanged at 30s; the
measured quantity is now dispatch (sub-second), immune to render contention.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk
… why

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk
Checkpoint the masked-forward scan BODY (scan(checkpoint(block))) so the backward
recomputes one layer at a time and stores only the residual carry, instead of stacking
all 32 layers' activations ([32,*,14336]) — the dominant step-memory term. Threaded as a
keyword-only `remat` arg through the DecomposedModel.masked_output protocol (scan models
remat per-layer; toys whole-forward); removes the wrong-granularity whole-forward
jax.checkpoint from train.py. Numerically transparent (faith 3.746e-4 vs 3.754e-4
whole-forward baseline; HLO-verified one-layer recompute; 87 tests pass). Reverts the
temporary seq-truncation data.py hack.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…nsient

ChunkwiseTransformerCIFn ran its per-chunk transformers under eqx.filter_vmap, which
unrolls + hoists every chunk's FSDP weight all-gather into the flat entry computation —
all n_chunks gathers (each ~ΣC/tp) live at once. At full scale that transient is the
dominant tp-dependent term (~15 GB/dev @TP8, ~62 @Tp2) and the reason low-tp configs OOM
at ANY batch (empirically tp2 OOMs ~89 GiB independent of batch). Replace with lax.scan
over the (homogeneous, already-asserted) n_chunks axis so the chunk iteration lowers as a
loop: one chunk's gather live at a time, then freed. ENTRY gather 4.88->0 MiB in the HLO.
Same math up to fp32 reassociation (different matmul layout; ~1e-6, fine at bf16); 217
tests pass. The unlock for low-tp.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Without it the ~31B CI-fn forward materialises into the step module -> ~80-min
compile + near-OOM (jobs 130423/130424). The schema defaults it false, so a
config that omits it silently gets the heavy path.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Persistent-PGD source + Adam m/v storage dtype is now a config knob
(PersistentPGDReconLossConfig.source_dtype), default float32 (SPEC N1, oracle parity
preserved as a no-op cast). bfloat16 opt-in halves the source + m + v footprint (the
~41 GiB f32 PPGD term -> ~21 GiB saved at full-32L scale). Threads dtype through
init_sources_sharded -> init_persistent_sources; Adam state inherits via zeros_like;
update casts grad in / update out around the [0,1] projection. NOT SPEC N1 on the bf16
path; bf16 second-moment v underflow is an untested stability risk -> exploratory branch.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
# Conflicts:
#	param_decomp_lab/experiments/lm/launch.py
…; trim run banner

The remat_ci_fn config knob was plumbed end-to-end (config -> engine -> train) but
the make_train_step call site hardcoded remat_ci_fn=False, so the flag was inert and
the run banner logged a value the step ignored — the ~31B CI fn was never checkpointed
regardless of config (the ~80-min compile + near-OOM we chased). Thread the real value.

Also trim the run-start banner: log per-kind site counts (sites=224 [q_proj×32, ...])
instead of dumping all 224 site names inline.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…nch docstring

- configs/llama8b_full32L: remove no-op fields stripped by back-compat validators
  (n_mask_samples, sampling, autocast_bf16, device) — they only misled readers.
- data.py: restore strict seq-width assert 'in (seq_len, seq_len+1)' (drop the >= TEMP
  HACK added for the seq-64 gbsweep). The +1 is the real label-token convention
  (fineweb artifact is 512-wide, pile is 513=512+label; both truncate to seq_len).
- launch.py docstring: 'one srun task per node claiming all 8 GPUs' (was '8 tasks/node').

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Each train log line leads with wall-clock elapsed<eta (e.g. 2:31<6:50) from
the recent step rate, and drops the per-param grad norms that dwarfed the
metrics — the full breakdown still rides to wandb + jsonl. No progress bar
(no in-place redraw in a SLURM .out anyway), so no tqdm dependency.


Claude-Session: https://claude.ai/code/session_01XBgqY1rYzgQbbLdgfQZsbD

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ot fp32 (#905)

* fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32

The slow/plot eval tier deliberately read the CI fn out in fp32, but the CI
transformer's attention routes to cuDNN flash, which rejects fp32 — so the slow
tier crashed at the first slow eval on GPU:
    NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32
(hit by every run, including bf16-only ones — unrelated to fp8 work).

Rather than route fp32 to the XLA attention impl (#885), run the CI-fn readout in
training precision (bf16) — matching train.py / eval.py and the hidden-acts +
attn-pattern slow tiers, which already do. This is the more faithful readout (the
deployed model runs bf16) and keeps cuDNN flash (faster, no (B,H,T,T) materialize).
Reductions/returns stay fp32 so the host accumulation is byte-unchanged.

Also incorporates #885's precision-independent multihost fix: the per-position
histogram sample keeps the dp-sharded batch axis, so `np.asarray` on >1 process
spans non-addressable devices — gather it with process_allgather(tiled=True).

Supersedes #885's ci_fn.py change (no fp32 attention path is needed once the slow
tier is bf16); incorporates its slow_eval.py gather fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

* style(slow-eval): trim the bf16-readout and gather comments to the load-bearing why

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…64 OOM) (#898)

* fix(load_run): replicate HF prefix without cross-host allgather (dp>=64 OOM)

jax.device_put(host_array, P()) runs multihost_utils.assert_equal ->
process_allgather(tiled=True), tiling the ~1GB embedding to ~process_count GB
(OOM at dp=64). Build the replicated global array from each host's local copy
via make_array_from_callback instead.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012PUbea772bseCPww4uv3JE

* docs(sharding): trim place_via_shardings docstring

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…throughput (#903)

Squash of the full32L HSDP perf line. Buffer donation (donate=all-except-first) + V/U reconstruction hoist take full Llama-8B VPD to b128/4-seq-per-GPU on 32 GPUs at ~3.6x the b32 throughput, numerics bit-identical. Includes env-gated memory/timing instrumentation, the b128 production config, and doc cleanup (PERF_NOTES distilled to Lore; completed migration docs removed). Full per-commit history on branch perf/gather-coalesce-unroll-k; canon in lore 2026-06-29--full32l-hsdp-donation-canonical-scaling-memory.
…review (#900)

* fix(jax): correctness + multi-host robustness fixes from feature/jax review

A review of the JAX trainer against the torch oracle (torch-oracle tag) and
SPEC.md found the training math faithful; these address the issues it surfaced.

- fix(eval): identity_ci_error counted off-diagonal errors only inside the
  min(shape) block, silently ignoring active components in the trailing columns
  of overcomplete decompositions (C > rows — the normal toy/LM case). Count over
  the full matrix minus the block diagonal, matching torch
  IdentityCIPattern.distance_from. Fixed in all three copies (slow_eval.py +
  tms/resid_mlp model.py).

- fix(run): SIGTERM was read from a per-process flag at collective decision
  points (faith-warmup exit, eval entry, orbax save), so a per-task SIGTERM with
  no cross-rank simultaneity could diverge ranks onto different branches and hang
  a collective. Reconcile the flag across ranks (OR-reduce) once per step into a
  local used at every gate; drop the now-redundant per-rank checks inside the LM
  eval_fn (the pass is admitted only after consensus and runs to completion).

- fix(config): wire faithfulness_warmup_weight_decay (was hardcoded 0.0,
  silently ignoring the field) and drop the redundant ==0.0 canonical assert;
  add a fail-fast assert that the persistent-PGD source LR is fn_type=="constant"
  (source_lr ignores decay, so a decaying schedule was silently flattened).

- fix(eval): log the LR the step actually applied (now_step-1), not next step's;
  permute the lower/upper CI heatmaps each by their own-derived permutation
  (torch parity) instead of sharing the upper-derived one.

Verified: make type clean; core tests pass at 1 and 4 simulated devices
(incl. the device-count-invariance suite); slow_eval tests pass under the Agg
backend (validating the heatmap change); toy/schedule/config/launch tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01UKuMWZoyShCYVd3Q937KfA

* refactor(jax): trim redundant comments, drop dead sigterm_received

The SIGTERM consensus rationale was stated four times; keep it once in the
_sigterm_consensus docstring and drop the two call-site comments that restated
it. Tighten the docstring and the LR-logging comment. Remove sigterm_received()
— now unused after the LM eval_fn stopped reading the per-rank flag.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01UKuMWZoyShCYVd3Q937KfA

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
* test(stacked-parity): portable fp32 tolerance, not bit-exact across CPUs

test_clean_output_bit_identical used jnp.array_equal and the sibling forward
pins used rtol=1e-5/atol=1e-6. Those encode bit-exactness against the
fixture-generating host, but ubuntu-latest is a heterogeneous runner pool:
float32 matmul reduction order differs by ~1 ULP across CPU microarchitectures,
so the exact check and the near-zero elements under the tight atol flaked
intermittently (failed 3 of the last 4 CI runs; passed only when a run happened
to land on a matching CPU). The forwards are CI-fn-independent and the math is
unchanged — this is pure reassociation noise, not a regression.

Fold clean_output into the tolerance-based pins (rename -> test_clean_output_matches)
and raise the shared forward tolerance to rtol=1e-4/atol=1e-5. Observed
cross-arch noise is <=3.1e-6 abs (measured on arm64 vs the fixture host), so the
new bound keeps ~3-4x headroom while staying essentially exact for fp32.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk

* test(llama8b): rebuild vu/ci_fn per state in fresh-pgd test (step donates)

make_train_step now donates the state (donate="all-except-first", #903), so
the shared module-level vu/ci_fn were deleted after the first run_step and the
second crashed on a deleted buffer. Build them fresh inside make_state with the
same deterministic keys — bit-identical init, independent donatable buffers.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2

* docs(stacked-parity): don't attribute the ~1e-4 magnitude to SPEC D4

D4 is the reassociation concept; the specific tolerance is ours. Reword so the
docstring cites D4 for the idea without misquoting it on the number.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…h_env (#906)

The SLURM rank env (XLA flags, NCCL/host-memory knobs, PD_* profiling toggles)
was a hardcoded `_RANK_ENV` block in launch.py, so a run's config.yaml didn't
capture what env/XLA flags it ran with, and A/B-ing a flag meant editing a shared
launcher file.

Add `RuntimeConfig.launch_env: LaunchEnv` (param_decomp.configs) as the single
source of truth — the formerly-hardcoded values become its defaults. `LaunchEnv`
carries typed knobs (xla_flags, xla_python_client_{mem_fraction,allocator},
xla_pjrt_gpu_host_memory_limit_gb, nccl_debug, malloc_arena_max), a free-form
`env` escape hatch (merged last), and a typed `ProfileConfig` that renders the
PD_* profiling toggles — so a profile run is a config, not an env hack.
`launch.py::_render_rank_env` renders `LaunchEnv.as_env()` into the exported bash
block; LD_LIBRARY_PATH stays launcher-computed (machine-specific). The old
`pd-lm --allocator` flag is now `launch_env.xla_python_client_allocator`.

Defaults render byte-identically to the old block (test pins this). Inline
(dp is None) path inherits the caller's env, unchanged.


Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
remat=True keeps nothing_saveable (production default, no behavior change); remat=False now uses dots_saveable so retained-activation forwards stop recomputing the matmuls. Gather is JIT-and-freed either way.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Passes gpu_max_activity_api_events via ProfileOptions.advanced_configuration so a full train step fits (the default ~1M-event cap truncates the step mid-forward).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Trimmed from full32L to layer 18 only; per-matrix C doubled; single recon chunk (sites_per_chunk 7, coeff 0.5).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
… in-block

The stochastic recon forwards each held a pre-built [n_layer,B,S,C] mask stack;
collapse them to ONE shared CI stack recomputed inside the checkpointed block
(faithful by checkpoint determinism). The CI envelope and the compute weights are
each forward-evaluated ONCE per step (single eqx.filter_vjp / jax.vjp) rather than
once detached for the adversary ascent + once inside the main backward.

Memory (full32L b128 dp32, MEMPROF 142183 -> MEMRECOMP 142698):
  temp arena 113.79 -> 98.88 GiB (-15), runtime peak 163.30 -> 148.38 GiB (-15),
  headroom under the 164.08 GiB limit 0.78 -> 15.70 GiB.

Interface: stack_ci + masked_output_stochastic on the DecomposedModel protocol;
scan targets recompute in-block, toys delegate to run_stochastic_masked_output.
SPEC unchanged (pure recompute restructuring; the stochastic draw need not match
the prior goldens -- only fwd==bwd faithfulness matters).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
The "lazy DDP" experiment replicated the bf16 compute weights fully (optimizer
masters still ÷N) to make the per-layer FSDP all-gather vanish — trading a full
model-scale V/U resident for no gather. It does not scale (holds a full model
copy; we are comms-bound precisely because the model is large), so retire it:
the ProfileConfig.replicate_weights toggle + as_env mapping, the train.py grad
reduce-scatter branch, and the llama8b _reconstruct_compute_weights branch. The
default ÷fsdp compute layout is now unconditional.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…gs → compiler_options

Replace the debuggy PD_* env-var surface with proper config the trainer reads directly:

- Profiling toggles (mem_profile / time_steps / trace / async_test / leaf_bench /
  no_checkpoint / profile_max_events) are read straight off `ProfileConfig` in run.py —
  no more config -> as_env() -> os.environ round-trip in the same process. Drop
  ProfileConfig.as_env().
- The compute-experiment toggles become RuntimeConfig fields threaded as args, not env:
  `scan_unroll` (native lax.scan(unroll=k)), `gather_fp8`, `ascend_replicate`.
- XLA compiler flags move off `launch_env.xla_flags` (XLA_FLAGS env) onto
  `RuntimeConfig.compiler_options`, passed NATIVELY to every jit via a typed
  `jit_util.filter_jit` helper. Unlike XLA_FLAGS env, compiler_options are in the
  compile-cache key, so a flag A/B actually recompiles (kills the stale-cache confound).
  Default = the tuned MaxText set, so a config with no overrides comes out tuned, not naked.
  `LaunchEnv.as_env()` now renders only the genuine pre-init env (client mem/allocator,
  NCCL, glibc).
- Remove the ci_broadcast experiment (vmap-all-chunks CI fn) and the PersistentPGD
  `start_frac` field (+ its config fixtures).

make check clean; 87 tests pass.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Resolve per-layer liveness at trace (live_set is static) and run the masked
forward as [frozen prefix] -> [live block] -> [frozen suffix] sub-scans, instead
of one scan over all layers with a per-site lax.cond. The cond was a packing +
scheduling barrier: removing it lets XLA pack the per-layer V/U all-gathers far
more aggressively and prefetch across the live block. Only the live block carries
V/U; frozen segments reuse clean_output's body and gather no V/U.

full32L 4-chunk (sites_per_chunk=56) step-2 wall: 8.088 -> 5.698s (-29.6%);
defining all-gather ops 467 -> 210. Numerically bit-identical (SPEC S2/S3 -- a
frozen segment is the old frozen branch op-for-op; equivalence goldens unchanged).

Assumes layer-aligned (sites_per_chunk % n_decomposed_kinds == 0) + contiguous
chunks, asserted loudly. Drops partial-layer (per_site / layerwise) recon support;
the one partial-layer unit test became a whole-layer ablation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
`pd-lm` runs via `fire.Fire`, which parses `--tags a,b,c` into a tuple (but keeps
a value containing a hyphen as a string). `tags.split(",")` then crashed with
"'tuple' object has no attribute 'split'". Normalize tags whether Fire hands back
a str, a tuple, or None.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
The engine already applies grad clip to whichever optimizer carries a
grad_clip_norm (`_adamw_with_clip` builds `chain(clip_by_global_norm_with_eps,
adamw)` for both), using the torch-parity clip. Only the canonical-config assert
gated it off for the CI-fn optimizer. Drop that assert so CI-fn grad clip is an
optional knob; the components clip stays required (part of the method).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants