We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 658952d commit b2363edCopy full SHA for b2363ed
1 file changed
dictionary_learning/cache.py
@@ -653,10 +653,9 @@ def collect(
653
last_submodule.output.stop()
654
655
for i in range(len(submodules)):
656
- activation_cache[i][-1] = (
657
- activation_cache[i][-1][store_mask.reshape(-1).bool()]
658
- .cpu()
659
- ) # remove padding tokens
+ activation_cache[i][-1] = activation_cache[i][-1][
+ store_mask.reshape(-1).bool()
+ ].cpu() # remove padding tokens
660
running_stats[i].update(activation_cache[i][-1].view(-1, d_model))
661
if dtype is not None:
662
activation_cache[i][-1] = activation_cache[i][-1].to(dtype)
0 commit comments