1717from .utils import set_decoder_norm_to_unit_norm , ActivationNormalizer
1818
1919
20- class NormalizableMixin (ABC ):
20+ class NormalizableMixin (nn . Module ):
2121 """
2222 Mixin class providing activation normalization functionality.
2323
@@ -34,6 +34,7 @@ def __init__(self, activation_normalizer: ActivationNormalizer | None = None):
3434 activation_normalizer: Optional normalizer for activations. If None,
3535 normalization is a no-op.
3636 """
37+ super ().__init__ ()
3738 self .activation_normalizer = activation_normalizer
3839 if self .activation_normalizer is not None :
3940 self .activation_normalizer .to (self .device )
@@ -400,7 +401,7 @@ def from_pretrained(
400401 return autoencoder .to (dtype = dtype , device = device )
401402
402403
403- class BatchTopKSAE (Dictionary , nn . Module , NormalizableMixin ):
404+ class BatchTopKSAE (NormalizableMixin , Dictionary ):
404405 """
405406 Batch Top-K Sparse Autoencoder implementation.
406407
@@ -943,7 +944,7 @@ def __repr__(self) -> str:
943944 return self .name
944945
945946
946- class CrossCoder (Dictionary , nn . Module , NormalizableMixin ):
947+ class CrossCoder (Dictionary , NormalizableMixin ):
947948 """
948949 A crosscoder sparse autoencoder for multi-layer activation processing.
949950
0 commit comments