1717from ..training .supervised import train_supervised
1818from ..utils .typing import assert_type
1919from .ccs_reporter import CcsConfig , CcsReporter
20- from .common import FitterConfig
20+ from .common import FitterConfig , Reporter
2121from .eigen_reporter import EigenFitter , EigenFitterConfig
2222
2323
@dataclass
class ReporterTrainResult:
    """Outcome of training one reporter: the fitted object plus its loss."""

    # The trained reporter instance (a CcsReporter or another Reporter).
    reporter: CcsReporter | Reporter

    # Loss value returned by fitting, or None when no loss was recorded.
    train_loss: float | None
28+
29+
2430@dataclass
2531class Elicit (Run ):
2632 """Full specification of a reporter training run."""
@@ -69,22 +75,11 @@ def make_eval(self, model, eval_dataset):
6975 disable_cache = self .disable_cache ,
7076 )
7177
72- def apply_to_layer (
73- self ,
74- layer : int ,
75- devices : list [str ],
76- world_size : int ,
77- probe_per_prompt : bool ,
78- ) -> dict [str , pd .DataFrame ]:
79- """Train a single reporter on a single layer."""
80-
81- self .make_reproducible (seed = self .net .seed + layer )
82- device = self .get_device (devices , world_size )
83-
78+ # Create a separate function to handle the reporter training.
79+ def train_reporter (self , device , layer , out_dir ) -> ReporterTrainResult :
8480 train_dict = self .prepare_data (device , layer , "train" )
85- val_dict = self .prepare_data (device , layer , "val" )
8681
87- (first_train_h , train_gt , _ ), * rest = train_dict .values ()
82+ (first_train_h , train_gt , _ ), * rest = train_dict .values () # TODO can remove?
8883 (_ , v , k , d ) = first_train_h .shape
8984 if not all (other_h .shape [- 1 ] == d for other_h , _ , _ in rest ):
9085 raise ValueError ("All datasets must have the same hidden state size" )
@@ -96,16 +91,12 @@ def apply_to_layer(
9691 if not all (other_h .shape [- 2 ] == k for other_h , _ , _ in rest ):
9792 raise ValueError ("All datasets must have the same number of classes" )
9893
99- reporter_dir , lr_dir = self .create_models_dir (assert_type (Path , self .out_dir ))
10094 train_loss = None
101-
10295 if isinstance (self .net , CcsConfig ):
10396 assert len (train_dict ) == 1 , "CCS only supports single-task training"
104-
10597 reporter = CcsReporter (self .net , d , device = device , num_variants = v )
10698 train_loss = reporter .fit (first_train_h )
10799
108- (_ , v , k , _ ) = first_train_h .shape
109100 reporter .platt_scale (
110101 to_one_hot (repeat (train_gt , "n -> (n v)" , v = v ), k ).flatten (),
111102 rearrange (first_train_h , "n v k d -> (n v k) d" ),
@@ -137,20 +128,50 @@ def apply_to_layer(
137128 raise ValueError (f"Unknown reporter config type: { type (self .net )} " )
138129
139130 # Save reporter checkpoint to disk
140- torch .save (reporter , reporter_dir / f"layer_{ layer } .pt" )
131+ torch .save (reporter , out_dir / f"layer_{ layer } .pt" )
141132
142- # Fit supervised logistic regression model
133+ return ReporterTrainResult (reporter , train_loss )
134+
def train_lr_model(self, train_dict, device, layer, out_dir):
    """Fit the supervised baseline for one layer and save it to disk.

    When ``self.supervised`` equals ``"none"`` nothing is trained or written
    and an empty list is returned. Otherwise ``train_supervised`` is run on
    ``train_dict`` with ``mode=self.supervised`` and the result is saved with
    ``torch.save`` to ``out_dir / f"layer_{layer}.pt"``.

    Args:
        train_dict: Training data handed unchanged to ``train_supervised``.
        device: Device passed through to ``train_supervised``.
        layer: Layer index, used only to name the checkpoint file.
        out_dir: Directory for the checkpoint; must support the ``/``
            operator (presumably a ``pathlib.Path`` -- confirm with caller).

    Returns:
        Whatever ``train_supervised`` returned, or ``[]`` when the supervised
        baseline is disabled.
    """
    # Baseline disabled: skip both the training and the checkpoint write.
    if self.supervised == "none":
        return []

    fitted = train_supervised(
        train_dict,
        device=device,
        mode=self.supervised,
    )

    checkpoint_path = out_dir / f"layer_{layer}.pt"
    with open(checkpoint_path, "wb") as sink:
        torch.save(fitted, sink)

    return fitted
148+
149+ def apply_to_layer (
150+ self ,
151+ layer : int ,
152+ devices : list [str ],
153+ world_size : int ,
154+ probe_per_prompt : bool ,
155+ ) -> dict [str , pd .DataFrame ]:
156+ """Train a single reporter on a single layer."""
157+
158+ self .make_reproducible (seed = self .net .seed + layer )
159+ device = self .get_device (devices , world_size )
160+
161+ train_dict = self .prepare_data (device , layer , "train" )
162+ val_dict = self .prepare_data (device , layer , "val" )
163+
164+ (first_train_h , train_gt , _ ), * rest = train_dict .values ()
165+ (_ , v , k , d ) = first_train_h .shape
166+
167+ reporter_dir , lr_dir = self .create_models_dir (assert_type (Path , self .out_dir ))
168+
169+ reporter_train_result = self .train_reporter (device , layer , reporter_dir )
170+ reporter = reporter_train_result .reporter
171+ train_loss = reporter_train_result .train_loss
172+
173+ lr_models = self .train_lr_model (train_dict , device , layer , lr_dir )
174+
154175 row_bufs = defaultdict (list )
155176 for ds_name in val_dict :
156177 val_h , val_gt , val_lm_preds = val_dict [ds_name ]
0 commit comments