@@ -21,6 +21,7 @@ def __init__(
2121 layer : int ,
2222 lm_name : str ,
2323 dict_class : type = BatchTopKSAE ,
24+ pretrained_ae : Optional [BatchTopKSAE ] = None ,
2425 lr : Optional [float ] = None ,
2526 auxk_alpha : float = 1 / 32 ,
2627 warmup_steps : int = 1000 ,
@@ -33,7 +34,7 @@ def __init__(
3334 activation_mean : Optional [t .Tensor ] = None ,
3435 activation_std : Optional [t .Tensor ] = None ,
3536 target_rms : float = 1.0 ,
36- encoder_init_norm : str = 1.0 ,
37+ encoder_init_norm : float = 1.0 ,
3738 ):
3839 super ().__init__ (seed )
3940 assert layer is not None and lm_name is not None
@@ -51,15 +52,18 @@ def __init__(
5152 t .manual_seed (seed )
5253 t .cuda .manual_seed_all (seed )
5354
54- self .ae = dict_class (
55- activation_dim ,
56- dict_size ,
57- k ,
58- activation_mean = activation_mean ,
59- activation_std = activation_std ,
60- target_rms = target_rms ,
61- encoder_init_norm = encoder_init_norm ,
62- )
55+ if pretrained_ae is None :
56+ self .ae = dict_class (
57+ activation_dim ,
58+ dict_size ,
59+ k ,
60+ activation_mean = activation_mean ,
61+ activation_std = activation_std ,
62+ target_rms = target_rms ,
63+ encoder_init_norm = encoder_init_norm ,
64+ )
65+ else :
66+ self .ae = pretrained_ae
6367
6468 if device is None :
6569 self .device = "cuda" if t .cuda .is_available () else "cpu"
0 commit comments