Skip to content

Commit 592a8c1

Browse files
committed
improve token completion
1 parent 5e2166b commit 592a8c1

File tree

1 file changed

+33
-14
lines changed

python_autocomplete/evaluate.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,28 @@
1515
EPS_PROB = 1e-6
1616
MIN_BEAM_PROB = 1e-4
1717

18+
1819
class PredictionComplete:
1920
def __call__(self, text, token_str: str):
2021
raise NotImplementedError
2122

2223

2324
class NextWordPredictionComplete(PredictionComplete):
24-
def __init__(self, prompt: str):
25+
def __init__(self, prompt: str, rest: str, min_length: int):
26+
self.min_length = min_length
27+
self.rest = rest
2528
self.is_id = False
2629
if prompt and prompt[-1] in ID_CHARS:
2730
self.is_id = True
2831

2932
def __call__(self, text, token_str: str):
30-
prediction = set(token_str)
31-
intersection = prediction.intersection(ID_CHARS)
32-
is_id = len(intersection) > 0 and intersection == prediction
33-
is_not_id = intersection != prediction
34-
if is_id and is_not_id:
35-
return True
36-
return is_id == self.is_id
33+
if len(text) - len(self.rest) < self.min_length:
34+
return False
35+
36+
prev_is_id = text[-1] in ID_CHARS
37+
last_is_id = token_str[-1] in ID_CHARS
38+
39+
return prev_is_id != last_is_id
3740

3841

3942
class BeamSearch:
@@ -77,6 +80,18 @@ def add_prediction(self, prob: float, beam_idx: int, token_str: str, state):
7780

7881
return True
7982

83+
def add_prediction_before_token(self, prob: float, beam_idx: int, state):
84+
if len(self.result_heap) == self.max_beam_size:
85+
if self.result_heap[0][0] > prob - EPS_PROB:
86+
return False
87+
heappop(self.result_heap)
88+
89+
state = self.state_updater.get_from_batch(state, beam_idx)
90+
text = self.text[beam_idx]
91+
heappush(self.result_heap, (prob, (text, state)))
92+
93+
return True
94+
8095
def add_beam(self, prob: float, beam_idx: int, token: int):
8196
if self.result_heap and self.result_heap[0][0] > prob - EPS_PROB:
8297
return False
@@ -121,7 +136,7 @@ def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]):
121136

122137
return new_prompt, new_state
123138

124-
def update(self, next_token, itos: List[str], state):
139+
def update(self, next_token, itos: List[str], state, old_state):
125140
self.beam_heap = []
126141

127142
for b, text in enumerate(self.text):
@@ -141,8 +156,12 @@ def update(self, next_token, itos: List[str], state):
141156
continue
142157

143158
if self.prediction_complete(text, token_str):
144-
if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state):
159+
if not self.add_prediction_before_token(self.probs[b], b, old_state):
160+
break
161+
else:
145162
break
163+
# if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state):
164+
# break
146165
elif not self.add_beam(self.probs[b] * tokens[token].item(), b, token):
147166
break
148167

@@ -190,9 +209,9 @@ def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List
190209
probs, self.is_token_by_token)
191210

192211
for _ in range(10):
193-
next_token, state = self._get_predictions(prompt, state)
194-
beam.update(next_token, self.tokenizer.itos, state)
195-
prompt, state = beam.next_batch(prompt, state, self.tokenizer.itos)
212+
next_token, new_state = self._get_predictions(prompt, state)
213+
beam.update(next_token, self.tokenizer.itos, new_state, state)
214+
prompt, state = beam.next_batch(prompt, new_state, self.tokenizer.itos)
196215

197216
if prompt is None:
198217
break
@@ -216,7 +235,7 @@ def evaluate(predictor: Predictor, text: str):
216235
prefix = text[:i + 1]
217236
stripped, prompt = predictor.rstrip(prefix)
218237
rest = prefix[len(stripped):]
219-
prediction_complete = NextWordPredictionComplete(stripped)
238+
prediction_complete = NextWordPredictionComplete(stripped, rest, 5)
220239
prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1)
221240

222241
predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5)

0 commit comments

Comments (0)