File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -503,15 +503,13 @@ def collect(
503503 ), "Shuffling shards and storing tokens is not supported yet"
504504
505505 # Check if we need to store sequence ranges
506- has_bos_token = model .tokenizer .bos_token_id is not None
506+ has_bos_token = model .tokenizer .bos_token is not None
507507 store_sequence_ranges = (
508508 store_tokens and
509509 not shuffle_shards and
510510 not has_bos_token
511511 )
512- if store_sequence_ranges :
513- print ("No BOS token found. Will store sequence ranges." )
514-
512+
515513 dataloader = DataLoader (data , batch_size = batch_size , num_workers = num_workers )
516514
517515 activation_cache = [[] for _ in submodules ]
Original file line number Diff line number Diff line change @@ -173,7 +173,7 @@ def loss(
173173 if step > self .threshold_start_step :
174174 self .update_threshold (f )
175175
176- x_hat = self .ae .decode (f , denormalize_activations = normalize_activations )
176+ x_hat = self .ae .decode (f , denormalize_activations = False )
177177
178178 e = x - x_hat
179179
Original file line number Diff line number Diff line change 1111import wandb
1212from typing import List , Optional
1313
14+ from .trainers .batch_top_k import BatchTopKTrainer
1415from .trainers .crosscoder import CrossCoderTrainer , BatchTopKCrossCoderTrainer
1516
1617
@@ -300,7 +301,7 @@ def trainSAE(
300301 use_threshold = False ,
301302 epoch_idx_per_step = epoch_idx_per_step ,
302303 )
303- if isinstance (trainer , BatchTopKCrossCoderTrainer ):
304+ if isinstance (trainer , BatchTopKCrossCoderTrainer ) or isinstance ( trainer , BatchTopKTrainer ) :
304305 log_stats (
305306 trainer ,
306307 step ,
You can’t perform that action at this time.
0 commit comments