Skip to content

Commit 1da9e13

Browse files
Oliver Clive-Griffinclaude
andcommitted
Widen RunBatch/ReconstructionLoss for structured outputs
Lets experiments package per-batch context (padding masks, labels, MSA aux features) into output dataclasses instead of smuggling them through tensor shapes. Surfaced while stress-testing the abstractions against ESM2, Carbon, and GPN-MSA bio models. - RunBatch: (model, batch) -> Tensor → -> Any - ReconstructionLoss args: (pred, target) → (output, target_output); types Any - OutputWithCache.output, MetricContext.target_out: Tensor → Any - (sum, n) return shape kept — earns its keep for variable-mask eval - Notes the tied-embedding gap in make_components (deferred) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d07cc3a commit 1da9e13

18 files changed

Lines changed: 100 additions & 62 deletions

param_decomp/batch_and_loss_fns.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,38 @@
1111

1212

1313
class RunBatch(Protocol):
14-
"""Callable that runs one batch through `model` and returns the output tensor."""
14+
"""Callable that runs one batch through `model` and returns its output.
1515
16-
def __call__(self, model: nn.Module, batch: Any) -> Tensor: ...
16+
The output type is experiment-defined (`Any`) — typically a tensor of logits, but
17+
may be a dataclass / dict carrying additional fields (attention masks, hidden
18+
states, labels) that the experiment's `ReconstructionLoss` consumes. The same
19+
`RunBatch` is invoked on both the frozen target and the decomposed model, so the
20+
two `output` values it produces share a structure.
21+
"""
1722

23+
def __call__(self, model: nn.Module, batch: Any) -> Any: ...
1824

19-
class ReconstructionLoss(Protocol):
20-
"""Callable that compares `pred` against `target` and returns `(sum, n_elements)`.
2125

22-
The first entry is the unreduced sum of per-element losses; the second is the count
23-
it summed over. Callers reduce `sum / n_elements` to a mean as needed.
26+
class ReconstructionLoss(Protocol):
27+
"""Compare a decomposed-model `output` against the frozen-target `target_output`.
28+
29+
Both are whatever the experiment's `RunBatch` returns. The return pair
30+
`(sum, n_elements)` is the unreduced sum of per-element losses and the count it
31+
summed over (or sum-of-weights for weighted/masked losses); callers reduce
32+
`sum / n_elements` to a mean as needed.
33+
34+
Per-batch context the loss needs (padding masks, MLM-masked positions,
35+
per-channel weights, labels) rides on the `output` / `target_output` structure
36+
— experiments are responsible for packaging it inside `RunBatch`. Static aux
37+
state (e.g. a k-mer→nucleotide lookup table) lives in a closure / partial /
38+
`__call__`-bearing class — the Protocol stays minimal.
2439
"""
2540

26-
def __call__(self, pred: Tensor, target: Tensor) -> tuple[Float[Tensor, ""], int]: ...
41+
def __call__(
42+
self,
43+
output: Any,
44+
target_output: Any,
45+
) -> tuple[Float[Tensor, ""], int]: ...
2746

2847

2948
def move_batch_to_device(batch: Any, device: str | torch.device) -> Any:

param_decomp/component_model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ class OutputWithCache(NamedTuple):
2828
"""Forward output paired with per-module cached activations.
2929
3030
Cache keys are target-module paths (or `f"{path}_{kind}"` for component-acts entries);
31-
contents depend on the `cache_type` requested.
31+
contents depend on the `cache_type` requested. `output` is whatever `RunBatch`
32+
returns — typically a tensor but may be a dataclass / dict for experiments that
33+
package per-batch context (masks, labels) for `ReconstructionLoss`.
3234
"""
3335

34-
output: Tensor
36+
output: Any
3537
cache: dict[str, Tensor]
3638

3739

@@ -168,10 +170,10 @@ def __call__(
168170
batch: Any,
169171
mask_infos: dict[str, ComponentsMaskInfo] | None = None,
170172
cache_type: Literal["none"] = "none",
171-
) -> Tensor: ...
173+
) -> Any: ...
172174

173175
@override
174-
def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache:
176+
def __call__(self, *args: Any, **kwargs: Any) -> Any | OutputWithCache:
175177
return super().__call__(*args, **kwargs)
176178

177179
@override
@@ -180,7 +182,7 @@ def forward(
180182
batch: Any,
181183
mask_infos: dict[str, ComponentsMaskInfo] | None = None,
182184
cache_type: Literal["component_acts", "input", "output", "none"] = "none",
183-
) -> Tensor | OutputWithCache:
185+
) -> Any | OutputWithCache:
184186
"""Run the target model with optional component replacement and/or caching.
185187
186188
With no extra args, this is just a forward pass through the frozen target model.
@@ -220,7 +222,7 @@ def forward(
220222
)
221223

222224
with self._attach_forward_hooks(hooks):
223-
out: Tensor = self._run_batch(self.target_model, batch)
225+
out: Any = self._run_batch(self.target_model, batch)
224226

225227
match cache_type:
226228
case "input" | "output" | "component_acts":

param_decomp/components.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,12 @@ def make_components(
274274
Dict keyed by the same submodule paths, mapping to a `Components` instance whose
275275
weights have been initialised but not yet trained.
276276
"""
277+
# NOTE: storage-tied weights (e.g. `tie_word_embeddings=True` on Llama/ESM/GPT-2,
278+
# where `embed_tokens.weight is lm_head.weight`) are not detected here — decomposing
279+
# both sides of a tie produces two independent `Components` instances that silently
280+
# learn the same target. Deferred: we don't currently decompose embeddings, so this
281+
# is dormant. Fix would be to detect shared `weight.data_ptr()` and either share one
282+
# `Components` instance or auto-add to `tied_weights`.
277283
out: dict[str, Components] = {}
278284
for path, C in module_to_c.items():
279285
target_module = target_model.get_submodule(path)

param_decomp/metrics/ci_masked_recon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class CIMaskedReconLossConfig(LossMetricConfig):
2020
def _ci_masked_recon_loss_update(
2121
model: ComponentModel,
2222
batch: Any,
23-
target_out: Tensor,
23+
target_out: Any,
2424
ci: dict[str, Float[Tensor, "... C"]],
2525
reconstruction_loss: ReconstructionLoss,
2626
) -> tuple[Float[Tensor, ""], int]:
@@ -32,7 +32,7 @@ def _ci_masked_recon_loss_update(
3232
def ci_masked_recon_loss(
3333
model: ComponentModel,
3434
batch: Any,
35-
target_out: Tensor,
35+
target_out: Any,
3636
ci: dict[str, Float[Tensor, "... C"]],
3737
reconstruction_loss: ReconstructionLoss,
3838
) -> Float[Tensor, ""]:

param_decomp/metrics/ci_masked_recon_layerwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CIMaskedReconLayerwiseLossConfig(LossMetricConfig):
2121
def _ci_masked_recon_layerwise_loss_update(
2222
model: ComponentModel,
2323
batch: Any,
24-
target_out: Tensor,
24+
target_out: Any,
2525
ci: dict[str, Float[Tensor, "... C"]],
2626
reconstruction_loss: ReconstructionLoss,
2727
) -> tuple[Float[Tensor, ""], int]:
@@ -39,7 +39,7 @@ def _ci_masked_recon_layerwise_loss_update(
3939
def ci_masked_recon_layerwise_loss(
4040
model: ComponentModel,
4141
batch: Any,
42-
target_out: Tensor,
42+
target_out: Any,
4343
ci: dict[str, Float[Tensor, "... C"]],
4444
reconstruction_loss: ReconstructionLoss,
4545
) -> Float[Tensor, ""]:

param_decomp/metrics/ci_masked_recon_subset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class CIMaskedReconSubsetLossConfig(LossMetricConfig):
3030
def _ci_masked_recon_subset_loss_update(
3131
model: ComponentModel,
3232
batch: Any,
33-
target_out: Tensor,
33+
target_out: Any,
3434
ci: dict[str, Float[Tensor, "... C"]],
3535
router: Router,
3636
reconstruction_loss: ReconstructionLoss,
@@ -51,7 +51,7 @@ def _ci_masked_recon_subset_loss_update(
5151
def ci_masked_recon_subset_loss(
5252
model: ComponentModel,
5353
batch: Any,
54-
target_out: Tensor,
54+
target_out: Any,
5555
ci: dict[str, Float[Tensor, "... C"]],
5656
routing: SubsetRoutingType,
5757
reconstruction_loss: ReconstructionLoss,

param_decomp/metrics/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ class MetricContext:
2424

2525
model: ComponentModel
2626
batch: Any
27-
target_out: Tensor
27+
target_out: (
28+
Any # Whatever `RunBatch` returns — Tensor in simple cases, dataclass / dict otherwise.
29+
)
2830
pre_weight_acts: dict[str, Float[Tensor, "..."]]
2931
ci: CIOutputs
3032
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]]

param_decomp/metrics/persistent_pgd_state.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Annotated, Any, Literal, override
1010

1111
import torch
12-
from jaxtyping import Float, Int
12+
from jaxtyping import Float
1313
from pydantic import Field, NonNegativeFloat, PositiveInt
1414
from torch import Tensor
1515
from torch.distributed import ReduceOp
@@ -311,8 +311,8 @@ def load_state_dict(self, state: dict[str, Any]) -> None:
311311
def warmup(
312312
self,
313313
model: ComponentModel,
314-
batch: Int[Tensor, "..."] | Float[Tensor, "..."],
315-
target_out: Float[Tensor, "... vocab"],
314+
batch: Any,
315+
target_out: Any,
316316
ci: dict[str, Float[Tensor, "... C"]],
317317
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
318318
) -> None:
@@ -331,8 +331,8 @@ def warmup(
331331
def compute_recon_sum_and_n(
332332
self,
333333
model: ComponentModel,
334-
batch: Int[Tensor, "..."] | Float[Tensor, "..."],
335-
target_out: Float[Tensor, "... vocab"],
334+
batch: Any,
335+
target_out: Any,
336336
ci: dict[str, Float[Tensor, "... C"]],
337337
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
338338
router: Router | None = None,
@@ -420,8 +420,8 @@ def _compute_ppgd_recon_loss(
420420
model: ComponentModel,
421421
ppgd_sources: PPGDSources,
422422
reconstruction_loss: ReconstructionLoss,
423-
batch: Int[Tensor, "..."] | Float[Tensor, "..."],
424-
target_out: Float[Tensor, "... vocab"],
423+
batch: Any,
424+
target_out: Any,
425425
ci: dict[str, Float[Tensor, "... C"]],
426426
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
427427
routing_masks: RoutingMasks,
@@ -431,5 +431,5 @@ def _compute_ppgd_recon_loss(
431431

432432
mask_infos = get_ppgd_mask_infos(ci, weight_deltas, ppgd_sources, routing_masks, batch_dims)
433433
out = model(batch, mask_infos=mask_infos)
434-
loss, n_examples = reconstruction_loss(pred=out, target=target_out)
434+
loss, n_examples = reconstruction_loss(output=out, target_output=target_out)
435435
return loss, n_examples

param_decomp/metrics/pgd_masked_recon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def pgd_recon_loss(
2222
*,
2323
model: ComponentModel,
2424
batch: Any,
25-
target_out: Tensor,
25+
target_out: Any,
2626
ci: dict[str, Float[Tensor, "... C"]],
2727
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
2828
pgd_config: PGDConfig,

param_decomp/metrics/pgd_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _forward_with_adv_sources(
129129
ci: dict[str, Float[Tensor, "... C"]],
130130
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
131131
routing_masks: RoutingMasks,
132-
target_out: Tensor,
132+
target_out: Any,
133133
batch_dims: tuple[int, ...],
134134
reconstruction_loss: ReconstructionLoss,
135135
) -> tuple[Float[Tensor, ""], int]:
@@ -149,7 +149,7 @@ def pgd_masked_recon_loss_update(
149149
batch: Any,
150150
ci: dict[str, Float[Tensor, "... C"]],
151151
weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None,
152-
target_out: Tensor,
152+
target_out: Any,
153153
router: Router,
154154
pgd_config: PGDConfig,
155155
reconstruction_loss: ReconstructionLoss,

0 commit comments

Comments
 (0)