feat(jax): smooth-L0 (Geman–McClure) importance-minimality penalty#879
Closed
danbraunai-goodfire wants to merge 1 commit into
Closed
feat(jax): smooth-L0 (Geman–McClure) importance-minimality penalty#879danbraunai-goodfire wants to merge 1 commit into
danbraunai-goodfire wants to merge 1 commit into
Conversation
Add a second importance-minimality penalty alongside the L_p baseline. Both share one structure (per-site Σψ(c) sparsity + β·mean·log2(1+sum) frequency term); they differ ONLY in the per-value penalty ψ and its annealed scalar: L_p : ψ(c) = (c+eps)^p p anneals 2.0 -> 0.4 (singular at 0) smooth-L0 : ψ(c) = c^2 / (c^2 + γ^2) γ anneals 1.0 -> 0.1 (flat at 0) smooth-L0 has ψ'(0)=0 and |ψ'|<=0.65/γ everywhere, so there's no origin cliff (no eps floor, no aggressive grad clip); it parks inert components at tiny-but-nonzero CI rather than exact 0 (report L0 at a CI>0.1 cutoff). - config: SmoothL0ImportanceMinimalityLossConfig + AnyImportanceMinimalityLossConfig union; added to the AnyLossMetricConfig YAML discriminated union. - losses.py: factor shared _imp_min_terms(ci, ψ) and _linear_anneal; add smooth_l0_importance_minimality_terms / annealed_gamma; type-dispatched annealed_imp_min_param / imp_min_terms. importance_minimality_terms keeps its signature (equivalence + global-reduction goldens unchanged). - recon.py / train.py: the imp-min slot accepts either penalty (exactly one); step dispatches on config type. train/p_imp logs whichever param is live. - SPEC.md: new invariant S9'. - test_smooth_l0_imp_min.py: ψ(γ)=1/2, ψ'(0)=0, 0.65/γ bound at c=γ/√3, redescent, per-site math, anneal + dispatch. - example config: smooth-L0 sibling (only the imp-min block swapped). make check / make check-jax clean; 198 JAX tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014k5TGHpEUTNnDeTxG2wbQe
Collaborator
Author
|
Superseded. See thread. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Adds a second importance-minimality penalty to the JAX trainer alongside the existing
L_pbaseline. Both penalties share one structure — a per-siteΣ_j ψ(c)sparsity term plus theβ · mean · log2(1 + sum)frequency term — and differ only in the per-value penaltyψand its annealed scalar:ψ(c)c→0⁺L_p(existing)(c + eps)^pp: 2.0 → 0.4→ ∞forp<1(ε-capped) — a cliffc² / (c² + γ²)γ: 1.0 → 0.1→ 0— flat at the originsmooth-L0 (Geman–McClure) has
ψ'(0)=0and|ψ'| ≤ 0.65/γeverywhere, so there is no singularity atc=0— noepsfloor, no aggressive grad clip. The gradient is localized on the threshold bandc ≈ γ/√3and redescends for clearly-on components. The flip side: it parks inert components at tiny-but-nonzero CI rather than driving them to exact 0, so its honest sparsity shows up at aCI>0.1cutoff (where it matches theL_pbaseline on both L0 and KL).Motivation and Context
The
L_ppenalty's apparent sparsity comes from the same origin cliff that destabilizes training (forcinggrad_clip_norm: 0.01and anepsfloor). smooth-L0 trades exact-zeros for a stable, bounded gradient. This adds it as a drop-in alternative penalty so the two can be compared head-to-head under identical recon/faithfulness settings.What changed
param_decomp_config—SmoothL0ImportanceMinimalityLossConfig(γ + linear-anneal fields;PositiveFloatsince γ is a denominator) +AnyImportanceMinimalityLossConfigunion; added to theAnyLossMetricConfigdiscriminated union.losses.py— factored the shared_imp_min_terms(ci, ψ)(per-site grouping, global-sum-before-log2,lp + β·entropy) and_linear_anneal; addedsmooth_l0_importance_minimality_terms/annealed_gamma; type-dispatchedannealed_imp_min_param/imp_min_terms.importance_minimality_termskeeps its original signature so the equivalence + global-reduction goldens don't churn.recon.py/train.py— the imp-min slot accepts either penalty (exactly one required); the step dispatches on config type.train/p_implogs whichever parameter is live (porγ).SPEC.md— new invariant S9′ documenting the penalty and the one-of-two slot rule...._ppgdbsc_smoothl0.yaml, a smooth-L0 sibling of an existing run config (only the imp-min block swapped).By design the two penalties are not entangled — when
L_pis eventually retired it's a localized delete (config class, union arm,annealed_pnorm/importance_minimality_terms, one arm in each match, the S9 invariant).How Has This Been Tested?
make checkandmake check-jaxboth clean.test_smooth_l0_imp_min.pypins the defining properties:ψ(γ)=½,ψ'(0)=0, the0.65/γgradient bound atc=γ/√3, redescent, the per-site term math, and the anneal + dispatch path.test_clean_output_bit_identicalfloat-reassociation mismatch against a committed golden — fails identically with these changes stashed).load_config→build_recon_terms, selecting the smooth-L0 penalty.Does this PR introduce a breaking change?
No. The
L_ppath is untouched (same config shape, same goldens); smooth-L0 is purely additive.🤖 Generated with Claude Code
https://claude.ai/code/session_014k5TGHpEUTNnDeTxG2wbQe