-
Notifications
You must be signed in to change notification settings - Fork 101
Expand file tree
/
Copy pathAutoFix.py
More file actions
126 lines (103 loc) · 4.04 KB
/
AutoFix.py
File metadata and controls
126 lines (103 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import json
from enum import IntEnum
from pathlib import Path
import yaml
from patchwork.common.utils.progress_bar import PatchflowProgressBar
from patchwork.common.utils.step_typing import validate_steps_with_inputs
from patchwork.logger import logger
from patchwork.step import Step
from patchwork.steps import (
LLM,
PR,
CallLLM,
CommitChanges,
CreatePR,
ExtractCode,
ExtractModelResponse,
ModifyCode,
PreparePR,
PreparePrompt,
ScanSemgrep,
)
_DEFAULT_PROMPT_JSON = Path(__file__).parent / "default_prompt.json"
_DEFAULT_INPUT_FILE = Path(__file__).parent / "defaults.yml"
class Compatibility(IntEnum):
HIGH = 3
MEDIUM = 2
LOW = 1
UNKNOWN = 0
@staticmethod
def from_str(value: str) -> "Compatibility":
try:
return Compatibility[value.upper()]
except KeyError:
logger.error(f"Invalid compatibility value: {value}")
return Compatibility.UNKNOWN
class AutoFix(Step):
    """Patchflow: Semgrep scan -> LLM fix -> apply -> PR.

    Runs ScanSemgrep and ExtractCode once, then up to ``n`` rounds of the LLM
    and ModifyCode steps (re-scanning between rounds), and finally the PR step.
    Caller-supplied inputs override the defaults loaded from ``defaults.yml``.
    """

    def __init__(self, inputs: dict):
        """Merge ``inputs`` over the bundled defaults and validate them.

        Args:
            inputs: Caller-supplied step inputs; these win over ``defaults.yml``.

        Raises:
            KeyError: if no ``"compatibility"`` value is present after merging.
        """
        PatchflowProgressBar(self).register_steps(
            CallLLM,
            CommitChanges,
            CreatePR,
            ExtractCode,
            ExtractModelResponse,
            ModifyCode,
            PreparePR,
            PreparePrompt,
            ScanSemgrep,
        )
        # safe_load returns None for an empty document; treat that as "no defaults".
        final_inputs = yaml.safe_load(_DEFAULT_INPUT_FILE.read_text(encoding="utf-8")) or {}
        final_inputs.update(inputs)

        self.n = int(final_inputs.get("n", 1))
        self.compatibility_threshold = Compatibility.from_str(final_inputs["compatibility"])

        final_inputs.setdefault("prompt_id", "fixprompt")
        final_inputs.setdefault("prompt_template_file", _DEFAULT_PROMPT_JSON)

        # Always overwritten rather than defaulted: the "compatibility" and
        # "patch" keys read in run() presumably come from these partitions.
        final_inputs["response_partitions"] = {
            "commit_message": ["A. Commit message:", "B. Change summary:"],
            "patch_message": ["B. Change summary:", "C. Compatibility Risk:"],
            "compatibility": ["C. Compatibility Risk:", "D. Fixed Code:"],
            "patch": ["D. Fixed Code:", "```", "\n", "```"],
        }
        final_inputs.setdefault("pr_title", f"PatchWork {self.__class__.__name__}")
        final_inputs.setdefault("branch_prefix", f"{self.__class__.__name__.lower()}-")

        # "prompt_values" is filled in by run() right before the LLM step, so
        # count it as provided for validation purposes.
        validate_steps_with_inputs(
            set(final_inputs).union({"prompt_values"}), ScanSemgrep, ExtractCode, LLM, ModifyCode, PR
        )
        self.inputs = final_inputs

    def _scan_and_extract(self) -> dict:
        """Run ScanSemgrep then ExtractCode, merging both outputs into ``self.inputs``.

        Returns:
            The ExtractCode outputs (``run`` reads ``files_to_patch`` from them).
        """
        self.inputs.update(ScanSemgrep(self.inputs).run())
        outputs = ExtractCode(self.inputs).run()
        self.inputs.update(outputs)
        return outputs

    def _drop_incompatible_patches(self) -> None:
        """Remove ``"patch"`` from each extracted response rated below the threshold."""
        # NOTE(review): the partition header reads "Compatibility Risk"; confirm the
        # prompt defines HIGH as "most compatible", since lower ratings are discarded.
        for extracted_response in self.inputs["extracted_responses"]:
            raw_compatibility = extracted_response.get("compatibility", "UNKNOWN")
            # A None/non-string value would crash on .strip(); treat it as UNKNOWN.
            if not isinstance(raw_compatibility, str):
                raw_compatibility = "UNKNOWN"
            response_compatibility = Compatibility.from_str(raw_compatibility.strip())
            if response_compatibility < self.compatibility_threshold:
                extracted_response.pop("patch", None)

    def _no_findings_remain(self) -> bool:
        """Return True when ``prompt_value_file`` is set and its JSON payload is empty."""
        prompt_value_file = self.inputs.get("prompt_value_file")
        if prompt_value_file is None:
            return False
        with open(prompt_value_file, "r", encoding="utf-8") as fp:
            vulns = json.load(fp)
        return len(vulns) < 1

    def run(self) -> dict:
        """Execute the patchflow and return the accumulated inputs/outputs dict."""
        outputs = self._scan_and_extract()

        for i in range(self.n):
            # `outputs` here is always the latest ExtractCode result.
            self.inputs["prompt_values"] = outputs.get("files_to_patch", [])
            self.inputs.update(LLM(self.inputs).run())
            self._drop_incompatible_patches()
            self.inputs.update(ModifyCode(self.inputs).run())

            if i == self.n - 1:
                break

            # validation: dropping the previous SARIF path presumably forces a
            # fresh Semgrep scan of the modified code.
            self.inputs.pop("sarif_file_path", None)
            outputs = self._scan_and_extract()
            if self._no_findings_remain():
                break

        self.inputs.update(PR(self.inputs).run())
        return self.inputs