|
5 | 5 | import torch.nn |
6 | 6 | from torch import nn |
7 | 7 |
|
8 | | -from labml import experiment, logger, lab, monit |
9 | | -from labml.logger import Text, Style |
10 | | -from labml.utils.pytorch import get_modules |
11 | 8 | from labml_helpers.module import Module |
12 | 9 | from python_autocomplete.dataset import Tokenizer, ID_CHARS |
13 | | -from python_autocomplete.train import Configs, StateUpdater |
| 10 | +from python_autocomplete.train import StateUpdater |
14 | 11 |
|
15 | 12 | EPS_PROB = 1e-6 |
16 | 13 | MIN_BEAM_PROB = 1e-4 |
@@ -221,176 +218,3 @@ def get_next_word(self, prompt: torch.Tensor, state: Any, rest: str, probs: List |
221 | 218 |
|
def rstrip(self, prompt: str) -> Tuple[str, List[int]]:
    """Right-strip *prompt* via the tokenizer.

    Returns the stripped text together with its token ids, exactly as
    produced by ``self.tokenizer.rstrip``.
    """
    stripped_text, token_ids = self.tokenizer.rstrip(prompt)
    return stripped_text, token_ids
224 | | - |
225 | | - |
def evaluate(predictor: Predictor, text: str):
    """Replay *text* through *predictor* and log a colorized report.

    Correctly predicted spans are underlined; fallback (wrong) characters are
    shown as warnings. Finally logs overall accuracy, the number of simulated
    key strokes, and the text length.
    """
    line_no = 1
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    correct = 0
    key_strokes = 0
    i = 0

    while i + 1 < len(text):
        prefix = text[:i + 1]
        stripped, prompt_ids = predictor.rstrip(prefix)
        rest = prefix[len(stripped):]
        completion_check = NextWordPredictionComplete(stripped, rest, 5)
        prompt = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(-1)

        suggestions = predictor.get_next_word(prompt, None, rest, [1.], completion_check, 5)
        suggestions.sort(key=lambda s: -s[0])

        predicted = suggestions[0].text[len(rest):] if suggestions else ''

        # A prediction counts only if it matches the upcoming text exactly.
        matched = bool(predicted) and predicted == text[i + 1: i + 1 + len(predicted)]
        if matched:
            correct += len(predicted)
        else:
            # Fall back to a single (wrong) character so the loop advances.
            predicted = text[i + 1]

        for j, ch in enumerate(predicted):
            if ch == '\n':
                logger.log(logs)
                line_no += 1
                logs = [(f"{line_no: 4d}: ", Text.meta)]
            elif ch == '\r':
                continue
            elif matched:
                # First char of a correct span styled differently from the rest.
                style = [Text.meta, Style.underline] if j == 0 else [Text.success, Style.underline]
                logs.append((ch, style))
            else:
                logs.append((ch, [Text.warning]))

        i += len(predicted)
        key_strokes += 1

    logger.log(logs)

    logger.inspect(accuracy=correct / (len(text) - 1),
                   key_strokes=key_strokes,
                   length=len(text))
279 | | - |
280 | | - |
def anomalies(predictor: Predictor, text: str):
    """Log *text* with each character styled by how probable the model found it.

    High-probability characters are rendered as successes, low-probability
    (anomalous) ones as warnings/dangers. Characters unknown to the tokenizer
    are logged unstyled.
    """
    line_no = 1
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    i = 0

    while i + 1 < len(text):
        preds, _ = predictor.get_predictions(text[:i + 1], None, calc_probs=True)
        preds = preds[0, :]
        c = text[i + 1]

        if c == '\n':
            logger.log(logs)
            line_no += 1
            logs = [(f"{line_no: 4d}: ", Text.meta)]
        elif c == '\r':
            # BUG FIX: this branch used `continue`, which skipped the
            # `i += 1` below and made the loop spin forever on any '\r'.
            # `pass` skips the character but still advances.
            pass
        elif c not in predictor.tokenizer.stoi:
            # Out-of-vocabulary character: no probability available.
            logs.append(c)
        else:
            next_id = predictor.tokenizer.stoi[c]
            prob = preds[next_id]
            # Style thresholds: higher probability -> stronger "success" styling,
            # lower probability -> stronger "danger" styling.
            if prob > 0.9:
                logs.append((c, [Style.bold, Text.success, Style.underline]))
            elif prob > 0.75:
                logs.append((c, [Text.success, Style.underline]))
            elif prob > 0.2:
                logs.append(c)
            elif prob > 0.1:
                logs.append((c, [Text.warning, Style.underline]))
            elif prob > 0.01:
                logs.append((c, [Style.bold, Text.warning, Style.underline]))
            elif prob > 0.001:
                logs.append((c, [Text.danger, Style.underline]))
            else:
                logs.append((c, [Style.bold, Text.danger, Style.underline]))

        i += 1

    logger.log(logs)
322 | | - |
323 | | - |
def complete(predictor: Predictor, text: str, completion: int):
    """Echo *text* and then generate *completion* further characters.

    Given characters are logged plainly; generated characters are logged bold
    and appended to the running text so later predictions see them.
    """
    line_no = 1
    logs = [(f"{line_no: 4d}: ", Text.meta), (text[0], Text.subtle)]

    given = len(text)
    i = 0

    while i + 1 < given + completion:
        is_given = len(text) > i + 1
        if is_given:
            c = text[i + 1]
        else:
            # Past the prompt: ask the model for the next character.
            c, _ = predictor.get_next_token(text[:i + 1], None)

        if c == '\n':
            logger.log(logs)
            line_no += 1
            logs = [(f"{line_no: 4d}: ", Text.meta)]
        elif c != '\r':
            logs.append(c if is_given else (c, [Style.bold]))

        if not is_given:
            text += c

        i += 1

    logger.log(logs)
353 | | - |
354 | | - |
def get_predictor() -> Predictor:
    """Load a trained model from a labml experiment run and wrap it in a Predictor.

    NOTE(review): the experiment.* calls are order-sensitive (evaluate mode,
    configs, model registration, checkpoint load, then start) — do not reorder.
    """
    conf = Configs()
    # Put the experiment framework into evaluation (no-training) mode.
    experiment.evaluate()

    # This will download a pretrained model checkpoint and some cached files.
    # It will download the archive as `saved_checkpoint.tar.gz` and extract it.
    #
    # If you have a locally trained model load it directly with
    # run_uuid = 'RUN_UUID'
    # And for latest checkpoint
    # checkpoint = None

    # Hard-coded run id of a BPE-tokenized pretrained model; checkpoint=None
    # selects the latest checkpoint of that run.
    run_uuid = 'a6cff3706ec411ebadd9bf753b33bae6'  # bpe
    checkpoint = None
    # run_uuid, checkpoint = experiment.load_bundle(
    #     lab.get_path() / 'saved_checkpoint.tar.gz',
    #     url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz')

    # Restore the run's saved hyperparameters, but skip loading training data.
    conf_dict = experiment.load_configs(run_uuid)
    conf_dict['text.is_load_data'] = False
    experiment.configs(conf, conf_dict)
    # Register the model modules so the checkpoint weights can be mapped in.
    experiment.add_pytorch_models(get_modules(conf))
    experiment.load(run_uuid, checkpoint)

    experiment.start()
    # Inference only: disable dropout/batch-norm training behavior.
    conf.model.eval()
    return Predictor(conf.model, conf.text.tokenizer,
                     state_updater=conf.state_updater,
                     is_token_by_token=conf.is_token_by_token)
384 | | - |
385 | | - |
def main():
    """Entry point: load the pretrained predictor and evaluate it on `sample.py`."""
    predictor = get_predictor()

    sample_path = lab.get_data_path() / 'sample.py'
    with open(str(sample_path), 'r') as f:
        sample = f.read()

    with monit.section('Evaluate'):
        evaluate(predictor, sample)


if __name__ == '__main__':
    main()
0 commit comments