Skip to content

Commit 97cd2f6

Browse files
committed
cleanup
1 parent 3524f4f commit 97cd2f6

File tree

2 files changed

+3
-6
lines changed

2 files changed

+3
-6
lines changed

python_autocomplete/evaluate/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@ def __call__(self, text, token_str: str):
1919

2020

2121
class NextWordPredictionComplete(PredictionComplete):
22-
def __init__(self, prompt: str, rest: str, min_length: int):
22+
def __init__(self, rest: str, min_length: int):
2323
self.min_length = min_length
2424
self.rest = rest
25-
self.is_id = False
26-
if prompt and prompt[-1] in ID_CHARS:
27-
self.is_id = True
2825

2926
def __call__(self, text, token_str: str):
3027
if len(text) - len(self.rest) < self.min_length:
@@ -159,7 +156,7 @@ def update(self, next_token, itos: List[str], state, old_state):
159156
break
160157
# if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state):
161158
# break
162-
elif not self.add_beam(self.probs[b] * tokens[token].item(), b, token):
159+
if not self.add_beam(self.probs[b] * tokens[token].item(), b, token):
163160
break
164161

165162

python_autocomplete/evaluate/eval_sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def evaluate(predictor: Predictor, text: str):
1818
prefix = text[:i + 1]
1919
stripped, prompt = predictor.rstrip(prefix)
2020
rest = prefix[len(stripped):]
21-
prediction_complete = NextWordPredictionComplete(stripped, rest, 5)
21+
prediction_complete = NextWordPredictionComplete(rest, 5)
2222
prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1)
2323

2424
predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5)

0 commit comments

Comments
 (0)