File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -59,6 +59,11 @@ def __init__(self,
5959 self .remove_bos = remove_bos
6060 self .add_special_tokens = add_special_tokens
6161
62+ print (self .model .tokenizer .padding_side )
63+
64+ if self .remove_bos :
65+ assert self .model .tokenizer .padding_side == "right" , "Padding side must be right (bos-trimming logic assumes right padding)"
66+
6267 def __iter__ (self ):
6368 return self
6469
@@ -138,6 +143,7 @@ def refresh(self):
138143 if isinstance (hidden_states , tuple ):
139144 hidden_states = hidden_states [0 ]
140145 if self .remove_bos :
146+ assert self .model .tokenizer .padding_side == "right" , "Padding side must be right (bos-trimming logic assumes right padding)"
141147 hidden_states = hidden_states [:, 1 :, :]
142148 attn_mask = attn_mask [:, 1 :]
143149 hidden_states = hidden_states [attn_mask != 0 ]
Original file line number Diff line number Diff line change @@ -123,6 +123,9 @@ def __init__(
123123
124124 if not self .tokenizer .pad_token :
125125 self .tokenizer .pad_token = self .tokenizer .eos_token
126+
127+ if self .remove_bos :
128+ assert self .tokenizer .padding_side == "right" , "Padding side must be right (bos-trimming logic assumes right padding)"
126129
127130 def __iter__ (self ):
128131 return self
@@ -194,6 +197,7 @@ def refresh(self):
194197 hidden_states = collect_activations (self .model , self .submodule , input )
195198 attn_mask = input ["attention_mask" ]
196199 if self .remove_bos :
200+ assert self .tokenizer .padding_side == "right" , "Padding side must be right (bos-trimming logic assumes right padding)"
197201 hidden_states = hidden_states [:, 1 :, :]
198202 attn_mask = attn_mask [:, 1 :]
199203 hidden_states = hidden_states [attn_mask != 0 ]
You can’t perform that action at this time.
0 commit comments