
Commit 4541d10

Refactor activation normalization handling and remove the `ActivationNormalizer` class, so that activation normalization is loaded automatically by `load_pretrained`.
This commit includes the following changes:
- Removed the `ActivationNormalizer` class from `utils.py`, centralizing activation normalization logic within the `NormalizableMixin`.
- Updated the `NormalizableMixin` to accept mean and standard deviation tensors directly, replacing the previous reliance on `ActivationNormalizer`.
- Modified the constructors of `BatchTopKSAE`, `CrossCoder`, and their trainers to accept `activation_mean` and `activation_std` parameters instead of an `activation_normalizer`.
- Adjusted normalization and denormalization methods to utilize the new mean and std tensors, ensuring that activations are processed correctly.
- Cleaned up related code in `cache.py`, `dictionary.py`, and various trainer files to reflect these changes.

These modifications enhance the clarity and maintainability of the code while ensuring proper handling of activation normalization across various components.
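As a rough illustration of the change described in the message, here is a minimal sketch of a module that takes mean and std tensors directly. `NormalizableSketch`, its `eps` parameter, and the `has_stats` flag are hypothetical names for illustration and are not taken from the repository. The use of `register_buffer` is likewise an assumption about how the statistics could end up in the saved weights. Only the parameter names `activation_mean` and `activation_std` come from the commit message.

```python
import torch as th
from torch import nn


class NormalizableSketch(nn.Module):
    """Hypothetical stand-in for NormalizableMixin; not the repository's code."""

    def __init__(self, activation_mean=None, activation_std=None, eps=1e-8):
        super().__init__()
        self.eps = eps  # assumed guard against division by zero
        self.has_stats = activation_mean is not None and activation_std is not None
        if self.has_stats:
            # Buffers are included in state_dict, so they are saved and
            # restored together with the learned weights.
            self.register_buffer("activation_mean", activation_mean.clone())
            self.register_buffer("activation_std", activation_std.clone())

    def normalize(self, x):
        # Without statistics, activations pass through unchanged.
        if not self.has_stats:
            return x
        return (x - self.activation_mean) / (self.activation_std + self.eps)

    def denormalize(self, x):
        # Inverse of normalize, using the same eps-adjusted std.
        if not self.has_stats:
            return x
        return x * (self.activation_std + self.eps) + self.activation_mean
```

Under these assumptions, a model restored from a checkpoint would recover its statistics through the ordinary `load_state_dict` path, with no separate normalizer object to reconstruct.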
1 parent 68b3025 commit 4541d10

5 files changed

Lines changed: 116 additions & 139 deletions

File tree

dictionary_learning/cache.py

Lines changed: 0 additions & 11 deletions
@@ -15,7 +15,6 @@
     dtype_to_str,
     str_to_dtype,
     torch_to_numpy_dtype,
-    ActivationNormalizer,
 )

 if DEBUG:
@@ -303,9 +302,6 @@ def std(self):
             )
         return self._std

-    @property
-    def normalizer(self):
-        return ActivationNormalizer(self.mean, self.std)

     @property
     def running_stats(self):
@@ -701,9 +697,6 @@ def std(self):
             (self.activation_cache_1.std, self.activation_cache_2.std), dim=0
         )

-    @property
-    def normalizer(self):
-        return ActivationNormalizer(self.mean, self.std)


 class ActivationCacheTuple:
@@ -732,7 +725,3 @@ def mean(self):
     @property
     def std(self):
         return th.stack([cache.std for cache in self.activation_caches], dim=0)
-
-    @property
-    def normalizer(self):
-        return ActivationNormalizer(self.mean, self.std)
