Skip to content

refactor(ci): chunkwise CI fn — protocol, read_activations seam, scale-driven sharding#883

Merged
ocg-goodfire merged 6 commits into
feature/jaxfrom
feature/chunkwise-ci-fn
Jun 23, 2026
Merged

refactor(ci): chunkwise CI fn — protocol, read_activations seam, scale-driven sharding#883
ocg-goodfire merged 6 commits into
feature/jaxfrom
feature/chunkwise-ci-fn

Conversation

@ocg-goodfire

Copy link
Copy Markdown
Collaborator

Description

Replaces the single global-concat transformer CI fn with a CIFn protocol (dict[InputTap, Array] -> CI) and a chunkwise-transformer implementation, and untangles the concerns that were welded into the old arch.

  • Protocol + bundle. CIFn: input keyspace (opaque tap keys) independent of output keyspace (decomposition sites); output sites must partition the model's sites. CI = {logits, lower, upper}, with the two squashings (S5/S6) centralized in CI.from_logits (was copy-pasted per impl).
  • Model seam. DecomposedModel.site_inputsread_activations(frozen, resid, wanted) — a general clean-path accessor keyed by opaque taps. The target is the sole key interpreter: the LM serves resid.{i} residual taps (chunkwise input) and per-site matrix inputs (harvest); toys serve per-site inputs. Deletes three copies of the per-site frozen-forward math.
  • Chunkwise transformer. Sites partition into chunks; each chunk reads one or more taps (RMS-normed per tap → concatenated → input_dim), runs an independent pre-norm bidirectional-RoPE transformer, emits per-site CI. Per-chunk transformers are stacked along an n_chunks axis and run under one eqx.filter_vmap. input_dim is a generic linear width computed lab-side — no residual-dim/transformer concept leaks into core. layerwise/global are degenerate chunkings.
  • One module, peers. All CI-fn arches (chunkwise transformer + the positionless MLPs, formerly ci_fn_mlp.py) live in ci_fn.py; CIFnArch + build_ci_fn unify construction.
  • Sharding decoupled from arch. Scale-driven init_decomp_vu_placed / init_ci_fn_placed (shardable = mesh>1 and all C%n==0); the match ci_fn_arch in run_state is gone. A non-tileable CI matrix replicating is logged, not silent.
  • Flat config. Single-type CI-config union (drops the nested mode × fn_type unions + the legacy null-drop shim).
  • Lab chunk generator. blocks_per_chunk consecutive blocks → one residual tap in, block sites out.
  • SPEC. §4.6 / S4 / S5 / S27 updated to the chunkwise realization (same read_activations → CI contract).

Related Issue

N/A

Motivation and Context

The old CI fn was a single transformer over every site's input concatenated — expensive, with sharding welded to arch identity (MLP→replicated / transformer→sharded, an incidental correlation) and a nested mode × fn_type config. The chunkwise design is the "no-regrets" generalization: it interpolates between layerwise (one chunk per site) and global (one chunk), reads the residual stream entering each chunk rather than every per-site activation, and makes the input/output keyspaces independent (enabling future read-from-elsewhere variants). Along the way the genuinely-orthogonal axes are separated: architecture (in ci_fn.py), construction (build_ci_fn), and placement (scale-driven), so adding a CI fn no longer touches a dispatch match.

How Has This Been Tested?

  • make type: 0 errors (after make install-dev syncs the worktree venv).
  • Core tests: 117 passed / 7 skipped / 1 xfailed at 1 device; sharding-sensitive subset 14 passed at 4 sim-devices (exercises the new C-sharded path).
  • Equivalence suite: 12 passed, unchanged — confirms the CI-fn change is correctly isolated from the recon / faithfulness / source-term goldens.
  • Toys (TMS / ResidMLP): 45 passed, end-to-end through read_activations + the training step.
  • A real core bug surfaced only by running (a lazy-split StopIteration in per-chunk init, invisible to the type checker) and is fixed; per-site read_activations outputs are proven bit-identical to the deleted re-threads via the committed out::site_input stacked-parity goldens.

Does this PR introduce a breaking change?

Yes.

  • CI-fn config schema changed to the flat type form (type: chunkwise_transformer / layerwise_mlp / global_mlp); old mode/fn_type YAMLs must be migrated (in-repo configs are migrated here).
  • CI-fn numerics changed by design — the CI fn is now JAX-native and no longer bit-faithful to the torch oracle's global-concat CI fn. The rest of the spec stays oracle-grounded.
  • Checkpoints with the old CI fn are not loadable; the stacked-parity CI-trajectory golden is xfail pending a torch-oracle regen.

🤖 Generated with Claude Code

ocg-goodfire and others added 6 commits June 22, 2026 14:53
…ale-driven sharding

Replace the single global-concat transformer CI fn with a `CIFn` protocol and a
chunkwise-transformer implementation, separating the concerns that were tangled in the
old arch.

- `CIFn` protocol: `dict[InputTap, Array] -> CI`. The input keyspace (opaque tap keys)
  is independent of the output keyspace (decomposition sites); output sites must
  partition the model's sites. `CI = {logits, lower, upper}` with the two squashings
  (SPEC S5/S6) centralized in `CI.from_logits` (was copy-pasted per impl).
- Model seam: `DecomposedModel.site_inputs` -> `read_activations(frozen, resid, wanted)`,
  a general clean-path accessor keyed by opaque taps. The target is the sole key
  interpreter: LM serves `resid.{i}` residual taps (chunkwise input) AND per-site matrix
  inputs (harvest); toys serve per-site inputs. Deletes three re-implementations of the
  per-site frozen-forward math.
- `ChunkwiseTransformerCIFn`: sites partition into chunks; each chunk reads one or more
  taps (RMS-normed per tap, concatenated -> `input_dim`), runs an independent pre-norm
  bidirectional-RoPE transformer, emits per-site CI. Per-chunk transformers are stacked
  along an `n_chunks` axis and run under one `eqx.filter_vmap`. `input_dim` is a generic
  linear-input width computed lab-side — no residual-dim/transformer concept in core.
- All CI-fn architectures (chunkwise transformer + the positionless per-site / global
  MLPs, formerly `ci_fn_mlp.py`) live in `ci_fn.py` as peers; `CIFnArch` + `build_ci_fn`
  unify construction.
- Sharding decoupled from arch: scale-driven `init_decomp_vu_placed` / `init_ci_fn_placed`
  (`shardable = mesh>1 and all C%n==0`); the `match ci_fn_arch` in run_state is gone.
  Replication of a non-tileable CI matrix is logged, not silent.
- Flat single-`type` CI-config union (drops the nested `mode x fn_type` unions + the
  legacy null-drop shim).
- Lab chunk generator: `blocks_per_chunk` consecutive blocks -> one residual tap in,
  block sites out; `input_dim` resolved from each target's residual width.
- SPEC §4.6 / S4 / S5 / S27 updated to the chunkwise realization (same
  `read_activations -> CI` contract; the CI-fn arch is now JAX-native, no longer
  torch-oracle bit-faithful — the rest of the spec stays oracle-grounded).

Tests green at 1 and 4 sim-devices; equivalence suite unchanged; type-clean. The one
xfail (stacked-parity CI trajectory) needs a torch-oracle golden regen (CI numerics
changed by design).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
…lias

Review follow-ups on the chunkwise CI fn (PR #883):

- read_activations (`clean_activations`, both LM targets): collapse the
  capture-then-re-advance double-compute into one return-threaded sweep. Each block's
  internals are computed once (h1 / attn_y / post_attn / mlp_in / down_in), the
  requested resid taps + per-site inputs are captured, and the residual is advanced
  from those same intermediates — no second forward. Bit-identical to the prior values
  (stacked-parity `out::site_input` goldens unchanged); removes the ~2x block compute
  on layers that have a requested site key. This is the JAX return-threading answer to
  "torch would just use a forward hook" — no mutable state, capture by returning.

- train.py: inline `AnyCIFn` to `CIFn`. It was a vestigial `AnyCIFn = CIFn` alias left
  over from the protocol collapse (formerly a union of the concrete CI-fn classes); now
  pointless indirection.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
…omponents.py

The decomposition representation lived in `llama8b.py`, so every target — including the
positionless TMS/ResidMLP toys — imported `DecompVU` / `init_decomp_vu` from a file named
`llama8b`, and `_site_out` (the SPEC §4.1 decomposed linear `((x@V)*m)@U + (x@Δ)*d`) was
byte-identically copy-pasted across `llama8b`, `llama_simple_mlp`, and both toys (a fourth
importer pulled the private `_site_out`).

Move the three domain-neutral primitives into a new `param_decomp/components.py`
(`_site_out` -> public `site_out`, single definition), and repoint all 24 importers. No
target imports the decomposition representation from `llama8b` anymore (AST-verified), and
the §4.1 primitive is one copy.

Verbatim move — values unchanged: `make type` clean, 107 tests pass (the stacked-parity
`out::site_input` goldens confirm bit-identity).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
…i_arch

`toy_ci_arch` + `layerwise_mlp_ci_arch` (experiments/config.py) + `_ci_arch`
(lm/config.py) were three spellings of the same `CiConfig -> CIFnArch` step. Replace with
one `ci_arch(ci_config, resolve_chunkwise)` that matches on the config type: MLP / global
resolve trivially (list->tuple); the chunkwise branch calls the injected
`resolve_chunkwise` closure. The chunkwise resolver (`_resolve_chunkwise_ci_arch`, formerly
`_ci_arch`) stays in lm/config.py because it needs LM-target internals — passing it as a
closure avoids inverting the experiments/config <- lm/config layering. Toys pass
`resolve_chunkwise=None` (asserts loudly if a positionless toy ever requests chunkwise).

Added a note in ci_fn.py that `MLPCIArch`/`GlobalMLPCIArch` ≈ their pydantic configs by
design (uniform `CIFnArch` union for `build_ci_fn`), so the duplication reads as deliberate.

make type clean; toy tests 42 passed; chunkwise branch build-checked.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
… DRY, sampling→continuous

Low-risk cleanups from the core design review:
- dead code: CIBlock.head_dim, BatchLocation.shard_idx, log.set_format, sharding.replicate, the install_sigterm_flag private twin
- types: make_eval_step takes EvalPGDConfig (not a bare tuple); drop RuntimeConfig.device/autocast_bf16 (torch-isms the JAX core never reads) + a before-validator strip for stored configs
- DRY: batch_shard_leading + all_false_routes hoisted to shared homes; PGD init/scope Literals reused; strategy_has_delta inlined; _log_wandb_safe extracted
- sampling: removed the SamplingType config axis (only "continuous" is supported) + a before-validator that strips/asserts continuous on stored configs
- docs: site_inputs->read_activations across docstrings/docs; clean_output=lambda -> clean_suffix_logits; stale/narrativized comments dropped

make type clean; equivalence + stacked-parity goldens green; no numeric changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
The 5 param_decomp/configs/*.yaml LM configs still carried the old
mode/fn_type global_shared_transformer ci_config — the chunkwise schema swap
migrated the toy YAMLs but missed these, leaving test_config + test_finetune_resume
red (14 tests). Migrate each to type: chunkwise_transformer (blocks_per_chunk=1
per-layer default; d_model/n_blocks/n_heads/mlp_hidden carried over; the old attn
max_len/rope_base have no field in the flat config and are dropped).

All 5 convert cleanly, incl. the model.-prefixed C49k config (handled by the existing
model. strip in _site_cs — not stale). The 9-layer config's chunk size is left at
blocks_per_chunk=1; the intended granularity (1/3/9) is a modeling call.

Full param_decomp/tests/ suite green (222 passed, 7 skipped, 1 xfailed); make type clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
@ocg-goodfire ocg-goodfire force-pushed the feature/chunkwise-ci-fn branch from 4994726 to 13b834f Compare June 22, 2026 16:52
@ocg-goodfire ocg-goodfire merged commit 13b834f into feature/jax Jun 23, 2026
1 check failed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant