Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions param_decomp_config/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,40 @@ class ImportanceMinimalityLossConfig(LossMetricConfig):
eps: NonNegativeFloat = 1e-12


class SmoothL0ImportanceMinimalityLossConfig(LossMetricConfig):
"""Geman–McClure smooth-L0 importance-minimality penalty on upper-leaky CI values.

Per-value penalty `phi_gamma(c) = c^2 / (c^2 + gamma^2)` (a smooth approximation to
the active-component count `1[c>0]`, exact only as `gamma -> 0`), summed over
components and fed through the same per-site `lp + beta * mean * log2(1 + sum)`
structure as `ImportanceMinimalityLoss`. Differs from the `L_p` penalty only in the
per-value shape: `phi'(0) = 0` and `|phi'| <= 0.65/gamma` everywhere, so there is no
singularity at the origin (no `eps` floor, no aggressive grad clip) — the gradient is
localized on the threshold band `c ~ gamma/sqrt(3)` and redescends for clearly-on
components.

`gamma` is the initial scale; it is linearly annealed toward `gamma_anneal_final_gamma`
between `gamma_anneal_start_frac` and `gamma_anneal_end_frac` of training. Annealing
`gamma` down sharpens the count (a typical `c >> gamma` then reads as "1"). A constant
schedule is `gamma_anneal_final_gamma == gamma`.
"""

type: Literal["SmoothL0ImportanceMinimalityLoss"] = "SmoothL0ImportanceMinimalityLoss"
gamma: PositiveFloat
beta: NonNegativeFloat
gamma_anneal_start_frac: Probability = 1.0
gamma_anneal_final_gamma: PositiveFloat | None = None
gamma_anneal_end_frac: Probability = 1.0


# The two importance-minimality penalties share the `coeff`/`beta` surface and the
# `lp + beta * entropy` aggregation; they differ only in the per-value penalty shape and
# its annealed parameter (`p` vs `gamma`). The trainer's imp-min slot accepts either.
AnyImportanceMinimalityLossConfig = (
ImportanceMinimalityLossConfig | SmoothL0ImportanceMinimalityLossConfig
)


class CIMaskedReconLossConfig(LossMetricConfig):
type: Literal["CIMaskedReconLoss"] = "CIMaskedReconLoss"

Expand Down
2 changes: 2 additions & 0 deletions param_decomp_config/pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PGDReconLayerwiseLossConfig,
PGDReconLossConfig,
PGDReconSubsetLossConfig,
SmoothL0ImportanceMinimalityLossConfig,
StochasticHiddenActsReconLossConfig,
StochasticReconLayerwiseLossConfig,
StochasticReconLossConfig,
Expand Down Expand Up @@ -64,6 +65,7 @@ class OptimizerConfig(BaseConfig):
| PGDReconLayerwiseLossConfig
| PGDReconLossConfig
| PGDReconSubsetLossConfig
| SmoothL0ImportanceMinimalityLossConfig
| StochasticHiddenActsReconLossConfig
| StochasticReconLayerwiseLossConfig
| StochasticReconLossConfig
Expand Down
8 changes: 6 additions & 2 deletions param_decomp_jax/jax_single_pool/SPEC.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,12 @@ def kl_per_position(masked_output, clean_output) =

def faithfulness_loss(components) = ( Σ_s ‖W_s − V_s@U_s‖_F² ) / ( Σ_s numel(W_s) ) (S17)

def importance_minimality_loss(ci_upper, pnorm): # per-site grouping (S7)
for s: per_component_sums[c] = Σ_{b,t} (ci_upper_s[b,t,c] + eps) ** pnorm (S8,S9)
def importance_minimality_loss(ci_upper, psi): # per-site grouping (S7)
for s: per_component_sums[c] = Σ_{b,t} psi(ci_upper_s[b,t,c]) (S8,S9)
return Σ_s Σ_c (per_component_sums[c]/(B·T)) · (1 + beta · log2(1 + per_component_sums[c]))
# psi is the per-value penalty; the two penalties differ ONLY in psi (S9'):
# L_p : psi(c) = (c + eps) ** pnorm pnorm anneals 2.0 → 0.4
# smooth-L0 : psi(c) = c**2 / (c**2 + gamma**2) gamma anneals 1.0 → 0.1

# RECON_PLAN: a static list of entries (live_sites, SAMPLE_ROUTING); each entry's sampler
# returns a statically-sized FAMILY of routing draws, each draw = one forward (§6).
Expand Down Expand Up @@ -238,6 +241,7 @@ on the clean path. Refs: `ci_fn.py` GELU line + `CI_FN_RMS_EPS`; torch
| S7 | Imp-min groups per site: the `log2(1+sum)` consumes one site's per-component sum. Merging sites/layers into one group is incorrect (convexity). |
| S8 | The per-component sums are over the **global batch**, accumulated before the `log2`. (Per-shard results combined after the log are incorrect — Jensen; see D2.) |
| S9 | `pnorm(step)` anneals linearly `2.0 → 0.4` over the configured frac window; `eps` sits inside the power. **JAX narrowing:** annealing is REQUIRED — `annealed_pnorm` asserts `cfg.p_anneal_final_p is not None` (`losses.py:53`) and `train.py:122` asserts it too. Torch supports a constant-p config (`importance_minimality.py:16-37` returns `initial_p` when no annealing window). Constant-p in JAX is expressed by setting `p_anneal_final_p == pnorm` (a flat schedule); any other torch constant-p config is REFUSED (fail-fast assert), never silently approximated. |
| S9′ | The imp-min slot accepts EITHER per-value penalty (one `ImportanceMinimalityLoss` *or* `SmoothL0ImportanceMinimalityLoss` in `loss_metrics`, never both — `build_recon_terms` asserts `imp_min is None` on the second). smooth-L0 (Geman–McClure) is `psi(c)=c²/(c²+γ²)`: `psi'(0)=0` (flat at the origin, NO `L_p` cliff, so no `eps` floor / no aggressive grad clip), `psi(γ)=½`, `|psi'|≤0.65/γ` with the peak at `c=γ/√3`, redescending for `c≫γ`. `γ` anneals linearly `gamma → gamma_anneal_final_gamma` over `[gamma_anneal_start_frac, gamma_anneal_end_frac]` (typically `1.0 → 0.1`); like S9 annealing is REQUIRED (`annealed_gamma` asserts `gamma_anneal_final_gamma is not None`; constant-γ is `final == gamma`). `γ>0` (denominator; pydantic `PositiveFloat`). Everything downstream of `psi` (per-site grouping S7, global-sum-before-log2 S8, `lp + beta·entropy`) is SHARED — `_imp_min_terms(ci_upper, psi)` is the one impl, dispatched by `imp_min_terms` / `annealed_imp_min_param` on the config type. The logged `train/p_imp` carries whichever parameter is live (`p` or `γ`). |
| S10′ | The recon objective is a static tuple of coefficiented loss TERMS (one per configured recon loss metric, in config order). Each term is a static plan of `(live_sites, SAMPLE_ROUTING, MASK_SOURCE)` entries; the term's loss = mean over ALL its forwards (every draw of every entry) of `kl_per_position`; the total adds `coeff · term` per term. Plan structures (live-sets, sampler identities, family sizes, strategy kinds) are fixed across steps. The §4 pseudocode shows the production two-term instantiation (`stochastic_recon_loss` + `adversarial_recon_loss`). Recon KL direction is pinned by S25; the mean-over-forwards ≡ accumulator identity by S26. |
| S11 | `uniform_k_routing`, per position: `k ~ U{1..|live_sites|}` then a uniform `k`-subset of the live sites routes True; non-live sites are not live at all. Routing draws are fresh per step, sampled inside the step. |
| S12′ | An adversarial term's loss forward consumes its sources as LEAVES (no ascent-graph history); gradient flows to components and (through `ci_lower`) to the CI fn — and, for persistent sources, to the leaves themselves (S14′). The PRODUCTION adversarial term masks ALL sites and routes everywhere; subset-routed adversarial terms route per their plan. |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# From-scratch 6-layer R&D decomposition of Llama-3.1-8B MLP layers 18-23 (18 targets,
# C=49152, seq 512, B=128, 40k steps, chunkwise recon 2 chunks of 9 sites). Same config
# as wandb run p-d69f728c (jax-l18-23-6L-seq512-b128-40k) with three deliberate edits:
# 1. FaithfulnessLoss coeff 1e5 -> 1e6 (faithfulness up 10x).
# 2. PersistentPGDReconLoss scope sc (broadcast_across_batch) -> bsc (independent
# source per batch element + position; skips cross-rank source sync).
# 3. PersistentPGDReconLoss Adam beta1 0.5 -> 0.01.
# p-d69f728c is the 6-layer (18-23) sibling of llama8b_l18-26_9layer_chunkwise.yaml:
# identical except 18 targets instead of 27 and ChunkwiseSubsetReconLoss coeff 1.0
# (= base 0.5 x 2 chunks) instead of 1.5 (x 3 chunks), undoing chunk-mean dilution so
# per-site recon pressure stays invariant to chunk count. remat ON.
#
# smooth-L0 sibling of ..._ppgdbsc.yaml: the ONLY change is the importance-minimality
# penalty — L_p (pnorm 2.0->0.4) swapped for Geman-McClure smooth-L0 (gamma 1.0->0.1).
# Same coeff (5e-6) and beta (0.2). smooth-L0 has a flat gradient at c=0 (no cliff, no
# eps floor) so it parks inert components at tiny-but-nonzero CI rather than exact 0 —
# report L0 at a CI>0.1 cutoff to compare against the L_p baseline.
run_name: jax-l18-23-6L-seq512-b128-40k-faith10x-ppgdbsc-smoothl0-g1-0.1
out_dir: /mnt/data/artifacts/mechanisms/param-decomp/jax_runs
cadence:
keep_last_n_checkpoints: 2
save_every: 5000
train_log_every: 200
data:
buffer_size: 1000
column_name: input_ids
data_files: /mnt/data/artifacts/mechanisms/param-decomp/datasets/fineweb_llama_tok_512/*.parquet
dataset_name: parquet
eval_split: train
is_tokenized: true
max_seq_len: 512
revision: null
shuffle_each_epoch: true
streaming: false
tokenizer_name: meta-llama/Llama-3.1-8B
train_split: train
eval:
batch_size: 128
every: 1000
metrics:
- n_batches_accum: 1
type: CIHistograms
- ci_alive_threshold: 0.0
type: ComponentActivationDensity
- ci_alive_threshold: 0.0
groups: null
type: CI_L0
- rounding_threshold: 0.0
type: CEandKLLosses
- type: CIMeanPerComponent
- coeff: null
type: StochasticHiddenActsReconLoss
- type: CIHiddenActsReconLoss
- coeff: null
init: random
mask_scope: shared_across_batch
n_steps: 20
step_size: 0.1
type: PGDReconLoss
n_steps: 1
slow_every: 10000
slow_on_first_step: true
pd:
batch_size: 128
ci_config:
fn_type: global_shared_transformer
hidden_dims: null
mode: global
simple_transformer_ci_cfg:
attn_config:
max_len: 512
n_heads: 64
rope_base: 10000.0
d_model: 4096
mlp_hidden_dim:
- 16384
n_blocks: 4
ci_fn_optimizer:
betas:
- 0.9
- 0.999
grad_clip_norm: null
lr_schedule:
final_val_frac: 0.1
fn_type: cosine
start_val: 5.0e-05
warmup_pct: 0.0
weight_decay: 0.0
components_optimizer:
betas:
- 0.9
- 0.999
grad_clip_norm: 0.01
lr_schedule:
final_val_frac: 0.1
fn_type: cosine
start_val: 1.5e-04
warmup_pct: 0.0
weight_decay: 0.0
decomposition_targets:
- C: 49152
module_pattern: model.layers.18.mlp.gate_proj
- C: 49152
module_pattern: model.layers.18.mlp.up_proj
- C: 49152
module_pattern: model.layers.18.mlp.down_proj
- C: 49152
module_pattern: model.layers.19.mlp.gate_proj
- C: 49152
module_pattern: model.layers.19.mlp.up_proj
- C: 49152
module_pattern: model.layers.19.mlp.down_proj
- C: 49152
module_pattern: model.layers.20.mlp.gate_proj
- C: 49152
module_pattern: model.layers.20.mlp.up_proj
- C: 49152
module_pattern: model.layers.20.mlp.down_proj
- C: 49152
module_pattern: model.layers.21.mlp.gate_proj
- C: 49152
module_pattern: model.layers.21.mlp.up_proj
- C: 49152
module_pattern: model.layers.21.mlp.down_proj
- C: 49152
module_pattern: model.layers.22.mlp.gate_proj
- C: 49152
module_pattern: model.layers.22.mlp.up_proj
- C: 49152
module_pattern: model.layers.22.mlp.down_proj
- C: 49152
module_pattern: model.layers.23.mlp.gate_proj
- C: 49152
module_pattern: model.layers.23.mlp.up_proj
- C: 49152
module_pattern: model.layers.23.mlp.down_proj
faithfulness_warmup_lr: 0.001
faithfulness_warmup_steps: 400
faithfulness_warmup_weight_decay: 0.0
identity_decomposition_targets: null
loss_metrics:
- beta: 0.2
coeff: 5.0e-06
gamma: 1.0
gamma_anneal_start_frac: 0.0
gamma_anneal_final_gamma: 0.1
gamma_anneal_end_frac: 1.0
type: SmoothL0ImportanceMinimalityLoss
- coeff: 1.0
n_samples: 1
routing:
type: uniform_k_subset
sites_per_chunk: 9
type: ChunkwiseSubsetReconLoss
- coeff: 0.5
n_samples: 1
n_warmup_steps: 2
optimizer:
beta1: 0.01
beta2: 0.99
eps: 1.0e-08
lr_schedule:
final_val_frac: 1.0
fn_type: constant
start_val: 0.01
warmup_pct: 0.025
type: adam
scope:
type: bsc
start_frac: 0.0
type: PersistentPGDReconLoss
- coeff: 1000000.0
type: FaithfulnessLoss
n_mask_samples: 1
sampling: continuous
seed: 0
sigmoid_type: leaky_hard
steps: 40000
tied_weights: null
use_delta_component: true
runtime:
remat_recon_forwards: true
autocast_bf16: true
device: cuda:0
dp: 64
target:
output_extract: logits
weights_dtype: bfloat16
spec:
kind: hf
model_class: transformers.LlamaForCausalLM
model_name: meta-llama/Llama-3.1-8B
wandb:
entity: null
project: param-decomp-llama
Loading
Loading