refactor(ci): chunkwise CI fn — protocol, read_activations seam, scale-driven sharding#883
Merged
Merged
Conversation
…ale-driven sharding
Replace the single global-concat transformer CI fn with a `CIFn` protocol and a
chunkwise-transformer implementation, separating the concerns that were tangled in the
old arch.
- `CIFn` protocol: `dict[InputTap, Array] -> CI`. The input keyspace (opaque tap keys)
is independent of the output keyspace (decomposition sites); output sites must
partition the model's sites. `CI = {logits, lower, upper}` with the two squashings
(SPEC S5/S6) centralized in `CI.from_logits` (was copy-pasted per impl).
- Model seam: `DecomposedModel.site_inputs` -> `read_activations(frozen, resid, wanted)`,
a general clean-path accessor keyed by opaque taps. The target is the sole key
interpreter: LM serves `resid.{i}` residual taps (chunkwise input) AND per-site matrix
inputs (harvest); toys serve per-site inputs. Deletes three re-implementations of the
per-site frozen-forward math.
- `ChunkwiseTransformerCIFn`: sites partition into chunks; each chunk reads one or more
taps (RMS-normed per tap, concatenated -> `input_dim`), runs an independent pre-norm
bidirectional-RoPE transformer, emits per-site CI. Per-chunk transformers are stacked
along an `n_chunks` axis and run under one `eqx.filter_vmap`. `input_dim` is a generic
linear-input width computed lab-side — no residual-dim/transformer concept in core.
- All CI-fn architectures (chunkwise transformer + the positionless per-site / global
MLPs, formerly `ci_fn_mlp.py`) live in `ci_fn.py` as peers; `CIFnArch` + `build_ci_fn`
unify construction.
- Sharding decoupled from arch: scale-driven `init_decomp_vu_placed` / `init_ci_fn_placed`
(`shardable = mesh>1 and all C%n==0`); the `match ci_fn_arch` in run_state is gone.
Replication of a non-tileable CI matrix is logged, not silent.
- Flat single-`type` CI-config union (drops the nested `mode x fn_type` unions + the
legacy null-drop shim).
- Lab chunk generator: `blocks_per_chunk` consecutive blocks -> one residual tap in,
block sites out; `input_dim` resolved from each target's residual width.
- SPEC §4.6 / S4 / S5 / S27 updated to the chunkwise realization (same
`read_activations -> CI` contract; the CI-fn arch is now JAX-native, no longer
torch-oracle bit-faithful — the rest of the spec stays oracle-grounded).
Tests green at 1 and 4 sim-devices; equivalence suite unchanged; type-clean. The one
xfail (stacked-parity CI trajectory) needs a torch-oracle golden regen (CI numerics
changed by design).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
…lias Review follow-ups on the chunkwise CI fn (PR #883): - read_activations (`clean_activations`, both LM targets): collapse the capture-then-re-advance double-compute into one return-threaded sweep. Each block's internals are computed once (h1 / attn_y / post_attn / mlp_in / down_in), the requested resid taps + per-site inputs are captured, and the residual is advanced from those same intermediates — no second forward. Bit-identical to the prior values (stacked-parity `out::site_input` goldens unchanged); removes the ~2x block compute on layers that have a requested site key. This is the JAX return-threading answer to "torch would just use a forward hook" — no mutable state, capture by returning. - train.py: inline `AnyCIFn` to `CIFn`. It was a vestigial `AnyCIFn = CIFn` alias left over from the protocol collapse (formerly a union of the concrete CI-fn classes); now pointless indirection. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
…omponents.py The decomposition representation lived in `llama8b.py`, so every target — including the positionless TMS/ResidMLP toys — imported `DecompVU` / `init_decomp_vu` from a file named `llama8b`, and `_site_out` (the SPEC §4.1 decomposed linear `((x@V)*m)@U + (x@Δ)*d`) was byte-identically copy-pasted across `llama8b`, `llama_simple_mlp`, and both toys (a fourth importer pulled the private `_site_out`). Move the three domain-neutral primitives into a new `param_decomp/components.py` (`_site_out` -> public `site_out`, single definition), and repoint all 24 importers. No target imports the decomposition representation from `llama8b` anymore (AST-verified), and the §4.1 primitive is one copy. Verbatim move — values unchanged: `make type` clean, 107 tests pass (the stacked-parity `out::site_input` goldens confirm bit-identity). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
…i_arch `toy_ci_arch` + `layerwise_mlp_ci_arch` (experiments/config.py) + `_ci_arch` (lm/config.py) were three spellings of the same `CiConfig -> CIFnArch` step. Replace with one `ci_arch(ci_config, resolve_chunkwise)` that matches on the config type: MLP / global resolve trivially (list->tuple); the chunkwise branch calls the injected `resolve_chunkwise` closure. The chunkwise resolver (`_resolve_chunkwise_ci_arch`, formerly `_ci_arch`) stays in lm/config.py because it needs LM-target internals — passing it as a closure avoids inverting the experiments/config <- lm/config layering. Toys pass `resolve_chunkwise=None` (asserts loudly if a positionless toy ever requests chunkwise). Added a note in ci_fn.py that `MLPCIArch`/`GlobalMLPCIArch` ≈ their pydantic configs by design (uniform `CIFnArch` union for `build_ci_fn`), so the duplication reads as deliberate. make type clean; toy tests 42 passed; chunkwise branch build-checked. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
… DRY, sampling→continuous Low-risk cleanups from the core design review: - dead code: CIBlock.head_dim, BatchLocation.shard_idx, log.set_format, sharding.replicate, the install_sigterm_flag private twin - types: make_eval_step takes EvalPGDConfig (not a bare tuple); drop RuntimeConfig.device/autocast_bf16 (torch-isms the JAX core never reads) + a before-validator strip for stored configs - DRY: batch_shard_leading + all_false_routes hoisted to shared homes; PGD init/scope Literals reused; strategy_has_delta inlined; _log_wandb_safe extracted - sampling: removed the SamplingType config axis (only "continuous" is supported) + a before-validator that strips/asserts continuous on stored configs - docs: site_inputs->read_activations across docstrings/docs; clean_output=lambda -> clean_suffix_logits; stale/narrativized comments dropped make type clean; equivalence + stacked-parity goldens green; no numeric changes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
The 5 param_decomp/configs/*.yaml LM configs still carried the old mode/fn_type global_shared_transformer ci_config — the chunkwise schema swap migrated the toy YAMLs but missed these, leaving test_config + test_finetune_resume red (14 tests). Migrate each to type: chunkwise_transformer (blocks_per_chunk=1 per-layer default; d_model/n_blocks/n_heads/mlp_hidden carried over; the old attn max_len/rope_base have no field in the flat config and are dropped). All 5 convert cleanly, incl. the model.-prefixed C49k config (handled by the existing model. strip in _site_cs — not stale). The 9-layer config's chunk size is left at blocks_per_chunk=1; the intended granularity (1/3/9) is a modeling call. Full param_decomp/tests/ suite green (222 passed, 7 skipped, 1 xfailed); make type clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_01LkSohqSpa5upFTjYzkuVRe
4994726 to
13b834f
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Replaces the single global-concat transformer CI fn with a
CIFnprotocol (dict[InputTap, Array] -> CI) and a chunkwise-transformer implementation, and untangles the concerns that were welded into the old arch.CIFn: input keyspace (opaque tap keys) independent of output keyspace (decomposition sites); output sites must partition the model's sites.CI = {logits, lower, upper}, with the two squashings (S5/S6) centralized inCI.from_logits(was copy-pasted per impl).DecomposedModel.site_inputs→read_activations(frozen, resid, wanted)— a general clean-path accessor keyed by opaque taps. The target is the sole key interpreter: the LM servesresid.{i}residual taps (chunkwise input) and per-site matrix inputs (harvest); toys serve per-site inputs. Deletes three copies of the per-site frozen-forward math.input_dim), runs an independent pre-norm bidirectional-RoPE transformer, emits per-site CI. Per-chunk transformers are stacked along ann_chunksaxis and run under oneeqx.filter_vmap.input_dimis a generic linear width computed lab-side — no residual-dim/transformer concept leaks into core. layerwise/global are degenerate chunkings.ci_fn_mlp.py) live inci_fn.py;CIFnArch+build_ci_fnunify construction.init_decomp_vu_placed/init_ci_fn_placed(shardable = mesh>1 and all C%n==0); thematch ci_fn_archinrun_stateis gone. A non-tileable CI matrix replicating is logged, not silent.typeCI-config union (drops the nestedmode × fn_typeunions + the legacy null-drop shim).blocks_per_chunkconsecutive blocks → one residual tap in, block sites out.read_activations → CIcontract).Related Issue
N/A
Motivation and Context
The old CI fn was a single transformer over every site's input concatenated — expensive, with sharding welded to arch identity (MLP→replicated / transformer→sharded, an incidental correlation) and a nested
mode × fn_typeconfig. The chunkwise design is the "no-regrets" generalization: it interpolates between layerwise (one chunk per site) and global (one chunk), reads the residual stream entering each chunk rather than every per-site activation, and makes the input/output keyspaces independent (enabling future read-from-elsewhere variants). Along the way the genuinely-orthogonal axes are separated: architecture (inci_fn.py), construction (build_ci_fn), and placement (scale-driven), so adding a CI fn no longer touches a dispatchmatch.How Has This Been Tested?
make type: 0 errors (aftermake install-devsyncs the worktree venv).read_activations+ the training step.splitStopIterationin per-chunk init, invisible to the type checker) and is fixed; per-siteread_activationsoutputs are proven bit-identical to the deleted re-threads via the committedout::site_inputstacked-parity goldens.Does this PR introduce a breaking change?
Yes.
typeform (type: chunkwise_transformer/layerwise_mlp/global_mlp); oldmode/fn_typeYAMLs must be migrated (in-repo configs are migrated here).xfailpending atorch-oracleregen.🤖 Generated with Claude Code