Skip to content

Commit 5e2166b

Browse files
committed
beam search efficiency
1 parent 879d494 commit 5e2166b

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

python_autocomplete/evaluate.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from python_autocomplete.dataset import Tokenizer, ID_CHARS
1313
from python_autocomplete.train import Configs, StateUpdater
1414

# Tolerance for comparing floating-point probabilities against the heap top;
# a candidate must beat the current worst by more than this to be accepted.
EPS_PROB = 1e-6
# Beams whose cumulative probability falls below this are pruned outright,
# which keeps the search from expanding hopeless branches.
MIN_BEAM_PROB = 1e-4
1517

1618
class PredictionComplete:
1719
def __call__(self, text, token_str: str):
@@ -27,7 +29,7 @@ def __init__(self, prompt: str):
2729
def __call__(self, text, token_str: str):
2830
prediction = set(token_str)
2931
intersection = prediction.intersection(ID_CHARS)
30-
is_id = intersection and intersection == prediction
32+
is_id = len(intersection) > 0 and intersection == prediction
3133
is_not_id = intersection != prediction
3234
if is_id and is_not_id:
3335
return True
@@ -65,25 +67,32 @@ def is_substr(original, token_str):
6567

6668
def add_prediction(self, prob: float, beam_idx: int, token_str: str, state):
    """Try to record a completed prediction for beam `beam_idx`.

    The finished text is `self.text[beam_idx] + token_str`. At most
    `max_beam_size` results are kept in `result_heap`, a min-heap keyed by
    probability, so the heap top is always the worst retained result.

    Returns True when the prediction was stored, False when it was rejected
    for not beating (by more than EPS_PROB) the current worst result.
    """
    heap = self.result_heap
    if len(heap) == self.max_beam_size:
        # Heap is full: only accept candidates strictly better than the
        # worst kept result (with a small epsilon to dodge float-equality).
        if heap[0][0] > prob - EPS_PROB:
            return False
        heappop(heap)

    # NOTE(review): on tied probabilities heapq compares the (text, state)
    # payload next — presumably ties never reach a state comparison; verify.
    beam_state = self.state_updater.get_from_batch(state, beam_idx)
    full_text = self.text[beam_idx] + token_str
    heappush(heap, (prob, (full_text, beam_state)))

    return True
def add_beam(self, prob: float, beam_idx: int, token: int):
    """Try to queue beam `beam_idx` extended by `token` for the next step.

    A candidate is rejected when it cannot beat the worst completed result,
    when its probability is below MIN_BEAM_PROB, or when it is no better
    than the worst already-queued beam. `beam_heap` is a min-heap keyed by
    probability holding at most `max_beam_size` entries.

    Returns True when the candidate was queued, False otherwise.
    """
    # Prune: expanding a beam that cannot beat a finished result is wasted work.
    if self.result_heap and self.result_heap[0][0] > prob - EPS_PROB:
        return False

    # Prune: drop vanishingly unlikely beams outright.
    if prob < MIN_BEAM_PROB:
        return False

    heap = self.beam_heap
    if len(heap) == self.max_beam_size:
        # Heap is full: must strictly beat the worst queued beam (epsilon-guarded).
        if heap[0][0] > prob - EPS_PROB:
            return False
        heappop(heap)

    heappush(heap, (prob, (beam_idx, token)))

    return True
def next_batch(self, prompt: torch.Tensor, state: Any, itos: List[str]):
8897
if not self.beam_heap:
8998
return None, None
@@ -122,13 +131,20 @@ def update(self, next_token, itos: List[str], state):
122131
else:
123132
check_rest = self.rest[len(text):]
124133

125-
for token, token_str in enumerate(itos):
134+
tokens = next_token[b]
135+
sort_idx = torch.argsort(tokens)
136+
137+
for i in reversed(range(len(tokens))):
138+
token = sort_idx[i]
139+
token_str = itos[token]
126140
if not self.is_substr(check_rest, token_str):
127141
continue
128142

129143
if self.prediction_complete(text, token_str):
130-
self.add_prediction(self.probs[b] * next_token[b][token].item(), b, token_str, state)
131-
self.add_beam(self.probs[b] * next_token[b][token].item(), b, token)
144+
if not self.add_prediction(self.probs[b] * tokens[token].item(), b, token_str, state):
145+
break
146+
elif not self.add_beam(self.probs[b] * tokens[token].item(), b, token):
147+
break
132148

133149

134150
class Prediction(NamedTuple):
@@ -329,7 +345,7 @@ def get_predictor() -> Predictor:
329345
# And for latest checkpoint
330346
# checkpoint = None
331347

332-
run_uuid = '109d1b8c6e8611eb80e13584488b68a4' # bpe
348+
run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6' # bpe
333349
checkpoint = None
334350
# run_uuid, checkpoint = experiment.load_bundle(
335351
# lab.get_path() / 'saved_checkpoint.tar.gz',

python_autocomplete/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,14 @@ def __call__(self, old_mem, new_mem):
230230
return mem
231231

232232
def get_from_batch(self, state, batch_idx):
    """Slice out the memory tensors belonging to one batch element.

    `state` is a list of tensors with the batch on dim 1; returns the
    per-element view for `batch_idx`, or None when there is no state yet.
    """
    if state is None:  # no memory accumulated yet (e.g. before the first step)
        return None
    return [mem[:, batch_idx] for mem in state]
234237

235238
def make_batch(self, batch):
    """Re-assemble per-element states into one batched state.

    Inverse of `get_from_batch`: each item in `batch` is a list of tensors,
    stacked layer-by-layer along a new batch dimension (dim=1). Returns
    None when the items carry no state.
    """
    if batch[0] is None:  # states are None for every element; nothing to stack
        return None
    n_layers = len(batch[0])
    return [torch.stack([item[layer] for item in batch], dim=1)
            for layer in range(n_layers)]
237242

238243

0 commit comments

Comments
 (0)