|
1 | | -import os |
2 | | -import json |
3 | 1 | import logging |
| 2 | +from os.path import join as pjoin, abspath |
| 3 | +import os |
4 | 4 |
|
5 | | -from funcy_chain import Chain |
6 | | -from funcy import lmap |
7 | | -from tqdm import tqdm |
8 | | -from returns.result import ResultE |
9 | 5 | import fire |
| 6 | +import numpy as np |
| 7 | +import orjson |
| 8 | +from returns.result import Success, Failure |
| 9 | +from tfbench import run_one_model, EvalResult |
10 | 10 |
|
11 | | -from tfbench.common import get_prompt |
12 | | -from tfbench.postprocessing import postprocess, RESPONSE_STRATEGIES |
13 | | -from tfbench.evaluation import evaluate |
14 | | -from tfbench.lm import router, LMAnswer, extract_response |
15 | | -from tfbench.load import load_from_hf |
| 11 | + |
def analysis(results: list[EvalResult]):
    """Summarise repeated runs as the (mean, std) of their accuracies.

    Each entry of ``results`` is read through its ``"accuracy"`` key; the
    standard deviation uses ``np.std`` with its default arguments.
    """
    accuracies = [run["accuracy"] for run in results]
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    return mean_acc, std_acc
16 | 16 |
|
17 | 17 |
|
18 | 18 | def main( |
19 | 19 | model: str, |
20 | | - pure: bool = False, |
21 | 20 | effort: str | None = None, |
22 | | - output_file: str | None = None, |
| 21 | + n_repeats: int = 3, |
23 | 22 | log_file: str = "evaluation_log.jsonl", |
24 | 23 | ): |
25 | | - """ |
26 | | - Run an experiment using various AI models to generate and evaluate type signatures. |
27 | | -
|
28 | | - Parameters: |
29 | | - model (str): Name of the model to use for generating type signatures. Must be one of: |
30 | | - - GPT_MODELS: ["gpt-3.5-turbo-0125", "gpt-4-turbo-2024-04-09", ...] |
31 | | - - OLLAMA_MODELS, CLAUDE_MODELS, or O1_MODELS. |
32 | | - Default is "gpt-3.5-turbo". |
33 | | -
|
34 | | - pure (bool): If True, uses the original variable naming in type inference. |
35 | | - If False, uses rewritten variable naming (e.g., `v1`, `v2`, ...). Default is False. |
36 | | -
|
37 | | - """ |
38 | | - |
39 | | - if output_file is None: |
40 | | - os.makedirs("result", exist_ok=True) |
41 | | - if "/" in model: |
42 | | - dir_name = model.split("/")[0] |
43 | | - os.makedirs(f"result/{dir_name}", exist_ok=True) |
44 | | - output_file = os.path.abspath(f"result/{model}.txt") |
45 | | - logging.info(f"Writing generation results in {output_file}.") |
46 | | - |
47 | | - client = router(model, pure, effort) |
48 | | - assert client, f"Failed to create client for {model}." |
49 | | - |
50 | | - tasks = load_from_hf("pure" if pure else "base") |
51 | | - prompts = lmap(get_prompt, tasks) |
52 | | - responses: list[ResultE[LMAnswer]] = lmap( |
53 | | - client.generate, tqdm(prompts, desc=model) |
54 | | - ) |
55 | | - |
56 | | - gen_results = ( |
57 | | - Chain(responses) |
58 | | - .map(extract_response) |
59 | | - .map(lambda s: postprocess(s, RESPONSE_STRATEGIES)) |
60 | | - .map(str.strip) |
61 | | - .value |
62 | | - ) |
63 | | - |
64 | | - # writing results |
65 | | - with open(output_file, "w", errors="ignore") as file: |
66 | | - file.write("\n".join(gen_results)) |
67 | | - |
68 | | - eval_acc = evaluate(tasks, gen_results) |
69 | | - print(eval_acc) |
70 | | - |
71 | | - os.makedirs(os.path.dirname(output_file), exist_ok=True) |
72 | | - with open(log_file, "a") as fp: |
73 | | - logging_result = {"model_name": model, **eval_acc, "pure": pure} |
74 | | - fp.write(f"{json.dumps(logging_result)}\n") |
| 24 | + """Main script to run experiments reported in the paper""" |
| 25 | + |
| 26 | + def _run(pure: bool): |
| 27 | + results = [] |
| 28 | + for i in range(n_repeats): |
| 29 | + ext = "pure" if pure else "base" |
| 30 | + |
| 31 | + result_dir = abspath(pjoin("results", model)) |
| 32 | + os.makedirs(result_dir, exist_ok=True) |
| 33 | + result_file = pjoin(result_dir, f"run-{i}.{ext}.jsonl") |
| 34 | + match run_one_model( |
| 35 | + model, pure=pure, output_file=result_file, effort=effort |
| 36 | + ): |
| 37 | + case Success(r): |
| 38 | + results.append(r) |
| 39 | + case Failure(e): |
| 40 | + return Failure(e) |
| 41 | + return Success(analysis(results)) |
| 42 | + |
| 43 | + def _eval(pure: bool): |
| 44 | + split = "pure" if pure else "base" |
| 45 | + logging.info(f"Running {model} on TF-Bench ({split}):") |
| 46 | + match _run(pure=False): |
| 47 | + case Success((mean, std)): |
| 48 | + logging.info(f"Accuracy: {mean:.4f} ± {std:.4f}") |
| 49 | + with open(log_file, "ab") as f: |
| 50 | + f.write( |
| 51 | + orjson.dumps( |
| 52 | + { |
| 53 | + "model": model, |
| 54 | + "split": split, |
| 55 | + "effort": effort, |
| 56 | + "n_repeats": n_repeats, |
| 57 | + "mean": mean, |
| 58 | + "std": std, |
| 59 | + }, |
| 60 | + option=orjson.OPT_APPEND_NEWLINE, |
| 61 | + ) |
| 62 | + ) |
| 63 | + case Failure(e): |
| 64 | + print(f"Error in base run: {e}") |
| 65 | + return |
| 66 | + |
| 67 | + _eval(pure=False) |
| 68 | + _eval(pure=True) |
75 | 69 |
|
76 | 70 |
|
77 | 71 | if __name__ == "__main__": |
|
0 commit comments