Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions param_decomp/CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ never silently diverge. Cite IDs (`S14`, `N1`, …) in commit messages and revie
## Architecture in one breath

`lm.py` defines `DecomposedModel` — ordered `sites` + `leading_axes` + five pure fns
(`clean_output`, `site_inputs`, `masked_output`, `weight_deltas`) plus a pluggable
(`clean_output`, `read_activations`, `masked_output`, `masked_site_outputs`,
`weight_deltas`) plus a pluggable
`recon_loss_fn` (default `kl_per_position`), flat site-name-keyed dicts at the boundary,
frozen pytree always a runtime arg (never a jit closure constant — an 8B target becomes a
multi-GB HLO constant). The activation waist is GENERIC `[*leading, d]` (masks/CI
Expand Down Expand Up @@ -82,13 +83,16 @@ renders it off the small on-host V/U + the probe CI as permutation source (cheap
gather), sharing `slow_eval.render_uv_figure` / `plot_uv_matrices` with the LM path.

**The toys (TMS, ResidMLP) live in the lab, not the core.** The core trainer carries ZERO
toy-specific code (CI-fn arches are the one allowed exception — see `ci_fn_mlp.py`). The
toy-specific code — the toy *targets* (`DecomposedModel`s, pretrain, identity-CI eval) are
all lab-side. CI-fn *architectures* are NOT toy-specific code: core owns every CI-fn arch
regardless of which experiments use it. The positionless MLPs and the sequence transformer
are peers in `ci_fn.py` (differing by domain, not status), not a toy carve-out. The
generic engine is `run.py::run_decomposition_training(pd, cadence, run, raw_cfg, lm, frozen,
ci_fn, data, remat_recon_forwards, sample_batch, eval_fn, eval_every, perf_tokens_per_step,
mesh)` — the ONE train loop every target runs through (init/restore/finetune/faith-warmup
via `_init_or_restore_state`, the recon-grid step factory, orbax checkpointing, schedules,
SIGTERM-save). It reads the pydantic `PDConfig` / `Cadence` (`param_decomp.configs`)
DIRECTLY — optimizers / loss metrics / faith warmup / seed / steps / sampling — so there is
DIRECTLY — optimizers / loss metrics / faith warmup / seed / steps — so there is
NO flattened mirror DC (the old `config.ExperimentConfig`); the run identity rides in
`config.RunInstance`, and the lab-built objects (`ci_fn` arch, `data`, the decomposed target)
pass alongside. A target injects exactly three seams: the data source
Expand All @@ -102,21 +106,22 @@ only `TargetConfig` / `LlamaSimpleMLPTargetConfig`). `BuiltRun.target` is typed
`config.TargetSites` protocol (just `.sites`), `BuiltRun.data` is `DataConfig | None` (None
for a toy run). The shared schema validation + run-identity / CI-fn-arch helpers are public
lab-side for the toys to reuse: `experiments.config.assert_canonical_algorithm_config` /
`run_instance` / `layerwise_mlp_ci_arch` / `toy_ci_arch`.
`run_instance` / `ci_arch`.

The TMS + ResidMLP targets now live under `param_decomp_lab/experiments/{tms,resid_mlp}/`
(`model.py` = the JAX `DecomposedModel` + frozen target + in-process pretrain + identity-CI
eval; `run.py` = the `pd-tms` / `pd-resid-mlp` CPU CLI that builds the `ExperimentConfig`
from the canonical schema and calls `run_decomposition_training`). They are positionless
(`leading_axes=()`) and use the layerwise per-site MLP CI fn. `ci_fn_mlp.py` (the second
CI-fn arch, the allowed exception) stays in the core: `LayerwiseMLPCIFn` (`fn_type=mlp`,
`expects_axes=()`, one independent MLP per site mapping `site_input [B,d_in] -> [B,C]`) plus
the new `GlobalMLPCIFn` (`fn_type=global_shared_mlp`, one shared MLP over all sites jointly,
concat/split in canonical site order). `run_state.init_train_state` dispatches CI-fn
construction on `cfg.ci_fn` (`CIArch` transformer / `MLPCIArch` layerwise / `GlobalMLPCIArch`
global) and uses replicated (not C-sharded) V/U + CI for the tiny toys; the core
`config.CIFnArch` admits all three and the lab `experiments.config.toy_ci_arch` builds the
layerwise / global arch from the toy ci_config (validated end-to-end on CPU via
(`leading_axes=()`) and use the MLP CI fns. All CI-fn architectures live together in
`ci_fn.py`: `LayerwiseMLPCIFn` (`expects_axes=()`, one independent MLP per site mapping
`site_input [B,d_in] -> [B,C]`), `GlobalMLPCIFn` (`expects_axes=()`, one shared MLP over all
sites jointly, concat/split in canonical site order), and the LM `ChunkwiseTransformerCIFn`
(`expects_axes=("sequence",)`, per-chunk transformers reading residual taps, stacked +
`filter_vmap`). `run_state.init_train_state` dispatches CI-fn construction on `cfg.ci_fn`
(`MLPCIArch` / `GlobalMLPCIArch` / `ChunkwiseTransformerCIArch`) and uses replicated (not
C-sharded) V/U + CI for the tiny toys; the core `config.CIFnArch` admits all three and the
lab `experiments.config.ci_arch` builds the layerwise / global arch from the toy
ci_config (validated end-to-end on CPU via
`pd-resid-mlp`). Harvest / slow-eval / export over the toys are NOT wired
(`experiments.lm.load_run.build_target` / `run_metadata` are LM-only).

Expand Down
2 changes: 1 addition & 1 deletion param_decomp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ python -m param_decomp.experiments.llama8b_real --real_weights --first_layer 20
## Design

- **Generic over vendored LMs.** The trainer sees only the `DecomposedModel` fn-table
(`lm.py`): ordered `sites`, `clean_output`, `site_inputs`, `masked_output`,
(`lm.py`): ordered `sites`, `clean_output`, `read_activations`, `masked_output`,
`masked_site_outputs` (the hidden-acts eval seam, SPEC S31), `weight_deltas` — all
pure, all taking the frozen pytree as a *runtime arg* (a frozen
8B target closed over as a jit constant bakes multi-GB weights into the HLO). Adding
Expand Down
91 changes: 64 additions & 27 deletions param_decomp/SPEC.md

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions param_decomp/TRANSITION.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ dead.
Promote what already exists — pure-functional, state-injected. There is no torch
counterpart to keep in sync.

- **`DecomposedModel`** (`lm.py`): ordered `sites` + `clean_output` / `site_inputs` /
`masked_output` / `weight_deltas` over `(frozen, vu)` pytrees.
- **`CIFn`** (`ci_fn.py`): `site_inputs -> (ci_lower, ci_upper)`. (SRP: model and CI are
- **`DecomposedModel`** (`lm.py`): ordered `sites` + `clean_output` / `read_activations` /
`masked_output` / `masked_site_outputs` / `weight_deltas` over `(frozen, vu)` pytrees.
- **`CIFn`** (`ci_fn.py`): `read_activations -> (ci_lower, ci_upper)`. (SRP: model and CI are
separate concerns; a run is the pair + frozen target.)
- **Dispatch lives only at the edges:** one `load(run) -> Decomposition` and one
`export(state) -> orbax`. Nothing else knows architecture/layout.
Expand Down
7 changes: 3 additions & 4 deletions param_decomp/adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
"""

from dataclasses import dataclass
from typing import Literal

import jax
import jax.numpy as jnp
from jax import random
from jaxtyping import Array, Float, PRNGKeyArray

from param_decomp.configs import AdamPGDConfig
from param_decomp.configs import AdamPGDConfig, MaskScopeLiteral, PGDInitStrategy
from param_decomp.lm import SiteSpec


Expand Down Expand Up @@ -54,8 +53,8 @@ def init_persistent_sources(

def init_fresh_pgd_sources(
sites: tuple[SiteSpec, ...],
init: Literal["random", "ones", "zeroes"],
scope: Literal["c", "bc", "bsc"],
init: PGDInitStrategy,
scope: MaskScopeLiteral,
leading: tuple[int, ...],
key: PRNGKeyArray,
) -> dict[str, Array]:
Expand Down
25 changes: 8 additions & 17 deletions param_decomp/attn_patterns_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@
from jax import random
from jaxtyping import Array, Float, PRNGKeyArray

from param_decomp.configs import SamplingType
from param_decomp.llama8b import FrozenAttn, Target
from param_decomp.llama_simple_mlp import SimpleMLPTarget
from param_decomp.lm import DecomposedModel
from param_decomp.lm import DecomposedModel, all_false_routes
from param_decomp.train import COMPUTE_DT, cast_floating
from vendored_jax.llama import apply_rope, repeat_kv, rope_cos_sin

Expand Down Expand Up @@ -123,10 +122,6 @@ class LayerKLReduction:
n_distributions: int


def _all_false_routes(site_names: tuple[str, ...], leading: tuple[int, ...]) -> dict[str, Array]:
return {s: jnp.zeros(leading, bool) for s in site_names}


def _pattern_kl(target_pattern: Array, masked_pattern: Array) -> Array:
"""`Σ target · (log target − log masked.clamp(1e-12))` in fp32 (torch
`F.kl_div(masked.clamp(1e-12).log(), target, reduction="sum")`)."""
Expand Down Expand Up @@ -164,7 +159,7 @@ def _clean_patterns(
frozen, components_bf16, residual,
{s: jnp.ones_like(ci_lower[s]) for s in site_names},
{s: jnp.zeros(leading, COMPUTE_DT) for s in site_names},
_all_false_routes(site_names, leading), site_names, False,
all_false_routes(site_names, leading), site_names, False,
) # fmt: skip
return {q: pattern_fn(clean_outputs[q], clean_outputs[k]) for q, k in layer_pairs}

Expand Down Expand Up @@ -206,10 +201,10 @@ def step(
residual: Float[Array, "*leading d"],
_key: PRNGKeyArray,
) -> tuple[dict[str, Array], dict[str, int]]:
site_inputs = lm.site_inputs(frozen, residual)
taps = lm.read_activations(frozen, residual, ci_fn.input_names)
components_bf16 = cast_floating(components, COMPUTE_DT)
ci_fn_bf16 = cast_floating(ci_fn, COMPUTE_DT)
ci_lower = ci_fn_bf16(site_inputs).lower
ci_lower = ci_fn_bf16(taps).lower

target_patterns = _clean_patterns(
lm, pattern_fn, layer_pairs, frozen, components_bf16, residual, ci_lower
Expand All @@ -227,7 +222,7 @@ def step(


def make_stochastic_attn_patterns_step(
lm: DecomposedModel, pattern_fn: AttnPatternFn, n_mask_samples: int, sampling: SamplingType
lm: DecomposedModel, pattern_fn: AttnPatternFn, n_mask_samples: int
) -> AttnPatternsStep:
"""Stochastic-mask attn-patterns step: `n_mask_samples` draws of `mask = ci + (1−ci)·s`
(with weight deltas), per-draw per-layer pattern KL summed. RNG via per-draw / per-site
Expand All @@ -245,10 +240,10 @@ def step(
residual: Float[Array, "*leading d"],
key: PRNGKeyArray,
) -> tuple[dict[str, Array], dict[str, int]]:
site_inputs = lm.site_inputs(frozen, residual)
taps = lm.read_activations(frozen, residual, ci_fn.input_names)
components_bf16 = cast_floating(components, COMPUTE_DT)
ci_fn_bf16 = cast_floating(ci_fn, COMPUTE_DT)
ci_lower = ci_fn_bf16(site_inputs).lower
ci_lower = ci_fn_bf16(taps).lower

target_patterns = _clean_patterns(
lm, pattern_fn, layer_pairs, frozen, components_bf16, residual, ci_lower
Expand All @@ -263,11 +258,7 @@ def step(
for site_idx, site in enumerate(site_names):
ci_site = ci_lower[site]
source_key = random.fold_in(mask_key, site_idx)
match sampling:
case "continuous":
source = random.uniform(source_key, ci_site.shape, COMPUTE_DT)
case _:
source = random.bernoulli(source_key, 0.5, ci_site.shape).astype(COMPUTE_DT)
source = random.uniform(source_key, ci_site.shape, COMPUTE_DT)
masks[site] = ci_site + (1.0 - ci_site) * source
delta_masks[site] = random.uniform(
random.fold_in(delta_key, site_idx), leading, COMPUTE_DT
Expand Down
Loading
Loading