diff --git a/pyproject.toml b/pyproject.toml index fafa193..d6497c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "openai==1.99.9", "orjson>=3.11.3", "orjsonl>=1.0.0", + "pokepalette>=0.0.6", "pyarrow>=21.0.0", "pydantic>=2.11.7", "pytest>=8.0.0", diff --git a/scripts/error_analysis.py b/scripts/error_analysis.py new file mode 100644 index 0000000..2babfb3 --- /dev/null +++ b/scripts/error_analysis.py @@ -0,0 +1,65 @@ +from os.path import abspath, basename, join as pjoin +import os +from typing import Literal + +from openai import OpenAI +import fire +from tqdm import tqdm +from orjsonl import orjsonl + +from tfbench import ( + load_tfb_from_hf, + load_gen_results_jsonl, +) +from tfbench.evaluation import get_incorrect +from tfbench.error_analysis import error_analysis, ErrorAnalysisResult + + +def analysis(result_file_dir: str, split: Literal["base", "pure"], output_file: str): + """script to run error analysis for incorrect TF-Bench tasks""" + client = OpenAI() + tasks = load_tfb_from_hf(split) + model = basename(abspath(result_file_dir)) + + split_result_dir = pjoin(result_file_dir, split) + + incorrect = [] + print(f"Collecting incorrect results from {split_result_dir} on split {split}") + for file in os.listdir(split_result_dir): + if not file.endswith(".jsonl"): + continue + result_file_path = pjoin(split_result_dir, file) + run_result = load_gen_results_jsonl(result_file_path) + + incorrect_of_run = get_incorrect(tasks, run_result, split == "pure") + incorrect.extend(incorrect_of_run) + + print(f"Running error classification on {len(incorrect)} incorrect results") + for task, answer, msg in tqdm(incorrect): + error = error_analysis(client, task, answer, error_msg=msg) + log_obj: ErrorAnalysisResult = { + "model": model, + "split": split, + "task_id": task.task_id, + "ground_truth": task.signature, + "predicted": answer.answer if answer else None, + "error_category": error.category, + "error_explanation": 
error.explanation, + } + orjsonl.append(output_file, log_obj) + + +def main(result_file_dir: str, output_file: str | None = None): + """run result analysis on both base and pure splits""" + + model = basename(abspath(result_file_dir)) + print(f"Running error analysis for model {model}") + if output_file is None: + output_file = f"{model}.error_analysis.jsonl" + + analysis(result_file_dir, "base", output_file) + analysis(result_file_dir, "pure", output_file) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/scripts/error_cls.py b/scripts/error_cls.py deleted file mode 100644 index bae55e9..0000000 --- a/scripts/error_cls.py +++ /dev/null @@ -1,129 +0,0 @@ -from os.path import abspath, basename, join as pjoin -import os - -import orjson -from pydantic import BaseModel -from openai import OpenAI -import fire -from tqdm import tqdm - -from tfbench import ( - load_tfb_from_hf, - load_gen_results_jsonl, - LMAnswer, -) -from tfbench.evaluation import get_incorrect -from tfbench.common import get_prompt as get_task_prompt, BenchmarkTask - - -PROMPT_TEMPLATE = """ -The Haskell type inference task is as follows: -{task} - -The ground-truth correct answer is: -{correct_answer} - -My incorrect answer is: -{wrong_answer} - -My reasoning behind my answer is: -{reasoning} - -What mistake did I make? -""" - -INSTRUCTION = """ -You are a programming language and logic expert. -You will be shown a Haskell type inference task, -an incorrect answer, and the reasoning behind it. -Your job is to identify the mistake in the answer, -and classify the mistake in one word. -The current error categories are: -{categories}. -Choose one category, or construct a new one if you are sure that -none of the current categories fit. -Only output the one-word classification and a short definition of the class. -NOTE that the short definition should be generalized to other tasks that fall in the same category. 
-""" - - -class ClsResponse(BaseModel): - category: str - definition: str - - def __hash__(self): - return hash(self.category) - - -def get_prompt(task: BenchmarkTask, answer: LMAnswer) -> str: - prompt = PROMPT_TEMPLATE.format( - task=get_task_prompt(task), - correct_answer=task.signature, - wrong_answer=answer.answer, - reasoning=answer.reasoning_steps, - ) - return prompt - - -def categories_str(categories: set[ClsResponse]) -> str: - """dump all categories in jsonl format string""" - return "\n".join(orjson.dumps(c.__dict__).decode() for c in categories) - - -def classify_run( - client: OpenAI, - categories: set[ClsResponse], - tasks: list[BenchmarkTask], - run_result: list[LMAnswer | None], -) -> set[ClsResponse]: - incorrect = get_incorrect(tasks, run_result) - - categories_: set[ClsResponse] = categories.copy() - for task, answer in tqdm(incorrect): - assert answer is not None - response = client.responses.parse( - model="gpt-5", - instructions=INSTRUCTION.format(categories=categories_str(categories_)), - input=get_prompt(task, answer), - reasoning={"effort": "medium"}, - text_format=ClsResponse, - ) - assert response.output_parsed is not None - categories_.add(response.output_parsed) - return categories_ - - -def main(result_file_dir: str): - - client = OpenAI() - categories: set[ClsResponse] = set() - - split = basename(abspath(result_file_dir)) - print(split) - base = load_tfb_from_hf(split) - - for file in os.listdir(result_file_dir): - if not file.endswith(".jsonl"): - continue - result_file_path = pjoin(result_file_dir, file) - run_result = load_gen_results_jsonl(result_file_path) - print(f"Processing {result_file_path}") - run_categories = classify_run( - client, - categories, - base, - run_result, - ) - categories.update(run_categories) - - with open("error_categories.json", "wb") as f: - f.write( - orjson.dumps( - [c.model_dump(mode="json") for c in categories], - option=orjson.OPT_INDENT_2, - ) - ) - - -if __name__ == "__main__": - fire.Fire(main) 
diff --git a/scripts/error_plot.py b/scripts/error_plot.py new file mode 100644 index 0000000..6850753 --- /dev/null +++ b/scripts/error_plot.py @@ -0,0 +1,181 @@ +from typing import get_args +import os + +import fire +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +import pokepalette + +from tfbench.error_analysis import ErrorCategories + +plt.rcParams["pdf.fonttype"] = 42 +plt.rcParams["ps.fonttype"] = 42 +FONT_SIZE = 20 + +# CMAP = pokepalette.get_colormap("lapras") +CMAP = pokepalette.get_colormap("gengar") + + +def plot_error_categories_pie_charts(df: pd.DataFrame): + """ + Plot pie charts showing the proportion of error categories by split and model. + + Args: + df: pandas DataFrame containing ErrorAnalysisResult data + """ + # Get unique models and splits + models = sorted(df["model"].unique()) + splits = ["base", "pure"] # Ensure consistent ordering + + # Set up the subplot grid + n_models = len(models) + fig, axes = plt.subplots(2, n_models, figsize=(5 * n_models, 10)) + + # Handle case where there's only one model (axes won't be 2D) + if n_models == 1: + axes = axes.reshape(-1, 1) + + # Define colors for consistency across plots + error_categories = list(get_args(ErrorCategories)) + colors = CMAP(np.linspace(0, 1, len(error_categories))) + color_map = dict(zip(error_categories, colors)) + + # Create pie charts for each split-model combination + for split_idx, split in enumerate(splits): + for model_idx, model in enumerate(models): + # Filter data for current split and model + subset = df[(df["split"] == split) & (df["model"] == model)] + + if len(subset) == 0: + # Handle empty subset + axes[split_idx, model_idx].text( + 0.5, + 0.5, + "No Data", + ha="center", + va="center", + transform=axes[split_idx, model_idx].transAxes, + ) + axes[split_idx, model_idx].set_title(f"{model} ({split})") + continue + + # Count error categories and ensure consistent ordering + error_counts = subset["error_category"].value_counts() + + # Prepare 
data for pie chart with consistent ordering + labels = [] + sizes = [] + plot_colors = [] + + for category in error_categories: + if category in error_counts: + labels.append(category) + sizes.append(error_counts[category]) + plot_colors.append(color_map[category]) + + # Create pie chart (without labels and autopct since we'll add custom percentages) + wedges = axes[split_idx, model_idx].pie( + sizes, + colors=plot_colors, + startangle=90, + )[0] + + # Calculate percentages + total = sum(sizes) + percentages = [(size / total) * 100 for size in sizes] + + # Add percentage labels only for slices >= 5% + for i, (wedge, pct) in enumerate(zip(wedges, percentages)): + if pct >= 5: # Only show percentage if >= 5% + # Get wedge center angle + angle = (wedge.theta2 + wedge.theta1) / 2 + + # Place inside the pie slice + radius = 0.7 + + # Calculate text position + x = radius * np.cos(np.radians(angle)) + y = radius * np.sin(np.radians(angle)) + + # Add percentage text + axes[split_idx, model_idx].text( + x, + y, + f"{pct:.1f}%", + ha="center", + va="center", + fontweight="bold", + fontsize=10, + color="black", + ) + + # Add model names as column headers (only on top row) + for model_idx, model in enumerate(models): + axes[0, model_idx].set_title(model, fontsize=FONT_SIZE) + + # Add split names as row labels (only on leftmost column) + for split_idx, split in enumerate(splits): + axes[split_idx, 0].text( + -0.03, + 0.5, + split.upper(), + transform=axes[split_idx, 0].transAxes, + fontsize=FONT_SIZE, + ha="center", + va="center", + rotation=90, + ) + + # Adjust layout + plt.tight_layout() + plt.subplots_adjust( + top=0.93, + wspace=-0.15, # controls the space between columns + hspace=-0.1, # controls the space between rows + ) + + # Add legend with all error categories that appear anywhere in the data + all_categories_in_data = set(df["error_category"].unique()) + + # Create legend for all categories that appear in the data, in consistent order + legend_elements = [ + 
plt.Rectangle((0, 0), 1, 1, fc=color_map[cat]) + for cat in error_categories + if cat in all_categories_in_data + ] + legend_labels_filtered = [ + cat for cat in error_categories if cat in all_categories_in_data + ] + + fig.legend( + legend_elements, + legend_labels_filtered, + loc="center", + bbox_to_anchor=(0.5, 0.02), + ncol=min(4, len(legend_labels_filtered)), + fontsize=16, + ) + + return fig + + +def main( + error_analysis_file_dir: str, + output_file: str = "error_analysis_pie_charts.png", +): + # Example usage: + # Assuming your DataFrame is called 'df' + files = [ + os.path.join(error_analysis_file_dir, f) + for f in os.listdir(error_analysis_file_dir) + if f.endswith(".jsonl") + ] + df = pd.concat([pd.read_json(f, lines=True) for f in files], ignore_index=True) + fig = plot_error_categories_pie_charts(df) + + fig.savefig(output_file, dpi=500, bbox_inches="tight") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/src/tfbench/__init__.py b/src/tfbench/__init__.py index 06a1ef6..87f2ee6 100644 --- a/src/tfbench/__init__.py +++ b/src/tfbench/__init__.py @@ -1,7 +1,11 @@ from dotenv import load_dotenv from .experiment import run_one_model -from .evaluation import EvalResult, analysis_multi_runs, evaluate, prover_evaluate +from .evaluation import ( + EvalResult, + analysis_multi_runs, + prover_evaluate, +) from .load import load_tfb_from_hf, load_gen_results_jsonl from .lm import LMAnswer @@ -11,7 +15,6 @@ "run_one_model", "EvalResult", "analysis_multi_runs", - "evaluate", "prover_evaluate", "load_tfb_from_hf", "load_gen_results_jsonl", diff --git a/src/tfbench/error_analysis.py b/src/tfbench/error_analysis.py new file mode 100644 index 0000000..c71e128 --- /dev/null +++ b/src/tfbench/error_analysis.py @@ -0,0 +1,124 @@ +from typing import TypedDict, Literal + +from pydantic import BaseModel +from openai import OpenAI + +from .common import get_prompt as get_task_prompt, BenchmarkTask +from .lm import LMAnswer + +PROMPT_TEMPLATE = """ +The Haskell 
type inference task is as follows: +{task} + +The ground-truth correct answer is: +{correct_answer} + +My incorrect answer is: +{wrong_answer} + +My reasoning behind my answer is: +{reasoning} + +The error message from GHC's type checker is: +{ghc_error} + +What mistake did I make? +""" + +INSTRUCTION = """ +You are a programming language and logic expert. +You will be shown a Haskell type inference task, +an incorrect answer, and the reasoning behind it. +The type signatures should be alpha-equivalent to the ground-truth answer. +Your job is to identify the mistake in the answer, +and classify the mistake in the following category. +The error categories and their definitions are: + +- OverGeneralization: Choose a type that is too general—used broader polymorphism +(e.g., different input/output type variables) where the most general valid type actually requires them to be the same. + +- UnderGeneralization: Added an unnecessary/stronger type-class constraint that is not provided by the implementation, +making the signature more specific than the most general valid type. + +- ArgOrderMismatch: Right type variables but in the wrong parameter order; +the type's argument sequence doesn't match the implementation (a permutation error, not a generality/constraint issue). + +- ArityMismatch: Provided a type with the wrong number of arguments (too many or too few) compared to the implementation. + +- ConstraintError: Used incorrect type-class constraints that don't align with the implementation's requirements. +The wrong type-class constraints were applied to the type variables. + +- SyntaxError: Provided an answer that is not a valid Haskell type signature. + +- InstructionFollowing: Failed to follow the instructions given in the prompt. + +- ResponseError: No answer was provided, or the answer is completely unrelated to the task. + +The prompt asked to only output the type signature, +but the answer contains additional text or explanation. 
+Choose one category from the above. +Only output the one-word classification and a short explanation of why this category fits. +""" + +ErrorCategories = Literal[ + "OverGeneralization", + "UnderGeneralization", + "ArgOrderMismatch", + "ArityMismatch", + "ConstraintError", + "SyntaxError", + "InstructionFollowing", + "ResponseError", +] + + +class ErrorAnalysisResponse(BaseModel): + category: ErrorCategories + explanation: str + + +def get_error_analysis_prompt( + task: BenchmarkTask, answer: LMAnswer, error_msg: str +) -> str: + """construct classification prompt for one task and answer pair""" + prompt = PROMPT_TEMPLATE.format( + task=get_task_prompt(task), + correct_answer=task.signature, + wrong_answer=answer.answer, + reasoning=answer.reasoning_steps, + ghc_error=error_msg, + ) + return prompt + + +def error_analysis( + client: OpenAI, + task: BenchmarkTask, + answer: LMAnswer | None, + error_msg: str, +) -> ErrorAnalysisResponse: + """classify the error for one incorrect task/answer pair""" + if answer is None: + return ErrorAnalysisResponse( + category="ResponseError", explanation="No answer provided." 
+ ) + + response = client.responses.parse( + model="gpt-5", + instructions=INSTRUCTION, + input=get_error_analysis_prompt(task, answer, error_msg=error_msg), + reasoning={"effort": "medium"}, + text_format=ErrorAnalysisResponse, + ) + assert response.output_parsed is not None + return response.output_parsed + + +class ErrorAnalysisResult(TypedDict): + model: str + split: Literal["base", "pure"] + task_id: str + ground_truth: str + predicted: str | None + error_category: ErrorCategories + error_explanation: str diff --git a/src/tfbench/evaluation.py b/src/tfbench/evaluation.py index e209083..82d727c 100644 --- a/src/tfbench/evaluation.py +++ b/src/tfbench/evaluation.py @@ -5,7 +5,7 @@ import numpy as np from deprecated import deprecated -from returns.result import Success +from returns.result import Success, Failure, Result from .common import BenchmarkTask from .postprocessing import postprocess, TASK_STRATEGIES, RESPONSE_STRATEGIES @@ -104,10 +104,10 @@ def evaluate(tasks: list[BenchmarkTask], results: list[LMAnswer | None]) -> Eval def prove_one_task( task: BenchmarkTask, result: LMAnswer | None, pure: bool = False -) -> bool: +) -> Result[None, str]: """prove two type signatures are equivalent using GHC""" if result is None: - return False + return Failure("Generation Failed") predicted_body = postprocess(result.answer, RESPONSE_STRATEGIES).strip() predicted = f"f :: {predicted_body}" @@ -118,7 +118,7 @@ def prove_one_task( .alt(lambda _: "Syntax Error: Tree-Sitter Parsing Failed") .bind(ghc_prove_equiv) ) - return isinstance(equiv, Success) + return equiv def prover_evaluate( @@ -148,7 +148,7 @@ def prover_evaluate( prove_one_task, zip(tasks, results, [pure] * len(tasks)) ) - n_correct = sum(eval_results) + n_correct = sum(1 for r in eval_results if isinstance(r, Success)) acc = n_correct / len(tasks) return { @@ -165,11 +165,25 @@ def analysis_multi_runs(results: list[EvalResult]) -> tuple[float, float]: def get_incorrect( - tasks: list[BenchmarkTask], 
results: list[LMAnswer | None] -) -> list[tuple[BenchmarkTask, LMAnswer | None]]: + tasks: list[BenchmarkTask], + results: list[LMAnswer | None], + pure: bool = False, + nproc: int = cpu_count(), +) -> list[tuple[BenchmarkTask, LMAnswer | None, str]]: """Get a list of tasks that were incorrectly answered.""" + + assert len(tasks) == len(results) + + with Pool(processes=nproc) as pool: + eval_results = pool.starmap( + prove_one_task, zip(tasks, results, [pure] * len(tasks)) + ) + incorrect = [] - for task, result in zip(tasks, results): - if not evaluate_one_task(task, result): - incorrect.append((task, result)) + for task, result, eval_result in zip(tasks, results, eval_results): + match eval_result: + case Success(_): + continue + case Failure(message): + incorrect.append((task, result, message)) return incorrect diff --git a/uv.lock b/uv.lock index 7cdcf1c..507b261 100644 --- a/uv.lock +++ b/uv.lock @@ -2177,6 +2177,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/c7/5572fa4a3f45740eaab6ae86fcdf7195b55beac1371ac8c619d880cfe948/pillow-11.3.0-cp314-cp314t-win_arm64.whl", hash = "sha256:79ea0d14d3ebad43ec77ad5272e6ff9bba5b679ef73375ea760261207fa8e0aa", size = 2512835 }, ] +[[package]] +name = "pip" +version = "25.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/16/650289cd3f43d5a2fadfd98c68bd1e1e7f2550a1a5326768cddfbcedb2c5/pip-25.2.tar.gz", hash = "sha256:578283f006390f85bb6282dffb876454593d637f5d1be494b5202ce4877e71f2", size = 1840021 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/3f/945ef7ab14dc4f9d7f40288d2df998d1837ee0888ec3659c813487572faa/pip-25.2-py3-none-any.whl", hash = "sha256:6d67a2b4e7f14d8b31b8b52648866fa717f45a1eb70e83002f4331d07e953717", size = 1752557 }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -2186,6 +2195,20 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 }, ] +[[package]] +name = "pokepalette" +version = "0.0.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "packaging" }, + { name = "pip" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b8/53/142f55ee0a2a2143f006b0616d082c17252e102bca93a14febf4c5912a75/pokepalette-0.0.6.tar.gz", hash = "sha256:68eed84640a7fe2f9ab3f68ab4370c462ab04030661d74595b938c04e9633a6d", size = 39916 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/34/7c1bcae762550dadbcc65b8259caba42890766d9e08448726f6defa001d2/pokepalette-0.0.6-py3-none-any.whl", hash = "sha256:03b7e6ed1179397e49a71e31e307a271da65f9ebcda2272078590c5ae89ae5c2", size = 66341 }, +] + [[package]] name = "prometheus-client" version = "0.22.1" @@ -3384,6 +3407,7 @@ dependencies = [ { name = "openai" }, { name = "orjson" }, { name = "orjsonl" }, + { name = "pokepalette" }, { name = "pyarrow" }, { name = "pydantic" }, { name = "pytest" }, @@ -3427,6 +3451,7 @@ requires-dist = [ { name = "openai", specifier = "==1.99.9" }, { name = "orjson", specifier = ">=3.11.3" }, { name = "orjsonl", specifier = ">=1.0.0" }, + { name = "pokepalette", specifier = ">=0.0.6" }, { name = "pyarrow", specifier = ">=21.0.0" }, { name = "pydantic", specifier = ">=2.11.7" }, { name = "pytest", specifier = ">=8.0.0" },