Skip to content

Commit c7b2527

Browse files
committed
assert right padding for remove_bos logic
1 parent d639166 commit c7b2527

2 files changed

Lines changed: 10 additions & 0 deletions

File tree

dictionary_learning/buffer.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,11 @@ def __init__(self,
5959
self.remove_bos = remove_bos
6060
self.add_special_tokens = add_special_tokens
6161

62+
print(self.model.tokenizer.padding_side)
63+
64+
if self.remove_bos:
65+
assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
66+
6267
def __iter__(self):
6368
return self
6469

@@ -138,6 +143,7 @@ def refresh(self):
138143
if isinstance(hidden_states, tuple):
139144
hidden_states = hidden_states[0]
140145
if self.remove_bos:
146+
assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
141147
hidden_states = hidden_states[:, 1:, :]
142148
attn_mask = attn_mask[:, 1:]
143149
hidden_states = hidden_states[attn_mask != 0]

dictionary_learning/pytorch_buffer.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,9 @@ def __init__(
123123

124124
if not self.tokenizer.pad_token:
125125
self.tokenizer.pad_token = self.tokenizer.eos_token
126+
127+
if self.remove_bos:
128+
assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
126129

127130
def __iter__(self):
128131
return self
@@ -194,6 +197,7 @@ def refresh(self):
194197
hidden_states = collect_activations(self.model, self.submodule, input)
195198
attn_mask = input["attention_mask"]
196199
if self.remove_bos:
200+
assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)"
197201
hidden_states = hidden_states[:, 1:, :]
198202
attn_mask = attn_mask[:, 1:]
199203
hidden_states = hidden_states[attn_mask != 0]

0 commit comments

Comments
 (0)