@@ -56,14 +56,9 @@ def __init__(self,
5656 self .refresh_batch_size = refresh_batch_size
5757 self .out_batch_size = out_batch_size
5858 self .device = device
59- self .remove_bos = remove_bos
59+ self .remove_bos = remove_bos and ( self . model . tokenizer . bos_token_id is not None )
6060 self .add_special_tokens = add_special_tokens
6161
62- print (self .model .tokenizer .padding_side )
63-
64- if self .remove_bos :
65- assert self .model .tokenizer .padding_side == "right" , "Padding side must be right (bos-trimming logic assumes right padding)"
66-
6762 def __iter__ (self ):
6863 return self
6964
@@ -138,15 +133,15 @@ def refresh(self):
138133 input = self .model .inputs .save ()
139134
140135 self .submodule .output .stop ()
141- attn_mask = input .value [1 ]["attention_mask" ]
136+
137+ mask = (input .value [1 ]["attention_mask" ] != 0 )
142138 hidden_states = hidden_states .value
143139 if isinstance (hidden_states , tuple ):
144140 hidden_states = hidden_states [0 ]
145141 if self .remove_bos :
146- assert self .model .tokenizer .padding_side == "right" , "Padding side must be right (bos-trimming logic assumes right padding)"
147- hidden_states = hidden_states [:, 1 :, :]
148- attn_mask = attn_mask [:, 1 :]
149- hidden_states = hidden_states [attn_mask != 0 ]
142+ bos_mask = (input .value [1 ]["input_ids" ] == self .model .tokenizer .bos_token_id )
143+ mask = mask & ~ bos_mask
144+ hidden_states = hidden_states [mask ]
150145
151146 remaining_space = self .activation_buffer_size - current_idx
152147 assert remaining_space > 0