Skip to content

Commit 3524f4f

Browse files
committed
refactor
1 parent 592a8c1 commit 3524f4f

File tree

6 files changed

+219
-178
lines changed

6 files changed

+219
-178
lines changed
Lines changed: 1 addition & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,9 @@
55
import torch.nn
66
from torch import nn
77

8-
from labml import experiment, logger, lab, monit
9-
from labml.logger import Text, Style
10-
from labml.utils.pytorch import get_modules
118
from labml_helpers.module import Module
129
from python_autocomplete.dataset import Tokenizer, ID_CHARS
13-
from python_autocomplete.train import Configs, StateUpdater
10+
from python_autocomplete.train import StateUpdater
1411

1512
EPS_PROB = 1e-6
1613
MIN_BEAM_PROB = 1e-4
@@ -221,176 +218,3 @@ def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List
221218

222219
def rstrip(self, prompt: str) -> Tuple[str, List[int]]:
    """Delegate to ``self.tokenizer.rstrip``.

    Returns a ``(stripped_prompt, token_ids)`` pair — presumably the prompt
    with any un-tokenizable tail removed plus the token ids of the stripped
    part; confirm against ``Tokenizer.rstrip``.
    """
    return self.tokenizer.rstrip(prompt)
224-
225-
226-
def evaluate(predictor: Predictor, text: str):
    """Replay *text* through the predictor and log keystroke savings.

    At each position the predictor proposes a completion; if it matches the
    upcoming text it is accepted wholesale, otherwise the single next
    character is "typed".  Every attempt costs one key stroke.  Finally
    logs accuracy, total key strokes and text length.

    :param predictor: trained ``Predictor`` to query
    :param text: sample text; assumed non-empty (``text[0]`` is read)
    """
    line_no = 1
    # Buffer of styled segments for the current output line
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    correct = 0      # characters predicted correctly
    i = 0            # index of the last character already consumed
    key_strokes = 0  # one per loop iteration (accept or type one char)

    while i + 1 < len(text):
        prefix = text[:i + 1]
        # Split the prefix into a tokenized prompt and a trailing `rest`
        # the tokenizer did not consume
        stripped, prompt = predictor.rstrip(prefix)
        rest = prefix[len(stripped):]
        prediction_complete = NextWordPredictionComplete(stripped, rest, 5)
        prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1)

        predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5)
        # Sort best-first by probability (first tuple element)
        predictions.sort(key=lambda x: -x[0])
        if predictions:
            next_token = predictions[0].text[len(rest):]
        else:
            next_token = ''

        # Accept the prediction only if it matches the upcoming text exactly
        if next_token and next_token == text[i + 1: i + 1 + len(next_token)]:
            correct += len(next_token)
            right = True
        else:
            # Otherwise the user types the single next character by hand
            next_token = text[i + 1]
            right = False

        for j, c in enumerate(next_token):
            if c == '\n':
                logger.log(logs)
                line_no += 1
                logs = [(f"{line_no: 4d}: ", Text.meta)]
            elif c == '\r':
                continue  # skip carriage returns entirely
            else:
                if right:
                    # First accepted character gets a distinct style
                    if j == 0:
                        logs.append((c, [Text.meta, Style.underline]))
                    else:
                        logs.append((c, [Text.success, Style.underline]))
                else:
                    logs.append((c, [Text.warning]))

        i += len(next_token)
        key_strokes += 1

    logger.log(logs)

    logger.inspect(accuracy=correct / (len(text) - 1),
                   key_strokes=key_strokes,
                   length=len(text))
279-
280-
281-
def anomalies(predictor: Predictor, text: str):
    """Log *text* line by line, styling each character by the probability
    the model assigned to it (low-probability characters stand out as
    likely anomalies).

    :param predictor: trained ``Predictor`` used to score characters
    :param text: text to scan; assumed non-empty (``text[0]`` is read)
    """
    line_no = 1
    # Buffer of styled segments for the current output line
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    i = 0

    while i + 1 < len(text):
        # Probability distribution over the token following text[:i + 1]
        preds, _ = predictor.get_predictions(text[:i + 1], None, calc_probs=True)
        preds = preds[0, :]
        c = text[i + 1]

        if c == '\n':
            logger.log(logs)
            line_no += 1
            logs = [(f"{line_no: 4d}: ", Text.meta)]
        elif c == '\r':
            # BUG FIX: the original used `continue` here, which skipped the
            # `i += 1` below and looped forever on any '\r'; just skip
            # logging this character instead.
            pass
        elif c not in predictor.tokenizer.stoi:
            # Character unknown to the vocabulary — log unstyled
            logs.append(c)
        else:
            next_id = predictor.tokenizer.stoi[c]
            prob = preds[next_id]
            # Style by probability band: green = expected, red = anomalous
            if prob > 0.9:
                logs.append((c, [Style.bold, Text.success, Style.underline]))
            elif prob > 0.75:
                logs.append((c, [Text.success, Style.underline]))
            elif prob > 0.2:
                logs.append(c)
            elif prob > 0.1:
                logs.append((c, [Text.warning, Style.underline]))
            elif prob > 0.01:
                logs.append((c, [Style.bold, Text.warning, Style.underline]))
            elif prob > 0.001:
                logs.append((c, [Text.danger, Style.underline]))
            else:
                logs.append((c, [Style.bold, Text.danger, Style.underline]))

        i += 1

    logger.log(logs)
322-
323-
324-
def complete(predictor: Predictor, text: str, completion: int):
    """Extend *text* by *completion* generated characters, logging the
    given text plainly and the generated text in bold.

    :param predictor: trained ``Predictor``
    :param text: starting prompt; assumed non-empty (``text[0]`` is read)
    :param completion: number of characters to generate past the prompt
    """
    line_no = 1
    # Buffer of styled segments for the current output line
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    i = 0
    given = len(text)  # original prompt length, before generation appends

    while i + 1 < given + completion:
        if len(text) > i + 1:
            # Still replaying the given prompt
            c = text[i + 1]
        else:
            # Past the prompt — ask the model for the next character
            c, _ = predictor.get_next_token(text[:i + 1], None)

        if c == '\n':
            logger.log(logs)
            line_no += 1
            logs = [(f"{line_no: 4d}: ", Text.meta)]
        elif c != '\r':
            if len(text) > i + 1:
                logs.append(c)
            else:
                # Generated characters are shown in bold
                logs.append((c, [Style.bold]))

        if len(text) <= i + 1:
            # Append the generated character so it becomes model context
            text += c
        i += 1

    logger.log(logs)
353-
354-
355-
def get_predictor() -> Predictor:
    """Set up an evaluation experiment, load a pretrained checkpoint and
    return a ready-to-use ``Predictor``.

    NOTE(review): loads a hard-coded run UUID (a bpe model) from the local
    experiment store; the commented-out ``load_bundle`` call is the
    alternative path that downloads the released checkpoint archive.
    """
    conf = Configs()
    experiment.evaluate()

    # This will download a pretrained model checkpoint and some cached files.
    # It will download the archive as `saved_checkpoint.tar.gz` and extract it.
    #
    # If you have a locally trained model load it directly with
    # run_uuid = 'RUN_UUID'
    # And for latest checkpoint
    # checkpoint = None

    run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6'  # bpe
    checkpoint = None
    # run_uuid, checkpoint = experiment.load_bundle(
    #     lab.get_path() / 'saved_checkpoint.tar.gz',
    #     url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')

    # Restore the saved configuration, but skip loading the training data
    conf_dict = experiment.load_configs(run_uuid)
    conf_dict['text.is_load_data'] = False
    experiment.configs(conf, conf_dict)
    experiment.add_pytorch_models(get_modules(conf))
    experiment.load(run_uuid, checkpoint)

    experiment.start()
    conf.model.eval()  # inference mode
    return Predictor(conf.model, conf.text.tokenizer,
                     state_updater=conf.state_updater,
                     is_token_by_token=conf.is_token_by_token)
384-
385-
386-
def main():
    """Evaluate the pretrained predictor on the bundled ``sample.py``."""
    predictor = get_predictor()

    with open(str(lab.get_data_path() / 'sample.py'), 'r') as f:
        sample = f.read()
    with monit.section('Evaluate'):
        evaluate(predictor, sample)


if __name__ == '__main__':
    main()
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from labml import logger, lab, monit
2+
from labml.logger import Text, Style
3+
from python_autocomplete.evaluate import Predictor
4+
from python_autocomplete.evaluate.factory import get_predictor
5+
6+
7+
def anomalies(predictor: Predictor, text: str):
    """Log *text* line by line, styling each character by the probability
    the model assigned to it (low-probability characters stand out as
    likely anomalies).

    :param predictor: trained ``Predictor`` used to score characters
    :param text: text to scan; assumed non-empty (``text[0]`` is read)
    """
    line_no = 1
    # Buffer of styled segments for the current output line
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    i = 0

    while i + 1 < len(text):
        # Probability distribution over the token following text[:i + 1]
        preds, _ = predictor.get_predictions(text[:i + 1], None, calc_probs=True)
        preds = preds[0, :]
        c = text[i + 1]

        if c == '\n':
            logger.log(logs)
            line_no += 1
            logs = [(f"{line_no: 4d}: ", Text.meta)]
        elif c == '\r':
            # BUG FIX: the original used `continue` here, which skipped the
            # `i += 1` below and looped forever on any '\r'; just skip
            # logging this character instead.
            pass
        elif c not in predictor.tokenizer.stoi:
            # Character unknown to the vocabulary — log unstyled
            logs.append(c)
        else:
            next_id = predictor.tokenizer.stoi[c]
            prob = preds[next_id]
            # Style by probability band: green = expected, red = anomalous
            if prob > 0.9:
                logs.append((c, [Style.bold, Text.success, Style.underline]))
            elif prob > 0.75:
                logs.append((c, [Text.success, Style.underline]))
            elif prob > 0.2:
                logs.append(c)
            elif prob > 0.1:
                logs.append((c, [Text.warning, Style.underline]))
            elif prob > 0.01:
                logs.append((c, [Style.bold, Text.warning, Style.underline]))
            elif prob > 0.001:
                logs.append((c, [Text.danger, Style.underline]))
            else:
                logs.append((c, [Style.bold, Text.danger, Style.underline]))

        i += 1

    logger.log(logs)
48+
49+
50+
def main():
    """Run anomaly highlighting over the bundled ``sample.py`` file."""
    p = get_predictor()

    sample_path = str(lab.get_data_path() / 'sample.py')
    with open(sample_path, 'r') as f:
        source = f.read()

    with monit.section('Anomalies'):
        anomalies(p, source)


if __name__ == '__main__':
    main()
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
3+
from labml import logger, lab, monit
4+
from labml.logger import Text, Style
5+
from python_autocomplete.evaluate import NextWordPredictionComplete, Predictor
6+
from python_autocomplete.evaluate.factory import get_predictor
7+
8+
9+
def evaluate(predictor: Predictor, text: str):
    """Replay *text* through the predictor and log keystroke savings.

    At each position the predictor proposes a completion; if it matches the
    upcoming text it is accepted wholesale, otherwise the single next
    character is "typed".  Every attempt costs one key stroke.  Finally
    logs accuracy, total key strokes and text length.

    :param predictor: trained ``Predictor`` to query
    :param text: sample text; assumed non-empty (``text[0]`` is read)
    """
    line_no = 1
    # Buffer of styled segments for the current output line
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    correct = 0      # characters predicted correctly
    i = 0            # index of the last character already consumed
    key_strokes = 0  # one per loop iteration (accept or type one char)

    while i + 1 < len(text):
        prefix = text[:i + 1]
        # Split the prefix into a tokenized prompt and a trailing `rest`
        # the tokenizer did not consume
        stripped, prompt = predictor.rstrip(prefix)
        rest = prefix[len(stripped):]
        prediction_complete = NextWordPredictionComplete(stripped, rest, 5)
        prompt = torch.tensor(prompt, dtype=torch.long).unsqueeze(-1)

        predictions = predictor.get_next_word(prompt, None, rest, [1.], prediction_complete, 5)
        # Sort best-first by probability (first tuple element)
        predictions.sort(key=lambda x: -x[0])
        if predictions:
            next_token = predictions[0].text[len(rest):]
        else:
            next_token = ''

        # Accept the prediction only if it matches the upcoming text exactly
        if next_token and next_token == text[i + 1: i + 1 + len(next_token)]:
            correct += len(next_token)
            right = True
        else:
            # Otherwise the user types the single next character by hand
            next_token = text[i + 1]
            right = False

        for j, c in enumerate(next_token):
            if c == '\n':
                logger.log(logs)
                line_no += 1
                logs = [(f"{line_no: 4d}: ", Text.meta)]
            elif c == '\r':
                continue  # skip carriage returns entirely
            else:
                if right:
                    # First accepted character gets a distinct style
                    if j == 0:
                        logs.append((c, [Text.meta, Style.underline]))
                    else:
                        logs.append((c, [Text.success, Style.underline]))
                else:
                    logs.append((c, [Text.warning]))

        i += len(next_token)
        key_strokes += 1

    logger.log(logs)

    logger.inspect(accuracy=correct / (len(text) - 1),
                   key_strokes=key_strokes,
                   length=len(text))
62+
63+
64+
def main():
    """Evaluate the pretrained predictor on the bundled ``sample.py``."""
    p = get_predictor()

    sample_path = str(lab.get_data_path() / 'sample.py')
    with open(sample_path, 'r') as f:
        source = f.read()

    with monit.section('Evaluate'):
        evaluate(p, source)


if __name__ == '__main__':
    main()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from labml import experiment
2+
from labml.utils.pytorch import get_modules
3+
from python_autocomplete.evaluate import Predictor
4+
from python_autocomplete.train import Configs
5+
6+
7+
def get_predictor() -> Predictor:
    """Set up an evaluation experiment, load a pretrained checkpoint and
    return a ready-to-use ``Predictor``.

    NOTE(review): loads a hard-coded run UUID (a bpe model) from the local
    experiment store; the commented-out ``load_bundle`` call is the
    alternative path that downloads the released checkpoint archive.
    """
    conf = Configs()
    experiment.evaluate()

    # This will download a pretrained model checkpoint and some cached files.
    # It will download the archive as `saved_checkpoint.tar.gz` and extract it.
    #
    # If you have a locally trained model load it directly with
    # run_uuid = 'RUN_UUID'
    # And for latest checkpoint
    # checkpoint = None

    run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6'  # bpe
    checkpoint = None
    # run_uuid, checkpoint = experiment.load_bundle(
    #     lab.get_path() / 'saved_checkpoint.tar.gz',
    #     url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')

    # Restore the saved configuration, but skip loading the training data
    conf_dict = experiment.load_configs(run_uuid)
    conf_dict['text.is_load_data'] = False
    experiment.configs(conf, conf_dict)
    experiment.add_pytorch_models(get_modules(conf))
    experiment.load(run_uuid, checkpoint)

    experiment.start()
    conf.model.eval()  # inference mode
    return Predictor(conf.model, conf.text.tokenizer,
                     state_updater=conf.state_updater,
                     is_token_by_token=conf.is_token_by_token)

0 commit comments

Comments
 (0)