diff --git a/.github/workflows/unitttest.yml b/.github/workflows/unitttest.yml index f23191e..aabc9ca 100644 --- a/.github/workflows/unitttest.yml +++ b/.github/workflows/unitttest.yml @@ -18,4 +18,4 @@ jobs: run: uv sync - name: Run Unit Tests - run: uv run pytest + run: uv run pytest -n auto diff --git a/README.md b/README.md index 69cd003..e2f41f0 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,18 @@ cd TF-Bench uv sync # create a virtual environment, and install dependencies ``` +### Haskell + +To run evaluation, you need GHC (the Glasgow Haskell Compiler) installed. +We recommend using [ghcup](https://www.haskell.org/ghcup/) to install. +You can use any version suggested by ghcup. + +```sh +curl --proto '=https' --tlsv1.2 -sSf https://get-ghcup.haskell.org | sh +``` + +Due to the GHC dependency, the evaluation module currently only supports Linux and macOS. + ## Building TF-Bench From Scratch (Optional) ### TF-Bench diff --git a/benchmark/task_19.hs.md b/benchmark/task_19.hs.md index 29c01f0..46d4d6a 100644 --- a/benchmark/task_19.hs.md +++ b/benchmark/task_19.hs.md @@ -7,7 +7,7 @@ Ad-hoc # signature ```haskell -quot :: Integral => a -> a -> a +quot :: Integral a => a -> a -> a ``` # code diff --git a/benchmark/task_59.hs.md b/benchmark/task_59.hs.md index d2c3ba5..6c0e3a9 100644 --- a/benchmark/task_59.hs.md +++ b/benchmark/task_59.hs.md @@ -7,7 +7,7 @@ Parametric # signature ```haskell -foldl :: (b -> a -> b) -> b -> t a -> b +foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b ``` # code diff --git a/pyproject.toml b/pyproject.toml index 8756bbc..fafa193 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "orjson>=3.11.3", "orjsonl>=1.0.0", "pyarrow>=21.0.0", + "pydantic>=2.11.7", "pytest>=8.0.0", "python-dotenv==1.0.1", "requests==2.32.3", @@ -34,6 +35,7 @@ dependencies = [ "tqdm>=4.66.2", "tree-sitter==0.22.3", "tree-sitter-haskell==0.21.0", + "types-deprecated>=1.2.15.20250304", "types-requests>=2.31.0", "vllm>=0.10.1.1", ] @@ -100,3 +102,9 @@ plugins = ["returns.contrib.mypy.returns_plugin"] [tool.ruff] exclude = ["tests/", "plots/"] + +[dependency-groups] +dev = [ + "pytest-cov>=6.2.1", + "pytest-xdist>=3.8.0", +] diff --git a/scripts/analysis_saved.py b/scripts/analysis_saved.py index afe0c2f..3945ea6 100644 --- a/scripts/analysis_saved.py +++ b/scripts/analysis_saved.py @@ -6,23 +6,22 @@ analysis_multi_runs, load_tfb_from_hf, load_gen_results_jsonl, - evaluate, + prover_evaluate, ) -def main(result_dir: str, log_file: str | None = None): +def eval_one_split(result_dir: str, split: str, log_file: str | None = None): """ - Arguments: result_dir (str): assumed in format `/some/path/...///`. - For example: results/gpt-5-nano-2025-08-07/base, where is `gpt-5-nano-2025-08-07` and is `base`. + For example: results/gpt-5-nano-2025-08-07/base, + where is `gpt-5-nano-2025-08-07` and is `base`. WARNING: we parse the and in this way. log_file (str | None): path to the log file. If None, this script only prints to stdout. """ - result_dir = abspath(result_dir) + result_dir = abspath(pjoin(result_dir, split)) model = basename(dirname(result_dir)) - split = basename(result_dir) tasks = load_tfb_from_hf(split) # load all jsonl files from `result_dir` @@ -30,7 +29,7 @@ def main(result_dir: str, log_file: str | None = None): pjoin(result_dir, f) for f in os.listdir(result_dir) if f.endswith(".jsonl") ] runs = [load_gen_results_jsonl(f) for f in jsonl_files] - accs = [evaluate(tasks, run) for run in runs] + accs = [prover_evaluate(tasks, run, split == "pure") for run in runs] mean, std = analysis_multi_runs(accs) print(f"Model: {model}") @@ -49,5 +48,11 @@ def main(result_dir: str, log_file: str | None = None): orjsonl.append(log_file, log_obj) +def main(result_dir: str, log_file: str | None = None): + """run evaluation on all jsonl files in the result directory""" + eval_one_split(result_dir, "base", log_file) + eval_one_split(result_dir, "pure", log_file) + + if __name__ == "__main__": fire.Fire(main) diff --git a/scripts/error_cls.py b/scripts/error_cls.py new file mode 100644 index 0000000..bae55e9 --- /dev/null +++ b/scripts/error_cls.py @@ -0,0 +1,129 @@ +from os.path import abspath, basename, join as pjoin +import os + +import orjson +from pydantic import BaseModel +from openai import OpenAI +import fire +from tqdm import tqdm + +from tfbench import ( + load_tfb_from_hf, + load_gen_results_jsonl, + LMAnswer, +) +from tfbench.evaluation import get_incorrect +from tfbench.common import get_prompt as get_task_prompt, BenchmarkTask + + +PROMPT_TEMPLATE = """ +The Haskell type inference task is as follows: +{task} + +The ground-truth correct answer is: +{correct_answer} + +My incorrect answer is: +{wrong_answer} + +My reasoning behind my answer is: +{reasoning} + +What mistake did I make? +""" + +INSTRUCTION = """ +You are a programming language and logic expert. +You will be shown a Haskell type inference task, +an incorrect answer, and the reasoning behind it. +Your job is to identify the mistake in the answer, +and classify the mistake in one word. +The current error categories are: +{categories}. +Choose one category, or construct a new one if you are sure that +none of the current categories fit. +Only output the one-word classification and a short definition of the class. +NOTE that the short definition should be generalized to other tasks that fall in the same category. +""" + + +class ClsResponse(BaseModel): + category: str + definition: str + + def __hash__(self): + return hash(self.category) + + +def get_prompt(task: BenchmarkTask, answer: LMAnswer) -> str: + prompt = PROMPT_TEMPLATE.format( + task=get_task_prompt(task), + correct_answer=task.signature, + wrong_answer=answer.answer, + reasoning=answer.reasoning_steps, + ) + return prompt + + +def categories_str(categories: set[ClsResponse]) -> str: + """dump all categories in jsonl format string""" + return "\n".join(orjson.dumps(c.__dict__).decode() for c in categories) + + +def classify_run( + client: OpenAI, + categories: set[ClsResponse], + tasks: list[BenchmarkTask], + run_result: list[LMAnswer | None], +) -> set[ClsResponse]: + incorrect = get_incorrect(tasks, run_result) + + categories_: set[ClsResponse] = categories.copy() + for task, answer in tqdm(incorrect): + assert answer is not None + response = client.responses.parse( + model="gpt-5", + instructions=INSTRUCTION.format(categories=categories_str(categories_)), + input=get_prompt(task, answer), + reasoning={"effort": "medium"}, + text_format=ClsResponse, + ) + assert response.output_parsed is not None + categories_.add(response.output_parsed) + return categories_ + + +def main(result_file_dir: str): + + client = OpenAI() + categories: set[ClsResponse] = set() + + split = basename(abspath(result_file_dir)) + print(split) + base = load_tfb_from_hf(split) + + for file in os.listdir(result_file_dir): + if not file.endswith(".jsonl"): + continue + result_file_path = pjoin(result_file_dir, file) + run_result = load_gen_results_jsonl(result_file_path) + print(f"Processing {result_file_path}") + run_categories = classify_run( + client, + categories, + base, + run_result, + ) + categories.update(run_categories) + + with open("error_categories.json", "wb") as f: + f.write( + orjson.dumps( + [c.model_dump(mode="json") for c in categories], + option=orjson.OPT_INDENT_2, + ) + ) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/scripts/preprocess_benchmark.py b/scripts/preprocess_benchmark.py index 818b2c7..cd4ada8 100644 --- a/scripts/preprocess_benchmark.py +++ b/scripts/preprocess_benchmark.py @@ -14,7 +14,9 @@ def main(input_raw_benchmark_path: str = "benchmark", output_path: str = "tfb.js # read in all files ending with .md in the input_raw_benchmark_path tasks: list[BenchmarkTask] = [] - for file in os.listdir(input_raw_benchmark_path): + files = os.listdir(input_raw_benchmark_path) + files_w_order = sorted(files) + for file in files_w_order: if not file.endswith(".hs.md"): continue with open(os.path.join(input_raw_benchmark_path, file), "r") as f: diff --git a/src/main.py b/src/main.py index 8996428..03cf67a 100755 --- a/src/main.py +++ b/src/main.py @@ -3,9 +3,8 @@ import fire from orjsonl import orjsonl -from returns.result import Success, Failure -from tfbench import run_one_model, analysis_multi_runs +from tfbench import run_one_model, analysis_multi_runs, EvalResult def main( @@ -17,42 +16,38 @@ def main( """Main script to run experiments reported in the paper""" def _run(pure: bool): - results = [] + results: list[EvalResult] = [] split = "pure" if pure else "base" for i in range(n_repeats): result_dir = abspath(pjoin("results", model, split)) os.makedirs(result_dir, exist_ok=True) result_file = pjoin(result_dir, f"run-{i}.jsonl") - match run_one_model( - model, pure=pure, output_file=result_file, effort=effort - ): - case Success(r): - results.append(r) - case Failure(e): - return Failure(e) - return Success(analysis_multi_runs(results)) + r = run_one_model( + model, + pure=pure, + output_file=result_file, + effort=effort, + ) + results.append(r) + return analysis_multi_runs(results) def _eval(pure: bool): split = "pure" if pure else "base" print(f"Running {model} on TF-Bench ({split}):") - match _run(pure=pure): - case Success((mean, std)): - print(f"Accuracy: {mean:.4f} ± {std:.4f}") - print("====================================") - orjsonl.append( - log_file, - { - "model": model, - "split": split, - "effort": effort, - "n_repeats": n_repeats, - "mean": mean, - "std": std, - }, - ) - case Failure(e): - print(f"Error in base run: {e}") - return + mean, std = _run(pure=pure) + print(f"Accuracy: {mean:.4f} ± {std:.4f}") + print("====================================") + orjsonl.append( + log_file, + { + "model": model, + "split": split, + "effort": effort, + "n_repeats": n_repeats, + "mean": mean, + "std": std, + }, + ) _eval(pure=False) _eval(pure=True) diff --git a/src/tfbench/__init__.py b/src/tfbench/__init__.py index 1597aa8..06a1ef6 100644 --- a/src/tfbench/__init__.py +++ b/src/tfbench/__init__.py @@ -1,7 +1,7 @@ from dotenv import load_dotenv from .experiment import run_one_model -from .evaluation import EvalResult, analysis_multi_runs, evaluate +from .evaluation import EvalResult, analysis_multi_runs, evaluate, prover_evaluate from .load import load_tfb_from_hf, load_gen_results_jsonl from .lm import LMAnswer @@ -12,6 +12,7 @@ "EvalResult", "analysis_multi_runs", "evaluate", + "prover_evaluate", "load_tfb_from_hf", "load_gen_results_jsonl", "LMAnswer", diff --git a/src/tfbench/add_dependency.py b/src/tfbench/add_dependency.py index adb328d..86e8c61 100644 --- a/src/tfbench/add_dependency.py +++ b/src/tfbench/add_dependency.py @@ -6,8 +6,7 @@ from dacite import from_dict from tfbench.common import extract_function_name -from tfbench.hs_parser import HASKELL_LANGUAGE -from tfbench.hs_parser.ast_util import AST +from tfbench.hs_parser import AST from tfbench.common import BenchmarkTask @@ -25,7 +24,7 @@ def get_func_calls(task: BenchmarkTask) -> set[str]: fn_name = extract_function_name(task) assert fn_name is not None - ast = AST(task.code, HASKELL_LANGUAGE) + ast = AST(task.code) root = ast.root calls: list[str] = ( diff --git a/src/tfbench/dataset.py b/src/tfbench/dataset.py index 69b637a..66789f1 100644 --- a/src/tfbench/dataset.py +++ b/src/tfbench/dataset.py @@ -12,9 +12,8 @@ from tqdm import tqdm from funcy import lmap -from tfbench.hs_parser import HASKELL_LANGUAGE -from tfbench.hs_parser.ast_util import AST, HaskellFunction -from tfbench.hs_parser.polymorphism import get_polymorphic_type +from tfbench.hs_parser import AST, HaskellFunction +from tfbench.hs_parser.type_util import get_polymorphic_type from tfbench.common import remove_comments @@ -48,7 +47,7 @@ def collect_from_file(file_path: str) -> list[dict[str, str]]: with open(file_path, "r", errors="replace") as fp: code = fp.read() - ast = AST(code, HASKELL_LANGUAGE) + ast = AST(code) def _to_json(func: HaskellFunction) -> dict[str, str]: func_id = f"{file_path}--{ast.get_fn_name(func.type_signature).value_or(None)}" diff --git a/src/tfbench/evaluation.py b/src/tfbench/evaluation.py index 7e4c1c5..e209083 100644 --- a/src/tfbench/evaluation.py +++ b/src/tfbench/evaluation.py @@ -1,12 +1,17 @@ from itertools import starmap import re from typing import TypedDict +from multiprocessing import Pool, cpu_count import numpy as np +from deprecated import deprecated +from returns.result import Success from .common import BenchmarkTask from .postprocessing import postprocess, TASK_STRATEGIES, RESPONSE_STRATEGIES from .lm import LMAnswer +from .ghc import get_prover, ghc_prove_equiv +from .type_def import get_type_defs def tokenize_type_signature(sig: str) -> list[str]: @@ -49,6 +54,7 @@ def normalize_type_vars(tokens: list[str]) -> list[str]: return result +@deprecated(reason="Use `ghc_prove_equiv` instead", version="0.1.0") def alpha_equiv(s1: str, s2: str) -> bool: """ Check if two type signatures are 'alpha-equivalent' under @@ -63,8 +69,12 @@ def alpha_equiv(s1: str, s2: str) -> bool: return n1 == n2 -def evaluate_one_task(task: BenchmarkTask, result: LMAnswer) -> bool: +@deprecated(reason="Use `prove_one_task` instead", version="0.1.0") +def evaluate_one_task(task: BenchmarkTask, result: LMAnswer | None) -> bool: """evaluate a single task against its result by alpha equivalence""" + if result is None: + return False + ground_truth = postprocess(task.signature, TASK_STRATEGIES).strip() predicted = postprocess(result.answer, RESPONSE_STRATEGIES).strip() return alpha_equiv(ground_truth, predicted) @@ -76,16 +86,73 @@ class EvalResult(TypedDict): accuracy: float -def evaluate(benchmark_f: list[BenchmarkTask], results: list[LMAnswer]) -> EvalResult: +@deprecated(reason="Use `prover_evaluate` instead", version="0.1.0") +def evaluate(tasks: list[BenchmarkTask], results: list[LMAnswer | None]) -> EvalResult: """evaluate all generation results""" - assert len(benchmark_f) == len(results) - eval_results = starmap(evaluate_one_task, zip(benchmark_f, results)) + assert len(tasks) == len(results) + eval_results = starmap(evaluate_one_task, zip(tasks, results)) + n_correct = sum(eval_results) + acc = n_correct / len(tasks) + + return { + "total": len(tasks), + "n_correct": n_correct, + "accuracy": acc, + } + + +def prove_one_task( + task: BenchmarkTask, result: LMAnswer | None, pure: bool = False +) -> bool: + """prove two type signatures are equivalent using GHC""" + if result is None: + return False + + predicted_body = postprocess(result.answer, RESPONSE_STRATEGIES).strip() + predicted = f"f :: {predicted_body}" + defs = get_type_defs(task) if pure else [] + # only failing case from get_prover is syntax error of generated type signature + equiv = ( + get_prover(task.signature, predicted, defs) + .alt(lambda _: "Syntax Error: Tree-Sitter Parsing Failed") + .bind(ghc_prove_equiv) + ) + return isinstance(equiv, Success) + + +def prover_evaluate( + tasks: list[BenchmarkTask], + results: list[LMAnswer | None], + pure: bool = False, + nproc: int = cpu_count(), +) -> EvalResult: + """evaluate all generation results using GHC to prove equivalence + + NOTE: currently only support the `base` split + + Args: + tasks (list[BenchmarkTask]): list of benchmark tasks + results (list[LMAnswer | None]): list of generation results + pure (bool, optional): whether to evaluate on the `pure` split or not. + Since we use TypeOperators to *prove type equivalence, + we need to define all custom types in the `pure` split. + Defaults to False. + nproc (int, optional): number of processes to use. + Defaults to cpu_count() to use all available CPUs. + """ + assert len(tasks) == len(results) + + with Pool(processes=nproc) as pool: + eval_results = pool.starmap( + prove_one_task, zip(tasks, results, [pure] * len(tasks)) + ) + n_correct = sum(eval_results) - acc = n_correct / len(benchmark_f) + acc = n_correct / len(tasks) return { - "total": len(benchmark_f), + "total": len(tasks), "n_correct": n_correct, "accuracy": acc, } @@ -95,3 +162,14 @@ def analysis_multi_runs(results: list[EvalResult]) -> tuple[float, float]: """calculate mean and std of accuracy of multiple runs""" accs = list(map(lambda r: r["accuracy"], results)) return np.mean(accs).item(), np.std(accs).item() + + +def get_incorrect( + tasks: list[BenchmarkTask], results: list[LMAnswer | None] +) -> list[tuple[BenchmarkTask, LMAnswer | None]]: + """Get a list of tasks that were incorrectly answered.""" + incorrect = [] + for task, result in zip(tasks, results): + if not evaluate_one_task(task, result): + incorrect.append((task, result)) + return incorrect diff --git a/src/tfbench/experiment.py b/src/tfbench/experiment.py index b4b8d86..c7aa2db 100644 --- a/src/tfbench/experiment.py +++ b/src/tfbench/experiment.py @@ -2,11 +2,9 @@ Experiment script """ -import logging from tqdm import tqdm -from returns.result import Success, Failure, ResultE -import orjson +from orjsonl import orjsonl from .common import get_prompt from .evaluation import evaluate, EvalResult @@ -19,7 +17,7 @@ def run_one_model( pure: bool = False, output_file: str | None = None, effort: str | None = None, -) -> ResultE[EvalResult]: +) -> EvalResult: """Running the generation & evaluation pipeline for one pre-supported model Args: @@ -34,22 +32,18 @@ def run_one_model( EvalResult: evaluation result including accuracy """ client = router(model, pure, effort) - if not client: - return Failure(Exception(f"Failed to create client for {model}.")) + assert client is not None, f"Failed to create client for {model}." tasks = load_tfb_from_hf("pure" if pure else "base") - gen_results: list[LMAnswer] = [] + gen_results: list[LMAnswer | None] = [] for task in tqdm(tasks, desc=model): prompt = get_prompt(task) - match client.generate(prompt): - case Success(r): - gen_results.append(r) - if output_file: - with open(output_file, "ab") as file: - file.write(orjson.dumps(r, option=orjson.OPT_APPEND_NEWLINE)) - case Failure(e): - logging.error(f"Error generating response: {e}") - return Failure(e) + + response = client.generate(prompt) + r: LMAnswer | None = response.value_or(None) + gen_results.append(r) + if output_file: + orjsonl.append(output_file, r if r else {"error": str(response.failure())}) eval_acc = evaluate(tasks, gen_results) - return Success(eval_acc) + return eval_acc diff --git a/src/tfbench/ghc.py b/src/tfbench/ghc.py new file mode 100644 index 0000000..7e2210a --- /dev/null +++ b/src/tfbench/ghc.py @@ -0,0 +1,107 @@ +"""Util functions for using GHC""" + +import os +import tempfile +import subprocess +from string import Template # NOTE: use t-string after py3.14 +from returns.result import Result, Success, Failure, safe + +from .hs_parser import get_type_vars, get_type_constraints + + +def ghc_prove_equiv(code: str) -> Result[None, str]: + """let GHC typecheck the given code snippet by compiling it + + Args: + code: Haskell code snippet to typecheck + Returns: + Success(None) if typechecks, Failure(error_message) otherwise + """ + file_name = "Check.hs" + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, file_name) + with open(path, "w") as f: + f.write(code) + # -fno-code = typecheck only; -v0 = quiet + process = subprocess.run( + ["ghc", "-fno-code", "-v0", file_name], + capture_output=True, + text=True, + cwd=tmpdir, + ) + + if process.returncode == 0: + return Success(None) + + return Failure(process.stderr) + + +PROVER = Template( + """ +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ImpredicativeTypes #-} +module Check where + +import Data.Type.Equality + +-- Some predefined types synonyms to avoid name clashes +type Int_ = Int +type Bool_ = Bool +type Char_ = Char +type Float_ = Float +type Double_ = Double +data Natural = Natural + +$new_types + +type TRUTH $truth_vars = $truth_signature +type ANSWER $answer_vars = $answer_signature + +proof :: TRUTH $truth_vars :~: ANSWER $truth_vars +proof = Refl +""" +) + + +def _get_var_str(source_code: str) -> str: + ty_vars = get_type_vars(source_code) + return " ".join(ty_vars) + + +def _get_body_str(source_code: str) -> str: + assert "::" in source_code, "invalid type signature" + return source_code.split("::", 1)[1].strip() + + +def reorder_constraints(type_signature: str) -> str: + """Reorder type classes constrains in a type signature to a canonical order""" + # No type classes to reorder + if "=>" not in type_signature: + return type_signature + + assert "::" in type_signature, "invalid type signature" + prefix, body = type_signature.split("::", 1) + + _, rest = body.split("=>", 1) + constrains = get_type_constraints(type_signature) + constrains.sort() # Sort alphabetically + return f"{prefix}:: ({', '.join(constrains)}) => {rest}" + + +@safe +def get_prover( + ground_truth: str, answer: str, types_defs: list[str] | None = None +) -> str: + """Construct Prover program based on Haskell TypeOperators and Type.Equality, + If this prover compiles, then the answer is equivalent to the ground truth. + """ + types_defs = types_defs or [] + ground_truth = reorder_constraints(ground_truth) + answer = reorder_constraints(answer) + return PROVER.substitute( + new_types="\n".join(types_defs), + truth_vars=_get_var_str(ground_truth), + truth_signature=_get_body_str(ground_truth), + answer_vars=_get_var_str(answer), + answer_signature=_get_body_str(answer), + ) diff --git a/src/tfbench/hoogle.py b/src/tfbench/hoogle.py index 879b6b2..0cc578c 100644 --- a/src/tfbench/hoogle.py +++ b/src/tfbench/hoogle.py @@ -12,8 +12,7 @@ from funcy_chain import Chain from tfbench.common import BenchmarkTask, extract_function_name -from tfbench.hs_parser import HASKELL_LANGUAGE -from tfbench.hs_parser.ast_util import AST +from tfbench.hs_parser import AST from tfbench.manual import MANUAL_TASKS @@ -29,7 +28,7 @@ def generate_variable_banlist(code: str): """ Generates list of variables that are already defined in the code """ - ast = AST(code, HASKELL_LANGUAGE) + ast = AST(code) root = ast.root # Remove variables that are already defined in the code @@ -61,7 +60,7 @@ def get_where_blacklist(task: BenchmarkTask) -> set[str]: where_index = task.code.index("where") where_code = task.code[(where_index + 5) :].strip() - ast = AST(where_code, HASKELL_LANGUAGE) + ast = AST(where_code) root = ast.root ban_list: list[str] = generate_variable_banlist(where_code) @@ -87,7 +86,7 @@ def get_func_calls(task: BenchmarkTask) -> set[str]: assert fn_name is not None print(f"Function: {fn_name}") - ast = AST(task.code, HASKELL_LANGUAGE) + ast = AST(task.code) root = ast.root variables: list[str] = ( diff --git a/src/tfbench/hs_parser/__init__.py b/src/tfbench/hs_parser/__init__.py index 79b1982..77ea621 100644 --- a/src/tfbench/hs_parser/__init__.py +++ b/src/tfbench/hs_parser/__init__.py @@ -1,5 +1,21 @@ -from tree_sitter import Language -import tree_sitter_haskell as ts_haskell +from .ast_util import AST, HASKELL_LANGUAGE, HaskellFunction +from .type_util import ( + PolymorphicType, + get_polymorphic_type, + to_type_node, + get_type_vars, + get_type_constraints, +) +from .extractor import TypeExtractor - -HASKELL_LANGUAGE = Language(ts_haskell.language()) +__all__ = [ + "AST", + "HASKELL_LANGUAGE", + "HaskellFunction", + "PolymorphicType", + "get_polymorphic_type", + "to_type_node", + "get_type_vars", + "get_type_constraints", + "TypeExtractor", +] diff --git a/src/tfbench/hs_parser/ast_util.py b/src/tfbench/hs_parser/ast_util.py index 8d51c34..f4ce11c 100644 --- a/src/tfbench/hs_parser/ast_util.py +++ b/src/tfbench/hs_parser/ast_util.py @@ -2,10 +2,13 @@ from dataclasses import dataclass from tree_sitter import Language, Parser, Tree, Node +import tree_sitter_haskell as ts_haskell from returns.maybe import Maybe, Nothing, Some from funcy_chain import Chain from funcy import lmap +HASKELL_LANGUAGE = Language(ts_haskell.language()) + @dataclass class ASTLoc: @@ -35,7 +38,7 @@ def from_pair(p: tuple[Node, list[Node]]): class AST: """Helper class to build, read, and manipulate ASTs using tree-sitter.""" - def __init__(self, source_code: str, lang: Language) -> None: + def __init__(self, source_code: str) -> None: """ Initializes an AST object with the given source code and language. @@ -45,7 +48,7 @@ def __init__(self, source_code: str, lang: Language) -> None: """ self.src = source_code self.parser = Parser() - self.parser.language = lang # Directly assign the language + self.parser.language = HASKELL_LANGUAGE self.tree: Tree = self.parser.parse(bytes(self.src, "utf8")) @property @@ -96,19 +99,6 @@ def get_fn_name(self, node: Node) -> Maybe[str]: return Nothing return Some(fn_name.strip()) - def get_fn_docstring(self, node: Node) -> Maybe[str]: - """ - Retrieves the docstring associated with a function node. - - Args: - node (Node): The AST node representing a function. - - Returns: - Maybe[str]: A Maybe containing the docstring if found, or Nothing otherwise. - """ - # todo: implement docstring finder - raise NotImplementedError - def func2src(self, func: HaskellFunction) -> tuple[str, str]: """ Converts a `HaskellFunction` object into its corresponding type signature and code source. @@ -203,10 +193,10 @@ def get_all_nodes_of_type( nodes: list[Node] = [] if max_level == 0: return nodes + if node_type is None or root.type == node_type: + nodes.append(root) for child in root.children: - if node_type is None or child.type == node_type: - nodes.append(child) nodes += AST.get_all_nodes_of_type( child, node_type, max_level=max_level - 1 ) diff --git a/src/tfbench/hs_parser/extractor.py b/src/tfbench/hs_parser/extractor.py new file mode 100644 index 0000000..e92122b --- /dev/null +++ b/src/tfbench/hs_parser/extractor.py @@ -0,0 +1,124 @@ +from collections import defaultdict, Counter +from dataclasses import dataclass +from tree_sitter import Node +from .ast_util import AST + + +class TypeExtractor(AST): + """Static analyzer for Haskell type signatures. + NOTE: this analyzer works on the body of a type signature only, + i.e. the part after the `=>` symbol if it has constraints, + or otherwise after the `::` symbol. + The constraints (if any) are handled in other modules. + """ + + def __init__(self, code: str): + super().__init__(code) + self.constructors: dict[str, Counter] = defaultdict(Counter) + self.names: set[str] = set() + + self._analysis_types() + + @property + def type_constructors(self) -> dict[str, int]: + """Get a mapping of type constructor names to their maximum observed arity (i.e. number of parameters).""" + return {k: max(v.keys()) for k, v in self.constructors.items()} + + def _analysis_types(self): + """analysis types in the function signature to fill out self.constructors and self.names""" + sigs = self.get_all_nodes_of_type(self.root, "signature") + functions = self.get_all_nodes_of_type(sigs[0], "function") + if len(functions) > 0: + self._visit(functions[0]) + + def _collect_from_tuple(self, node: Node): + # record tuple arity if you care: arity = count of element children + # then continue walking children + for ch in node.named_children: + self._visit(ch) + + def _visit(self, n: Node): + t = n.type + + if t == "apply": + # Count this application chain once, at the top-most 'apply' only. + parent = n.parent + if not ( + parent + and parent.type == "apply" + and parent.child_by_field_name("constructor") is n + ): + apply_chain = _peel_apply_chain(n) + ctor_name = self.get_src_from_node(apply_chain.constructor) + self.constructors[ctor_name][apply_chain.arity] += 1 + # Recurse into children so we also catch nested names/applications. + for ch in n.named_children: + self._visit(ch) + return + + if t == "constructor": + # Zero-arity constructor occurrence (e.g., `Int`) not part of an apply + parent = n.parent + if not ( + parent + and parent.type == "apply" + and parent.child_by_field_name("constructor") is n + ): + name_node = n.child_by_field_name("name") or ( + n.named_children[0] if n.named_children else None + ) + if name_node: + constructor_name = self.get_src_from_node(name_node) + self.constructors[constructor_name][0] += 1 + # still walk inside + for ch in n.named_children: + self._visit(ch) + return + + if t == "tuple": + self._collect_from_tuple(n) + return + + if t == "name": + # Treat as a plain type variable/name when not under a constructor role. + p = n.parent + # If its parent is 'constructor', it's part of a constructor; skip here. + if p is None or p.type != "constructor": + self.names.add(self.get_src_from_node(n)) + return + + # default: recurse + for ch in n.named_children: + self._visit(ch) + + +@dataclass +class TypeApplyChain: + constructor: Node + arity: int + arguments: list[Node] + + +def _peel_apply_chain(node: Node) -> TypeApplyChain: + """ + Given an (apply ...) subtree, walk left through nested apply nodes to + find the root constructor name and count how many arguments were applied. + # Returns (arity, arg_nodes_list, constructor_node). + """ + args = [] + arity = 0 + cur = node + while cur.type == "apply": + arity += 1 + arg = cur.child_by_field_name("argument") + if arg is not None: + args.append(arg) + # could be 'constructor' or another 'apply' + next_level = cur.child_by_field_name("constructor") + if not next_level: + break + cur = next_level + + # now cur is either a 'constructor' node or a 'name' (rare) + ctor_node = cur + return TypeApplyChain(constructor=ctor_node, arity=arity, arguments=args) diff --git a/src/tfbench/hs_parser/polymorphism.py b/src/tfbench/hs_parser/type_util.py similarity index 53% rename from src/tfbench/hs_parser/polymorphism.py rename to src/tfbench/hs_parser/type_util.py index 3b0a616..1830e2e 100644 --- a/src/tfbench/hs_parser/polymorphism.py +++ b/src/tfbench/hs_parser/type_util.py @@ -1,5 +1,6 @@ from enum import Enum from tree_sitter import Node +from funcy_chain import Chain from .ast_util import AST @@ -48,3 +49,47 @@ def get_polymorphic_type(type_signature: Node) -> PolymorphicType: return PolymorphicType.PARAMETRIC return PolymorphicType.MONO + + +def get_type_vars(source_code: str) -> list[str]: + """extract type variables from a type signature source code. + + NOTE: since GHC proves the `forall` quantification of type variables, + the order of type variables does not really matter + as long as they are **consistent**. + + Args: + source_code (str): the source code of the type signature + + Returns: + list[str]: type variables + """ + ast = AST(source_code=source_code) + sig = ast.get_all_nodes_of_type(ast.root, "signature")[0] + type_node = to_type_node(sig) + + ty_vars = [ + ast.get_src_from_node(n) + for n in ast.get_all_nodes_of_type(type_node, "variable") + ] + return list(dict.fromkeys(ty_vars)) # remove duplicates while preserving order + + +def get_type_constraints(source_code: str) -> list[str]: + """extract type class constraints from a type signature source code""" + assert "=>" in source_code, "no type class constraints found" + + ast = AST(source_code) + signature = ast.get_all_nodes_of_type(ast.root, "signature")[0] + + # context node is the body of type signature + context = ast.get_all_nodes_of_type(signature, "context")[0] + + type_constrains: list[str] = ( + Chain(ast.get_all_nodes_of_type(context.children[0], "apply")) + .map(ast.get_src_from_node) + .map(str.strip) + .filter(lambda c: c[0].isupper()) + .value + ) + return type_constrains diff --git a/src/tfbench/lm/_openai.py b/src/tfbench/lm/_openai.py index ba561af..e36cb9f 100644 --- a/src/tfbench/lm/_openai.py +++ b/src/tfbench/lm/_openai.py @@ -2,6 +2,7 @@ from openai import OpenAI, NOT_GIVEN from openai.types.shared.reasoning_effort import ReasoningEffort +from openai.types.responses.response import Response from ._types import LM, LMAnswer, NoneResponseError @@ -97,11 +98,28 @@ def _gen(self, prompt: str) -> LMAnswer: instructions=self.instruction, input=prompt, reasoning=( - { - "effort": self.effort, - } + {"effort": self.effort, "summary": "detailed"} if self.effort else NOT_GIVEN ), ) - return LMAnswer(answer=response.output_text) + return LMAnswer( + answer=response.output_text, + reasoning_steps=_reasoning_summary(response), + ) + + +def _reasoning_summary(response: Response) -> str: + """helper function to extract response from OpenAI Response API, + implementation follows `openai/types/responses/response.py#L275` + for format, please see + https://platform.openai.com/docs/guides/reasoning?api-mode=responses#reasoning-summaries + """ + texts: list[str] = [] + for output in response.output: + if output.type == "reasoning" and output.summary is not None: + for summary in output.summary: + if summary.type == "summary_text": + texts.append(summary.text) + + return "".join(texts) diff --git a/src/tfbench/load.py b/src/tfbench/load.py index aafb12a..73714c1 100644 --- a/src/tfbench/load.py +++ b/src/tfbench/load.py @@ -33,7 +33,7 @@ def load_tfb_from_hf(split: str = "base") -> list[BenchmarkTask]: return [_cast(d) for d in dataset] -def load_gen_results_jsonl(result_file: str) -> list[LMAnswer]: +def load_gen_results_jsonl(result_file: str) -> list[LMAnswer | None]: """load generation results from a jsonl file""" - objs: list[dict[str, str | None]] = orjsonl.load(result_file) # type: ignore - return [from_dict(LMAnswer, obj) for obj in objs] + objs: list[dict[str, str]] = orjsonl.load(result_file) # type: ignore + return [from_dict(LMAnswer, obj) if "answer" in obj else None for obj in objs] diff --git a/src/tfbench/prelude.py b/src/tfbench/prelude.py index 6873234..ec52c83 100644 --- a/src/tfbench/prelude.py +++ b/src/tfbench/prelude.py @@ -6,8 +6,7 @@ from funcy_chain import Chain from dacite import from_dict -from tfbench.hs_parser import HASKELL_LANGUAGE -from tfbench.hs_parser.ast_util import AST +from tfbench.hs_parser import AST from tfbench.add_dependency import add_dependencies from tfbench.common import clean_tab_spaces, BenchmarkTask, task2md @@ -25,7 +24,7 @@ def main( with open(prelude, "r") as fp: prelude_code = fp.read() - ast = AST(prelude_code, HASKELL_LANGUAGE) + ast = AST(prelude_code) root = ast.root prelude_vars = lmap( ast.get_src_from_node, ast.get_all_nodes_of_type(root, "variable") diff --git a/src/tfbench/type_def.py b/src/tfbench/type_def.py new file mode 100644 index 0000000..6ea920a --- /dev/null +++ b/src/tfbench/type_def.py @@ -0,0 +1,74 @@ +from funcy import lfilter + +from .common import BenchmarkTask +from .hs_parser import AST, get_type_constraints +from .hs_parser.extractor import TypeExtractor + + +def _is_type(code: str, type_name: str) -> bool: + ast = AST(code) + decl = ast.get_all_nodes_of_type(ast.root, "declarations")[0] + decl_fst_child = decl.child(0) + return decl_fst_child is not None and decl_fst_child.type == type_name + + +def is_data_type(code: str) -> bool: + """check if the given line of code is a data type definition""" + return _is_type(code, "data_type") + + +def is_class(code: str) -> bool: + """check if the given line of code is a type class definition""" + return _is_type(code, "class") + + +def def_new_type(type_name: str) -> str: + """construct a new, empty yet unique type definition for a given Monomorphic type name""" + return f"data {type_name} = {type_name}" + + +def def_new_type_class(class_name: str, type_vars: list[str]) -> str: + """construct a new, empty yet unique type class definition for a given Ad-hoc type class name""" + return f"class {class_name} {' '.join(type_vars)}" + + +def def_new_type_constructor(constructor_name: str, type_vars: list[str]) -> str: + """construct a new, empty yet unique type constructor definition for a given type constructor name""" + return f"data {constructor_name} {' '.join(type_vars)}" + + +def is_type_def(code: str) -> bool: + """check if the given line of code is a type definition (data type or type class)""" + return is_data_type(code) or is_class(code) + + +def is_type_defined(type_name: str, type_defs: list[str]) -> bool: + """check if a type name is defined in the given list of type definitions""" + return any(type_name in td for td in type_defs) + + +def get_type_defs(task: BenchmarkTask) -> list[str]: + """Get Haskell type definitions from a BenchmarkTask""" + existing_defs = lfilter(is_type_def, task.dependencies) + + if "=>" in task.signature: + constrains = get_type_constraints(task.signature) + for c in constrains: + [ty_class, *ty_vars] = c.split(" ") + if is_type_defined(ty_class, existing_defs): + continue + existing_defs.append(def_new_type_class(ty_class, ty_vars)) + + extractor = TypeExtractor(task.signature) + for ctor_name, arity in extractor.type_constructors.items(): + if is_type_defined(ctor_name, existing_defs): + continue + type_vars = [f"t{i}" for i in range(arity)] + existing_defs.append(def_new_type_constructor(ctor_name, type_vars)) + + for type_name in extractor.names: + if is_type_defined(type_name, existing_defs): + continue + existing_defs.append(def_new_type(type_name)) + + return list(existing_defs) diff --git a/tests/ast_node_test.py b/tests/ast_node_test.py index 0f73eb6..26c3f55 100644 --- a/tests/ast_node_test.py +++ b/tests/ast_node_test.py @@ -1,17 +1,7 @@ # importing the requests library -from tfbench.hs_parser.ast_util import AST -import json -from dacite import from_dict -import fire from funcy_chain import Chain -import requests -from urllib.parse import quote -from tfbench.common import BenchmarkTask -from tfbench.common import extract_function_name -from tfbench.hs_parser import HASKELL_LANGUAGE -from functools import lru_cache from pprint import pprint -from tree_sitter import Node +from tfbench.hs_parser import AST """ This is a test file for seeing all the Nodes in the AST of certain pieces of code @@ -21,7 +11,7 @@ def main(): assert True code = "lines \"\" = []\nlines s = cons (case break (== '\\n') s of\n (l, s') -> (l, case s' of\n [] -> []\n _:s'' -> lines s''))\n where\n cons ~(h, t) = h : t" - ast = AST(code, HASKELL_LANGUAGE) + ast = AST(code) root = ast.root # Get both types and type classes diff --git a/tests/test_ast.py b/tests/test_ast.py index a6a9ab8..c1873f3 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,10 +1,9 @@ -import pytest -from tfbench.hs_parser import HASKELL_LANGUAGE -from tfbench.hs_parser.ast_util import AST, ASTLoc from hypothesis import given import hypothesis.strategies as st from funcy_chain import Chain -from operator import add + +from tfbench.hs_parser import HASKELL_LANGUAGE +from tfbench.hs_parser import AST def test_function_extract(): @@ -36,13 +35,8 @@ def test_function_extract(): @given(st.permutations(hs_lines)) def haskell_lines_in_any_order(lines): code = "\n".join(lines) - ast = AST(code, HASKELL_LANGUAGE) - fs = ( - Chain(ast.get_functions()) - .map(ast.func2src) - .map(lambda xs: "\n".join(xs)) - .value - ) + ast = AST(code) + fs = Chain(ast.get_functions()).map(ast.func2src).map("\n".join).value fs.sort() assert fs[0] == fn_add.strip("\n") diff --git a/tests/test_dependency.py b/tests/test_dependency.py index a9d77b7..ca2a33c 100644 --- a/tests/test_dependency.py +++ b/tests/test_dependency.py @@ -1,7 +1,7 @@ -from tfbench.common import BenchmarkTask -from tfbench.add_dependency import get_func_calls from dacite import from_dict import traceback +from tfbench.common import BenchmarkTask +from tfbench.add_dependency import get_func_calls def test_len(): diff --git a/tests/test_eval_diff.py b/tests/test_eval_diff.py new file mode 100644 index 0000000..38bf147 --- /dev/null +++ b/tests/test_eval_diff.py @@ -0,0 +1,71 @@ +from os.path import abspath, dirname, basename, join as pjoin +import os +from itertools import starmap +from multiprocessing import Pool + +import pytest +import fire +from orjsonl import orjsonl +from tqdm import tqdm +from tfbench import ( + analysis_multi_runs, + load_tfb_from_hf, + load_gen_results_jsonl, + prover_evaluate, +) +from tfbench.ghc import get_prover +from tfbench.evaluation import evaluate_one_task, prove_one_task +from tfbench.common import task2md +from tfbench.type_def import get_type_defs +from tfbench.postprocessing import postprocess, TASK_STRATEGIES, RESPONSE_STRATEGIES + + +def diff_one_file(file_path: str, split: str): + tasks = load_tfb_from_hf(split) + answers = load_gen_results_jsonl(abspath(file_path)) + + old_eval = starmap(evaluate_one_task, zip(tasks, answers)) + with Pool() as pool: + new_eval = pool.starmap( + prove_one_task, zip(tasks, answers, [split == "pure"] * len(tasks)) + ) + + for t, a, o, n in zip(tasks, answers, old_eval, new_eval): + if a is None: + continue + # if o: + # assert n, "both evaluations should return a result" + if o and not n: + print(task2md(t)) + defs = get_type_defs(t) + + predicted_body = postprocess(a.answer, RESPONSE_STRATEGIES).strip() + predicted = f"f :: {predicted_body}" + print(get_prover(t.signature, predicted, defs).unwrap()) + assert False + + +def test_diff_recorded(): + """different test evaluation function with recorded results + Since the new prover evaluation fixes the false negative issue, + we assume if an answer is determined as correct by the old evaluation, + it should also be correct by the new evaluation. + """ + + result_path = abspath("results") + # skip the test if there are not recorded results + if not os.path.exists(result_path): + pytest.skip("No recorded results found, skip the test.") + + # walk the result directory to find all jsonl files + for b, _, f in os.walk(result_path): + for file in f: + if file.endswith(".jsonl"): + file_path = pjoin(b, file) + split = basename(b) + print(f"Diffing {file_path} ...") + diff_one_file(file_path, split) + + +if __name__ == "__main__": + fire.Fire(test_diff_recorded) diff --git a/tests/test_extractor.py b/tests/test_extractor.py new file mode 100644 index 0000000..6279972 --- /dev/null +++ b/tests/test_extractor.py @@ -0,0 +1,27 @@ +from tfbench.hs_parser import TypeExtractor + + +def test_real_cases(): + code = "f:: T1 t1 => t1" + et = TypeExtractor(code) + assert not et.type_constructors + + code = "f:: T1 t1 => T2 -> t1" + et = TypeExtractor(code) + assert not et.type_constructors + assert et.names == {"T2"} + + code = "f:: T1 t1 => T2 T3 -> t1" + et = TypeExtractor(code) + assert et.type_constructors == {"T2": 1} + assert et.names == {"T2", "T3"} + + code = "f:: T1 -> T2 T3 -> Either T1 T3 -> (T1, T3, T2 T3)" + et = TypeExtractor(code) + assert et.type_constructors == {"T2": 1, "Either": 2} + assert et.names == {"T1", "T2", "T3", "Either"} + + code = "g:: Ord a => Int -> Either String a -> T3 T1 T2 T4" + et = TypeExtractor(code) + assert et.type_constructors == {"Either": 2, "T3": 3} + assert et.names == {"Int", "String", "T1", "T2", "T3", "T4", "Either"} diff --git a/tests/test_ghc.py b/tests/test_ghc.py new file mode 100644 index 0000000..9e3855d --- /dev/null +++ b/tests/test_ghc.py @@ -0,0 +1,284 @@ +from returns.result import Result, Success, Failure +from tfbench.ghc import ghc_prove_equiv, get_prover, reorder_constraints + + +def _equiv( + truth: str, + answer: str, + new_types: list[str] | None = None, + should_pass: bool = True, +): + equiv = get_prover(truth, answer, new_types).alt(str).bind(ghc_prove_equiv) + match equiv: + case Success(_): + assert should_pass + case Failure(err): + assert not should_pass, err + + +def _not_equiv( + truth: str, + answer: str, + new_types: list[str] | None = None, +): + _equiv(truth, answer, new_types, should_pass=False) + + +def test_monomorphic(): + """test GHC type equivalence prover for monomorphic types""" + + _equiv("f::Int -> Int", "g ::Int -> Int") + _equiv("f::(Int, Bool) -> Int", "g ::(Int, Bool) -> Int") + _equiv("f::Int -> Bool -> Int", "g ::Int -> Bool -> Int") + _equiv( + "f::(Int -> Bool) -> [Int] -> [Bool]", + "g ::(Int -> Bool) -> [Int] -> [Bool]", + ) + _equiv( + "f::(Int -> Bool) -> [Int] -> Maybe Bool", + "g ::(Int -> Bool) -> [Int] -> Maybe Bool", + ) + + # negative cases + _not_equiv("f::Int -> Int", "g ::Bool -> Bool") + _not_equiv("f::(Int, Bool) -> Int", "g ::(Bool, Int) -> Int") + _not_equiv( + "f::(Int -> Bool) -> [Int] -> [Bool]", + "g ::(Bool -> Int) -> [Int] -> [Bool]", + ) + + +def test_parametric(): + """test GHC type equivalence prover for parametric types""" + + _equiv("f :: a -> a", "g :: b -> b") + _equiv("f :: a -> b -> a", "g :: x -> y -> x") + _equiv("f :: (a, b) -> (b, a)", "g :: (x, y) -> (y, x)") + _equiv("f :: (a -> b) -> [a] -> [b]", "g :: (x -> y) -> [x] -> [y]") + _equiv( + "f :: (a -> b) -> [a] -> Maybe b", + "g :: (x -> y) -> [x] -> Maybe y", + ) + _equiv("g :: t1 -> t2 -> (t1, t2)", "h :: u1 -> u2 -> (u1, u2)") + _equiv( + "h :: (m -> n) -> (n -> o) -> m -> o", + "k :: (x -> y) -> (y -> z) -> x -> z", + ) + _equiv( + "k :: (a -> b) -> (b -> c) -> (c -> d) -> a -> d", + "m :: (x -> y) -> (y -> z) -> (z -> w) -> x -> w", + ) + _equiv( + "m :: (x -> y) -> [x] -> Maybe y", + "n :: (a -> b) -> [a] -> Maybe b", + ) + _equiv( + "n :: (f a -> f b) -> [a] -> [b]", + "p :: (g x -> g y) -> [x] -> [y]", + ) + + # although it is very rare to have these kinds of function types, + # the following should still hold, since a, b, c,d are arbitrary type variables + # it is not really possible to write a function that return any type variable + _equiv("f :: a -> b", "g :: b -> a") + _equiv("f :: a -> b", "g :: c -> d") + + # however, the following should not hold + # although a, b are arbitrary type variables + # `g` requires both input and output to be the same type variable, + # whereas `f` does not. + _not_equiv("f :: a -> b", "g :: a -> a") + + # negative cases + _not_equiv("f :: a -> a", "g :: a -> b") + _not_equiv("f :: a -> b", "g :: a -> a") + + _not_equiv("f :: a -> b", "g :: Int-> Int") + _not_equiv( + "f :: (a, b) -> c", + "g :: (Int, Int) -> Int", + ) + + +def test_mono_para_mixed(): + """test GHC type equivalence prover for mixed monomorphic and parametric types""" + + _equiv("f :: a -> Int", "g :: b -> Int") + _equiv("f :: Int -> a", "g :: Int -> b") + _equiv("f :: (a, Int) -> (Int, a)", "g :: (b, Int) -> (Int, b)") + _equiv( + "f :: (a -> Int) -> [Int] -> [a]", + "g :: (b -> Int) -> [Int] -> [b]", + ) + _equiv( + "f :: (a -> Int) -> [Int] -> Maybe a", + "g :: (b -> Int) -> [Int] -> Maybe b", + ) + _equiv( + "g :: t1 -> Int -> (t1, Int)", + "h :: u1 -> Int -> (u1, Int)", + ) + _equiv( + "h :: (Int -> n) -> (n -> Int) -> Int -> Int", + "k :: (Int -> y) -> (y -> Int) -> Int -> Int", + ) + _equiv( + "k :: (a -> Int) -> (Int -> c) -> (c -> d) -> a -> d", + "m :: (x -> Int) -> (Int -> y) -> (y -> z) -> x -> z", + ) + _equiv( + "m :: (x -> y) -> [Int] -> Maybe y", + "n :: (a -> b) -> [Int] -> Maybe b", + ) + _equiv( + "n :: (f a -> f b) -> [Int] -> [b]", + "p :: (g x -> g y) -> [Int] -> [y]", + ) + + # negative cases + _not_equiv("f :: a -> Int", "g :: a-> a") + _not_equiv("f :: a-> a", "g :: a-> Int") + + _not_equiv("f :: a-> Int", "g :: Int-> Int") + _not_equiv( + "f :: (a, Int) -> Int", + "g :: (Int, Int) -> Int", + ) + + +def test_adhoc(): + """test GHC type equivalence prover for ad-hoc polymorphic types, + i.e., with type class constraints + """ + + _equiv("f :: Eq a => a -> a", "g :: Eq b => b -> b") + _equiv("f :: (Eq a) => a -> a", "g :: (Eq b) => b -> b") + # test with parenthesis around constraints + _equiv("f :: (Eq a) => a -> a", "g :: Eq b => b -> b") + _equiv("f :: (Eq a, Show b) => a -> b -> a", "g :: (Eq x, Show y) => x -> y -> x") + _equiv( + "f :: (Ord a, Show b) => (a, b) -> (b, a)", + "g :: (Ord x, Show y) => (x, y) -> (y, x)", + ) + _equiv( + "f :: (Eq a, Show b) => (a -> b) -> [a] -> [b]", + "g :: (Eq x, Show y) => (x -> y) -> [x] -> [y]", + ) + _equiv( + "f :: (Eq a, Show b) => (a -> b) -> [a] -> Maybe b", + "g :: (Eq x, Show y) => (x -> y) -> [x] -> Maybe y", + ) + _equiv( + "m :: (Eq x, Show y) => (x -> y) -> [x] -> Maybe y", + "n :: (Eq a, Show b) => (a -> b) -> [a] -> Maybe b", + ) + _equiv( + "n :: (Eq (f a), Show (f b)) => (f a -> f b) -> [a] -> [b]", + "p :: (Eq (g x), Show (g y)) => (g x -> g y) -> [x] -> [y]", + ) + + # negative cases + _not_equiv("f :: Eq a => a -> a", "g :: a-> a") + _not_equiv("f :: a-> a", "g :: Eq a => a-> a") + + _not_equiv("f :: Eq a => a-> Int", "g :: Int-> Int") + _not_equiv( + "f :: (Eq a, Show b) => (a, Int) -> Int", + "g :: (Int, Int) -> Int", + ) + + +def test_para_adhoc_mix(): + """test GHC type equivalence prover for mixed parametric and ad-hoc polymorphic types""" + + _equiv("f :: Eq a => a -> Int", "g :: Eq b => b -> Int") + _equiv("f :: Int -> a", "g :: Int -> b") + _equiv("f :: (Eq a) => (a, Int) -> (Int, a)", "g :: (Eq b) => (b, Int) -> (Int, b)") + _equiv( + "f :: (Ord a, Show b) => (a -> Int) -> [Int] -> [a]", + "g :: (Ord x, Show y) => (x -> Int) -> [Int] -> [x]", + ) + _equiv( + "f :: (Eq a, Show b) => (a -> Int) -> [Int] -> Maybe a", + "g :: (Eq x, Show y) => (x -> Int) -> [Int] -> Maybe x", + ) + _equiv( + "g :: Eq t1 => t1 -> Int -> (t1, Int)", + "h :: Eq u1 => u1 -> Int -> (u1, Int)", + ) + _equiv( + "h :: (Int -> n) -> (n -> Int) -> Int -> Int", + "k :: (Int -> y) -> (y -> Int) -> Int -> Int", + ) + _equiv( + "k :: (Eq a, Show c) => (a -> Int) -> (Int -> c) -> (c -> d) -> a -> d", + "m :: (Eq x, Show y) => (x -> Int) -> (Int -> y) -> (y -> z) -> x -> z", + ) + _equiv( + "m :: (Eq x, Show y) => (x -> y) -> [Int] -> Maybe y", + "n :: (Eq a, Show b) => (a -> b) -> [Int] -> Maybe b", + ) + _equiv( + "n :: (Eq (f a), Show (f b)) => (f a -> f b) -> [a] -> [b]", + "p :: (Eq (g x), Show (g y)) => (g x -> g y) -> [x] -> [y]", + ) + _not_equiv( + "n :: (Eq (f a), Show (f b)) => (f a -> f b) -> [Int] -> [b]", + "p :: (Eq (g x), Show (g y)) => (g x -> g y) -> [x] -> [y]", + ) + + +def test_tfb_real(): + """test cases from TF-Bench real tasks, + where the deprecated evaluate failed + """ + + # type -> is right-associative + _equiv( + "uncurry :: (a -> b -> c) -> ((a, b) -> c)", + "g::(a -> b -> c) -> (a, b) -> c", + ) + _equiv( + "(.) :: (b -> c) -> (a -> b) -> a -> c", + "(.) :: (b -> c) -> (a -> b) -> (a -> c)", + ) + _equiv("($) :: (a -> b) -> a -> b", "($) :: (a -> b) -> (a -> b)") + + # type class constraints are commutative + _equiv( + "elem :: (Foldable t, Eq a) => a -> t a -> Bool", + "g :: (Eq a, Foldable t) => a -> t a -> Bool", + ) + + # type alias not expanded + _equiv( + "showList :: Show a => [a] -> ShowS", + "g :: Show a => [a] -> String -> String", + ) + + +def test_reorder(): + """test reorder_type_classes function""" + s1 = "f :: (Eq a, Show a) => a -> String" + s2 = "f :: (Show a, Eq a) => a -> String" + + rs1 = reorder_constraints(s1) + rs2 = reorder_constraints(s2) + + _equiv(rs1, rs2) + _equiv(rs2, rs2) + _equiv(rs1, rs1) + + _equiv(s1, rs2) + _equiv(s2, rs1) + _equiv(s1, s2) + + # single constraint + s3 = "f :: Eq a => a -> String" + rs3 = reorder_constraints(s3) + _equiv(s3, rs3) + + # constrain in body + s4 = "f :: a -> Maybe a" + rs4 = reorder_constraints(s4) + _equiv(s4, rs4) diff --git a/tests/test_ghc_pure.py b/tests/test_ghc_pure.py new file mode 100644 index 0000000..8f1dafc --- /dev/null +++ b/tests/test_ghc_pure.py @@ -0,0 +1,67 @@ +from returns.result import Result, Success, Failure +from funcy import lmap +from tfbench.ghc import ghc_prove_equiv, get_prover, reorder_constraints +from tfbench.type_def import def_new_type, def_new_type_constructor + + +def _equiv( + truth: str, + answer: str, + new_types: list[str] | None = None, + should_pass: bool = True, +): + equiv = get_prover(truth, answer, new_types or []).alt(str).bind(ghc_prove_equiv) + match equiv: + case Success(_): + assert should_pass + case Failure(err): + assert not should_pass, err + + +def _not_equiv( + truth: str, + answer: str, + new_types: list[str] | None = None, +): + _equiv(truth, answer, new_types, should_pass=False) + + +def test_mono(): + """test GHC type equivalence prover for monomorphic types after rewriting""" + # check with type after rewriting, + # i.e. T1, ... T_n + _equiv("f:: T1-> T1", "g ::T1 -> T1", new_types=[def_new_type("T1")]) + _not_equiv( + "f:: T1-> T1", + "g ::T2 -> T2", + new_types=lmap(def_new_type, ["T1", "T2"]), + ) + + _equiv( + "f:: (T1, T2) -> T1", + "g ::(T1, T2) -> T1", + new_types=lmap(def_new_type, ["T1", "T2"]), + ) + _equiv( + "f:: (Int, T2) -> Int", + "g ::(Int, T2) -> Int", + new_types=lmap(def_new_type, ["T2"]), + ) + _not_equiv( + "f:: (Int, T2) -> Int", + "g ::(Int, T2) -> T2", + new_types=lmap(def_new_type, ["T2"]), + ) + + +def test_typeclass_in_body(): + f = "f :: T1 -> T2 T3" + _equiv( + f, + f, + new_types=[ + def_new_type("T1"), + def_new_type("T3"), + def_new_type_constructor("T2", ["a"]), + ], + ) diff --git a/tests/test_polymorphism.py b/tests/test_polymorphism.py index 4902d4c..e1de5ad 100644 --- a/tests/test_polymorphism.py +++ b/tests/test_polymorphism.py @@ -1,9 +1,4 @@ -from tfbench.hs_parser import HASKELL_LANGUAGE -from tfbench.hs_parser.ast_util import AST, ASTLoc, HaskellFunction -from hypothesis import given -import hypothesis.strategies as st -from funcy_chain import Chain -from tfbench.hs_parser.polymorphism import get_polymorphic_type, PolymorphicType +from tfbench.hs_parser import get_polymorphic_type, PolymorphicType, AST def test_polymorphism(): @@ -15,7 +10,7 @@ def test_polymorphism(): f1 :: forall a b. a -> b -> a """ - ast = AST(types, HASKELL_LANGUAGE) + ast = AST(types) fn_addInt, fn_map, fn_id, fn_elem, fn_f1 = ast.get_functions() assert get_polymorphic_type(fn_addInt.type_signature) == PolymorphicType.MONO diff --git a/tests/test_type_def.py b/tests/test_type_def.py new file mode 100644 index 0000000..fb1029e --- /dev/null +++ b/tests/test_type_def.py @@ -0,0 +1,20 @@ +from tfbench.hs_parser import AST +from tfbench.type_def import is_class, is_data_type + + +def test_def_type_checker(): + type1 = "data T1 = T1" + class1 = "class T1 a" + f = "f :: Int -> Int" + assert is_data_type(type1) + assert not is_data_type(class1) + assert is_class(class1) + assert not is_class(type1) + assert not is_class(f) + assert not is_data_type(f) + + +def test_checker_after_rewrite(): + t1 = "data T1 a = Nothing | Just a" + assert is_data_type(t1) + assert not is_class(t1) diff --git a/tests/test_type_vars.py b/tests/test_type_vars.py new file mode 100644 index 0000000..5d3bd7e --- /dev/null +++ b/tests/test_type_vars.py @@ -0,0 +1,125 @@ +from tfbench.hs_parser import get_type_vars + + +def test_monomorphic(): + """ + test type var extraction for monomorphic types, + which should return empty set no matter how + """ + candidates = [ + "f :: Int -> Int", + "f::Int->Int", + "f :: Int -> Int ", + "f:: Int -> Int ", + "f::Int-> Int", + "f ::Int->Int", + "g :: Int -> Int -- with comment", + "g::Int->Int--with comment", + "g:: Int", + "f1 :: Char -> Char\n", + "f2 :: T1 -> T2", + "f3 :: (Int, Char) -> (Char, Int)", + ] + + ty_vars = map(get_type_vars, candidates) + assert all(len(vs) == 0 for vs in ty_vars) + + +def _assert_equal(sig: str, expected: list[str]): + ty_vars = get_type_vars(sig) + assert ty_vars == expected, f"for {sig}, expected {expected}, got {ty_vars}" + + +def test_parametric(): + """test type var extraction for parametric types""" + + _assert_equal("f :: a -> a", ["a"]) + _assert_equal("f :: a -> b -> a", ["a", "b"]) + _assert_equal("f :: (a, b) -> (b, a)", ["a", "b"]) + _assert_equal("f :: (a -> b) -> [a] -> [b]", ["a", "b"]) + _assert_equal("f :: (a -> b) -> [a] -> Maybe b", ["a", "b"]) + _assert_equal("g :: t1 -> t2 -> (t1, t2)", ["t1", "t2"]) + _assert_equal("h :: (m -> n) -> (n -> o) -> m -> o", ["m", "n", "o"]) + _assert_equal( + "k :: (a -> b) -> (b -> c) -> (c -> d) -> a -> d", + ["a", "b", "c", "d"], + ) + _assert_equal("m :: (x -> y) -> [x] -> Maybe y", ["x", "y"]) + _assert_equal("n :: (f a -> f b) -> [a] -> [b]", ["f", "a", "b"]) + + # arbitrary order + _assert_equal("p :: b -> a -> b", ["b", "a"]) + _assert_equal("q :: (b, a) -> (a, b)", ["b", "a"]) + _assert_equal("r :: (f b -> f a) -> [b] -> [a]", ["f", "b", "a"]) + + +def test_mono_para_mixed(): + """test type var extraction for mixed monomorphic and parametric types""" + + _assert_equal("f :: a -> Int", ["a"]) + _assert_equal("f :: Int -> a", ["a"]) + _assert_equal("f :: (a, Int) -> (Int, a)", ["a"]) + _assert_equal("f :: (Int -> a) -> [Int] -> [a]", ["a"]) + _assert_equal("f :: (Int -> a) -> [Int] -> Maybe a", ["a"]) + _assert_equal("g :: t1 -> Int -> (t1, Int)", ["t1"]) + _assert_equal("h :: (Int -> n) -> (n -> Int) -> Int -> Int", ["n"]) + _assert_equal( + "k :: (a -> Int) -> (Int -> c) -> (c -> d) -> a -> d", + ["a", "c", "d"], + ) + _assert_equal("m :: (x -> y) -> [Int] -> Maybe y", ["x", "y"]) + _assert_equal("n :: (f a -> f b) -> [Int] -> [b]", ["f", "a", "b"]) + + # arbitrary order + _assert_equal("p :: b -> Int -> b", ["b"]) + _assert_equal("q :: (b, Int) -> (Int, b)", ["b"]) + _assert_equal("r :: (f b -> f a) -> [Int] -> [a]", ["f", "b", "a"]) + + +def test_adhoc(): + """test type var extraction for ad-hoc polymorphic types, + i.e., with type class constraints + """ + + _assert_equal("f :: Eq a => a -> a", ["a"]) + _assert_equal("f :: (Eq a) => a -> a", ["a"]) + _assert_equal("f :: (Eq a, Show b) => a -> b -> a", ["a", "b"]) + _assert_equal("f :: (Ord a, Show b) => (a, b) -> (b, a)", ["a", "b"]) + _assert_equal( + "f :: (Eq a, Show b) => (a -> b) -> [a] -> [b]", + ["a", "b"], + ) + _assert_equal( + "f :: (Eq a, Show b) => (a -> b) -> [a] -> Maybe b", + ["a", "b"], + ) + _assert_equal( + "m :: (Eq x, Show y) => (x -> y) -> [x] -> Maybe y", + ["x", "y"], + ) + _assert_equal( + "n :: (Eq (f a), Show (f b)) => (f a -> f b) -> [a] -> [b]", + ["f", "a", "b"], + ) + + +def test_para_adhoc_mix(): + """test type var extraction for mixed parametric and ad-hoc polymorphic types""" + _assert_equal("g :: Eq t1 => t1 -> t2 -> (t1, t2)", ["t1", "t2"]) + _assert_equal( + "h :: (Ord m, Eq n) => (m -> n) -> (n -> o) -> m -> o", + ["m", "n", "o"], + ) + # type class constraints order affects the order of type vars + _assert_equal( + "k :: (Eq a, Show c) => (a -> b) -> (b -> c) -> (c -> d) -> a -> d", + ["a", "c", "b", "d"], + ) + + # arbitrary order + _assert_equal("p :: Show b => b -> a -> b", ["b", "a"]) + _assert_equal("q :: Ord b => (b, a) -> (a, b)", ["b", "a"]) + _assert_equal( + "r :: Eq (f b) => (f b -> f a) -> [b] -> [a]", + ["f", "b", "a"], + ) diff --git a/uv.lock b/uv.lock index 1c18d54..7cdcf1c 100644 --- a/uv.lock +++ b/uv.lock @@ -414,6 +414,70 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/8c/469afb6465b853afff216f9528ffda78a915ff880ed58813ba4faf4ba0b6/contourpy-1.3.3-cp314-cp314t-win_arm64.whl", hash = "sha256:b7448cb5a725bb1e35ce88771b86fba35ef418952474492cf7c764059933ff8b", size = 203831 }, ] +[[package]] +name = "coverage" +version = "7.10.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/14/70/025b179c993f019105b79575ac6edb5e084fb0f0e63f15cdebef4e454fb5/coverage-7.10.6.tar.gz", hash = "sha256:f644a3ae5933a552a29dbb9aa2f90c677a875f80ebea028e5a52a4f429044b90", size = 823736 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/06/263f3305c97ad78aab066d116b52250dd316e74fcc20c197b61e07eb391a/coverage-7.10.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5b2dd6059938063a2c9fee1af729d4f2af28fd1a545e9b7652861f0d752ebcea", size = 217324 }, + { url = "https://files.pythonhosted.org/packages/e9/60/1e1ded9a4fe80d843d7d53b3e395c1db3ff32d6c301e501f393b2e6c1c1f/coverage-7.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:388d80e56191bf846c485c14ae2bc8898aa3124d9d35903fef7d907780477634", size = 217560 }, + { url = "https://files.pythonhosted.org/packages/b8/25/52136173c14e26dfed8b106ed725811bb53c30b896d04d28d74cb64318b3/coverage-7.10.6-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:90cb5b1a4670662719591aa92d0095bb41714970c0b065b02a2610172dbf0af6", size = 249053 }, + { url = "https://files.pythonhosted.org/packages/cb/1d/ae25a7dc58fcce8b172d42ffe5313fc267afe61c97fa872b80ee72d9515a/coverage-7.10.6-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:961834e2f2b863a0e14260a9a273aff07ff7818ab6e66d2addf5628590c628f9", size = 251802 }, + { url = "https://files.pythonhosted.org/packages/f5/7a/1f561d47743710fe996957ed7c124b421320f150f1d38523d8d9102d3e2a/coverage-7.10.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf9a19f5012dab774628491659646335b1928cfc931bf8d97b0d5918dd58033c", size = 252935 }, + { url = "https://files.pythonhosted.org/packages/6c/ad/8b97cd5d28aecdfde792dcbf646bac141167a5cacae2cd775998b45fabb5/coverage-7.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99c4283e2a0e147b9c9cc6bc9c96124de9419d6044837e9799763a0e29a7321a", size = 250855 }, + { url = "https://files.pythonhosted.org/packages/33/6a/95c32b558d9a61858ff9d79580d3877df3eb5bc9eed0941b1f187c89e143/coverage-7.10.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:282b1b20f45df57cc508c1e033403f02283adfb67d4c9c35a90281d81e5c52c5", size = 248974 }, + { url = "https://files.pythonhosted.org/packages/0d/9c/8ce95dee640a38e760d5b747c10913e7a06554704d60b41e73fdea6a1ffd/coverage-7.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8cdbe264f11afd69841bd8c0d83ca10b5b32853263ee62e6ac6a0ab63895f972", size = 250409 }, + { url = "https://files.pythonhosted.org/packages/04/12/7a55b0bdde78a98e2eb2356771fd2dcddb96579e8342bb52aa5bc52e96f0/coverage-7.10.6-cp312-cp312-win32.whl", hash = "sha256:a517feaf3a0a3eca1ee985d8373135cfdedfbba3882a5eab4362bda7c7cf518d", size = 219724 }, + { url = "https://files.pythonhosted.org/packages/36/4a/32b185b8b8e327802c9efce3d3108d2fe2d9d31f153a0f7ecfd59c773705/coverage-7.10.6-cp312-cp312-win_amd64.whl", hash = "sha256:856986eadf41f52b214176d894a7de05331117f6035a28ac0016c0f63d887629", size = 220536 }, + { url = "https://files.pythonhosted.org/packages/08/3a/d5d8dc703e4998038c3099eaf77adddb00536a3cec08c8dcd556a36a3eb4/coverage-7.10.6-cp312-cp312-win_arm64.whl", hash = "sha256:acf36b8268785aad739443fa2780c16260ee3fa09d12b3a70f772ef100939d80", size = 219171 }, + { url = "https://files.pythonhosted.org/packages/bd/e7/917e5953ea29a28c1057729c1d5af9084ab6d9c66217523fd0e10f14d8f6/coverage-7.10.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ffea0575345e9ee0144dfe5701aa17f3ba546f8c3bb48db62ae101afb740e7d6", size = 217351 }, + { url = "https://files.pythonhosted.org/packages/eb/86/2e161b93a4f11d0ea93f9bebb6a53f113d5d6e416d7561ca41bb0a29996b/coverage-7.10.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:95d91d7317cde40a1c249d6b7382750b7e6d86fad9d8eaf4fa3f8f44cf171e80", size = 217600 }, + { url = "https://files.pythonhosted.org/packages/0e/66/d03348fdd8df262b3a7fb4ee5727e6e4936e39e2f3a842e803196946f200/coverage-7.10.6-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3e23dd5408fe71a356b41baa82892772a4cefcf758f2ca3383d2aa39e1b7a003", size = 248600 }, + { url = "https://files.pythonhosted.org/packages/73/dd/508420fb47d09d904d962f123221bc249f64b5e56aa93d5f5f7603be475f/coverage-7.10.6-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0f3f56e4cb573755e96a16501a98bf211f100463d70275759e73f3cbc00d4f27", size = 251206 }, + { url = "https://files.pythonhosted.org/packages/e9/1f/9020135734184f439da85c70ea78194c2730e56c2d18aee6e8ff1719d50d/coverage-7.10.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:db4a1d897bbbe7339946ffa2fe60c10cc81c43fab8b062d3fcb84188688174a4", size = 252478 }, + { url = "https://files.pythonhosted.org/packages/a4/a4/3d228f3942bb5a2051fde28c136eea23a761177dc4ff4ef54533164ce255/coverage-7.10.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d8fd7879082953c156d5b13c74aa6cca37f6a6f4747b39538504c3f9c63d043d", size = 250637 }, + { url = "https://files.pythonhosted.org/packages/36/e3/293dce8cdb9a83de971637afc59b7190faad60603b40e32635cbd15fbf61/coverage-7.10.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:28395ca3f71cd103b8c116333fa9db867f3a3e1ad6a084aa3725ae002b6583bc", size = 248529 }, + { url = "https://files.pythonhosted.org/packages/90/26/64eecfa214e80dd1d101e420cab2901827de0e49631d666543d0e53cf597/coverage-7.10.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:61c950fc33d29c91b9e18540e1aed7d9f6787cc870a3e4032493bbbe641d12fc", size = 250143 }, + { url = "https://files.pythonhosted.org/packages/3e/70/bd80588338f65ea5b0d97e424b820fb4068b9cfb9597fbd91963086e004b/coverage-7.10.6-cp313-cp313-win32.whl", hash = "sha256:160c00a5e6b6bdf4e5984b0ef21fc860bc94416c41b7df4d63f536d17c38902e", size = 219770 }, + { url = "https://files.pythonhosted.org/packages/a7/14/0b831122305abcc1060c008f6c97bbdc0a913ab47d65070a01dc50293c2b/coverage-7.10.6-cp313-cp313-win_amd64.whl", hash = "sha256:628055297f3e2aa181464c3808402887643405573eb3d9de060d81531fa79d32", size = 220566 }, + { url = "https://files.pythonhosted.org/packages/83/c6/81a83778c1f83f1a4a168ed6673eeedc205afb562d8500175292ca64b94e/coverage-7.10.6-cp313-cp313-win_arm64.whl", hash = "sha256:df4ec1f8540b0bcbe26ca7dd0f541847cc8a108b35596f9f91f59f0c060bfdd2", size = 219195 }, + { url = "https://files.pythonhosted.org/packages/d7/1c/ccccf4bf116f9517275fa85047495515add43e41dfe8e0bef6e333c6b344/coverage-7.10.6-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c9a8b7a34a4de3ed987f636f71881cd3b8339f61118b1aa311fbda12741bff0b", size = 218059 }, + { url = "https://files.pythonhosted.org/packages/92/97/8a3ceff833d27c7492af4f39d5da6761e9ff624831db9e9f25b3886ddbca/coverage-7.10.6-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8dd5af36092430c2b075cee966719898f2ae87b636cefb85a653f1d0ba5d5393", size = 218287 }, + { url = "https://files.pythonhosted.org/packages/92/d8/50b4a32580cf41ff0423777a2791aaf3269ab60c840b62009aec12d3970d/coverage-7.10.6-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b0353b0f0850d49ada66fdd7d0c7cdb0f86b900bb9e367024fd14a60cecc1e27", size = 259625 }, + { url = "https://files.pythonhosted.org/packages/7e/7e/6a7df5a6fb440a0179d94a348eb6616ed4745e7df26bf2a02bc4db72c421/coverage-7.10.6-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d6b9ae13d5d3e8aeca9ca94198aa7b3ebbc5acfada557d724f2a1f03d2c0b0df", size = 261801 }, + { url = "https://files.pythonhosted.org/packages/3a/4c/a270a414f4ed5d196b9d3d67922968e768cd971d1b251e1b4f75e9362f75/coverage-7.10.6-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:675824a363cc05781b1527b39dc2587b8984965834a748177ee3c37b64ffeafb", size = 264027 }, + { url = "https://files.pythonhosted.org/packages/9c/8b/3210d663d594926c12f373c5370bf1e7c5c3a427519a8afa65b561b9a55c/coverage-7.10.6-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:692d70ea725f471a547c305f0d0fc6a73480c62fb0da726370c088ab21aed282", size = 261576 }, + { url = "https://files.pythonhosted.org/packages/72/d0/e1961eff67e9e1dba3fc5eb7a4caf726b35a5b03776892da8d79ec895775/coverage-7.10.6-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:851430a9a361c7a8484a36126d1d0ff8d529d97385eacc8dfdc9bfc8c2d2cbe4", size = 259341 }, + { url = "https://files.pythonhosted.org/packages/3a/06/d6478d152cd189b33eac691cba27a40704990ba95de49771285f34a5861e/coverage-7.10.6-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d9369a23186d189b2fc95cc08b8160ba242057e887d766864f7adf3c46b2df21", size = 260468 }, + { url = "https://files.pythonhosted.org/packages/ed/73/737440247c914a332f0b47f7598535b29965bf305e19bbc22d4c39615d2b/coverage-7.10.6-cp313-cp313t-win32.whl", hash = "sha256:92be86fcb125e9bda0da7806afd29a3fd33fdf58fba5d60318399adf40bf37d0", size = 220429 }, + { url = "https://files.pythonhosted.org/packages/bd/76/b92d3214740f2357ef4a27c75a526eb6c28f79c402e9f20a922c295c05e2/coverage-7.10.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6b3039e2ca459a70c79523d39347d83b73f2f06af5624905eba7ec34d64d80b5", size = 221493 }, + { url = "https://files.pythonhosted.org/packages/fc/8e/6dcb29c599c8a1f654ec6cb68d76644fe635513af16e932d2d4ad1e5ac6e/coverage-7.10.6-cp313-cp313t-win_arm64.whl", hash = "sha256:3fb99d0786fe17b228eab663d16bee2288e8724d26a199c29325aac4b0319b9b", size = 219757 }, + { url = "https://files.pythonhosted.org/packages/d3/aa/76cf0b5ec00619ef208da4689281d48b57f2c7fde883d14bf9441b74d59f/coverage-7.10.6-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:6008a021907be8c4c02f37cdc3ffb258493bdebfeaf9a839f9e71dfdc47b018e", size = 217331 }, + { url = "https://files.pythonhosted.org/packages/65/91/8e41b8c7c505d398d7730206f3cbb4a875a35ca1041efc518051bfce0f6b/coverage-7.10.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5e75e37f23eb144e78940b40395b42f2321951206a4f50e23cfd6e8a198d3ceb", size = 217607 }, + { url = "https://files.pythonhosted.org/packages/87/7f/f718e732a423d442e6616580a951b8d1ec3575ea48bcd0e2228386805e79/coverage-7.10.6-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0f7cb359a448e043c576f0da00aa8bfd796a01b06aa610ca453d4dde09cc1034", size = 248663 }, + { url = "https://files.pythonhosted.org/packages/e6/52/c1106120e6d801ac03e12b5285e971e758e925b6f82ee9b86db3aa10045d/coverage-7.10.6-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c68018e4fc4e14b5668f1353b41ccf4bc83ba355f0e1b3836861c6f042d89ac1", size = 251197 }, + { url = "https://files.pythonhosted.org/packages/3d/ec/3a8645b1bb40e36acde9c0609f08942852a4af91a937fe2c129a38f2d3f5/coverage-7.10.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cd4b2b0707fc55afa160cd5fc33b27ccbf75ca11d81f4ec9863d5793fc6df56a", size = 252551 }, + { url = "https://files.pythonhosted.org/packages/a1/70/09ecb68eeb1155b28a1d16525fd3a9b65fbe75337311a99830df935d62b6/coverage-7.10.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:4cec13817a651f8804a86e4f79d815b3b28472c910e099e4d5a0e8a3b6a1d4cb", size = 250553 }, + { url = "https://files.pythonhosted.org/packages/c6/80/47df374b893fa812e953b5bc93dcb1427a7b3d7a1a7d2db33043d17f74b9/coverage-7.10.6-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:f2a6a8e06bbda06f78739f40bfb56c45d14eb8249d0f0ea6d4b3d48e1f7c695d", size = 248486 }, + { url = "https://files.pythonhosted.org/packages/4a/65/9f98640979ecee1b0d1a7164b589de720ddf8100d1747d9bbdb84be0c0fb/coverage-7.10.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:081b98395ced0d9bcf60ada7661a0b75f36b78b9d7e39ea0790bb4ed8da14747", size = 249981 }, + { url = "https://files.pythonhosted.org/packages/1f/55/eeb6603371e6629037f47bd25bef300387257ed53a3c5fdb159b7ac8c651/coverage-7.10.6-cp314-cp314-win32.whl", hash = "sha256:6937347c5d7d069ee776b2bf4e1212f912a9f1f141a429c475e6089462fcecc5", size = 220054 }, + { url = "https://files.pythonhosted.org/packages/15/d1/a0912b7611bc35412e919a2cd59ae98e7ea3b475e562668040a43fb27897/coverage-7.10.6-cp314-cp314-win_amd64.whl", hash = "sha256:adec1d980fa07e60b6ef865f9e5410ba760e4e1d26f60f7e5772c73b9a5b0713", size = 220851 }, + { url = "https://files.pythonhosted.org/packages/ef/2d/11880bb8ef80a45338e0b3e0725e4c2d73ffbb4822c29d987078224fd6a5/coverage-7.10.6-cp314-cp314-win_arm64.whl", hash = "sha256:a80f7aef9535442bdcf562e5a0d5a5538ce8abe6bb209cfbf170c462ac2c2a32", size = 219429 }, + { url = "https://files.pythonhosted.org/packages/83/c0/1f00caad775c03a700146f55536ecd097a881ff08d310a58b353a1421be0/coverage-7.10.6-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:0de434f4fbbe5af4fa7989521c655c8c779afb61c53ab561b64dcee6149e4c65", size = 218080 }, + { url = "https://files.pythonhosted.org/packages/a9/c4/b1c5d2bd7cc412cbeb035e257fd06ed4e3e139ac871d16a07434e145d18d/coverage-7.10.6-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6e31b8155150c57e5ac43ccd289d079eb3f825187d7c66e755a055d2c85794c6", size = 218293 }, + { url = "https://files.pythonhosted.org/packages/3f/07/4468d37c94724bf6ec354e4ec2f205fda194343e3e85fd2e59cec57e6a54/coverage-7.10.6-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:98cede73eb83c31e2118ae8d379c12e3e42736903a8afcca92a7218e1f2903b0", size = 259800 }, + { url = "https://files.pythonhosted.org/packages/82/d8/f8fb351be5fee31690cd8da768fd62f1cfab33c31d9f7baba6cd8960f6b8/coverage-7.10.6-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f863c08f4ff6b64fa8045b1e3da480f5374779ef187f07b82e0538c68cb4ff8e", size = 261965 }, + { url = "https://files.pythonhosted.org/packages/e8/70/65d4d7cfc75c5c6eb2fed3ee5cdf420fd8ae09c4808723a89a81d5b1b9c3/coverage-7.10.6-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b38261034fda87be356f2c3f42221fdb4171c3ce7658066ae449241485390d5", size = 264220 }, + { url = "https://files.pythonhosted.org/packages/98/3c/069df106d19024324cde10e4ec379fe2fb978017d25e97ebee23002fbadf/coverage-7.10.6-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0e93b1476b79eae849dc3872faeb0bf7948fd9ea34869590bc16a2a00b9c82a7", size = 261660 }, + { url = "https://files.pythonhosted.org/packages/fc/8a/2974d53904080c5dc91af798b3a54a4ccb99a45595cc0dcec6eb9616a57d/coverage-7.10.6-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ff8a991f70f4c0cf53088abf1e3886edcc87d53004c7bb94e78650b4d3dac3b5", size = 259417 }, + { url = "https://files.pythonhosted.org/packages/30/38/9616a6b49c686394b318974d7f6e08f38b8af2270ce7488e879888d1e5db/coverage-7.10.6-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ac765b026c9f33044419cbba1da913cfb82cca1b60598ac1c7a5ed6aac4621a0", size = 260567 }, + { url = "https://files.pythonhosted.org/packages/76/16/3ed2d6312b371a8cf804abf4e14895b70e4c3491c6e53536d63fd0958a8d/coverage-7.10.6-cp314-cp314t-win32.whl", hash = "sha256:441c357d55f4936875636ef2cfb3bee36e466dcf50df9afbd398ce79dba1ebb7", size = 220831 }, + { url = "https://files.pythonhosted.org/packages/d5/e5/d38d0cb830abede2adb8b147770d2a3d0e7fecc7228245b9b1ae6c24930a/coverage-7.10.6-cp314-cp314t-win_amd64.whl", hash = "sha256:073711de3181b2e204e4870ac83a7c4853115b42e9cd4d145f2231e12d670930", size = 221950 }, + { url = "https://files.pythonhosted.org/packages/f4/51/e48e550f6279349895b0ffcd6d2a690e3131ba3a7f4eafccc141966d4dea/coverage-7.10.6-cp314-cp314t-win_arm64.whl", hash = "sha256:137921f2bac5559334ba66122b753db6dc5d1cf01eb7b64eb412bb0d064ef35b", size = 219969 }, + { url = "https://files.pythonhosted.org/packages/44/0c/50db5379b615854b5cf89146f8f5bd1d5a9693d7f3a987e269693521c404/coverage-7.10.6-py3-none-any.whl", hash = "sha256:92c4ecf6bf11b2e85fd4d8204814dc26e6a19f0c9d938c207c5cb0eadfcabbe3", size = 208986 }, +] + [[package]] name = "cupy-cuda12x" version = "13.6.0" @@ -567,6 +631,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/15/545e2b6cf2e3be84bc1ed85613edd75b8aea69807a71c26f4ca6a9258e82/email_validator-2.3.0-py3-none-any.whl", hash = "sha256:80f13f623413e6b197ae73bb10bf4eb0908faf509ad8362c5edeb0be7fd450b4", size = 35604 }, ] +[[package]] +name = "execnet" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/ff/b4c0dc78fbe20c3e59c0c7334de0c27eb4001a2b2017999af398bf730817/execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3", size = 166524 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612 }, +] + [[package]] name = "fastapi" version = "0.116.1" @@ -2523,6 +2596,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 }, ] +[[package]] +name = "pytest-cov" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage" }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/99/668cade231f434aaa59bbfbf49469068d2ddd945000621d3d165d2e7dd7b/pytest_cov-6.2.1.tar.gz", hash = "sha256:25cc6cc0a5358204b8108ecedc51a9b57b34cc6b8c967cc2c01a4e00d8a67da2", size = 69432 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644 }, +] + +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3285,6 +3385,7 @@ dependencies = [ { name = "orjson" }, { name = "orjsonl" }, { name = "pyarrow" }, + { name = "pydantic" }, { name = "pytest" }, { name = "python-dotenv" }, { name = "requests" }, @@ -3296,10 +3397,17 @@ dependencies = [ { name = "tqdm" }, { name = "tree-sitter" }, { name = "tree-sitter-haskell" }, + { name = "types-deprecated" }, { name = "types-requests" }, { name = "vllm" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest-cov" }, + { name = "pytest-xdist" }, +] + [package.metadata] requires-dist = [ { name = "anthropic", specifier = "==0.49.0" }, @@ -3320,6 +3428,7 @@ requires-dist = [ { name = "orjson", specifier = ">=3.11.3" }, { name = "orjsonl", specifier = ">=1.0.0" }, { name = "pyarrow", specifier = ">=21.0.0" }, + { name = "pydantic", specifier = ">=2.11.7" }, { name = "pytest", specifier = ">=8.0.0" }, { name = "python-dotenv", specifier = "==1.0.1" }, { name = "requests", specifier = "==2.32.3" }, @@ -3331,10 +3440,17 @@ requires-dist = [ { name = "tqdm", specifier = ">=4.66.2" }, { name = "tree-sitter", specifier = "==0.22.3" }, { name = "tree-sitter-haskell", specifier = "==0.21.0" }, + { name = "types-deprecated", specifier = ">=1.2.15.20250304" }, { name = "types-requests", specifier = ">=2.31.0" }, { name = "vllm", specifier = ">=0.10.1.1" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "pytest-cov", specifier = ">=6.2.1" }, + { name = "pytest-xdist", specifier = ">=3.8.0" }, +] + [[package]] name = "tiktoken" version = "0.7.0" @@ -3559,6 +3675,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/76/06dbe78f39b2203d2a47d5facc5df5102d0561e2807396471b5f7c5a30a1/typer-0.16.1-py3-none-any.whl", hash = "sha256:90ee01cb02d9b8395ae21ee3368421faf21fa138cb2a541ed369c08cec5237c9", size = 46397 }, ] +[[package]] +name = "types-deprecated" +version = "1.2.15.20250304" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/67/eeefaaabb03b288aad85483d410452c8bbcbf8b2bd876b0e467ebd97415b/types_deprecated-1.2.15.20250304.tar.gz", hash = "sha256:c329030553029de5cc6cb30f269c11f4e00e598c4241290179f63cda7d33f719", size = 8015 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/e3/c18aa72ab84e0bc127a3a94e93be1a6ac2cb281371d3a45376ab7cfdd31c/types_deprecated-1.2.15.20250304-py3-none-any.whl", hash = "sha256:86a65aa550ea8acf49f27e226b8953288cd851de887970fbbdf2239c116c3107", size = 8553 }, +] + [[package]] name = "types-requests" version = "2.32.4.20250809"