Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2617c30
fix: allow generation to fail
EYH0602 Aug 30, 2025
a917fd0
remove unnecessary imports
EYH0602 Aug 30, 2025
507c521
fix: OpenAI response add reasoning summary
EYH0602 Aug 31, 2025
b8a0bf6
fix: load_gen_results_json type
EYH0602 Sep 2, 2025
e35dc5f
fix: analysis_saved script
EYH0602 Sep 2, 2025
bcc82ee
fix: evaluation benchmark name
EYH0602 Sep 2, 2025
cec0c6e
fix: OpenAI response API add summary
EYH0602 Sep 2, 2025
ebb259e
use pydantic-v2
EYH0602 Sep 3, 2025
b50c21d
extract incorrect task-answer pairs
EYH0602 Sep 3, 2025
a0c4f9e
fix: groundtruth error (#63)
EYH0602 Sep 3, 2025
a071c8a
extract type variables from source code
EYH0602 Sep 4, 2025
6eff8f4
add GHC type check by proving type equiv
EYH0602 Sep 4, 2025
c808523
fix: cp -> process
EYH0602 Sep 4, 2025
aafcc13
fix: API change for AST
EYH0602 Sep 4, 2025
e579af1
feat: type prover support new type definition
EYH0602 Sep 4, 2025
0c73488
test: ghc and type_util
EYH0602 Sep 4, 2025
e817200
feat: use prover_evaluate for base split
EYH0602 Sep 4, 2025
abe9b40
test: add real tfbench test cases, which the deprecated evaluation fa…
EYH0602 Sep 5, 2025
38cb54c
alt error to syntax parsing error
EYH0602 Sep 5, 2025
ca3c8b7
feat: typeclass constraints reorder
EYH0602 Sep 5, 2025
f67954c
fix: AST.get_all_nodes_of_type ignores the root itself
EYH0602 Sep 5, 2025
9f77a85
reorder_constraints using compiler frontend static analysis
EYH0602 Sep 5, 2025
0d3e7a7
feat: add type definitions for pure tasks
EYH0602 Sep 5, 2025
a5814a5
test: check type equivalence prover after rewriting mono types
EYH0602 Sep 5, 2025
7a52e9d
fix: handle type classes alone when adding new definitions
EYH0602 Sep 6, 2025
7032a3e
feat: define new types automatically for pure tasks
EYH0602 Sep 9, 2025
fb99ad6
ghc prover remove standalone type class
EYH0602 Sep 9, 2025
5f8b0ff
doc: detailed docstring for prover_evaluate
EYH0602 Sep 9, 2025
3f44e2a
script: analysis_saved run both split
EYH0602 Sep 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/unitttest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ jobs:
run: uv sync

- name: Run Unit Tests
run: uv run pytest
run: uv run pytest -n auto
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ cd TF-Bench
uv sync # create a virtual environment, and install dependencies
```

### Haskell

To run evaluation, you need GHC (the Glasgow Haskell Compiler) installed.
We recommend using [ghcup](https://www.haskell.org/ghcup/) to install.
You can use any version suggested by ghcup.

```sh
curl --proto '=https' --tlsv1.2 -sSf https://get-ghcup.haskell.org | sh
```

Due to the GHC dependency, the evaluation module currently only supports Linux and macOS.

## Building TF-Bench From Scratch (Optional)

### TF-Bench
Expand Down
2 changes: 1 addition & 1 deletion benchmark/task_19.hs.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Ad-hoc

# signature
```haskell
quot :: Integral => a -> a -> a
quot :: Integral a => a -> a -> a
```

# code
Expand Down
2 changes: 1 addition & 1 deletion benchmark/task_59.hs.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Parametric

# signature
```haskell
foldl :: (b -> a -> b) -> b -> t a -> b
foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b
```

# code
Expand Down
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"orjson>=3.11.3",
"orjsonl>=1.0.0",
"pyarrow>=21.0.0",
"pydantic>=2.11.7",
"pytest>=8.0.0",
"python-dotenv==1.0.1",
"requests==2.32.3",
Expand All @@ -34,6 +35,7 @@ dependencies = [
"tqdm>=4.66.2",
"tree-sitter==0.22.3",
"tree-sitter-haskell==0.21.0",
"types-deprecated>=1.2.15.20250304",
"types-requests>=2.31.0",
"vllm>=0.10.1.1",
]
Expand Down Expand Up @@ -100,3 +102,9 @@ plugins = ["returns.contrib.mypy.returns_plugin"]

[tool.ruff]
exclude = ["tests/", "plots/"]

[dependency-groups]
dev = [
"pytest-cov>=6.2.1",
"pytest-xdist>=3.8.0",
]
19 changes: 12 additions & 7 deletions scripts/analysis_saved.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,30 @@
analysis_multi_runs,
load_tfb_from_hf,
load_gen_results_jsonl,
evaluate,
prover_evaluate,
)


def main(result_dir: str, log_file: str | None = None):
def eval_one_split(result_dir: str, split: str, log_file: str | None = None):
"""

Arguments:
result_dir (str): assumed in format `/some/path/.../<model>/<split>/`.
For example: results/gpt-5-nano-2025-08-07/base, where <model> is `gpt-5-nano-2025-08-07` and <split> is `base`.
For example: results/gpt-5-nano-2025-08-07/base,
where <model> is `gpt-5-nano-2025-08-07` and <split> is `base`.
WARNING: we parse the <model> and <split> in this way.
log_file (str | None): path to the log file. If None, this script only prints to stdout.
"""

result_dir = abspath(result_dir)
result_dir = abspath(pjoin(result_dir, split))
model = basename(dirname(result_dir))
split = basename(result_dir)

tasks = load_tfb_from_hf(split)
# load all jsonl files from `result_dir`
jsonl_files = [
pjoin(result_dir, f) for f in os.listdir(result_dir) if f.endswith(".jsonl")
]
runs = [load_gen_results_jsonl(f) for f in jsonl_files]
accs = [evaluate(tasks, run) for run in runs]
accs = [prover_evaluate(tasks, run, split == "pure") for run in runs]
mean, std = analysis_multi_runs(accs)

print(f"Model: {model}")
Expand All @@ -49,5 +48,11 @@ def main(result_dir: str, log_file: str | None = None):
orjsonl.append(log_file, log_obj)


def main(result_dir: str, log_file: str | None = None):
    """Evaluate every jsonl result file in `result_dir`, for both splits."""
    for split in ("base", "pure"):
        eval_one_split(result_dir, split, log_file)


if __name__ == "__main__":
fire.Fire(main)
129 changes: 129 additions & 0 deletions scripts/error_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from os.path import abspath, basename, join as pjoin
import os

import orjson
from pydantic import BaseModel
from openai import OpenAI
import fire
from tqdm import tqdm

from tfbench import (
load_tfb_from_hf,
load_gen_results_jsonl,
LMAnswer,
)
from tfbench.evaluation import get_incorrect
from tfbench.common import get_prompt as get_task_prompt, BenchmarkTask


# Per-task prompt sent to the classifier model. Filled via str.format with:
#   task           - the rendered benchmark task (get_task_prompt output)
#   correct_answer - the ground-truth type signature
#   wrong_answer   - the model's incorrect answer
#   reasoning      - the model's reasoning steps for that answer
PROMPT_TEMPLATE = """
The Haskell type inference task is as follows:
{task}

The ground-truth correct answer is:
{correct_answer}

My incorrect answer is:
{wrong_answer}

My reasoning behind my answer is:
{reasoning}

What mistake did I make?
"""

# System instruction for the classifier. `{categories}` is filled with the
# JSONL dump of categories discovered so far, so the model prefers reusing
# existing classes over inventing new ones.
INSTRUCTION = """
You are a programming language and logic expert.
You will be shown a Haskell type inference task,
an incorrect answer, and the reasoning behind it.
Your job is to identify the mistake in the answer,
and classify the mistake in one word.
The current error categories are:
{categories}.
Choose one category, or construct a new one if you are sure that
none of the current categories fit.
Only output the one-word classification and a short definition of the class.
NOTE that the short definition should be generalized to other tasks that fall in the same category.
"""


class ClsResponse(BaseModel):
    """One error category produced by the classifier model.

    Instances are deduplicated in sets by `category` name only: the model
    may word the `definition` differently on each call, and two responses
    with the same category name are the same error class.
    """

    # one-word label of the error class
    category: str
    # short, task-generalized description of the class
    definition: str

    def __eq__(self, other: object) -> bool:
        # Pydantic's default __eq__ compares all fields, which would let the
        # same category appear multiple times in a set whenever the model
        # rewords the definition. Compare on `category` alone, consistent
        # with __hash__.
        if not isinstance(other, ClsResponse):
            return NotImplemented
        return self.category == other.category

    def __hash__(self) -> int:
        # Must hash on the same key __eq__ compares on.
        return hash(self.category)


def get_prompt(task: BenchmarkTask, answer: LMAnswer) -> str:
    """Build the error-analysis prompt for one incorrectly answered task."""
    fields = {
        "task": get_task_prompt(task),
        "correct_answer": task.signature,
        "wrong_answer": answer.answer,
        "reasoning": answer.reasoning_steps,
    }
    return PROMPT_TEMPLATE.format(**fields)


def categories_str(categories: set[ClsResponse]) -> str:
    """Serialize all known categories as a JSONL-formatted string.

    Each line is one category rendered as a JSON object, e.g.
    {"category": "...", "definition": "..."}.
    """
    # model_dump() is the pydantic-v2 serialization API; c.__dict__ would
    # bypass field serialization and can leak pydantic's private attributes.
    return "\n".join(
        orjson.dumps(c.model_dump(mode="json")).decode() for c in categories
    )


def classify_run(
    client: OpenAI,
    categories: set[ClsResponse],
    tasks: list[BenchmarkTask],
    run_result: list[LMAnswer | None],
) -> set[ClsResponse]:
    """Classify every incorrect answer of one run, growing the category set.

    Returns a copy of `categories` extended with any error classes found
    while processing this run; the input set itself is left untouched.
    """
    known: set[ClsResponse] = categories.copy()
    for task, answer in tqdm(get_incorrect(tasks, run_result)):
        assert answer is not None
        # The instruction embeds the categories known *so far*, so later
        # queries in the same run can reuse classes minted earlier.
        response = client.responses.parse(
            model="gpt-5",
            instructions=INSTRUCTION.format(categories=categories_str(known)),
            input=get_prompt(task, answer),
            reasoning={"effort": "medium"},
            text_format=ClsResponse,
        )
        parsed = response.output_parsed
        assert parsed is not None
        known.add(parsed)
    return known


def main(result_file_dir: str):
    """Classify the errors of every run saved under `result_file_dir`.

    The directory's basename is taken as the benchmark split (e.g.
    results/<model>/base -> "base"), and every *.jsonl file inside is one
    generation run. The accumulated error categories are written to
    error_categories.json in the current working directory.
    """
    client = OpenAI()
    categories: set[ClsResponse] = set()

    split = basename(abspath(result_file_dir))
    print(split)
    tasks = load_tfb_from_hf(split)

    for file in os.listdir(result_file_dir):
        if not file.endswith(".jsonl"):
            continue
        path = pjoin(result_file_dir, file)
        run = load_gen_results_jsonl(path)
        print(f"Processing {path}")
        categories.update(classify_run(client, categories, tasks, run))

    dumped = orjson.dumps(
        [c.model_dump(mode="json") for c in categories],
        option=orjson.OPT_INDENT_2,
    )
    with open("error_categories.json", "wb") as f:
        f.write(dumped)


if __name__ == "__main__":
fire.Fire(main)
4 changes: 3 additions & 1 deletion scripts/preprocess_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def main(input_raw_benchmark_path: str = "benchmark", output_path: str = "tfb.js

# read in all files ending with .md in the input_raw_benchmark_path
tasks: list[BenchmarkTask] = []
for file in os.listdir(input_raw_benchmark_path):
files = os.listdir(input_raw_benchmark_path)
files_w_order = sorted(files)
for file in files_w_order:
if not file.endswith(".hs.md"):
continue
with open(os.path.join(input_raw_benchmark_path, file), "r") as f:
Expand Down
53 changes: 24 additions & 29 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import fire
from orjsonl import orjsonl
from returns.result import Success, Failure

from tfbench import run_one_model, analysis_multi_runs
from tfbench import run_one_model, analysis_multi_runs, EvalResult


def main(
Expand All @@ -17,42 +16,38 @@ def main(
"""Main script to run experiments reported in the paper"""

def _run(pure: bool):
results = []
results: list[EvalResult] = []
split = "pure" if pure else "base"
for i in range(n_repeats):
result_dir = abspath(pjoin("results", model, split))
os.makedirs(result_dir, exist_ok=True)
result_file = pjoin(result_dir, f"run-{i}.jsonl")
match run_one_model(
model, pure=pure, output_file=result_file, effort=effort
):
case Success(r):
results.append(r)
case Failure(e):
return Failure(e)
return Success(analysis_multi_runs(results))
r = run_one_model(
model,
pure=pure,
output_file=result_file,
effort=effort,
)
results.append(r)
return analysis_multi_runs(results)

def _eval(pure: bool):
split = "pure" if pure else "base"
print(f"Running {model} on TF-Bench ({split}):")
match _run(pure=pure):
case Success((mean, std)):
print(f"Accuracy: {mean:.4f} ± {std:.4f}")
print("====================================")
orjsonl.append(
log_file,
{
"model": model,
"split": split,
"effort": effort,
"n_repeats": n_repeats,
"mean": mean,
"std": std,
},
)
case Failure(e):
print(f"Error in base run: {e}")
return
mean, std = _run(pure=pure)
print(f"Accuracy: {mean:.4f} ± {std:.4f}")
print("====================================")
orjsonl.append(
log_file,
{
"model": model,
"split": split,
"effort": effort,
"n_repeats": n_repeats,
"mean": mean,
"std": std,
},
)

_eval(pure=False)
_eval(pure=True)
Expand Down
3 changes: 2 additions & 1 deletion src/tfbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dotenv import load_dotenv

from .experiment import run_one_model
from .evaluation import EvalResult, analysis_multi_runs, evaluate
from .evaluation import EvalResult, analysis_multi_runs, evaluate, prover_evaluate
from .load import load_tfb_from_hf, load_gen_results_jsonl
from .lm import LMAnswer

Expand All @@ -12,6 +12,7 @@
"EvalResult",
"analysis_multi_runs",
"evaluate",
"prover_evaluate",
"load_tfb_from_hf",
"load_gen_results_jsonl",
"LMAnswer",
Expand Down
5 changes: 2 additions & 3 deletions src/tfbench/add_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from dacite import from_dict

from tfbench.common import extract_function_name
from tfbench.hs_parser import HASKELL_LANGUAGE
from tfbench.hs_parser.ast_util import AST
from tfbench.hs_parser import AST
from tfbench.common import BenchmarkTask


Expand All @@ -25,7 +24,7 @@ def get_func_calls(task: BenchmarkTask) -> set[str]:
fn_name = extract_function_name(task)
assert fn_name is not None

ast = AST(task.code, HASKELL_LANGUAGE)
ast = AST(task.code)
root = ast.root

calls: list[str] = (
Expand Down
7 changes: 3 additions & 4 deletions src/tfbench/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from tqdm import tqdm
from funcy import lmap

from tfbench.hs_parser import HASKELL_LANGUAGE
from tfbench.hs_parser.ast_util import AST, HaskellFunction
from tfbench.hs_parser.polymorphism import get_polymorphic_type
from tfbench.hs_parser import AST, HaskellFunction
from tfbench.hs_parser.type_util import get_polymorphic_type
from tfbench.common import remove_comments


Expand Down Expand Up @@ -48,7 +47,7 @@ def collect_from_file(file_path: str) -> list[dict[str, str]]:
with open(file_path, "r", errors="replace") as fp:
code = fp.read()

ast = AST(code, HASKELL_LANGUAGE)
ast = AST(code)

def _to_json(func: HaskellFunction) -> dict[str, str]:
func_id = f"{file_path}--{ast.get_fn_name(func.type_signature).value_or(None)}"
Expand Down
Loading