# Tolerance used when comparing accumulated beam probabilities
# (see the `prob - EPS_PROB` heap comparisons below): avoids evicting a
# result over a float rounding difference.
EPS_PROB = 1e-6
# NOTE(review): not referenced in this excerpt — presumably the pruning
# floor for beam probabilities; confirm against the full file.
MIN_BEAM_PROB = 1e-4
18+
class PredictionComplete:
    """Callable interface that decides whether a beam's prediction is done.

    Subclasses implement ``__call__``, which receives the text generated so
    far and the newest token string, and return a truthy value when the
    prediction should be considered complete.
    """

    def __call__(self, text, token_str: str):
        raise NotImplementedError
2223
class NextWordPredictionComplete(PredictionComplete):
    """Completion test for next-word prediction.

    The prediction is complete once at least ``min_length`` characters beyond
    ``rest`` have been generated and the generation crosses an
    identifier/non-identifier boundary (the identifier-ness of the last
    character of ``text`` differs from that of the last character of the
    newest token).
    """

    def __init__(self, prompt: str, rest: str, min_length: int):
        self.min_length = min_length
        self.rest = rest
        # True when the prompt itself ends on an identifier character.
        # NOTE(review): kept for interface compatibility — ``__call__``
        # below does not consult it.
        self.is_id = bool(prompt and prompt[-1] in ID_CHARS)

    def __call__(self, text, token_str: str):
        # Demand a minimum number of freshly generated characters first.
        if len(text) - len(self.rest) < self.min_length:
            return False

        # Done exactly when identifier-ness flips between the accumulated
        # text and the newest token's final character.
        return (text[-1] in ID_CHARS) != (token_str[-1] in ID_CHARS)
3740
3841
3942class BeamSearch :
@@ -77,6 +80,18 @@ def add_prediction(self, prob: float, beam_idx: int, token_str: str, state):
7780
7881 return True
7982
83+ def add_prediction_before_token (self , prob : float , beam_idx : int , state ):
84+ if len (self .result_heap ) == self .max_beam_size :
85+ if self .result_heap [0 ][0 ] > prob - EPS_PROB :
86+ return False
87+ heappop (self .result_heap )
88+
89+ state = self .state_updater .get_from_batch (state , beam_idx )
90+ text = self .text [beam_idx ]
91+ heappush (self .result_heap , (prob , (text , state )))
92+
93+ return True
94+
8095 def add_beam (self , prob : float , beam_idx : int , token : int ):
8196 if self .result_heap and self .result_heap [0 ][0 ] > prob - EPS_PROB :
8297 return False
@@ -121,7 +136,7 @@ def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]):
121136
122137 return new_prompt , new_state
123138
124- def update (self , next_token , itos : List [str ], state ):
139+ def update (self , next_token , itos : List [str ], state , old_state ):
125140 self .beam_heap = []
126141
127142 for b , text in enumerate (self .text ):
@@ -141,8 +156,12 @@ def update(self, next_token, itos: List[str], state):
141156 continue
142157
143158 if self .prediction_complete (text , token_str ):
144- if not self .add_prediction (self .probs [b ] * tokens [token ].item (), b , token_str , state ):
159+ if not self .add_prediction_before_token (self .probs [b ], b , old_state ):
160+ break
161+ else :
145162 break
163+ # if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state):
164+ # break
146165 elif not self .add_beam (self .probs [b ] * tokens [token ].item (), b , token ):
147166 break
148167
@@ -190,9 +209,9 @@ def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List
190209 probs , self .is_token_by_token )
191210
192211 for _ in range (10 ):
193- next_token , state = self ._get_predictions (prompt , state )
194- beam .update (next_token , self .tokenizer .itos , state )
195- prompt , state = beam .next_batch (prompt , state , self .tokenizer .itos )
212+ next_token , new_state = self ._get_predictions (prompt , state )
213+ beam .update (next_token , self .tokenizer .itos , new_state , state )
214+ prompt , state = beam .next_batch (prompt , new_state , self .tokenizer .itos )
196215
197216 if prompt is None :
198217 break
@@ -216,7 +235,7 @@ def evaluate(predictor: Predictor, text: str):
216235 prefix = text [:i + 1 ]
217236 stripped , prompt = predictor .rstrip (prefix )
218237 rest = prefix [len (stripped ):]
219- prediction_complete = NextWordPredictionComplete (stripped )
238+ prediction_complete = NextWordPredictionComplete (stripped , rest , 5 )
220239 prompt = torch .tensor (prompt , dtype = torch .long ).unsqueeze (- 1 )
221240
222241 predictions = predictor .get_next_word (prompt , None , rest , [1. ], prediction_complete , 5 )
0 commit comments