Skip to content

Commit febe8e4

Browse files
committed
run bug run tests and prompts
1 parent 2c97749 commit febe8e4

10 files changed

Lines changed: 425 additions & 35 deletions

File tree

elleelleaime/core/benchmarks/runbugrun/runbugrun.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import subprocess
77
import logging
8-
8+
from tqdm import tqdm
99
import pandas as pd
1010

1111
class RunBugRun(Benchmark):
@@ -25,32 +25,68 @@ def initialize(self) -> None:
2525
logging.info("Initializing RunBugRun benchmark...")
2626

2727
python_path = Path(self.get_path(), 'python_valid0.jsonl')
28-
# test_path = Path(self.get_path(), 'tests_all.jsonl')
28+
test_path = Path(self.get_path(), 'tests_all.jsonl')
2929

3030
python_df = pd.read_json(python_path, lines=True).set_index('problem_id')
31-
31+
test_df = pd.read_json(test_path, lines=True).set_index('id')
32+
33+
subprocess.run(
34+
f"mkdir -p {self.path}/buggy",
35+
shell=True,
36+
capture_output=True,
37+
check=True,
38+
)
39+
40+
subprocess.run(
41+
f"mkdir -p {self.path}/fixed",
42+
shell=True,
43+
capture_output=True,
44+
check=True,
45+
)
46+
47+
buggy_submissions = python_df.drop_duplicates(subset=['buggy_submission_id']).head(10)
48+
3249
for prob_id, (buggy_submission_id, buggy_code, fixed_submission_id, fixed_code) \
33-
in python_df.drop_duplicates(subset=['buggy_submission_id'])[
34-
['buggy_submission_id','buggy_code', 'fixed_submission_id', 'fixed_code']
35-
].iterrows():
36-
37-
buggy_file = Path(self.path, f'{prob_id}_{buggy_submission_id}.py')
38-
fixed_file = Path(self.path, f'{prob_id}_{fixed_submission_id}.py')
50+
in tqdm(
51+
buggy_submissions[['buggy_submission_id','buggy_code', 'fixed_submission_id', 'fixed_code']].iterrows(),
52+
total=len(buggy_submissions)
53+
):
54+
55+
buggy_file = Path(self.path, 'buggy', f'{prob_id}_{buggy_submission_id}.py')
56+
fixed_file = Path(self.path, 'fixed', f'{prob_id}_{buggy_submission_id}.py') # using buggy id for both to maintain file correspondence
57+
58+
with open(buggy_file, 'w') as f:
59+
f.write(buggy_code)
60+
f.write('\n')
61+
62+
with open(fixed_file, 'w') as f:
63+
f.write(fixed_code)
64+
f.write('\n')
3965

4066
run = subprocess.run(
4167
f"""cd {self.get_path()} &&
42-
echo '''{buggy_code}''' > {buggy_file} &&
43-
echo '''{fixed_code}''' > {fixed_file} &&
4468
diff --unified {fixed_file.relative_to(self.path)} {buggy_file.relative_to(self.path)}""",
4569
shell=True,
4670
capture_output=True
4771
)
48-
if run.returncode:
49-
print (run)
5072

5173
diff = PatchSet(run.stdout.decode("utf-8"))
5274
# Change the source file path to point to the buggy version
5375
diff[0].source_file = f"{buggy_file.relative_to(self.path)}"
76+
77+
failing_tests = {}
78+
79+
for test_id, (test_input, test_output) in test_df[test_df.problem_id == prob_id][['input', 'output']].iterrows():
80+
error_code, result = RunBugRunBug.execute_test_case(buggy_file, test_input)
81+
82+
if error_code:
83+
cause = f"""Function with input {test_input.replace('"', "'")} failed with error: {result}"""
84+
elif result != test_output.strip():
85+
cause = f"""Expected function with input {test_input.replace('"', "'")} to output {test_output.replace('"', "'").replace("'", r"\'")} but got {result}"""
86+
else:
87+
continue # skip passing
88+
89+
failing_tests[f"""{test_input} -> {test_output}"""] = cause
5490

55-
self.add_bug(RunBugRunBug(self, f"{prob_id}_{buggy_submission_id}", str(diff)))
91+
self.add_bug(RunBugRunBug(self, f"{prob_id}_{buggy_submission_id}", str(diff), failing_tests))
5692

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,85 @@
11
import subprocess
22
import shutil
33
import os
4-
from elleelleaime.core.benchmarks.benchmark import Benchmark
4+
from pathlib import Path
55

6-
from elleelleaime.core.benchmarks.bug import Bug
6+
from elleelleaime.core.benchmarks.benchmark import Benchmark
7+
from elleelleaime.core.benchmarks.bug import RichBug
78
from elleelleaime.core.benchmarks.test_result import TestResult
89
from elleelleaime.core.benchmarks.compile_result import CompileResult
910

10-
class RunBugRunBug(Bug):
11+
class RunBugRunBug(RichBug):
1112
"""
1213
The class for representing RunBugRun bugs
1314
"""
14-
15-
def __init__(self, benchmark: Benchmark, bid: str, ground_truth: str) -> None:
16-
super().__init__(benchmark, bid, ground_truth, True)
1715

1816
def checkout(self, path: str, fixed: bool = False) -> bool:
19-
pass
17+
# Remove the directory if it exists
18+
shutil.rmtree(path, ignore_errors=True)
19+
# Make the directory
20+
subprocess.run(
21+
f"mkdir -p {path}",
22+
shell=True,
23+
capture_output=True,
24+
check=True,
25+
)
26+
27+
# Checkout the bug is the same as copying the entire benchmark
28+
# Copy source files
29+
cmd = f"cd {self.benchmark.get_path()}; mkdir {path}; cp {'fixed' if fixed else 'buggy'}/{self.identifier}.py {path}"
30+
run = subprocess.run(cmd, shell=True, capture_output=True, check=True)
31+
32+
# Copy test files
33+
# cmd = f"cd {self.benchmark.get_path()}; mkdir -p {path}/java_testcases/junit; cp java_testcases/junit/{self.identifier}_TEST.java {path}/java_testcases/junit; cp java_testcases/junit/QuixFixOracleHelper.java {path}/java_testcases/junit"
34+
# run = subprocess.run(cmd, shell=True, capture_output=True, check=True)
35+
return run.returncode == 0
2036

2137
def compile(self, path: str) -> CompileResult:
22-
pass
38+
file_path = Path(path, f"{self.get_identifier()}.py")
39+
assert file_path.exists()
40+
41+
with open(file_path) as f:
42+
bug_code = f.read()
43+
assert bug_code
44+
45+
try:
46+
compile(bug_code, file_path, 'exec')
47+
return CompileResult(True)
48+
except:
49+
return CompileResult(False)
2350

2451
def test(self, path: str) -> TestResult:
25-
pass
52+
file_path = Path(path, f"{self.get_identifier()}.py")
53+
assert file_path.exists()
54+
55+
for test_case in self.failing_tests:
56+
57+
test_input, test_output = test_case.split(' -> ')
58+
59+
error_code, result = RunBugRunBug.execute_test_case(file_path, test_input)
60+
if error_code:
61+
return TestResult(False)
62+
elif result != test_output.strip():
63+
return TestResult(False)
64+
65+
return TestResult(True)
66+
67+
@staticmethod
68+
def execute_test_case(code_path, test_input):
69+
if test_input.strip():
70+
cmd = f"""echo "{test_input}" | python {code_path}"""
71+
else:
72+
cmd = f"""python {code_path}"""
73+
74+
run = subprocess.run(
75+
cmd,
76+
shell=True,
77+
capture_output=True,
78+
check=False,
79+
)
80+
81+
return run.returncode, run.stderr.decode("utf-8").strip() if run.returncode else run.stdout.decode("utf-8").strip()
82+
83+
def get_src_test_dir(self, path: str) -> str:
84+
pass
85+

elleelleaime/core/utils/benchmarks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from elleelleaime.core.benchmarks.humanevaljava.humanevaljava import HumanEvalJava
44
from elleelleaime.core.benchmarks.quixbugs.quixbugs import QuixBugs
55
from elleelleaime.core.benchmarks.gitbugjava.gitbugjava import GitBugJava
6+
from elleelleaime.core.benchmarks.runbugrun.runbugrun import RunBugRun
67

78
from typing import Optional
89

@@ -11,6 +12,7 @@
1112
"HumanEvalJava": HumanEvalJava,
1213
"QuixBugs": QuixBugs,
1314
"GitBugJava": GitBugJava,
15+
"RunBugRun": RunBugRun
1416
}
1517

1618

elleelleaime/core/utils/python/__init__.py

Whitespace-only changes.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Optional, Tuple, List
2+
from unidiff import PatchSet
3+
from uuid import uuid4
4+
from pathlib import Path
5+
import logging
6+
import getpass, tempfile, difflib, shutil
7+
import subprocess
8+
import re
9+
import ast
10+
11+
from elleelleaime.core.benchmarks.bug import Bug, RichBug
12+
13+
def extract_functions(source_code):
14+
# Parse the source code into an AST
15+
tree = ast.parse(source_code)
16+
17+
# Extract all function definitions
18+
functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
19+
20+
# Convert the function nodes back to source code
21+
function_sources= [ast.get_source_segment(source_code, func) for func in functions]
22+
23+
return function_sources
24+
25+
def extract_single_function(bug: Bug) -> Optional[Tuple[str, str]]:
26+
"""
27+
Extracts the buggy and fixed code of single-function bugs.
28+
Returns None is bug is not single-function
29+
30+
Args:
31+
bug (Bug): THe bug to extract the code from
32+
33+
Returns:
34+
Optional[Tuple[str, str]]: None if the bug is not single-function, otherwise a tuple of the form (buggy_code, fixed_code)
35+
"""
36+
buggy_path = Path(
37+
tempfile.gettempdir(),
38+
f"elleelleaime-{getpass.getuser()}",
39+
bug.get_identifier(),
40+
str(uuid4()),
41+
)
42+
fixed_path = Path(
43+
tempfile.gettempdir(),
44+
f"elleelleaime-{getpass.getuser()}",
45+
bug.get_identifier(),
46+
str(uuid4()),
47+
)
48+
49+
try:
50+
# Checkout the buggy and fixed versions of the bug
51+
bug.checkout(str(buggy_path), fixed=False)
52+
bug.checkout(str(fixed_path), fixed=True)
53+
54+
with open(Path(buggy_path, f"{bug.get_identifier()}.py")) as f:
55+
buggy_code = f.read()
56+
57+
with open(Path(fixed_path, f"{bug.get_identifier()}.py")) as f:
58+
fixed_code = f.read()
59+
60+
buggy_functions = extract_functions(buggy_code)
61+
fixed_functions = extract_functions(fixed_code)
62+
63+
assert len(buggy_functions) == len(fixed_functions)
64+
65+
# if len(buggy_functions) == len(fixed_functions) == 1:
66+
# return buggy_functions[0], fixed_functions[0]
67+
68+
# most of run bug run are straight through scripts, not functions
69+
return buggy_code, fixed_code
70+
71+
72+
finally:
73+
# Remove the checked-out bugs
74+
shutil.rmtree(buggy_path, ignore_errors=True)
75+
shutil.rmtree(fixed_path, ignore_errors=True)

elleelleaime/sample/registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .strategy import PromptingStrategy
22
from .strategies.infilling import InfillingPrompting
33
from .strategies.instruct import InstructPrompting
4+
from .strategies.instruct_python import InstructPromptingPython
45

56

67
class PromptStrategyRegistry:
@@ -11,6 +12,7 @@ class PromptStrategyRegistry:
1112
__STRATEGIES: dict[str, type] = {
1213
"infilling": InfillingPrompting,
1314
"instruct": InstructPrompting,
15+
"instruct_python": InstructPromptingPython,
1416
}
1517

1618
@classmethod
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Optional, Tuple
2+
from unidiff import PatchSet
3+
4+
from elleelleaime.sample.strategy import PromptingStrategy
5+
from elleelleaime.core.benchmarks.bug import RichBug
6+
from elleelleaime.core.utils.python.python import (
7+
extract_single_function,
8+
# extract_failing_test_cases,
9+
)
10+
11+
12+
class InstructPromptingPython(PromptingStrategy):
13+
"""
14+
Implements instruction prompting strategies.
15+
"""
16+
17+
def __init__(self, **kwargs):
18+
super().__init__("instruct_python")
19+
20+
def instruct(
21+
self, bug: RichBug
22+
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
23+
"""
24+
Builds an instruction prompt for the given bug.
25+
26+
Args:
27+
bug: The bug to generate the prompt for.
28+
Returns:
29+
Tuple: A tuple of the form (buggy_code, fixed_code, prompt).
30+
"""
31+
result = extract_single_function(bug)
32+
if result is None:
33+
return None, None, None
34+
35+
buggy_code, fixed_code = result
36+
37+
failing_test_causes = bug.get_failing_tests()
38+
39+
failing_tests_string = ""
40+
for test_case, cause in failing_test_causes.items():
41+
failing_tests_string += f"""Test `{test_case}`:
42+
```python
43+
assert result == {test_case.split(' -> ')[-1]}
44+
```
45+
Test `{test_case}` error:
46+
```
47+
{cause}
48+
```
49+
50+
"""
51+
52+
prompt = f"""You are an automatic program repair tool. Your task is to fix the provided buggy code.
53+
54+
The following code contains a buggy function:
55+
```python
56+
{buggy_code}
57+
```
58+
59+
The code fails the following tests.
60+
61+
{failing_tests_string}
62+
Please provide a fixed version of the buggy function, and only that function, inside a code block.
63+
"""
64+
65+
return buggy_code, fixed_code, prompt
66+
67+
def prompt(self, bug: RichBug) -> dict[str, Optional[str]]:
68+
"""
69+
Returns the prompt for the given bug.
70+
71+
:param bug: The bug to generate the prompt for.
72+
"""
73+
result = {
74+
"identifier": bug.get_identifier(),
75+
"buggy_code": None,
76+
"fixed_code": None,
77+
"prompt_strategy": self.strategy_name,
78+
"prompt": None,
79+
"ground_truth": bug.get_ground_truth(),
80+
}
81+
82+
diff = PatchSet(bug.get_ground_truth())
83+
84+
# This strategy only supports single-file prompts
85+
if len(diff) != 1:
86+
return result
87+
88+
(
89+
result["buggy_code"],
90+
result["fixed_code"],
91+
result["prompt"],
92+
) = self.instruct(bug)
93+
return result

0 commit comments

Comments
 (0)