Skip to content

Commit daec121

Browse files
committed
refactor reporter training
1 parent 78549f5 commit daec121

1 file changed

Lines changed: 43 additions & 22 deletions

File tree

elk/training/train.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@
1717
from ..training.supervised import train_supervised
1818
from ..utils.typing import assert_type
1919
from .ccs_reporter import CcsConfig, CcsReporter
20-
from .common import FitterConfig
20+
from .common import FitterConfig, Reporter
2121
from .eigen_reporter import EigenFitter, EigenFitterConfig
2222

2323

24+
@dataclass
25+
class ReporterTrainResult:
26+
reporter: CcsReporter | Reporter
27+
train_loss: float | None
28+
29+
2430
@dataclass
2531
class Elicit(Run):
2632
"""Full specification of a reporter training run."""
@@ -69,22 +75,11 @@ def make_eval(self, model, eval_dataset):
6975
disable_cache=self.disable_cache,
7076
)
7177

72-
def apply_to_layer(
73-
self,
74-
layer: int,
75-
devices: list[str],
76-
world_size: int,
77-
probe_per_prompt: bool,
78-
) -> dict[str, pd.DataFrame]:
79-
"""Train a single reporter on a single layer."""
80-
81-
self.make_reproducible(seed=self.net.seed + layer)
82-
device = self.get_device(devices, world_size)
83-
78+
# Create a separate function to handle the reporter training.
79+
def train_reporter(self, device, layer, out_dir) -> ReporterTrainResult:
8480
train_dict = self.prepare_data(device, layer, "train")
85-
val_dict = self.prepare_data(device, layer, "val")
8681

87-
(first_train_h, train_gt, _), *rest = train_dict.values()
82+
(first_train_h, train_gt, _), *rest = train_dict.values() # TODO can remove?
8883
(_, v, k, d) = first_train_h.shape
8984
if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
9085
raise ValueError("All datasets must have the same hidden state size")
@@ -96,16 +91,12 @@ def apply_to_layer(
9691
if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
9792
raise ValueError("All datasets must have the same number of classes")
9893

99-
reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
10094
train_loss = None
101-
10295
if isinstance(self.net, CcsConfig):
10396
assert len(train_dict) == 1, "CCS only supports single-task training"
104-
10597
reporter = CcsReporter(self.net, d, device=device, num_variants=v)
10698
train_loss = reporter.fit(first_train_h)
10799

108-
(_, v, k, _) = first_train_h.shape
109100
reporter.platt_scale(
110101
to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
111102
rearrange(first_train_h, "n v k d -> (n v k) d"),
@@ -137,20 +128,50 @@ def apply_to_layer(
137128
raise ValueError(f"Unknown reporter config type: {type(self.net)}")
138129

139130
# Save reporter checkpoint to disk
140-
torch.save(reporter, reporter_dir / f"layer_{layer}.pt")
131+
torch.save(reporter, out_dir / f"layer_{layer}.pt")
141132

142-
# Fit supervised logistic regression model
133+
return ReporterTrainResult(reporter, train_loss)
134+
135+
def train_lr_model(self, train_dict, device, layer, out_dir):
143136
if self.supervised != "none":
144137
lr_models = train_supervised(
145138
train_dict,
146139
device=device,
147140
mode=self.supervised,
148141
)
149-
with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
142+
with open(out_dir / f"layer_{layer}.pt", "wb") as file:
150143
torch.save(lr_models, file)
151144
else:
152145
lr_models = []
153146

147+
return lr_models
148+
149+
def apply_to_layer(
150+
self,
151+
layer: int,
152+
devices: list[str],
153+
world_size: int,
154+
probe_per_prompt: bool,
155+
) -> dict[str, pd.DataFrame]:
156+
"""Train a single reporter on a single layer."""
157+
158+
self.make_reproducible(seed=self.net.seed + layer)
159+
device = self.get_device(devices, world_size)
160+
161+
train_dict = self.prepare_data(device, layer, "train")
162+
val_dict = self.prepare_data(device, layer, "val")
163+
164+
(first_train_h, train_gt, _), *rest = train_dict.values()
165+
(_, v, k, d) = first_train_h.shape
166+
167+
reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
168+
169+
reporter_train_result = self.train_reporter(device, layer, reporter_dir)
170+
reporter = reporter_train_result.reporter
171+
train_loss = reporter_train_result.train_loss
172+
173+
lr_models = self.train_lr_model(train_dict, device, layer, lr_dir)
174+
154175
row_bufs = defaultdict(list)
155176
for ds_name in val_dict:
156177
val_h, val_gt, val_lm_preds = val_dict[ds_name]

0 commit comments

Comments
 (0)