Skip to content

Commit 330dae3

Browse files
authored
feat: evaluation prove type equiv using TypeOperators (#64)
* fix: allow generation to fail * remove unnecessary imports * fix: OpenAI response add reasoning summary * fix: load_gen_results_json type * fix: analysis_saved script * fix: evaluation benchmark name * fix: OpenAI response API add summary * use pydantic-v2 * extract incorrect task-answer pairs * fix: groundtruth error (#63) * fix: missing type class and typevar in benchmark * fix: order of tasks in tfb * fix: allow load_gen_results to load error * remove error_cls unused imports * extract type variables from source code * add GHC type check by proving type equiv * fix: cp -> process * fix: API change for AST * feat: type prover support new type definition * test: ghc and type_util * feat: use prover_evaluate for base split * test: add real tfbench test cases, which the deprecated evaluation failed * alt error to syntax parsing error * feat: typeclass constraints reorder * fix: AST.get_all_nodes_of_type ignores the root itself * reorder_constraints using compiler frontend static analysis * feat: add type definitions for pure tasks * test: check type equivalence prover after rewriting mono types * fix: handle type classes alone when adding new definitions * feat: define new types automatically for pure tasks * ghc prover remove standalone type class * doc: detailed docstring for prover_evaluate * script: analysis_saved run both split
1 parent 029e7d8 commit 330dae3

35 files changed

Lines changed: 1431 additions & 139 deletions

.github/workflows/unitttest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ jobs:
1818
run: uv sync
1919

2020
- name: Run Unit Tests
21-
run: uv run pytest
21+
run: uv run pytest -n auto

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ cd TF-Bench
1414
uv sync # create a virtual environment, and install dependencies
1515
```
1616

17+
### Haskell
18+
19+
To run evaluation, you need GHC (the Glasgow Haskell Compiler) installed.
20+
We recommend using [ghcup](https://www.haskell.org/ghcup/) to install.
21+
You can use any version suggested by ghcup.
22+
23+
```sh
24+
curl --proto '=https' --tlsv1.2 -sSf https://get-ghcup.haskell.org | sh
25+
```
26+
27+
Due to the GHC dependency, the evaluation module currently only supports Linux and macOS.
28+
1729
## Building TF-Bench From Scratch (Optional)
1830

1931
### TF-Bench

benchmark/task_19.hs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Ad-hoc
77

88
# signature
99
```haskell
10-
quot :: Integral => a -> a -> a
10+
quot :: Integral a => a -> a -> a
1111
```
1212

1313
# code

benchmark/task_59.hs.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Parametric
77

88
# signature
99
```haskell
10-
foldl :: (b -> a -> b) -> b -> t a -> b
10+
foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b
1111
```
1212

1313
# code

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dependencies = [
2323
"orjson>=3.11.3",
2424
"orjsonl>=1.0.0",
2525
"pyarrow>=21.0.0",
26+
"pydantic>=2.11.7",
2627
"pytest>=8.0.0",
2728
"python-dotenv==1.0.1",
2829
"requests==2.32.3",
@@ -34,6 +35,7 @@ dependencies = [
3435
"tqdm>=4.66.2",
3536
"tree-sitter==0.22.3",
3637
"tree-sitter-haskell==0.21.0",
38+
"types-deprecated>=1.2.15.20250304",
3739
"types-requests>=2.31.0",
3840
"vllm>=0.10.1.1",
3941
]
@@ -100,3 +102,9 @@ plugins = ["returns.contrib.mypy.returns_plugin"]
100102

101103
[tool.ruff]
102104
exclude = ["tests/", "plots/"]
105+
106+
[dependency-groups]
107+
dev = [
108+
"pytest-cov>=6.2.1",
109+
"pytest-xdist>=3.8.0",
110+
]

scripts/analysis_saved.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,30 @@
66
analysis_multi_runs,
77
load_tfb_from_hf,
88
load_gen_results_jsonl,
9-
evaluate,
9+
prover_evaluate,
1010
)
1111

1212

13-
def main(result_dir: str, log_file: str | None = None):
13+
def eval_one_split(result_dir: str, split: str, log_file: str | None = None):
1414
"""
15-
1615
Arguments:
1716
result_dir (str): assumed in format `/some/path/.../<model>/<split>/`.
18-
For example: results/gpt-5-nano-2025-08-07/base, where <model> is `gpt-5-nano-2025-08-07` and <split> is `base`.
17+
For example: results/gpt-5-nano-2025-08-07/base,
18+
where <model> is `gpt-5-nano-2025-08-07` and <split> is `base`.
1919
WARNING: we parse the <model> and <split> in this way.
2020
log_file (str | None): path to the log file. If None, this script only prints to stdout.
2121
"""
2222

23-
result_dir = abspath(result_dir)
23+
result_dir = abspath(pjoin(result_dir, split))
2424
model = basename(dirname(result_dir))
25-
split = basename(result_dir)
2625

2726
tasks = load_tfb_from_hf(split)
2827
# load all jsonl files from `result_dir`
2928
jsonl_files = [
3029
pjoin(result_dir, f) for f in os.listdir(result_dir) if f.endswith(".jsonl")
3130
]
3231
runs = [load_gen_results_jsonl(f) for f in jsonl_files]
33-
accs = [evaluate(tasks, run) for run in runs]
32+
accs = [prover_evaluate(tasks, run, split == "pure") for run in runs]
3433
mean, std = analysis_multi_runs(accs)
3534

3635
print(f"Model: {model}")
@@ -49,5 +48,11 @@ def main(result_dir: str, log_file: str | None = None):
4948
orjsonl.append(log_file, log_obj)
5049

5150

51+
def main(result_dir: str, log_file: str | None = None):
52+
"""run evaluation on all jsonl files in the result directory"""
53+
eval_one_split(result_dir, "base", log_file)
54+
eval_one_split(result_dir, "pure", log_file)
55+
56+
5257
if __name__ == "__main__":
5358
fire.Fire(main)

scripts/error_cls.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from os.path import abspath, basename, join as pjoin
2+
import os
3+
4+
import orjson
5+
from pydantic import BaseModel
6+
from openai import OpenAI
7+
import fire
8+
from tqdm import tqdm
9+
10+
from tfbench import (
11+
load_tfb_from_hf,
12+
load_gen_results_jsonl,
13+
LMAnswer,
14+
)
15+
from tfbench.evaluation import get_incorrect
16+
from tfbench.common import get_prompt as get_task_prompt, BenchmarkTask
17+
18+
19+
PROMPT_TEMPLATE = """
20+
The Haskell type inference task is as follows:
21+
{task}
22+
23+
The ground-truth correct answer is:
24+
{correct_answer}
25+
26+
My incorrect answer is:
27+
{wrong_answer}
28+
29+
My reasoning behind my answer is:
30+
{reasoning}
31+
32+
What mistake did I make?
33+
"""
34+
35+
INSTRUCTION = """
36+
You are a programming language and logic expert.
37+
You will be shown a Haskell type inference task,
38+
an incorrect answer, and the reasoning behind it.
39+
Your job is to identify the mistake in the answer,
40+
and classify the mistake in one word.
41+
The current error categories are:
42+
{categories}.
43+
Choose one category, or construct a new one if you are sure that
44+
none of the current categories fit.
45+
Only output the one-word classification and a short definition of the class.
46+
NOTE that the short definition should be generalized to other tasks that fall in the same category.
47+
"""
48+
49+
50+
class ClsResponse(BaseModel):
51+
category: str
52+
definition: str
53+
54+
def __hash__(self):
55+
return hash(self.category)
56+
57+
58+
def get_prompt(task: BenchmarkTask, answer: LMAnswer) -> str:
59+
prompt = PROMPT_TEMPLATE.format(
60+
task=get_task_prompt(task),
61+
correct_answer=task.signature,
62+
wrong_answer=answer.answer,
63+
reasoning=answer.reasoning_steps,
64+
)
65+
return prompt
66+
67+
68+
def categories_str(categories: set[ClsResponse]) -> str:
69+
"""dump all categories in jsonl format string"""
70+
return "\n".join(orjson.dumps(c.__dict__).decode() for c in categories)
71+
72+
73+
def classify_run(
74+
client: OpenAI,
75+
categories: set[ClsResponse],
76+
tasks: list[BenchmarkTask],
77+
run_result: list[LMAnswer | None],
78+
) -> set[ClsResponse]:
79+
incorrect = get_incorrect(tasks, run_result)
80+
81+
categories_: set[ClsResponse] = categories.copy()
82+
for task, answer in tqdm(incorrect):
83+
assert answer is not None
84+
response = client.responses.parse(
85+
model="gpt-5",
86+
instructions=INSTRUCTION.format(categories=categories_str(categories_)),
87+
input=get_prompt(task, answer),
88+
reasoning={"effort": "medium"},
89+
text_format=ClsResponse,
90+
)
91+
assert response.output_parsed is not None
92+
categories_.add(response.output_parsed)
93+
return categories_
94+
95+
96+
def main(result_file_dir: str):
97+
98+
client = OpenAI()
99+
categories: set[ClsResponse] = set()
100+
101+
split = basename(abspath(result_file_dir))
102+
print(split)
103+
base = load_tfb_from_hf(split)
104+
105+
for file in os.listdir(result_file_dir):
106+
if not file.endswith(".jsonl"):
107+
continue
108+
result_file_path = pjoin(result_file_dir, file)
109+
run_result = load_gen_results_jsonl(result_file_path)
110+
print(f"Processing {result_file_path}")
111+
run_categories = classify_run(
112+
client,
113+
categories,
114+
base,
115+
run_result,
116+
)
117+
categories.update(run_categories)
118+
119+
with open("error_categories.json", "wb") as f:
120+
f.write(
121+
orjson.dumps(
122+
[c.model_dump(mode="json") for c in categories],
123+
option=orjson.OPT_INDENT_2,
124+
)
125+
)
126+
127+
128+
if __name__ == "__main__":
129+
fire.Fire(main)

scripts/preprocess_benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def main(input_raw_benchmark_path: str = "benchmark", output_path: str = "tfb.js
1414

1515
# read in all files ending with .md in the input_raw_benchmark_path
1616
tasks: list[BenchmarkTask] = []
17-
for file in os.listdir(input_raw_benchmark_path):
17+
files = os.listdir(input_raw_benchmark_path)
18+
files_w_order = sorted(files)
19+
for file in files_w_order:
1820
if not file.endswith(".hs.md"):
1921
continue
2022
with open(os.path.join(input_raw_benchmark_path, file), "r") as f:

src/main.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33

44
import fire
55
from orjsonl import orjsonl
6-
from returns.result import Success, Failure
76

8-
from tfbench import run_one_model, analysis_multi_runs
7+
from tfbench import run_one_model, analysis_multi_runs, EvalResult
98

109

1110
def main(
@@ -17,42 +16,38 @@ def main(
1716
"""Main script to run experiments reported in the paper"""
1817

1918
def _run(pure: bool):
20-
results = []
19+
results: list[EvalResult] = []
2120
split = "pure" if pure else "base"
2221
for i in range(n_repeats):
2322
result_dir = abspath(pjoin("results", model, split))
2423
os.makedirs(result_dir, exist_ok=True)
2524
result_file = pjoin(result_dir, f"run-{i}.jsonl")
26-
match run_one_model(
27-
model, pure=pure, output_file=result_file, effort=effort
28-
):
29-
case Success(r):
30-
results.append(r)
31-
case Failure(e):
32-
return Failure(e)
33-
return Success(analysis_multi_runs(results))
25+
r = run_one_model(
26+
model,
27+
pure=pure,
28+
output_file=result_file,
29+
effort=effort,
30+
)
31+
results.append(r)
32+
return analysis_multi_runs(results)
3433

3534
def _eval(pure: bool):
3635
split = "pure" if pure else "base"
3736
print(f"Running {model} on TF-Bench ({split}):")
38-
match _run(pure=pure):
39-
case Success((mean, std)):
40-
print(f"Accuracy: {mean:.4f} ± {std:.4f}")
41-
print("====================================")
42-
orjsonl.append(
43-
log_file,
44-
{
45-
"model": model,
46-
"split": split,
47-
"effort": effort,
48-
"n_repeats": n_repeats,
49-
"mean": mean,
50-
"std": std,
51-
},
52-
)
53-
case Failure(e):
54-
print(f"Error in base run: {e}")
55-
return
37+
mean, std = _run(pure=pure)
38+
print(f"Accuracy: {mean:.4f} ± {std:.4f}")
39+
print("====================================")
40+
orjsonl.append(
41+
log_file,
42+
{
43+
"model": model,
44+
"split": split,
45+
"effort": effort,
46+
"n_repeats": n_repeats,
47+
"mean": mean,
48+
"std": std,
49+
},
50+
)
5651

5752
_eval(pure=False)
5853
_eval(pure=True)

src/tfbench/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dotenv import load_dotenv
22

33
from .experiment import run_one_model
4-
from .evaluation import EvalResult, analysis_multi_runs, evaluate
4+
from .evaluation import EvalResult, analysis_multi_runs, evaluate, prover_evaluate
55
from .load import load_tfb_from_hf, load_gen_results_jsonl
66
from .lm import LMAnswer
77

@@ -12,6 +12,7 @@
1212
"EvalResult",
1313
"analysis_multi_runs",
1414
"evaluate",
15+
"prover_evaluate",
1516
"load_tfb_from_hf",
1617
"load_gen_results_jsonl",
1718
"LMAnswer",

0 commit comments

Comments
 (0)