Skip to content

Commit a16d664

Browse files
committed
use AST for transform and testgen
1 parent 5a47088 commit a16d664

1 file changed

Lines changed: 209 additions & 101 deletions

File tree

fuzz/collect_fuzz_python.py

Lines changed: 209 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
usage: PYTHONPATH=. python3 fuzz/collect_fuzz_python.py --pipeline all
44
"""
55
from pathlib import Path
6+
import ast
7+
import astunparse
68
import logging
7-
from typing import Optional, List, Tuple
9+
from typing import Optional
810
import fire
911
import os
1012
from UniTSyn.frontend.util import wrap_repo, parallel_subprocess
@@ -159,10 +161,28 @@ def fuzz_repos(repos: list[str], jobs: int, timeout: int):
159161
logging.info(f"Starting parallel fuzzing with {jobs} jobs, timeout={timeout}s per target")
160162
parallel_subprocess(all_targets, jobs, lambda p: fuzz_one_target(p, timeout), on_exit=None)
161163

164+
def transform_repos(repos: list[str], jobs: int):
    """
    Build test templates for every fuzz target of every repository.

    Args:
        repos (list[str]): Repository paths to process
        jobs (int): Size of the worker pool

    Returns:
        A list with one entry per repository (same order as ``repos``); each
        entry is the list of values returned by ``generate_test_template``
        for that repository's targets.
    """
    logging.info("Generating test templates")

    def _templates_for(repo_path: str):
        # The project name is the last path component; target discovery is
        # rooted two directory levels above the repository itself.
        name = os.path.basename(repo_path)
        fuzz_root = Path(repo_path).parent.parent
        results = []
        for target in discover_targets(name, fuzz_root):
            results.append(generate_test_template(target, repo_path))
        return results

    with ProcessingPool(jobs) as pool:
        per_repo = pool.map(_templates_for, repos)
        return list(per_repo)
182+
162183
def generate_test_template(target_name: str, repo_path: str):
163184
"""
164-
Generate Python test template for a single target by stripping license header,
165-
main() block, and converting TestInput/TestOneInput to test_ with data=b"".
185+
Generate Python test template using AST for more precise code transformations
166186
"""
167187
src_file = pjoin(repo_path, target_name + ".py")
168188
if not os.path.exists(src_file):
@@ -184,68 +204,22 @@ def generate_test_template(target_name: str, repo_path: str):
184204
)
185205
code_no_license = re.sub(license_pattern, "", original_code, count=1)
186206

187-
# --- 2. Remove main function and if __name__ == '__main__' ---
188-
code_no_main = re.sub(
189-
r"\n?def\s+main\([\s\S]*?(?=^if\s+__name__\s*==\s*['\"]__main__['\"]:)",
190-
"",
191-
code_no_license,
192-
flags=re.MULTILINE
193-
)
194-
code_no_main = re.sub(
195-
r"\n?if\s+__name__\s*==\s*['\"]__main__['\"]:\s*main\(\s*\)\s*",
196-
"",
197-
code_no_main,
198-
flags=re.MULTILINE
199-
)
207+
# --- 2. Parse code to AST ---
208+
try:
209+
tree = ast.parse(code_no_license)
210+
except SyntaxError as e:
211+
logging.error(f"Syntax error in {src_file}: {e}")
212+
return None
200213

201-
# --- 3. Convert test functions ---
202-
def process_test_function(match):
203-
# Extract the complete function definition and body
204-
func_str = match.group(0)
205-
206-
# 1. Remove print(data) statements
207-
func_str = re.sub(r'print\s*\(\s*data\s*\)\s*', '', func_str)
208-
209-
# 2. Change TestInput/TestOneInput to test_()
210-
func_str = re.sub(r'def\s+(TestInput|TestOneInput)\s*\(data\)', 'def test_()', func_str)
211-
212-
# 3. Insert data = b"" before the first executable line in the function body
213-
# Find the first non-empty line (ignoring empty lines and comments)
214-
lines = func_str.splitlines()
215-
if len(lines) < 2:
216-
return func_str
217-
218-
# Find the first non-empty, non-comment line after the function definition
219-
insert_idx = None
220-
for i in range(1, len(lines)):
221-
line = lines[i].strip()
222-
if line and not line.startswith('#'):
223-
insert_idx = i
224-
break
225-
226-
if insert_idx is None:
227-
return func_str
228-
229-
# Get the indentation level of that line
230-
indent_match = re.match(r'^(\s*)', lines[insert_idx])
231-
if not indent_match:
232-
return func_str
233-
234-
indent = indent_match.group(1)
235-
236-
# Insert data = b""
237-
lines.insert(insert_idx, f"{indent}data = b\"\"")
238-
239-
return "\n".join(lines)
240-
241-
cleaned_code = re.sub(
242-
r"def\s+(TestInput|TestOneInput)\s*\(data\):[\s\S]*?(?=\n\w|\Z)",
243-
process_test_function,
244-
code_no_main,
245-
flags=re.MULTILINE
246-
)
214+
# --- 3. AST transformation ---
215+
transformer = TestFunctionTransformer()
216+
new_tree = transformer.visit(tree)
217+
ast.fix_missing_locations(new_tree)
218+
219+
# --- 4. Generate cleaned code ---
220+
cleaned_code = astunparse.unparse(new_tree)
247221

248-
# --- 4. Output to tests-gen directory ---
222+
# --- 5. Output to tests-gen directory ---
249223
template_dir = pjoin(repo_path, "tests-gen")
250224
os.makedirs(template_dir, exist_ok=True)
251225

@@ -260,26 +234,144 @@ def process_test_function(match):
260234

261235
logging.info(f"Generated cleaned template: {template_path}")
262236
return template_path
263-
264-
def transform_repos(repos: list[str], jobs: int):
265-
"""
266-
Generate test templates for all targets
267237

268-
Args:
269-
repos (list[str]): List of repository paths
270-
jobs (int): Number of parallel tasks
271-
"""
272-
logging.info("Generating test templates")
238+
class TestFunctionTransformer(ast.NodeTransformer):
    """AST transformer that turns a fuzz harness into a test template.

    The transformation performed while visiting a module is:

    * every function named ``main`` is removed;
    * every ``if __name__ == "__main__":`` block is removed;
    * a function named ``TestInput`` or ``TestOneInput`` is renamed to
      ``test_``, loses its parameters, gets ``<first param> = b""`` inserted
      at the start of its body (after the docstring, if any) and has its
      top-level ``print(<first param>)`` statements dropped.
    """

    def visit_FunctionDef(self, node):
        """Drop ``main``, rewrite fuzz entry points, recurse into the rest.

        Returns ``None`` (which deletes the node) for ``main``; otherwise
        returns the possibly modified node.
        """
        # Remove the main function entirely.
        if node.name == "main":
            return None

        if node.name in ("TestInput", "TestOneInput"):
            # Remember the name of the first positional parameter (the fuzz
            # input); only that one parameter is considered.
            param_name = node.args.args[0].arg if node.args.args else None

            node.name = "test_"

            # Replace the signature with an empty parameter list.
            node.args = ast.arguments(
                posonlyargs=[],
                args=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
            )

            if param_name:
                # Rebind the former parameter as a local, then drop the
                # debugging prints of it.
                self.add_param_assignment(node, param_name)
                self.remove_print_param(node, param_name)

        # Keep traversing so nested main guards / main functions are removed.
        self.generic_visit(node)

        # Pruning can leave a function with no statements, which cannot be
        # unparsed or compiled; keep the definition valid with ``pass``.
        if not node.body:
            node.body = [ast.Pass()]
        return node

    def add_param_assignment(self, node, param_name):
        """Insert ``param_name = b""`` at the beginning of the function body.

        The assignment goes after the docstring when the function has one,
        otherwise it becomes the first statement.
        """
        assign_node = ast.Assign(
            targets=[ast.Name(id=param_name, ctx=ast.Store())],
            value=ast.Constant(value=b""),
        )

        # ast.Str was removed in Python 3.12; a docstring is an expression
        # statement wrapping a str-valued ast.Constant.
        first = node.body[0] if node.body else None
        has_docstring = (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        )
        node.body.insert(1 if has_docstring else 0, assign_node)

    @staticmethod
    def _is_print_of(stmt, param_name):
        """Return True if ``stmt`` is ``print(...)`` with ``param_name`` as a positional argument."""
        if not isinstance(stmt, ast.Expr):
            return False
        call = stmt.value
        return (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "print"
            and any(
                isinstance(arg, ast.Name) and arg.id == param_name
                for arg in call.args
            )
        )

    def remove_print_param(self, node, param_name):
        """Remove top-level ``print`` statements that print the parameter.

        Only direct statements of the function body are inspected; prints
        nested inside other statements are left untouched.
        """
        node.body = [
            stmt for stmt in node.body
            if not self._is_print_of(stmt, param_name)
        ]

    @staticmethod
    def _is_main_guard(test):
        """Return True if ``test`` compares ``__name__`` with ``"__main__"`` using ``==``."""
        if not (isinstance(test, ast.Compare) and test.ops and test.comparators):
            return False
        if not isinstance(test.ops[0], ast.Eq):
            return False
        # Accept both `__name__ == "__main__"` and `"__main__" == __name__`.
        operands = (test.left, test.comparators[0])
        has_name = any(
            isinstance(op, ast.Name) and op.id == "__name__" for op in operands
        )
        has_main = any(
            isinstance(op, ast.Constant) and op.value == "__main__"
            for op in operands
        )
        return has_name and has_main

    def visit_If(self, node):
        """Remove ``if __name__ == '__main__'`` blocks; recurse into other ifs."""
        if self._is_main_guard(node.test):
            # Drop the whole if statement.
            return None

        self.generic_visit(node)
        return node
324+
class TestGenTransformer(ast.NodeTransformer):
    """AST transformer that turns a test template into a concrete test case.

    A function named ``test_`` is renamed to ``test_{idx}`` and the first
    top-level ``<name> = b""`` assignment in its body is rewritten to assign
    the supplied fuzz input instead.

    Attributes:
        idx: Number appended to the generated test function name.
        fuzz_input: Bytes substituted for the empty placeholder.
        found_test_function: Set to True once a ``test_`` function was seen.
    """

    def __init__(self, idx: int, fuzz_input: bytes):
        self.idx = idx
        self.fuzz_input = fuzz_input
        self.found_test_function = False

    def visit_FunctionDef(self, node):
        """Rename ``test_`` and inject the fuzz input; leave other functions as is.

        Function bodies are not traversed further, so nested definitions are
        never rewritten.
        """
        if node.name == "test_":
            self.found_test_function = True
            node.name = f"test_{self.idx}"
            self.replace_data_assignment(node)

        return node

    def replace_data_assignment(self, node):
        """Replace the first ``<name> = b""`` placeholder with the fuzz input.

        Only direct statements of the function body with a single plain-name
        target are considered; the target name itself is kept. Nothing happens
        when no placeholder is found.
        """
        for stmt in node.body:
            if not isinstance(stmt, ast.Assign):
                continue
            if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
                continue
            value = stmt.value
            # b"" and b'' are the same constant once parsed, so a single
            # check covers both spellings of the placeholder.
            if (isinstance(value, ast.Constant)
                    and isinstance(value.value, bytes)
                    and value.value == b""):
                stmt.value = ast.copy_location(
                    ast.Constant(value=self.fuzz_input), value
                )
                return
283375

284376
def substitute_one_repo(
285377
repo: str,
@@ -291,6 +383,7 @@ def substitute_one_repo(
291383
):
292384
"""
293385
Copy files from fuzz target template and generate multiple testgen files based on fuzz inputs
386+
using AST transformations
294387
"""
295388
input_dir = pjoin(repo, "fuzz_inputs")
296389
template_dir = pjoin(repo, "tests-gen")
@@ -307,15 +400,15 @@ def substitute_one_repo(
307400
logging.warning(f"Input file not found: {input_path}")
308401
continue
309402

310-
# Read all valid input data
403+
# 读取所有有效的输入数据
311404
valid_inputs = []
312405
with open(input_path, "rb") as f_input:
313406
for line in f_input:
314407
try:
315-
# Attempt to decode the line to check content
408+
# 尝试解码行以检查内容
316409
decoded = line.decode('utf-8', errors='replace')
317410

318-
# Only process lines starting with b' or b"
411+
# 只处理以 b' b" 开头的行
319412
if decoded.startswith(("b'", 'b"')):
320413
if decoded.startswith("b'") and decoded.endswith("'\n"):
321414
byte_data = line[2:-2]
@@ -336,7 +429,7 @@ def substitute_one_repo(
336429

337430
logging.info(f"Loaded {len(valid_inputs)} inputs for {target_name}")
338431

339-
# Strategy for selecting inputs
432+
# 策略选择输入
340433
if strategy == "shuffle":
341434
random.shuffle(valid_inputs)
342435
inputs = valid_inputs[:n_fuzz]
@@ -345,27 +438,42 @@ def substitute_one_repo(
345438
else:
346439
inputs = valid_inputs[:n_fuzz]
347440

348-
# Generate a separate file for each fuzz input
441+
# 每个 fuzz input 生成一个单独的文件(使用 AST)
349442
for idx, fuzz_input in enumerate(inputs, start=1):
350443
with open(source_file, "r") as f_src:
351444
code = f_src.read()
352445

353-
# 1. Change function name from test_ to test_{idx}
354-
code = re.sub(r'def\s+test_', f'def test_{idx}', code)
355-
356-
# 2. Replace data = b"" with input data
357-
input_repr = repr(fuzz_input)
358-
code = code.replace('data = b""', f'data = {input_repr}')
359-
360-
out_path = pjoin(template_dir, f"{target_name}.testgen_{idx}.py")
361-
with open(out_path, "w") as f_out:
362-
f_out.write(code)
363-
364-
# Format code
365446
try:
366-
subprocess.run(["black", out_path], check=False)
367-
except FileNotFoundError:
368-
logging.warning("Black formatter not found, skipping formatting")
447+
# 解析为 AST
448+
tree = ast.parse(code)
449+
450+
# 应用转换器
451+
transformer = TestGenTransformer(idx, fuzz_input)
452+
new_tree = transformer.visit(tree)
453+
ast.fix_missing_locations(new_tree)
454+
455+
# 确保找到并处理了测试函数
456+
if not transformer.found_test_function:
457+
logging.warning(f"No test_ function found in {source_file}")
458+
continue
459+
460+
# 生成新代码
461+
new_code = astunparse.unparse(new_tree)
462+
463+
out_path = pjoin(template_dir, f"{target_name}.testgen_{idx}.py")
464+
with open(out_path, "w") as f_out:
465+
f_out.write(new_code)
466+
467+
# 格式化代码
468+
try:
469+
subprocess.run(["black", out_path], check=False)
470+
except FileNotFoundError:
471+
logging.warning("Black formatter not found, skipping formatting")
472+
473+
except SyntaxError as e:
474+
logging.error(f"Syntax error when processing {source_file}: {e}")
475+
except Exception as e:
476+
logging.error(f"Error generating test case for {target_name}: {e}")
369477
def testgen_repos(
370478
repos: list[str],
371479
jobs: int,
@@ -405,7 +513,7 @@ def testgen_repos(
405513
))
406514

407515
def main(
408-
repo_id: str = "data/valid_projects3.txt",
516+
repo_id: str = "data/valid_projects.txt",
409517
repo_root: str = "fuzz/oss-fuzz/projects/",
410518
timeout: int = 30,
411519
jobs: int = 8,

0 commit comments

Comments
 (0)