From 0e1b5205ffd4e127ffcf2a04eeb6433b725b43f3 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 19 Jun 2026 14:28:38 +0100 Subject: [PATCH] feat(jax): smooth-L0 (Geman-McClure) importance-minimality penalty MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a second importance-minimality penalty alongside the L_p baseline. Both share one structure (per-site Σψ(c) sparsity + β·mean·log2(1+sum) frequency term); they differ ONLY in the per-value penalty ψ and its annealed scalar: L_p : ψ(c) = (c+eps)^p p anneals 2.0 -> 0.4 (singular at 0) smooth-L0 : ψ(c) = c^2 / (c^2 + γ^2) γ anneals 1.0 -> 0.1 (flat at 0) smooth-L0 has ψ'(0)=0 and |ψ'|<=0.65/γ everywhere, so there's no origin cliff (no eps floor, no aggressive grad clip); it parks inert components at tiny-but-nonzero CI rather than exact 0 (report L0 at a CI>0.1 cutoff). - config: SmoothL0ImportanceMinimalityLossConfig + AnyImportanceMinimalityLossConfig union; added to the AnyLossMetricConfig YAML discriminated union. - losses.py: factor shared _imp_min_terms(ci, ψ) and _linear_anneal; add smooth_l0_importance_minimality_terms / annealed_gamma; type-dispatched annealed_imp_min_param / imp_min_terms. importance_minimality_terms keeps its signature (equivalence + global-reduction goldens unchanged). - recon.py / train.py: the imp-min slot accepts either penalty (exactly one); step dispatches on config type. train/p_imp logs whichever param is live. - SPEC.md: new invariant S9'. - test_smooth_l0_imp_min.py: ψ(γ)=1/2, ψ'(0)=0, 0.65/γ bound at c=γ/√3, redescent, per-site math, anneal + dispatch. - example config: smooth-L0 sibling (only the imp-min block swapped). make check / make check-jax clean; 198 JAX tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) Claude-Session: https://claude.ai/code/session_014k5TGHpEUTNnDeTxG2wbQe --- param_decomp_config/losses.py | 34 +++ param_decomp_config/pd.py | 2 + param_decomp_jax/jax_single_pool/SPEC.md | 8 +- ...r_chunkwise_faith10x_ppgdbsc_smoothl0.yaml | 195 ++++++++++++++++++ param_decomp_jax/jax_single_pool/losses.py | 112 ++++++++-- param_decomp_jax/jax_single_pool/recon.py | 14 +- .../jax_single_pool/tests/test_config.py | 6 +- .../tests/test_smooth_l0_imp_min.py | 87 ++++++++ param_decomp_jax/jax_single_pool/train.py | 12 +- 9 files changed, 446 insertions(+), 24 deletions(-) create mode 100644 param_decomp_jax/jax_single_pool/configs/llama8b_l18-23_6layer_chunkwise_faith10x_ppgdbsc_smoothl0.yaml create mode 100644 param_decomp_jax/jax_single_pool/tests/test_smooth_l0_imp_min.py diff --git a/param_decomp_config/losses.py b/param_decomp_config/losses.py index 3c389e0f3..03da9d3df 100644 --- a/param_decomp_config/losses.py +++ b/param_decomp_config/losses.py @@ -60,6 +60,40 @@ class ImportanceMinimalityLossConfig(LossMetricConfig): eps: NonNegativeFloat = 1e-12 +class SmoothL0ImportanceMinimalityLossConfig(LossMetricConfig): + """Geman–McClure smooth-L0 importance-minimality penalty on upper-leaky CI values. + + Per-value penalty `phi_gamma(c) = c^2 / (c^2 + gamma^2)` (a smooth approximation to + the active-component count `1[c>0]`, exact only as `gamma -> 0`), summed over + components and fed through the same per-site `lp + beta * mean * log2(1 + sum)` + structure as `ImportanceMinimalityLoss`. Differs from the `L_p` penalty only in the + per-value shape: `phi'(0) = 0` and `|phi'| <= 0.65/gamma` everywhere, so there is no + singularity at the origin (no `eps` floor, no aggressive grad clip) — the gradient is + localized on the threshold band `c ~ gamma/sqrt(3)` and redescends for clearly-on + components. + + `gamma` is the initial scale; it is linearly annealed toward `gamma_anneal_final_gamma` + between `gamma_anneal_start_frac` and `gamma_anneal_end_frac` of training. Annealing + `gamma` down sharpens the count (a typical `c >> gamma` then reads as "1"). A constant + schedule is `gamma_anneal_final_gamma == gamma`. + """ + + type: Literal["SmoothL0ImportanceMinimalityLoss"] = "SmoothL0ImportanceMinimalityLoss" + gamma: PositiveFloat + beta: NonNegativeFloat + gamma_anneal_start_frac: Probability = 1.0 + gamma_anneal_final_gamma: PositiveFloat | None = None + gamma_anneal_end_frac: Probability = 1.0 + + +# The two importance-minimality penalties share the `coeff`/`beta` surface and the +# `lp + beta * entropy` aggregation; they differ only in the per-value penalty shape and +# its annealed parameter (`p` vs `gamma`). The trainer's imp-min slot accepts either. +AnyImportanceMinimalityLossConfig = ( + ImportanceMinimalityLossConfig | SmoothL0ImportanceMinimalityLossConfig +) + + class CIMaskedReconLossConfig(LossMetricConfig): type: Literal["CIMaskedReconLoss"] = "CIMaskedReconLoss" diff --git a/param_decomp_config/pd.py b/param_decomp_config/pd.py index 1b385b996..330a58495 100644 --- a/param_decomp_config/pd.py +++ b/param_decomp_config/pd.py @@ -31,6 +31,7 @@ PGDReconLayerwiseLossConfig, PGDReconLossConfig, PGDReconSubsetLossConfig, + SmoothL0ImportanceMinimalityLossConfig, StochasticHiddenActsReconLossConfig, StochasticReconLayerwiseLossConfig, StochasticReconLossConfig, @@ -64,6 +65,7 @@ class OptimizerConfig(BaseConfig): | PGDReconLayerwiseLossConfig | PGDReconLossConfig | PGDReconSubsetLossConfig + | SmoothL0ImportanceMinimalityLossConfig | StochasticHiddenActsReconLossConfig | StochasticReconLayerwiseLossConfig | StochasticReconLossConfig diff --git a/param_decomp_jax/jax_single_pool/SPEC.md b/param_decomp_jax/jax_single_pool/SPEC.md index 3f2248dc7..13fa44a40 100644 --- a/param_decomp_jax/jax_single_pool/SPEC.md +++ b/param_decomp_jax/jax_single_pool/SPEC.md @@ -135,9 +135,12 @@ def kl_per_position(masked_output, clean_output) = def faithfulness_loss(components) = ( Σ_s ‖W_s − V_s@U_s‖_F² ) / ( Σ_s numel(W_s) ) (S17) -def importance_minimality_loss(ci_upper, pnorm): # per-site grouping (S7) - for s: per_component_sums[c] = Σ_{b,t} (ci_upper_s[b,t,c] + eps) ** pnorm (S8,S9) +def importance_minimality_loss(ci_upper, psi): # per-site grouping (S7) + for s: per_component_sums[c] = Σ_{b,t} psi(ci_upper_s[b,t,c]) (S8,S9) return Σ_s Σ_c (per_component_sums[c]/(B·T)) · (1 + beta · log2(1 + per_component_sums[c])) + # psi is the per-value penalty; the two penalties differ ONLY in psi (S9'): + # L_p : psi(c) = (c + eps) ** pnorm pnorm anneals 2.0 → 0.4 + # smooth-L0 : psi(c) = c**2 / (c**2 + gamma**2) gamma anneals 1.0 → 0.1 # RECON_PLAN: a static list of entries (live_sites, SAMPLE_ROUTING); each entry's sampler # returns a statically-sized FAMILY of routing draws, each draw = one forward (§6). @@ -238,6 +241,7 @@ on the clean path. Refs: `ci_fn.py` GELU line + `CI_FN_RMS_EPS`; torch | S7 | Imp-min groups per site: the `log2(1+sum)` consumes one site's per-component sum. Merging sites/layers into one group is incorrect (convexity). | | S8 | The per-component sums are over the **global batch**, accumulated before the `log2`. (Per-shard results combined after the log are incorrect — Jensen; see D2.) | | S9 | `pnorm(step)` anneals linearly `2.0 → 0.4` over the configured frac window; `eps` sits inside the power. **JAX narrowing:** annealing is REQUIRED — `annealed_pnorm` asserts `cfg.p_anneal_final_p is not None` (`losses.py:53`) and `train.py:122` asserts it too. Torch supports a constant-p config (`importance_minimality.py:16-37` returns `initial_p` when no annealing window). Constant-p in JAX is expressed by setting `p_anneal_final_p == pnorm` (a flat schedule); any other torch constant-p config is REFUSED (fail-fast assert), never silently approximated. | +| S9′ | The imp-min slot accepts EITHER per-value penalty (one `ImportanceMinimalityLoss` *or* `SmoothL0ImportanceMinimalityLoss` in `loss_metrics`, never both — `build_recon_terms` asserts `imp_min is None` on the second). smooth-L0 (Geman–McClure) is `psi(c)=c²/(c²+γ²)`: `psi'(0)=0` (flat at the origin, NO `L_p` cliff, so no `eps` floor / no aggressive grad clip), `psi(γ)=½`, `|psi'|≤0.65/γ` with the peak at `c=γ/√3`, redescending for `c≫γ`. `γ` anneals linearly `gamma → gamma_anneal_final_gamma` over `[gamma_anneal_start_frac, gamma_anneal_end_frac]` (typically `1.0 → 0.1`); like S9 annealing is REQUIRED (`annealed_gamma` asserts `gamma_anneal_final_gamma is not None`; constant-γ is `final == gamma`). `γ>0` (denominator; pydantic `PositiveFloat`). Everything downstream of `psi` (per-site grouping S7, global-sum-before-log2 S8, `lp + beta·entropy`) is SHARED — `_imp_min_terms(ci_upper, psi)` is the one impl, dispatched by `imp_min_terms` / `annealed_imp_min_param` on the config type. The logged `train/p_imp` carries whichever parameter is live (`p` or `γ`). | | S10′ | The recon objective is a static tuple of coefficiented loss TERMS (one per configured recon loss metric, in config order). Each term is a static plan of `(live_sites, SAMPLE_ROUTING, MASK_SOURCE)` entries; the term's loss = mean over ALL its forwards (every draw of every entry) of `kl_per_position`; the total adds `coeff · term` per term. Plan structures (live-sets, sampler identities, family sizes, strategy kinds) are fixed across steps. The §4 pseudocode shows the production two-term instantiation (`stochastic_recon_loss` + `adversarial_recon_loss`). Recon KL direction is pinned by S25; the mean-over-forwards ≡ accumulator identity by S26. | | S11 | `uniform_k_routing`, per position: `k ~ U{1..|live_sites|}` then a uniform `k`-subset of the live sites routes True; non-live sites are not live at all. Routing draws are fresh per step, sampled inside the step. | | S12′ | An adversarial term's loss forward consumes its sources as LEAVES (no ascent-graph history); gradient flows to components and (through `ci_lower`) to the CI fn — and, for persistent sources, to the leaves themselves (S14′). The PRODUCTION adversarial term masks ALL sites and routes everywhere; subset-routed adversarial terms route per their plan. | diff --git a/param_decomp_jax/jax_single_pool/configs/llama8b_l18-23_6layer_chunkwise_faith10x_ppgdbsc_smoothl0.yaml b/param_decomp_jax/jax_single_pool/configs/llama8b_l18-23_6layer_chunkwise_faith10x_ppgdbsc_smoothl0.yaml new file mode 100644 index 000000000..3565e7b48 --- /dev/null +++ b/param_decomp_jax/jax_single_pool/configs/llama8b_l18-23_6layer_chunkwise_faith10x_ppgdbsc_smoothl0.yaml @@ -0,0 +1,195 @@ +# From-scratch 6-layer R&D decomposition of Llama-3.1-8B MLP layers 18-23 (18 targets, +# C=49152, seq 512, B=128, 40k steps, chunkwise recon 2 chunks of 9 sites). Same config +# as wandb run p-d69f728c (jax-l18-23-6L-seq512-b128-40k) with three deliberate edits: +# 1. FaithfulnessLoss coeff 1e5 -> 1e6 (faithfulness up 10x). +# 2. PersistentPGDReconLoss scope sc (broadcast_across_batch) -> bsc (independent +# source per batch element + position; skips cross-rank source sync). +# 3. PersistentPGDReconLoss Adam beta1 0.5 -> 0.01. +# p-d69f728c is the 6-layer (18-23) sibling of llama8b_l18-26_9layer_chunkwise.yaml: +# identical except 18 targets instead of 27 and ChunkwiseSubsetReconLoss coeff 1.0 +# (= base 0.5 x 2 chunks) instead of 1.5 (x 3 chunks), undoing chunk-mean dilution so +# per-site recon pressure stays invariant to chunk count. remat ON. +# +# smooth-L0 sibling of ..._ppgdbsc.yaml: the ONLY change is the importance-minimality +# penalty — L_p (pnorm 2.0->0.4) swapped for Geman-McClure smooth-L0 (gamma 1.0->0.1). +# Same coeff (5e-6) and beta (0.2). smooth-L0 has a flat gradient at c=0 (no cliff, no +# eps floor) so it parks inert components at tiny-but-nonzero CI rather than exact 0 — +# report L0 at a CI>0.1 cutoff to compare against the L_p baseline. +run_name: jax-l18-23-6L-seq512-b128-40k-faith10x-ppgdbsc-smoothl0-g1-0.1 +out_dir: /mnt/data/artifacts/mechanisms/param-decomp/jax_runs +cadence: + keep_last_n_checkpoints: 2 + save_every: 5000 + train_log_every: 200 +data: + buffer_size: 1000 + column_name: input_ids + data_files: /mnt/data/artifacts/mechanisms/param-decomp/datasets/fineweb_llama_tok_512/*.parquet + dataset_name: parquet + eval_split: train + is_tokenized: true + max_seq_len: 512 + revision: null + shuffle_each_epoch: true + streaming: false + tokenizer_name: meta-llama/Llama-3.1-8B + train_split: train +eval: + batch_size: 128 + every: 1000 + metrics: + - n_batches_accum: 1 + type: CIHistograms + - ci_alive_threshold: 0.0 + type: ComponentActivationDensity + - ci_alive_threshold: 0.0 + groups: null + type: CI_L0 + - rounding_threshold: 0.0 + type: CEandKLLosses + - type: CIMeanPerComponent + - coeff: null + type: StochasticHiddenActsReconLoss + - type: CIHiddenActsReconLoss + - coeff: null + init: random + mask_scope: shared_across_batch + n_steps: 20 + step_size: 0.1 + type: PGDReconLoss + n_steps: 1 + slow_every: 10000 + slow_on_first_step: true +pd: + batch_size: 128 + ci_config: + fn_type: global_shared_transformer + hidden_dims: null + mode: global + simple_transformer_ci_cfg: + attn_config: + max_len: 512 + n_heads: 64 + rope_base: 10000.0 + d_model: 4096 + mlp_hidden_dim: + - 16384 + n_blocks: 4 + ci_fn_optimizer: + betas: + - 0.9 + - 0.999 + grad_clip_norm: null + lr_schedule: + final_val_frac: 0.1 + fn_type: cosine + start_val: 5.0e-05 + warmup_pct: 0.0 + weight_decay: 0.0 + components_optimizer: + betas: + - 0.9 + - 0.999 + grad_clip_norm: 0.01 + lr_schedule: + final_val_frac: 0.1 + fn_type: cosine + start_val: 1.5e-04 + warmup_pct: 0.0 + weight_decay: 0.0 + decomposition_targets: + - C: 49152 + module_pattern: model.layers.18.mlp.gate_proj + - C: 49152 + module_pattern: model.layers.18.mlp.up_proj + - C: 49152 + module_pattern: model.layers.18.mlp.down_proj + - C: 49152 + module_pattern: model.layers.19.mlp.gate_proj + - C: 49152 + module_pattern: model.layers.19.mlp.up_proj + - C: 49152 + module_pattern: model.layers.19.mlp.down_proj + - C: 49152 + module_pattern: model.layers.20.mlp.gate_proj + - C: 49152 + module_pattern: model.layers.20.mlp.up_proj + - C: 49152 + module_pattern: model.layers.20.mlp.down_proj + - C: 49152 + module_pattern: model.layers.21.mlp.gate_proj + - C: 49152 + module_pattern: model.layers.21.mlp.up_proj + - C: 49152 + module_pattern: model.layers.21.mlp.down_proj + - C: 49152 + module_pattern: model.layers.22.mlp.gate_proj + - C: 49152 + module_pattern: model.layers.22.mlp.up_proj + - C: 49152 + module_pattern: model.layers.22.mlp.down_proj + - C: 49152 + module_pattern: model.layers.23.mlp.gate_proj + - C: 49152 + module_pattern: model.layers.23.mlp.up_proj + - C: 49152 + module_pattern: model.layers.23.mlp.down_proj + faithfulness_warmup_lr: 0.001 + faithfulness_warmup_steps: 400 + faithfulness_warmup_weight_decay: 0.0 + identity_decomposition_targets: null + loss_metrics: + - beta: 0.2 + coeff: 5.0e-06 + gamma: 1.0 + gamma_anneal_start_frac: 0.0 + gamma_anneal_final_gamma: 0.1 + gamma_anneal_end_frac: 1.0 + type: SmoothL0ImportanceMinimalityLoss + - coeff: 1.0 + n_samples: 1 + routing: + type: uniform_k_subset + sites_per_chunk: 9 + type: ChunkwiseSubsetReconLoss + - coeff: 0.5 + n_samples: 1 + n_warmup_steps: 2 + optimizer: + beta1: 0.01 + beta2: 0.99 + eps: 1.0e-08 + lr_schedule: + final_val_frac: 1.0 + fn_type: constant + start_val: 0.01 + warmup_pct: 0.025 + type: adam + scope: + type: bsc + start_frac: 0.0 + type: PersistentPGDReconLoss + - coeff: 1000000.0 + type: FaithfulnessLoss + n_mask_samples: 1 + sampling: continuous + seed: 0 + sigmoid_type: leaky_hard + steps: 40000 + tied_weights: null + use_delta_component: true +runtime: + remat_recon_forwards: true + autocast_bf16: true + device: cuda:0 + dp: 64 +target: + output_extract: logits + weights_dtype: bfloat16 + spec: + kind: hf + model_class: transformers.LlamaForCausalLM + model_name: meta-llama/Llama-3.1-8B +wandb: + entity: null + project: param-decomp-llama diff --git a/param_decomp_jax/jax_single_pool/losses.py b/param_decomp_jax/jax_single_pool/losses.py index 169b90556..c488ceb83 100644 --- a/param_decomp_jax/jax_single_pool/losses.py +++ b/param_decomp_jax/jax_single_pool/losses.py @@ -1,13 +1,18 @@ """The pure loss terms (SPEC §2) and their schedules — fp32 reductions, no state.""" import math +from collections.abc import Callable import jax import jax.numpy as jnp from beartype import beartype from jaxtyping import Array, Float, jaxtyped -from param_decomp_config.losses import ImportanceMinimalityLossConfig +from param_decomp_config.losses import ( + AnyImportanceMinimalityLossConfig, + ImportanceMinimalityLossConfig, + SmoothL0ImportanceMinimalityLossConfig, +) @jaxtyped(typechecker=beartype) @@ -37,14 +42,16 @@ def faithfulness_loss(weight_deltas: dict[str, Float[Array, "_ _"]]) -> Float[Ar return numerator / denominator -@jaxtyped(typechecker=beartype) -def importance_minimality_terms( - ci_upper: dict[str, Float[Array, "*leading _"]], pnorm: Float[Array, ""], eps: float +def _imp_min_terms( + ci_upper: dict[str, Float[Array, "*leading _"]], + per_value_penalty: Callable[[Float[Array, "*leading _"]], Float[Array, "*leading _"]], ) -> tuple[Float[Array, ""], Float[Array, ""]]: - """`(lp, entropy)` with per-site grouping and the global-batch sum inside the log2 - (SPEC S7/S8); the loss is `lp + beta * entropy`. Torch's `_no_beta` diagnostic (`lp` - alone) is emitted only on the eval path (`Metric.compute`), never from the train - step, so this trainer does not log a `train/loss/*_no_beta` key. + """`(lp, entropy)` for any per-value penalty `psi`, with per-site grouping and the + global-batch sum inside the log2 (SPEC S7/S8); the loss is `lp + beta * entropy`. + + `lp` is the mean-over-positions sparsity term `Σ_j mean_i psi(c_ij)`; `entropy` is the + frequency-minimality term `Σ_j mean_i psi · log2(1 + Σ_i psi)`. The two imp-min + penalties (`L_p`, smooth-L0) differ ONLY in `psi`. Under GSPMD the `*leading` axes are the global batch, so `jnp.sum` IS the exact global per-component sum — XLA reduces across shards inside the graph.""" @@ -54,20 +61,101 @@ def importance_minimality_terms( ci = ci.astype(jnp.float32) # (*leading, C) leading_axes = tuple(range(ci.ndim - 1)) n_positions = math.prod(ci.shape[:-1]) - per_component_sums = jnp.sum((ci + eps) ** pnorm, axis=leading_axes) # (C,) + per_component_sums = jnp.sum(per_value_penalty(ci), axis=leading_axes) # (C,) per_component_means = per_component_sums / n_positions lp = lp + jnp.sum(per_component_means) entropy = entropy + jnp.sum(per_component_means * jnp.log2(1.0 + per_component_sums)) return lp, entropy +@jaxtyped(typechecker=beartype) +def importance_minimality_terms( + ci_upper: dict[str, Float[Array, "*leading _"]], pnorm: Float[Array, ""], eps: float +) -> tuple[Float[Array, ""], Float[Array, ""]]: + """`L_p` imp-min terms: per-value penalty `(c + eps)^pnorm`. Singular at `c=0` for + `pnorm < 1` (the `eps` floor caps the gradient there). Torch's `_no_beta` diagnostic + (`lp` alone) is emitted only on the eval path (`Metric.compute`), never from the train + step, so this trainer does not log a `train/loss/*_no_beta` key.""" + return _imp_min_terms(ci_upper, lambda ci: (ci + eps) ** pnorm) + + +@jaxtyped(typechecker=beartype) +def smooth_l0_importance_minimality_terms( + ci_upper: dict[str, Float[Array, "*leading _"]], gamma: Float[Array, ""] +) -> tuple[Float[Array, ""], Float[Array, ""]]: + """Geman–McClure smooth-L0 imp-min terms: per-value penalty `c^2 / (c^2 + gamma^2)`. + Flat at the origin (`phi'(0)=0`) and bounded (`|phi'| <= 0.65/gamma`) — no singularity, + no `eps` floor. Approaches the true `L_0` count as `gamma -> 0`.""" + gamma_sq = gamma * gamma + return _imp_min_terms(ci_upper, lambda ci: ci**2 / (ci**2 + gamma_sq)) + + +def _linear_anneal( + step_f32: Array, + total_steps: int, + initial: float, + final: float, + start_frac: float, + end_frac: float, +) -> Array: + span = max(end_frac - start_frac, 1e-9) + progress = jnp.clip((step_f32 / total_steps - start_frac) / span, 0.0, 1.0) + return jnp.asarray(initial + (final - initial) * progress) + + def annealed_pnorm(step_f32: Array, total_steps: int, cfg: ImportanceMinimalityLossConfig) -> Array: """`p` anneals linearly `pnorm → p_anneal_final_p` over `[p_anneal_start_frac, p_anneal_end_frac]` of training (SPEC S9).""" assert cfg.p_anneal_final_p is not None - span = max(cfg.p_anneal_end_frac - cfg.p_anneal_start_frac, 1e-9) - progress = jnp.clip((step_f32 / total_steps - cfg.p_anneal_start_frac) / span, 0.0, 1.0) - return jnp.asarray(cfg.pnorm + (cfg.p_anneal_final_p - cfg.pnorm) * progress) + return _linear_anneal( + step_f32, + total_steps, + cfg.pnorm, + cfg.p_anneal_final_p, + cfg.p_anneal_start_frac, + cfg.p_anneal_end_frac, + ) + + +def annealed_gamma( + step_f32: Array, total_steps: int, cfg: SmoothL0ImportanceMinimalityLossConfig +) -> Array: + """`gamma` anneals linearly `gamma → gamma_anneal_final_gamma` over + `[gamma_anneal_start_frac, gamma_anneal_end_frac]` of training (SPEC S9').""" + assert cfg.gamma_anneal_final_gamma is not None + return _linear_anneal( + step_f32, + total_steps, + cfg.gamma, + cfg.gamma_anneal_final_gamma, + cfg.gamma_anneal_start_frac, + cfg.gamma_anneal_end_frac, + ) + + +def annealed_imp_min_param( + step_f32: Array, total_steps: int, cfg: AnyImportanceMinimalityLossConfig +) -> Array: + """The annealed per-value-penalty parameter at this step (`p` for `L_p`, `gamma` for + smooth-L0). Pure function of the step, so it's hoisted out of the loss `grad`.""" + match cfg: + case ImportanceMinimalityLossConfig(): + return annealed_pnorm(step_f32, total_steps, cfg) + case SmoothL0ImportanceMinimalityLossConfig(): + return annealed_gamma(step_f32, total_steps, cfg) + + +def imp_min_terms( + ci_upper: dict[str, Float[Array, "*leading _"]], + cfg: AnyImportanceMinimalityLossConfig, + annealed_param: Array, +) -> tuple[Float[Array, ""], Float[Array, ""]]: + """Dispatch `(lp, entropy)` on the imp-min penalty kind, given its annealed parameter.""" + match cfg: + case ImportanceMinimalityLossConfig(): + return importance_minimality_terms(ci_upper, annealed_param, cfg.eps) + case SmoothL0ImportanceMinimalityLossConfig(): + return smooth_l0_importance_minimality_terms(ci_upper, annealed_param) def warmup_then_constant_lr( diff --git a/param_decomp_jax/jax_single_pool/recon.py b/param_decomp_jax/jax_single_pool/recon.py index 9d6444116..ffb0a52c8 100644 --- a/param_decomp_jax/jax_single_pool/recon.py +++ b/param_decomp_jax/jax_single_pool/recon.py @@ -22,6 +22,7 @@ from jax_single_pool.lm import chunk_sites from param_decomp_config.losses import ( AdamPGDConfig, + AnyImportanceMinimalityLossConfig, BSCScope, ChunkwiseSubsetReconLossConfig, CIMaskedReconLayerwiseLossConfig, @@ -34,6 +35,7 @@ PGDReconLossConfig, PGDReconSubsetLossConfig, SCScope, + SmoothL0ImportanceMinimalityLossConfig, StochasticReconLayerwiseLossConfig, StochasticReconLossConfig, StochasticReconSubsetLossConfig, @@ -281,7 +283,7 @@ class LossSpec: (state_key -> shared config; SPEC S23: each key feeds exactly one term).""" faith_coeff: float - imp_min: ImportanceMinimalityLossConfig + imp_min: AnyImportanceMinimalityLossConfig recon_terms: ReconLossTerms persistent: dict[str, PersistentPGDReconLossConfig] @@ -308,7 +310,7 @@ def build_recon_terms( Term ORDER follows the config list (recon terms only) — per-term RNG keys are derived from that order, so it is semantically load-bearing (SPEC R1).""" faith_coeff: float | None = None - imp_min: ImportanceMinimalityLossConfig | None = None + imp_min: AnyImportanceMinimalityLossConfig | None = None terms: list[ReconLossTerm] = [] persistent: dict[str, PersistentPGDReconLossConfig] = {} @@ -327,6 +329,10 @@ def assert_unique_instance_key(cfg: AnyLossMetricConfig) -> str: assert imp_min is None assert cfg.p_anneal_final_p is not None imp_min = cfg + case SmoothL0ImportanceMinimalityLossConfig(): + assert imp_min is None + assert cfg.gamma_anneal_final_gamma is not None + imp_min = cfg case UnmaskedReconLossConfig() | CIMaskedReconLossConfig(): value = 1.0 if isinstance(cfg, UnmaskedReconLossConfig) else 0.0 plan = make_plan( @@ -402,7 +408,9 @@ def assert_unique_instance_key(cfg: AnyLossMetricConfig) -> str: raise AssertionError(f"unsupported training loss {cfg.type!r}") assert faith_coeff is not None and imp_min is not None, ( - f"need FaithfulnessLoss + ImportanceMinimalityLoss, got {[m.type for m in loss_metrics]}" + "need FaithfulnessLoss + an importance-minimality loss " + f"(ImportanceMinimalityLoss | SmoothL0ImportanceMinimalityLoss), " + f"got {[m.type for m in loss_metrics]}" ) assert terms, "no recon loss terms configured" for term in terms: diff --git a/param_decomp_jax/jax_single_pool/tests/test_config.py b/param_decomp_jax/jax_single_pool/tests/test_config.py index bc4aa18de..a22bdd6ac 100644 --- a/param_decomp_jax/jax_single_pool/tests/test_config.py +++ b/param_decomp_jax/jax_single_pool/tests/test_config.py @@ -21,7 +21,10 @@ from jax_single_pool.lm import SiteC from jax_single_pool.recon import build_recon_terms from param_decomp_config.lm import LMExperimentConfig -from param_decomp_config.losses import PersistentPGDReconLossConfig +from param_decomp_config.losses import ( + ImportanceMinimalityLossConfig, + PersistentPGDReconLossConfig, +) CONFIGS = Path(__file__).parent.parent / "configs" RUN_ID = "p-0123abcd" @@ -55,6 +58,7 @@ def test_b128_config_converts(tmp_path: Path): converted.loss_metrics, tuple(sc.name for sc in converted.target.sites), converted.n_mask_samples, converted.sampling, ) # fmt: skip + assert isinstance(spec.imp_min, ImportanceMinimalityLossConfig) assert spec.faith_coeff == 1e5 and spec.imp_min.pnorm == 2.0 (ppgd,) = spec.persistent.values() assert isinstance(ppgd, PersistentPGDReconLossConfig) diff --git a/param_decomp_jax/jax_single_pool/tests/test_smooth_l0_imp_min.py b/param_decomp_jax/jax_single_pool/tests/test_smooth_l0_imp_min.py new file mode 100644 index 000000000..b9057bd50 --- /dev/null +++ b/param_decomp_jax/jax_single_pool/tests/test_smooth_l0_imp_min.py @@ -0,0 +1,87 @@ +"""smooth-L0 (Geman–McClure) importance-minimality penalty (SPEC S7/S8/S9'). + +The penalty shares the per-site `lp + beta * mean * log2(1 + sum)` structure with the +`L_p` penalty and differs ONLY in the per-value shape `phi_gamma(c) = c^2/(c^2+gamma^2)`. +These checks pin the properties that motivate it over `L_p` (see the smooth-L0 report): +flat at the origin (`phi'(0)=0`, no singularity to clip), bounded gradient +(`|phi'| <= 0.65/gamma`), redescent for clearly-on components, and the half-saturation +crossover `phi(gamma) = 1/2`. +""" + +import jax +import jax.numpy as jnp + +from jax_single_pool.losses import ( + annealed_gamma, + annealed_imp_min_param, + imp_min_terms, + smooth_l0_importance_minimality_terms, +) +from param_decomp_config.losses import SmoothL0ImportanceMinimalityLossConfig + + +def _phi(c: jax.Array, gamma: float) -> jax.Array: + return c**2 / (c**2 + gamma**2) + + +def test_phi_shape_invariants(): + for gamma in (1.0, 0.1): + assert float(_phi(jnp.array(0.0), gamma)) == 0.0 # off -> exactly 0 + assert abs(float(_phi(jnp.array(gamma), gamma)) - 0.5) < 1e-6 # half-saturation + assert float(_phi(jnp.array(10.0 * gamma), gamma)) > 0.99 # clearly-on -> ~1 + + +def test_phi_gradient_flat_at_origin_and_bounded(): + """phi'(0) = 0 (no L_p cliff) and the peak |phi'| ~ 0.65/gamma sits at c = gamma/sqrt(3).""" + for gamma in (1.0, 0.1): + dphi = jax.grad(lambda c, g=gamma: _phi(c, g)) + assert float(dphi(jnp.array(0.0))) == 0.0 + cs = jnp.linspace(0.0, 5.0 * gamma, 4096) + grads = jnp.abs(jax.vmap(dphi)(cs)) + peak = float(grads.max()) + assert peak <= 0.65 / gamma + 1e-3 + c_peak = float(cs[jnp.argmax(grads)]) + assert abs(c_peak - gamma / jnp.sqrt(3.0)) < 0.02 * gamma + # redescent: gradient at a clearly-on point is far below the peak. + assert float(dphi(jnp.array(5.0 * gamma))) < 0.2 * peak + + +def test_terms_match_manual_per_site_structure(): + ci = { + "a": jnp.array([[0.0, 0.5, 1.0], [0.2, 0.0, 0.9]]), + "b": jnp.array([[0.3], [0.7]]), + } + gamma = 0.1 + lp, entropy = smooth_l0_importance_minimality_terms(ci, jnp.asarray(gamma)) + + exp_lp = jnp.zeros(()) + exp_entropy = jnp.zeros(()) + for v in ci.values(): + sums = _phi(v, gamma).sum(axis=0) + means = sums / v.shape[0] + exp_lp = exp_lp + means.sum() + exp_entropy = exp_entropy + (means * jnp.log2(1.0 + sums)).sum() + assert jnp.allclose(lp, exp_lp) + assert jnp.allclose(entropy, exp_entropy) + + +def test_anneal_and_dispatch(): + cfg = SmoothL0ImportanceMinimalityLossConfig( + coeff=2e-4, + gamma=1.0, + beta=0.5, + gamma_anneal_start_frac=0.0, + gamma_anneal_final_gamma=0.1, + gamma_anneal_end_frac=1.0, + ) + total = 100 + assert abs(float(annealed_gamma(jnp.asarray(0.0), total, cfg)) - 1.0) < 1e-6 + assert abs(float(annealed_gamma(jnp.asarray(50.0), total, cfg)) - 0.55) < 1e-6 + assert abs(float(annealed_gamma(jnp.asarray(total), total, cfg)) - 0.1) < 1e-6 + + ci = {"a": jnp.array([[0.0, 0.5, 1.0], [0.2, 0.0, 0.9]])} + param = annealed_imp_min_param(jnp.asarray(float(total)), total, cfg) + via_dispatch = imp_min_terms(ci, cfg, param) + direct = smooth_l0_importance_minimality_terms(ci, param) + assert jnp.allclose(via_dispatch[0], direct[0]) + assert jnp.allclose(via_dispatch[1], direct[1]) diff --git a/param_decomp_jax/jax_single_pool/train.py b/param_decomp_jax/jax_single_pool/train.py index 068e1deae..f588ad27c 100644 --- a/param_decomp_jax/jax_single_pool/train.py +++ b/param_decomp_jax/jax_single_pool/train.py @@ -38,9 +38,9 @@ from jax_single_pool.ci_fn_mlp import LayerwiseMLPCIFn from jax_single_pool.lm import DecomposedModel from jax_single_pool.losses import ( - annealed_pnorm, + annealed_imp_min_param, faithfulness_loss, - importance_minimality_terms, + imp_min_terms, warmup_then_constant_lr, ) from jax_single_pool.recon import ( @@ -136,7 +136,7 @@ def make_train_step( recon_terms = loss_spec.recon_terms imp_min = loss_spec.imp_min faith_coeff = loss_spec.faith_coeff - assert imp_min.coeff is not None and imp_min.p_anneal_final_p is not None + assert imp_min.coeff is not None imp_coeff = imp_min.coeff term_coeff_by_state_key = { entry.sources.state_key: term.coeff @@ -267,7 +267,7 @@ def step( state: TrainState, frozen: Any, residual: Float[Array, "*leading d"], key: PRNGKeyArray ) -> tuple[TrainState, dict[str, Array]]: step_f32 = state.step.astype(jnp.float32) - pnorm = annealed_pnorm(step_f32, total_steps, imp_min) + imp_min_param = annealed_imp_min_param(step_f32, total_steps, imp_min) leading = residual.shape[:-1] residual = batch_sharded(residual) @@ -413,7 +413,7 @@ def loss_fn( ci_fn_bf16 = cast_floating(ci_fn, COMPUTE_DT) ci = batch_sharded_ci(ci_fn_bf16(site_inputs)) faith_loss = faithfulness_loss(lm.weight_deltas(frozen, components)) - imp_lp, imp_entropy = importance_minimality_terms(ci.upper, pnorm, imp_min.eps) + imp_lp, imp_entropy = imp_min_terms(ci.upper, imp_min, imp_min_param) imp_loss = imp_lp + imp_min.beta * imp_entropy term_losses: list[Array] = [] @@ -532,7 +532,7 @@ def loss_fn( "total": total_loss, "faith": faith_loss, "imp": imp_loss, - "p_imp": pnorm, + "p_imp": imp_min_param, **{f"loss/{t.name}": v for t, v in zip(recon_terms, term_losses, strict=True)}, **grad_norm_metrics, }