Skip to content

Commit b2363ed

Browse files
committed
fmt
1 parent 658952d commit b2363ed

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

dictionary_learning/cache.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,9 @@ def collect(
653653
last_submodule.output.stop()
654654

655655
for i in range(len(submodules)):
656-
activation_cache[i][-1] = (
657-
activation_cache[i][-1][store_mask.reshape(-1).bool()]
658-
.cpu()
659-
) # remove padding tokens
656+
activation_cache[i][-1] = activation_cache[i][-1][
657+
store_mask.reshape(-1).bool()
658+
].cpu() # remove padding tokens
660659
running_stats[i].update(activation_cache[i][-1].view(-1, d_model))
661660
if dtype is not None:
662661
activation_cache[i][-1] = activation_cache[i][-1].to(dtype)

0 commit comments

Comments
 (0)