Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle#560
Open
ocg-goodfire wants to merge 557 commits into
Open
Migrate VPD to JAX: train + analyze in one framework; retire torch to oracle#560ocg-goodfire wants to merge 557 commits into
ocg-goodfire wants to merge 557 commits into
Conversation
Repo convention keeps the WANDB key in .env; the offline runner never goes through init_pd_run / get_wandb_entity (the reference yaml pins the entity), so nothing else loads it. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The streaming fineweb loader leaves non-daemon parquet/arrow reader threads that block interpreter shutdown indefinitely (repro: iterate the eval loader on CPU and return) — both validation jobs sat RUNNING on a B200 for 25+ min after printing results. os._exit(0) once stdout + wandb are flushed. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…rt -> pd-offline-eval -> cleanup) per committed checkpoint
…e per-consumer all-to-alls The CI head's out_w is ΣC-sharded so its output is born C-sharded; GSPMD then resharded it separately for every consumer (each chunk forward + PPGD + imp-min, fwd and bwd) — at 36 sites those ~1.2 GB all-to-all buffers dominated jit_step's temp arena (109.46 GiB OOM at 32 GPU, job 50542; XLA dump memprobe_mc_50581). batch_sharded_ci reshards lower+upper once at the producer. Pure sharding constraint — values exact, trajectories unchanged. AOT probe (L20-31 C=8192 bl1 remat-on, 8 GPU): temp 67.0 -> 49.7 GiB.
…memory table + recommendation Harvested from the stranded remat agent's completed jobs (50467/50468 A/B, 50475-50478 multi-chunk probes) plus this session's post-pin re-probe.
…ig route Converter relaxations, each asserted or printed: raw-HF Llama target spec (transformers.LlamaForCausalLM == vendored weights, export-bridge verified), 'model.'-prefixed site patterns, fp32 weights_dtype accepted with a loud bf16-frozen divergence note; wandb entity honestly nullable (upstream 'entity: null' = API-key default). configs/torch/llama8b_l18_C30k_200k_1pool.yaml = torch run p-19645bf7's snapshot with four header-documented edits: pretokenized data @ seq 2048 (upstream streamed raw text @ 512), ci attn max_len 2048, PPGD scope broadcast_across_batch (upstream per_batch_per_position — not implemented here, deliberate), save_every 5000.
…uncher (C30k 16-GPU arena exceeded the 75% default pool)
…nlocks 32-GPU mesh vs gcd-capped 16) AOT probe: bl4 remat-off arena 70.3 GiB — fits the 32-GPU launch comfortably.
…s deliberately uncompensated) AOT probes (8 GPU): bl8 remat-off 110.5 GiB (the launch shape, 64 GPU), bl4 off 70.3, bl12 off 149.3 (32-GPU B=384 ruled out — no margin for a multi-day run).
load_run_dir_config rebuilds the config from the pinned copies (wrapper as config.yaml + torch yaml as torch_config.yaml), so torch-wrapper runs export and offline-eval like native ones. The wrapper's launch-relative path field is ignored in favor of the pinned torch yaml.
…4x partial bump for the 4x batch; header edit 7)
…dispatch on runtime.topology) The C49k JAX run's reference yaml is single-pool; everything this module touches is shared between the schemas (and PDConfig is run_eval_pass's native type).
…e launcher CUDA_ERROR_STREAM_CAPTURE_INVALIDATED killed 4 jobs tonight across disjoint allocations (50453/50525/50676/50743) — systematic in the capture path, not bad nodes. Torch never exercises CUDA graphs, which is why the torch trainer needed no such resilience. ~1-3% launch-overhead cost.
…ot just the batch shell --signal=B:TERM@300 only signals the batch script; the trainer's SIGTERM-save handler never fired and preemption (job 50818) hard-killed the ranks after the grace window — losing everything since the last save_every checkpoint.
jsp-export writes vendored-layout keys regardless of the run yaml's target spec, so eval must build VendoredLlama (same weights) with model.-prefix patterns stripped; frozen forced bf16 = what the JAX run actually trains with (its documented fp32-yaml divergence), and the more faithful eval reference. Unblocks the C49k offline-eval loop (ckpt-5000 attempt died on the raw-HF dispatch).
…val args (C49k needs --micro-batch-size 4: mb8 OOMs at 157.6+7.8 GiB on a 178 GiB B200)
Per Oli: resilient-launch loops are an anti-pattern this stack shouldn't need. The one recurring crash class is fixed at the source (command buffers off, ~0% cost); the loop demonstrably contributed nothing across last night's incidents — it burned all 5 attempts on each deterministic OOM, could not escape a suspect node (retries reuse the allocation), and every actual recovery was a diagnosed resubmission. Plain --requeue + the now-working SIGTERM save = the torch model.
…one eval job per checkpoint save Replaces the polling sidecar chain (which self-multiplied and hid failures): rank 0 fires sbatch offline_eval_once.sbatch right after each save, fire-and-forget (a failed submission must not kill a multi-day run — the one place graceful handling beats fail-fast; it prints loudly instead). Per-run serialization via --job-name=jsp-oeval-<run> + --dependency=singleton; the run-dir marker dedups across requeues; optional <run_dir>/offline_eval_args carries run-specific flags (C49k: --micro-batch-size 4).
…yout, torch postprocess contract Wrapper gains run_id (p-<8hex>, the torch generate_run_id convention; authored into the wrapper so resume derives the same identity and the byte-compare pins it): run dir = out_dir/<run_id>, wandb id = run_id with run_name as the display name. The pinned torch yaml is additionally written as experiment_config.yaml (SavedLMRun's contract name), and the push-triggered eval job materializes model_<step>.pth (safetensors -> torch.save, pruned to newest 2) — together making a JAX run dir a first-class citizen of harvest/app/postprocess with zero special-casing. wandb config gains a jax_runtime section (actual device count etc.) since the upstream yaml's runtime.dp describes the torch run, not this one (it read 'dp: 32' while running on 64 GPUs). run_id is optional ONLY for the pre-scheme live C49k wrapper — collapse after it migrates.
…rity fixtures gen_stacked_fixtures.py runs against the PRE-restructure feature/jax-single-pool-pd code (stacked DecompVU, LayerRange) and pins, on the tiny L3-5 MLP target: target weights, per-site V/U, CI-fn/source init leaves, clean/masked/site-input/weight-delta forwards (S2/S3/S4/N2), and a 2-step training trajectory (metrics + final V/U + adversary sources, S13-S15). The generated npz was produced from the live base-branch checkout's venv; the script is excluded from this distribution's basedpyright (it deliberately imports the old API and only runs on the base branch). The follow-up site-generality restructure must reproduce these to SPEC D4 reassociation tolerance (clean logits bit-identical). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…e/up/down, per-site C
Generalizes the Llama target from contiguous-layer MLP-only to any per-layer matrix
site, each with its own C. The trainer (train.py) was already site-generic; this
rebuilds the target + plumbing around per-site state:
* llama8b.py — DecompVU becomes {site: (V (d_in,C_s), U (C_s,d_out))} fp32 masters
(same init scale rule: V~N(0,d_in^-0.5), U~N(0,C^-0.5)); Target is a uniform
SuffixLayer list from first_decomposed_layer(sites) to the LM head — layers with no
sites run the plain frozen block, and any site absent from the spec computes the
exact frozen x@W path (S2: masked_logits(live=()) stays bit-identical to
clean_logits, S3). q/k/v sites are decomposed BEFORE RoPE/SDPA via FrozenAttn.core
(head-count reshapes asserted); o applies to the attention output. Clean CI inputs
(S4): q/k/v <- post-LN1 residual, o <- pre-o_proj attn output, gate/up <- post-LN2
residual, down <- silu(gate)*up. weight_deltas stays fp32 W - V@U per site (N2).
Canonical site order = layer-ascending x KIND_ORDER (computation order); site names
keep the torch-module-path convention layers.{i}.{self_attn,mlp}.{kind}_proj.
* llama8b_sharding.py — per-site placement: V shards C on axis 1 (P(None,'dp')),
U on axis 0 (P('dp',None)); per-site C % mesh-size asserted; the
jit-with-out_shardings init pattern and batch_sharded_ci are untouched.
* config.py / torch_config.py / run.py / run_state.py / export.py — TargetConfig
carries canonical (site, C) pairs; the native {first_layer,last_layer,C} YAML maps
to the MLP family (unchanged behavior); the torch route accepts any
(model.)layers.N.{self_attn,mlp}.{kind}_proj target with per-site C (contiguity /
uniform-C / all-three-kinds asserts dropped; identity-target and other-pattern
refusals kept). Export reads V/U per site and renames any decomposed matrix's
frozen weight to .target_weight — attention sites included (the torch vendored
ComponentLlamaAttention componentizes q/k/v/o, so attention-site export is fully
supported; round-trip-verified against the real LMComponentModel via the new
l18_attn fixture case, heterogeneous C, key parity exact).
Equivalence with the stacked implementation (tests/stacked_parity, fixtures pinned
from feature/jax-single-pool-pd): clean/masked/site-input/weight-delta forwards
BIT-identical; 2-step train trajectory rel <= ~2.5e-6 (D4; sole divergence is
clip-global-norm leaf-order reassociation). Cross-framework loss-term equivalence
(tests/equivalence) unchanged and green.
invariance_check gains an abs tolerance floor (1e-7) alongside rel 1e-4: the per-leaf
grad-norm diagnostics include ~1e-4-magnitude norms whose cross-shard reduction
cancellation can graze rel 1e-4 (observed: one ci-fn wv leaf, abs err ~1e-8, step-0
only, non-growing) while every loss term sits at ~1e-6. Verified at 4 and 8 sim
devices (worst loss-term rel 2.5e-6 / 7.1e-6).
Checkpoints are NOT cross-compatible with stacked-era runs (components pytree layout
+ V/U init RNG derivation changed).
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…EC §3 site prose SPEC.md needed no semantic change: §1/§3 already define a site as any selected weight matrix. Amended the one §3 orientation sentence that described the *implementation* as contiguous-MLP-only (now: any per-layer q/k/v/o/gate/up/down matrix, per-site C; q/k/v decomposed before RoPE/SDPA) and added the attention-site examples to the §4.1 site_inputs comment (S4). No invariant changed. AUDIT.md records: stacked->per-site parity evidence, checkpoint non-compatibility across the restructure, and the verify_export_torch production-numerics pass becoming measure-only (the documented GELU/eps divergence is amplified on the tiny attention fixture; the asserted jax-matched pass is the mapping proof). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…orch PGDReconLoss parity) AdversaryConfig = SourceAdamConfig | FreshPGDConfig selects the recon adversary variant in make_train_step: PPGD keeps its persistent sources + Adam moments + fused (n_warmup+1)-th ascent (SPEC S13/S14, unchanged); FreshPGD samples per-site sources every step (unique_per_datapoint (B,T,C+1) or shared_across_batch (1,1,C+1)), runs n_steps of step_size*sign(grad) with clamp [0,1], and carries NO cross-step state (TrainState.sources stays empty; the post-backward source cotangent is unused). Metrics key 'pgd' -> wandb train/loss/PGDReconLoss. ExperimentConfig.ppgd renamed to .adversary (native yaml key 'ppgd:' unchanged); the torch-config route accepts PGDReconLoss as the adversary slot. Matches Dan's p-a005ed60 config (n_steps=1, step_size=1, init random, unique_per_datapoint).
…vs replicated vs bridged) + remaining gaps in dependency order
…LM implementation
The t-9d2b8f02 pile model (4L, d768, 6 heads, GELU MLP, plain rotate-half
RoPE, tied head) as a decomposition target. llama_simple_mlp.py mirrors
llama8b.py's seams: per-site maybe-masked matmuls for all six site kinds
(h.{i}.attn.{q,k,v,o}_proj before-RoPE/after, h.{i}.mlp.c_fc before the
GELU, down_proj after), frozen residual-start suffix + embedding/blocks
prefix, fp32 weight deltas (N2), live=() bit-equal to clean (S2/S3).
Registration: TargetConfig | LlamaSimpleMLPTargetConfig union in config.py
(native target block dispatches on pretrain_run_path); torch_config accepts
kind: pretrained LlamaSimpleMLP specs with h.* wildcard decomposition
patterns expanded over the checkpoint's n_layer; run.py builds the target
by kind and threads prefix_residual_fn; export/offline-eval guarded
llama8b-only.
Weights load from the torch pretrain cache via a one-off .pt -> safetensors
conversion (tools/convert_llama_simple_mlp_checkpoint.py, torch venv; tied
lm_head stored once as wte.weight).
Torch-fixture equivalence (tests/simple_mlp_equivalence/, fp32): tiny
random model (GQA repeat=2) max abs logits diff ~2e-7; real t-9d2b8f02
weights ~5e-5 — pins RoPE construction, tanh-GELU, GQA grouping, rms eps.
Tiny-config site tests mirror test_llama8b.py (masked identity, live=()
bitwise, mixed attn+mlp train step with heterogeneous C: S13/S15/S9/N1).
Suite green at 1 and 8 sim devices; basedpyright + ruff clean.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
EvalConfig.l0_groups carries the fnmatch site patterns; make_eval_step resolves members at build time (unmatched group refuses) and emits the torch group keys (l0/<thr>_<group> = SUM of member-site L0s). Unblocks Dan's pile config whose eval block groups per-layer + total.
… feature/jax-site-generality
Scope names now spell the stored source shape in tensor order: single_source->c,
broadcast_across_batch->sc, repeat_across_batch->nsc, per_batch_per_position->bsc;
PGD MaskScope -> {c, bc, bsc}. Classes renamed to match (CScope/SCScope/NSCScope/
BSCScope). Code + tests + reference yamls swept together; 514 tests green.
Stored experiment_config.yaml files predate the shape-spelled scope names; BeforeValidator aliases exactly the literals that occur in stored data (broadcast_across_batch->sc, per_batch_per_position->bsc; mask scopes shared_across_batch->c, unique_per_datapoint->bsc). Delete once stored runs are migrated.
# Conflicts: # param_decomp/metrics/base.py # param_decomp/metrics/importance_minimality.py # param_decomp/optimize.py # param_decomp/tests/metrics/test_importance_minimality_loss.py # param_decomp_lab/experiments/lm/run.py # param_decomp_lab/run_sink.py # param_decomp_lab/tests/test_resumption.py
…emory lever for batch) Add an enablable `remat_ci_fn` runtime flag (mirrors `remat_recon_forwards`) that wraps the CI-fn forward in `eqx.filter_checkpoint`, recomputing it in the backward instead of storing its activations. The CI-fn forward (4-block transformer × n_chunks) is the one hot-path component with NO checkpointing today, and its activations scale with batch — so this is the main activation-memory lever for training at larger batch on big targets. Pure memory/compute trade, no algorithm effect: verified `remat_ci_fn` True vs False give identical loss (bit-level: max Δloss 1.9e-6, max param-update Δ 1.9e-6 — float-reassoc noise from the recompute) on the tiny-llama step. Threaded: configs.py (RuntimeConfig field, default False) → run.py engine param + wandb record → train.py make_train_step (`apply_ci_fn` = `eqx.filter_checkpoint(_apply_ci_fn)` when set) → the three composition roots (`built.runtime.remat_ci_fn`). Test/experiment callers pass `remat_ci_fn=False` explicitly (the arg is required, like remat_recon_forwards). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…L-port Integrates the two upstream commits: - 23881ea sources/sources_opt_state -> single `adversaries: dict[str,PersistentAdversary]` - 8dfb861 smooth-L0 (Geman–McClure) imp-min penalty (annealed_imp_min_param / imp_min_terms) Resolution: the residual-first-class removal STAYS (this branch's core change). train.py collided because both refactors rewrote the recon/adversary core and upstream still used `residual` — resolved by taking upstream's train.py (adversaries + smooth-L0 structure) and re-applying this branch's two changes on top: (1) residual->batch / embed-internal (`leading` derived from a tap, not the opaque token batch), (2) `remat_ci_fn` (the CI-fn activation-checkpoint lever). The 4 residual-start experiment/tool files this branch deleted (llama8b_real, mem_probe, migrate_c49k_checkpoint, verify_c49k_migration) stay deleted — they targeted the old residual API and are unreferenced. Validated: make type 0/0/0; core suite 259 passed / 3 skipped / 11 xfailed (the xfails are the pre-existing residual-fed goldens), incl. stacked-parity, equivalence goldens, the adversary/PPGD tests, and upstream's new smooth-L0 test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…h collectives (#889) The cluster sets NCCL_DEBUG=INFO / NCCL_DEBUG_SUBSYS=ALL in the node environment, which logs every NCCL collective — the slurm logs hit tens of GB per run (~100% NCCL lines). slurm.py already intends NCCL_DEBUG=WARN, but pd-lm's _RANK_ENV never exported it, so the inherited INFO/ALL won. Export WARN in the rank env to override it. Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Ports #543 to the JAX trainer. The rolled imp-min `lp + beta·entropy` becomes two independently-coefficiented terms computed from one shared `(c+eps)^p` pass: imp = Σ_c f_c (imp.coeff) freq = Σ_c f_c · log2(1 + a'·f_c) (freq.coeff), a' = reference_token_count The frequency normalizer is now EXPLICIT (`reference_token_count`) instead of the implicit global `B·T`, so the penalty's curvature is invariant to batch size at a fixed firing rate — batch and frequency-penalty strength become independently tunable. Setting `a' = B·T` reproduces the old behavior exactly; coefficients transfer as `freq.coeff = old imp.coeff · beta`. Tied form: nested `frequency: (coeff, reference_token_count) | None` on BOTH imp configs (`ImportanceMinimalityLoss` + `SmoothL0`), reusing the shared per-component sums in one pass — not a separate flat `LossTerm` (which would refetch the sums and break the self-contained-term model). `beta` is removed. SPEC S7/S8 amended + new S8'. Migrations: all in-repo YAMLs + test/experiment configs (`freq.coeff = imp.coeff·beta`, `reference_token_count = global B·T`; `beta:0` → no freq block). Equivalence goldens preserved WITHOUT regen (`a' = n_positions` reproduces the old `entropy` bit-for-bit). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01JDynbo3BeS7AEwu8iyje6F
… render test_in_loop_slow_tier_fires_on_cadence_without_stalling timed a window that, because submit() serializes renders via join() to cap one in-flight, was dominated by ~3 back-to-back matplotlib renders (4s each) rather than the dispatch cost it claims to bound. The test drives slow evals with zero train-step gap, so the one-in-flight join blocks instead of being the no-op it is in the real loop (slow_every=3000 train steps separate two slow evals, so a render always finishes off-thread first). Under CI's -n4 oversubscription each render stretched ~3.4x and the 3-render window blew past the 30s budget (40.5s); in isolation and on a faster box it passes. Measure only what the loop pays — the collective accumulate (~0.3s incl. compile) plus the submit dispatch — by joining the renderer BETWEEN submits, outside the timed window, mirroring the real train-step gap. Budget unchanged at 30s; the measured quantity is now dispatch (sub-second), immune to render contention. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk
… why Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk
Checkpoint the masked-forward scan BODY (scan(checkpoint(block))) so the backward recomputes one layer at a time and stores only the residual carry, instead of stacking all 32 layers' activations ([32,*,14336]) — the dominant step-memory term. Threaded as a keyword-only `remat` arg through the DecomposedModel.masked_output protocol (scan models remat per-layer; toys whole-forward); removes the wrong-granularity whole-forward jax.checkpoint from train.py. Numerically transparent (faith 3.746e-4 vs 3.754e-4 whole-forward baseline; HLO-verified one-layer recompute; 87 tests pass). Reverts the temporary seq-truncation data.py hack. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…nsient ChunkwiseTransformerCIFn ran its per-chunk transformers under eqx.filter_vmap, which unrolls + hoists every chunk's FSDP weight all-gather into the flat entry computation — all n_chunks gathers (each ~ΣC/tp) live at once. At full scale that transient is the dominant tp-dependent term (~15 GB/dev @TP8, ~62 @Tp2) and the reason low-tp configs OOM at ANY batch (empirically tp2 OOMs ~89 GiB independent of batch). Replace with lax.scan over the (homogeneous, already-asserted) n_chunks axis so the chunk iteration lowers as a loop: one chunk's gather live at a time, then freed. ENTRY gather 4.88->0 MiB in the HLO. Same math up to fp32 reassociation (different matmul layout; ~1e-6, fine at bf16); 217 tests pass. The unlock for low-tp. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Without it the ~31B CI-fn forward materialises into the step module -> ~80-min compile + near-OOM (jobs 130423/130424). The schema defaults it false, so a config that omits it silently gets the heavy path. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Persistent-PGD source + Adam m/v storage dtype is now a config knob (PersistentPGDReconLossConfig.source_dtype), default float32 (SPEC N1, oracle parity preserved as a no-op cast). bfloat16 opt-in halves the source + m + v footprint (the ~41 GiB f32 PPGD term -> ~21 GiB saved at full-32L scale). Threads dtype through init_sources_sharded -> init_persistent_sources; Adam state inherits via zeros_like; update casts grad in / update out around the [0,1] projection. NOT SPEC N1 on the bf16 path; bf16 second-moment v underflow is an untested stability risk -> exploratory branch. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
# Conflicts: # param_decomp_lab/experiments/lm/launch.py
…; trim run banner The remat_ci_fn config knob was plumbed end-to-end (config -> engine -> train) but the make_train_step call site hardcoded remat_ci_fn=False, so the flag was inert and the run banner logged a value the step ignored — the ~31B CI fn was never checkpointed regardless of config (the ~80-min compile + near-OOM we chased). Thread the real value. Also trim the run-start banner: log per-kind site counts (sites=224 [q_proj×32, ...]) instead of dumping all 224 site names inline. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…nch docstring - configs/llama8b_full32L: remove no-op fields stripped by back-compat validators (n_mask_samples, sampling, autocast_bf16, device) — they only misled readers. - data.py: restore strict seq-width assert 'in (seq_len, seq_len+1)' (drop the >= TEMP HACK added for the seq-64 gbsweep). The +1 is the real label-token convention (fineweb artifact is 512-wide, pile is 513=512+label; both truncate to seq_len). - launch.py docstring: 'one srun task per node claiming all 8 GPUs' (was '8 tasks/node'). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Each train log line leads with wall-clock elapsed<eta (e.g. 2:31<6:50) from the recent step rate, and drops the per-param grad norms that dwarfed the metrics — the full breakdown still rides to wandb + jsonl. No progress bar (no in-place redraw in a SLURM .out anyway), so no tqdm dependency. Claude-Session: https://claude.ai/code/session_01XBgqY1rYzgQbbLdgfQZsbD Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ot fp32 (#905) * fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32 The slow/plot eval tier deliberately read the CI fn out in fp32, but the CI transformer's attention routes to cuDNN flash, which rejects fp32 — so the slow tier crashed at the first slow eval on GPU: NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32 (hit by every run, including bf16-only ones — unrelated to fp8 work). Rather than route fp32 to the XLA attention impl (#885), run the CI-fn readout in training precision (bf16) — matching train.py / eval.py and the hidden-acts + attn-pattern slow tiers, which already do. This is the more faithful readout (the deployed model runs bf16) and keeps cuDNN flash (faster, no (B,H,T,T) materialize). Reductions/returns stay fp32 so the host accumulation is byte-unchanged. Also incorporates #885's precision-independent multihost fix: the per-position histogram sample keeps the dp-sharded batch axis, so `np.asarray` on >1 process spans non-addressable devices — gather it with process_allgather(tiled=True). Supersedes #885's ci_fn.py change (no fp32 attention path is needed once the slow tier is bf16); incorporates its slow_eval.py gather fix. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL * style(slow-eval): trim the bf16-readout and gather comments to the load-bearing why Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…64 OOM) (#898) * fix(load_run): replicate HF prefix without cross-host allgather (dp>=64 OOM) jax.device_put(host_array, P()) runs multihost_utils.assert_equal -> process_allgather(tiled=True), tiling the ~1GB embedding to ~process_count GB (OOM at dp=64). Build the replicated global array from each host's local copy via make_array_from_callback instead. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012PUbea772bseCPww4uv3JE * docs(sharding): trim place_via_shardings docstring Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2 --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…throughput (#903) Squash of the full32L HSDP perf line. Buffer donation (donate=all-except-first) + V/U reconstruction hoist take full Llama-8B VPD to b128/4-seq-per-GPU on 32 GPUs at ~3.6x the b32 throughput, numerics bit-identical. Includes env-gated memory/timing instrumentation, the b128 production config, and doc cleanup (PERF_NOTES distilled to Lore; completed migration docs removed). Full per-commit history on branch perf/gather-coalesce-unroll-k; canon in lore 2026-06-29--full32l-hsdp-donation-canonical-scaling-memory.
…review (#900) * fix(jax): correctness + multi-host robustness fixes from feature/jax review A review of the JAX trainer against the torch oracle (torch-oracle tag) and SPEC.md found the training math faithful; these address the issues it surfaced. - fix(eval): identity_ci_error counted off-diagonal errors only inside the min(shape) block, silently ignoring active components in the trailing columns of overcomplete decompositions (C > rows — the normal toy/LM case). Count over the full matrix minus the block diagonal, matching torch IdentityCIPattern.distance_from. Fixed in all three copies (slow_eval.py + tms/resid_mlp model.py). - fix(run): SIGTERM was read from a per-process flag at collective decision points (faith-warmup exit, eval entry, orbax save), so a per-task SIGTERM with no cross-rank simultaneity could diverge ranks onto different branches and hang a collective. Reconcile the flag across ranks (OR-reduce) once per step into a local used at every gate; drop the now-redundant per-rank checks inside the LM eval_fn (the pass is admitted only after consensus and runs to completion). - fix(config): wire faithfulness_warmup_weight_decay (was hardcoded 0.0, silently ignoring the field) and drop the redundant ==0.0 canonical assert; add a fail-fast assert that the persistent-PGD source LR is fn_type=="constant" (source_lr ignores decay, so a decaying schedule was silently flattened). - fix(eval): log the LR the step actually applied (now_step-1), not next step's; permute the lower/upper CI heatmaps each by their own-derived permutation (torch parity) instead of sharing the upper-derived one. Verified: make type clean; core tests pass at 1 and 4 simulated devices (incl. the device-count-invariance suite); slow_eval tests pass under the Agg backend (validating the heatmap change); toy/schedule/config/launch tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01UKuMWZoyShCYVd3Q937KfA * refactor(jax): trim redundant comments, drop dead sigterm_received The SIGTERM consensus rationale was stated four times; keep it once in the _sigterm_consensus docstring and drop the two call-site comments that restated it. Tighten the docstring and the LR-logging comment. Remove sigterm_received() — now unused after the LM eval_fn stopped reading the per-rank flag. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01UKuMWZoyShCYVd3Q937KfA --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
* test(stacked-parity): portable fp32 tolerance, not bit-exact across CPUs test_clean_output_bit_identical used jnp.array_equal and the sibling forward pins used rtol=1e-5/atol=1e-6. Those encode bit-exactness against the fixture-generating host, but ubuntu-latest is a heterogeneous runner pool: float32 matmul reduction order differs by ~1 ULP across CPU microarchitectures, so the exact check and the near-zero elements under the tight atol flaked intermittently (failed 3 of the last 4 CI runs; passed only when a run happened to land on a matching CPU). The forwards are CI-fn-independent and the math is unchanged — this is pure reassociation noise, not a regression. Fold clean_output into the tolerance-based pins (rename -> test_clean_output_matches) and raise the shared forward tolerance to rtol=1e-4/atol=1e-5. Observed cross-arch noise is <=3.1e-6 abs (measured on arm64 vs the fixture host), so the new bound keeps ~3-4x headroom while staying essentially exact for fp32. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk * test(llama8b): rebuild vu/ci_fn per state in fresh-pgd test (step donates) make_train_step now donates the state (donate="all-except-first", #903), so the shared module-level vu/ci_fn were deleted after the first run_step and the second crashed on a deleted buffer. Build them fresh inside make_state with the same deterministic keys — bit-identical init, independent donatable buffers. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2 * docs(stacked-parity): don't attribute the ~1e-4 magnitude to SPEC D4 D4 is the reassociation concept; the specific tolerance is ours. Reword so the docstring cites D4 for the idea without misquoting it on the number. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2 --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…h_env (#906) The SLURM rank env (XLA flags, NCCL/host-memory knobs, PD_* profiling toggles) was a hardcoded `_RANK_ENV` block in launch.py, so a run's config.yaml didn't capture what env/XLA flags it ran with, and A/B-ing a flag meant editing a shared launcher file. Add `RuntimeConfig.launch_env: LaunchEnv` (param_decomp.configs) as the single source of truth — the formerly-hardcoded values become its defaults. `LaunchEnv` carries typed knobs (xla_flags, xla_python_client_{mem_fraction,allocator}, xla_pjrt_gpu_host_memory_limit_gb, nccl_debug, malloc_arena_max), a free-form `env` escape hatch (merged last), and a typed `ProfileConfig` that renders the PD_* profiling toggles — so a profile run is a config, not an env hack. `launch.py::_render_rank_env` renders `LaunchEnv.as_env()` into the exported bash block; LD_LIBRARY_PATH stays launcher-computed (machine-specific). The old `pd-lm --allocator` flag is now `launch_env.xla_python_client_allocator`. Defaults render byte-identically to the old block (test pins this). Inline (dp is None) path inherits the caller's env, unchanged. Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
remat=True keeps nothing_saveable (production default, no behavior change); remat=False now uses dots_saveable so retained-activation forwards stop recomputing the matmuls. Gather is JIT-and-freed either way. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Passes gpu_max_activity_api_events via ProfileOptions.advanced_configuration so a full train step fits (the default ~1M-event cap truncates the step mid-forward). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Trimmed from full32L to layer 18 only; per-matrix C doubled; single recon chunk (sites_per_chunk 7, coeff 0.5). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
… in-block The stochastic recon forwards each held a pre-built [n_layer,B,S,C] mask stack; collapse them to ONE shared CI stack recomputed inside the checkpointed block (faithful by checkpoint determinism). The CI envelope and the compute weights are each forward-evaluated ONCE per step (single eqx.filter_vjp / jax.vjp) rather than once detached for the adversary ascent + once inside the main backward. Memory (full32L b128 dp32, MEMPROF 142183 -> MEMRECOMP 142698): temp arena 113.79 -> 98.88 GiB (-15), runtime peak 163.30 -> 148.38 GiB (-15), headroom under the 164.08 GiB limit 0.78 -> 15.70 GiB. Interface: stack_ci + masked_output_stochastic on the DecomposedModel protocol; scan targets recompute in-block, toys delegate to run_stochastic_masked_output. SPEC unchanged (pure recompute restructuring; the stochastic draw need not match the prior goldens -- only fwd==bwd faithfulness matters). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
The "lazy DDP" experiment replicated the bf16 compute weights fully (optimizer masters still ÷N) to make the per-layer FSDP all-gather vanish — trading a full model-scale V/U resident for no gather. It does not scale (holds a full model copy; we are comms-bound precisely because the model is large), so retire it: the ProfileConfig.replicate_weights toggle + as_env mapping, the train.py grad reduce-scatter branch, and the llama8b _reconstruct_compute_weights branch. The default ÷fsdp compute layout is now unconditional. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
…gs → compiler_options Replace the debuggy PD_* env-var surface with proper config the trainer reads directly: - Profiling toggles (mem_profile / time_steps / trace / async_test / leaf_bench / no_checkpoint / profile_max_events) are read straight off `ProfileConfig` in run.py — no more config -> as_env() -> os.environ round-trip in the same process. Drop ProfileConfig.as_env(). - The compute-experiment toggles become RuntimeConfig fields threaded as args, not env: `scan_unroll` (native lax.scan(unroll=k)), `gather_fp8`, `ascend_replicate`. - XLA compiler flags move off `launch_env.xla_flags` (XLA_FLAGS env) onto `RuntimeConfig.compiler_options`, passed NATIVELY to every jit via a typed `jit_util.filter_jit` helper. Unlike XLA_FLAGS env, compiler_options are in the compile-cache key, so a flag A/B actually recompiles (kills the stale-cache confound). Default = the tuned MaxText set, so a config with no overrides comes out tuned, not naked. `LaunchEnv.as_env()` now renders only the genuine pre-init env (client mem/allocator, NCCL, glibc). - Remove the ci_broadcast experiment (vmap-all-chunks CI fn) and the PersistentPGD `start_frac` field (+ its config fixtures). make check clean; 87 tests pass. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
Resolve per-layer liveness at trace (live_set is static) and run the masked forward as [frozen prefix] -> [live block] -> [frozen suffix] sub-scans, instead of one scan over all layers with a per-site lax.cond. The cond was a packing + scheduling barrier: removing it lets XLA pack the per-layer V/U all-gathers far more aggressively and prefetch across the live block. Only the live block carries V/U; frozen segments reuse clean_output's body and gather no V/U. full32L 4-chunk (sites_per_chunk=56) step-2 wall: 8.088 -> 5.698s (-29.6%); defining all-gather ops 467 -> 210. Numerically bit-identical (SPEC S2/S3 -- a frozen segment is the old frozen branch op-for-op; equivalence goldens unchanged). Assumes layer-aligned (sites_per_chunk % n_decomposed_kinds == 0) + contiguous chunks, asserted loudly. Drops partial-layer (per_site / layerwise) recon support; the one partial-layer unit test became a whole-layer ablation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
`pd-lm` runs via `fire.Fire`, which parses `--tags a,b,c` into a tuple (but keeps
a value containing a hyphen as a string). `tags.split(",")` then crashed with
"'tuple' object has no attribute 'split'". Normalize tags whether Fire hands back
a str, a tuple, or None.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
The engine already applies grad clip to whichever optimizer carries a grad_clip_norm (`_adamw_with_clip` builds `chain(clip_by_global_norm_with_eps, adamw)` for both), using the torch-parity clip. Only the canonical-config assert gated it off for the CI-fn optimizer. Drop that assert so CI-fn grad clip is an optional knob; the components clip stays required (part of the method). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01QvFotbQNtDNsgXJQZuzghR
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Top-line
Migrates VPD fully to JAX — train and analyze in one framework — and retires the torch stack to a git-tagged oracle. Squash-merges to
mainonce, at the end. Net: ~430 commits, +24k / −50k LOC (mostly deletion).Key decisions
torch-oracle/torch-oracle-npool; JAX conforms to it (SPEC.mdis the normative contract, numeric seams default to matching torch, goldens prove it).DecomposedModelviaopen_jax_run) — harvest, clustering, autointerp, slow/offline eval, app. The JAX→torch export bridge is dead.DecomposedModel— orderedsites+ pure fnsclean_output/site_inputs/masked_output/weight_deltas/masked_site_outputs— generic over input/output/recon-loss with[B,T,d]as the fixed waist.plan × mask-source strategy(make_plan+ chunking helpers); loss is KL on final logits. Hidden-acts recon is a separate eval diagnostic over themasked_site_outputsseam (amends SPEC S31), not a recon-grid variant.# type: ignore/Any/castwithout sign-off;make check-jaxgated in pre-commit; fail-fast, types-first.Status — functionally complete; green
JAX trainer; all consumer ports; read-only app; dropped-feature deletion;
DecomposedLM → DecomposedModelrename; hidden-acts eval port; llama8b loader; type-debt → 0 + pre-commit gate; a code-review + fix pass (#854/#855/#856) and a first dead-trainer deletion (#857). Suites green: ~415 lab + ~166 jax at 1 and 4 devices;make type/check-jaxclean; torch↔JAX per-term equivalence + stacked-parity trajectory goldens pass bit-unmodified; validated end-to-end on SimpleMLP pile runp-761bc061.Remaining before main-merge (each gated — see commit history / memory)
offline_eval.py/pd-offline-eval/jsp-export; rewire_submit_offline_eval→jsp-slow-eval). Gated on parity-validating JAX slow-eval vs torch on a real llama8b run — no current-format llama8b run is loadable, so this needs a fresh run.param_decomp/core deletion (metrics/tree,train_step, …). Gated on the above (the live offline-eval path still imports them). Bridge/capstone surface stays.async_eval(in-loop autointerp) decision.Reading guide
param_decomp_jax/jax_single_pool/SPEC.md→lm.py→recon.py→train.py.TRANSITION.md= the settled plan;LOSS_PARITY_DESIGN.md= recon unification. VPD paper rides separately as #562.🤖 Generated with Claude Code