1212from python_autocomplete .dataset import Tokenizer , ID_CHARS
1313from python_autocomplete .train import Configs , StateUpdater
1414
15+ EPS_PROB = 1e-6
16+ MIN_BEAM_PROB = 1e-4
1517
1618class PredictionComplete :
1719 def __call__ (self , text , token_str : str ):
@@ -27,7 +29,7 @@ def __init__(self, prompt: str):
2729 def __call__ (self , text , token_str : str ):
2830 prediction = set (token_str )
2931 intersection = prediction .intersection (ID_CHARS )
30- is_id = intersection and intersection == prediction
32+ is_id = len ( intersection ) > 0 and intersection == prediction
3133 is_not_id = intersection != prediction
3234 if is_id and is_not_id :
3335 return True
@@ -65,25 +67,32 @@ def is_substr(original, token_str):
6567
def add_prediction(self, prob: float, beam_idx: int, token_str: str, state):
    """Record a finished prediction on the result heap.

    When the heap is already at capacity, the candidate must beat the
    current worst entry by more than EPS_PROB to displace it; otherwise
    it is rejected.

    Returns True if the prediction was stored, False if it was rejected.
    """
    heap_is_full = len(self.result_heap) == self.max_beam_size
    if heap_is_full:
        # Not meaningfully better than the weakest stored result — reject.
        if self.result_heap[0][0] > prob - EPS_PROB:
            return False
        # Evict the weakest result to make room.
        heappop(self.result_heap)

    # Extract this beam's slice of the batched model state and extend its text.
    beam_state = self.state_updater.get_from_batch(state, beam_idx)
    completed_text = self.text[beam_idx] + token_str
    heappush(self.result_heap, (prob, (completed_text, beam_state)))
    return True
79+
def add_beam(self, prob: float, beam_idx: int, token: int):
    """Queue a (beam, token) continuation on the beam heap.

    A candidate is rejected when it cannot meaningfully beat the worst
    completed result, when its probability falls below MIN_BEAM_PROB, or
    when the beam heap is full and it cannot displace the weakest beam.

    Returns True if the beam was queued, False if it was rejected.
    """
    # Guard: a beam that cannot beat the worst finished result is pointless.
    worse_than_results = (
        bool(self.result_heap) and self.result_heap[0][0] > prob - EPS_PROB
    )
    if worse_than_results:
        return False

    # Guard: prune vanishingly unlikely continuations outright.
    if prob < MIN_BEAM_PROB:
        return False

    if len(self.beam_heap) == self.max_beam_size:
        if self.beam_heap[0][0] > prob - EPS_PROB:
            return False
        # Evict the weakest queued beam to make room.
        heappop(self.beam_heap)

    heappush(self.beam_heap, (prob, (beam_idx, token)))
    return True
95+
8796 def next_batch (self , prompt : torch .Tensor , state : Any , itos : List [str ]):
8897 if not self .beam_heap :
8998 return None , None
@@ -122,13 +131,20 @@ def update(self, next_token, itos: List[str], state):
122131 else :
123132 check_rest = self .rest [len (text ):]
124133
125- for token , token_str in enumerate (itos ):
134+ tokens = next_token [b ]
135+ sort_idx = torch .argsort (tokens )
136+
137+ for i in reversed (range (len (tokens ))):
138+ token = sort_idx [i ]
139+ token_str = itos [token ]
126140 if not self .is_substr (check_rest , token_str ):
127141 continue
128142
129143 if self .prediction_complete (text , token_str ):
130- self .add_prediction (self .probs [b ] * next_token [b ][token ].item (), b , token_str , state )
131- self .add_beam (self .probs [b ] * next_token [b ][token ].item (), b , token )
144+ if not self .add_prediction (self .probs [b ] * tokens [token ].item (), b , token_str , state ):
145+ break
146+ elif not self .add_beam (self .probs [b ] * tokens [token ].item (), b , token ):
147+ break
132148
133149
134150class Prediction (NamedTuple ):
@@ -329,7 +345,7 @@ def get_predictor() -> Predictor:
329345 # And for latest checkpoint
330346 # checkpoint = None
331347
332- run_uuid = '109d1b8c6e8611eb80e13584488b68a4 ' # bpe
348+ run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6 ' # bpe
333349 checkpoint = None
334350 # run_uuid, checkpoint = experiment.load_bundle(
335351 # lab.get_path() / 'saved_checkpoint.tar.gz',
0 commit comments