
Commit 59abc88

mask out bos activations
Parent: c7b2527

2 files changed: 11 additions & 20 deletions


dictionary_learning/buffer.py

Lines changed: 6 additions & 11 deletions
@@ -56,14 +56,9 @@ def __init__(self,
         self.refresh_batch_size = refresh_batch_size
         self.out_batch_size = out_batch_size
         self.device = device
-        self.remove_bos = remove_bos
+        self.remove_bos = remove_bos and (self.model.tokenizer.bos_token_id is not None)
         self.add_special_tokens = add_special_tokens

-        print(self.model.tokenizer.padding_side)
-
-        if self.remove_bos:
-            assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
-
     def __iter__(self):
         return self

@@ -138,15 +133,15 @@ def refresh(self):
                     input = self.model.inputs.save()

                     self.submodule.output.stop()
-            attn_mask = input.value[1]["attention_mask"]
+
+            mask = (input.value[1]["attention_mask"] != 0)
             hidden_states = hidden_states.value
             if isinstance(hidden_states, tuple):
                 hidden_states = hidden_states[0]
             if self.remove_bos:
-                assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
-                hidden_states = hidden_states[:, 1:, :]
-                attn_mask = attn_mask[:, 1:]
-            hidden_states = hidden_states[attn_mask != 0]
+                bos_mask = (input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id)
+                mask = mask & ~bos_mask
+            hidden_states = hidden_states[mask]

             remaining_space = self.activation_buffer_size - current_idx
             assert remaining_space > 0
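
The refresh logic above replaces positional BOS trimming with token-id masking. As an illustration (a minimal sketch of the same idea, not code from the repository; gather_valid_activations is a hypothetical helper), the combined mask behaves like this on a right-padded batch:

    import torch

    def gather_valid_activations(hidden_states, input_ids, attention_mask,
                                 bos_token_id, remove_bos=True):
        # hidden_states: (batch, seq, d_model); input_ids / attention_mask: (batch, seq)
        mask = attention_mask != 0                     # drop padding positions
        if remove_bos and bos_token_id is not None:    # mirrors the remove_bos guard in __init__
            mask = mask & (input_ids != bos_token_id)  # drop BOS wherever it sits
        return hidden_states[mask]                     # boolean indexing -> (n_kept, d_model)

    # Toy batch: token id 1 is BOS, 0 is padding.
    hs = torch.randn(2, 5, 8)
    ids = torch.tensor([[1, 5, 6, 0, 0],
                        [1, 7, 8, 9, 0]])
    am = torch.tensor([[1, 1, 1, 0, 0],
                       [1, 1, 1, 1, 0]])
    out = gather_valid_activations(hs, ids, am, bos_token_id=1)
    assert out.shape == (5, 8)   # 7 attended positions minus 2 BOS tokens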

dictionary_learning/pytorch_buffer.py

Lines changed: 5 additions & 9 deletions
@@ -117,15 +117,12 @@ def __init__(
         self.refresh_batch_size = refresh_batch_size
         self.out_batch_size = out_batch_size
         self.device = device
-        self.remove_bos = remove_bos
         self.add_special_tokens = add_special_tokens
         self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
+        self.remove_bos = remove_bos and (self.tokenizer.bos_token_id is not None)

         if not self.tokenizer.pad_token:
             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        if self.remove_bos:
-            assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"

     def __iter__(self):
         return self
@@ -195,12 +192,11 @@ def refresh(self):
             with t.no_grad():
                 input = self.tokenized_batch()
                 hidden_states = collect_activations(self.model, self.submodule, input)
-            attn_mask = input["attention_mask"]
+            mask = (input["attention_mask"] != 0)
             if self.remove_bos:
-                assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
-                hidden_states = hidden_states[:, 1:, :]
-                attn_mask = attn_mask[:, 1:]
-            hidden_states = hidden_states[attn_mask != 0]
+                bos_mask = (input["input_ids"] == self.tokenizer.bos_token_id)
+                mask = mask & ~bos_mask
+            hidden_states = hidden_states[mask]

             remaining_space = self.activation_buffer_size - current_idx
             assert remaining_space > 0
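
One consequence of this change (my reading of the diff, not a claim from the commit message): the old hidden_states[:, 1:, :] trim only removes the first column, which is BOS only under right padding, hence the deleted assertions; the id-based mask is position-independent, so either padding side works. Reusing the hypothetical sketch above with a left-padded batch:

    # Left padding: BOS is not in column 0, so slicing off the first column
    # would have kept the BOS activation and dropped a pad position instead.
    ids = torch.tensor([[0, 0, 1, 5, 6]])   # 1 = BOS, 0 = pad
    am = torch.tensor([[0, 0, 1, 1, 1]])
    hs = torch.randn(1, 5, 8)
    out = gather_valid_activations(hs, ids, am, bos_token_id=1)
    assert out.shape == (2, 8)   # only the two real tokens survive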
