feat(jax): smooth-L0 (Geman–McClure) importance-minimality penalty #879

Closed
danbraunai-goodfire wants to merge 1 commit into feature/jax from feature/smooth-l0-imp-min

@danbraunai-goodfire

Collaborator

Description

Adds a second importance-minimality penalty to the JAX trainer alongside the existing L_p baseline. Both penalties share one structure — a per-site Σ_j ψ(c) sparsity term plus the β · mean · log2(1 + sum) frequency term — and differ only in the per-value penalty ψ and its annealed scalar:

| penalty | per-value ψ(c) | annealed param | gradient at c→0⁺ |
| --- | --- | --- | --- |
| L_p (existing) | (c + eps)^p | p: 2.0 → 0.4 | → ∞ for p<1 (ε-capped): a cliff |
| smooth-L0 (new) | c² / (c² + γ²) | γ: 1.0 → 0.1 | → 0: flat at the origin |

smooth-L0 (Geman–McClure) has ψ'(0)=0 and |ψ'| ≤ 0.65/γ everywhere (the maximum, 3√3/(8γ) ≈ 0.6495/γ, is attained at c = γ/√3), so there is no singularity at c=0 and no need for an eps floor or an aggressive grad clip. The gradient is concentrated around the threshold band c ≈ γ/√3 and redescends for clearly-on components. The flip side: it parks inert components at tiny-but-nonzero CI rather than driving them to exact 0, so its sparsity is only visible when L0 is measured at a CI>0.1 cutoff, where it matches the L_p baseline on both L0 and KL.
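The properties claimed above can be checked numerically. The following is a minimal standalone sketch in plain Python, not the PR's JAX code; the function names and the `eps` value are illustrative assumptions.

```python
import math


def psi_smooth_l0(c: float, gamma: float) -> float:
    """Geman-McClure per-value penalty: c^2 / (c^2 + gamma^2)."""
    return c * c / (c * c + gamma * gamma)


def dpsi_smooth_l0(c: float, gamma: float) -> float:
    """Analytic derivative in c: 2 c gamma^2 / (c^2 + gamma^2)^2."""
    return 2.0 * c * gamma**2 / (c * c + gamma * gamma) ** 2


def dpsi_lp(c: float, p: float, eps: float) -> float:
    """Derivative in c of the L_p per-value penalty (c + eps)^p."""
    return p * (c + eps) ** (p - 1.0)


gamma = 0.1  # final annealed value from the table
c_peak = gamma / math.sqrt(3.0)  # where |psi'| is largest
peak = dpsi_smooth_l0(c_peak, gamma)  # 3*sqrt(3)/(8*gamma), about 0.6495/gamma

print("psi(gamma)        =", psi_smooth_l0(gamma, gamma))   # one half by construction
print("psi'(0)           =", dpsi_smooth_l0(0.0, gamma))    # flat at the origin
print("peak psi'         =", peak, "bound:", 0.65 / gamma)
print("psi'(1.0)         =", dpsi_smooth_l0(1.0, gamma))    # redescends for clearly-on c
# eps = 1e-12 is an assumed floor, chosen only to show the size of the L_p cliff.
print("L_p grad at c = 0 =", dpsi_lp(0.0, p=0.4, eps=1e-12))
```

The L_p gradient at the origin grows without bound as `eps` shrinks, while the smooth-L0 gradient stays under `0.65 / gamma` for every `c`.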

Motivation and Context

The L_p penalty's apparent sparsity comes from the same origin cliff that destabilizes training (forcing grad_clip_norm: 0.01 and an eps floor). smooth-L0 trades exact-zeros for a stable, bounded gradient. This adds it as a drop-in alternative penalty so the two can be compared head-to-head under identical recon/faithfulness settings.

What changed

  • param_decomp_config: SmoothL0ImportanceMinimalityLossConfig (γ + linear-anneal fields; PositiveFloat since γ is a denominator) + AnyImportanceMinimalityLossConfig union; added to the AnyLossMetricConfig discriminated union.
  • losses.py: factored the shared _imp_min_terms(ci, ψ) (per-site grouping, global-sum-before-log2, lp + β·entropy) and _linear_anneal; added smooth_l0_importance_minimality_terms / annealed_gamma; type-dispatched annealed_imp_min_param / imp_min_terms. importance_minimality_terms keeps its original signature so the equivalence + global-reduction goldens don't churn.
  • recon.py / train.py: the imp-min slot accepts either penalty (exactly one required); the step dispatches on config type. train/p_imp logs whichever parameter is live (p or γ).
  • SPEC.md: new invariant S9′ documenting the penalty and the one-of-two slot rule.
  • Example config: ..._ppgdbsc_smoothl0.yaml, a smooth-L0 sibling of an existing run config (only the imp-min block swapped).
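To illustrate the idea behind a shared helper parameterized by ψ, here is a minimal plain-Python sketch. It is not the PR's `_imp_min_terms`: the name `imp_min_terms_sketch`, the nested-list layout of `ci`, and the exact reduction (per-component sum and mean over examples, then `mean + β · mean · log2(1 + sum)`) are assumptions read off the one-line description above.

```python
import math
from typing import Callable, Mapping, Sequence

Psi = Callable[[float], float]


def imp_min_terms_sketch(
    ci: Mapping[str, Sequence[Sequence[float]]],  # site -> rows of per-component CI
    psi: Psi,
    beta: float,
) -> float:
    """Sum over sites and components of a sparsity term plus a frequency term.

    Assumed reduction: for each component, sum psi over examples, take the
    mean, and add mean + beta * mean * log2(1 + sum).
    """
    total = 0.0
    for rows in ci.values():
        n_examples = len(rows)
        n_components = len(rows[0])
        for j in range(n_components):
            comp_sum = sum(psi(row[j]) for row in rows)
            comp_mean = comp_sum / n_examples
            total += comp_mean + beta * comp_mean * math.log2(1.0 + comp_sum)
    return total


# The two penalties differ only in the psi that is passed in.
gamma, p, eps = 0.1, 0.4, 1e-12  # illustrative values; eps is an assumption
ci = {"site_a": [[0.0, 0.9], [0.02, 1.0]]}  # hypothetical site with 2 components
smooth_l0 = imp_min_terms_sketch(ci, lambda c: c * c / (c * c + gamma * gamma), beta=0.5)
l_p = imp_min_terms_sketch(ci, lambda c: (c + eps) ** p, beta=0.5)
print(smooth_l0, l_p)
```

Keeping ψ as the only varying argument is what lets the grouping and reduction logic live in one place.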

By design the two penalties are not entangled — when L_p is eventually retired it's a localized delete (config class, union arm, annealed_pnorm/importance_minimality_terms, one arm in each match, the S9 invariant).

How Has This Been Tested?

  • make check and make check-jax both clean.
  • New test_smooth_l0_imp_min.py pins the defining properties: ψ(γ)=½, ψ'(0)=0, the 0.65/γ gradient bound at c=γ/√3, redescent, the per-site term math, and the anneal + dispatch path.
  • Full JAX suite: 198 passed (one pre-existing, unrelated test_clean_output_bit_identical float-reassociation mismatch against a committed golden — fails identically with these changes stashed).
  • The example config loads through the real load_config → build_recon_terms path, selecting the smooth-L0 penalty.

Does this PR introduce a breaking change?

No. The L_p path is untouched (same config shape, same goldens); smooth-L0 is purely additive.

🤖 Generated with Claude Code

https://claude.ai/code/session_014k5TGHpEUTNnDeTxG2wbQe

Add a second importance-minimality penalty alongside the L_p baseline. Both
share one structure (per-site Σψ(c) sparsity + β·mean·log2(1+sum) frequency
term); they differ ONLY in the per-value penalty ψ and its annealed scalar:

  L_p        : ψ(c) = (c+eps)^p            p anneals 2.0 -> 0.4   (singular at 0)
  smooth-L0  : ψ(c) = c^2 / (c^2 + γ^2)    γ anneals 1.0 -> 0.1   (flat at 0)

smooth-L0 has ψ'(0)=0 and |ψ'|<=0.65/γ everywhere, so there's no origin cliff
(no eps floor, no aggressive grad clip); it parks inert components at
tiny-but-nonzero CI rather than exact 0 (report L0 at a CI>0.1 cutoff).

- config: SmoothL0ImportanceMinimalityLossConfig + AnyImportanceMinimalityLossConfig
  union; added to the AnyLossMetricConfig YAML discriminated union.
- losses.py: factor shared _imp_min_terms(ci, ψ) and _linear_anneal; add
  smooth_l0_importance_minimality_terms / annealed_gamma; type-dispatched
  annealed_imp_min_param / imp_min_terms. importance_minimality_terms keeps its
  signature (equivalence + global-reduction goldens unchanged).
- recon.py / train.py: the imp-min slot accepts either penalty (exactly one);
  step dispatches on config type. train/p_imp logs whichever param is live.
- SPEC.md: new invariant S9'.
- test_smooth_l0_imp_min.py: ψ(γ)=1/2, ψ'(0)=0, 0.65/γ bound at c=γ/√3,
  redescent, per-site math, anneal + dispatch.
- example config: smooth-L0 sibling (only the imp-min block swapped).

make check / make check-jax clean; 198 JAX tests pass.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014k5TGHpEUTNnDeTxG2wbQe
@danbraunai-goodfire

Collaborator Author

Superseded. See thread.
