33usage: PYTHONPATH=. python3 fuzz/collect_fuzz_python.py --pipeline all
44"""
55from pathlib import Path
6+ import ast
7+ import astunparse
68import logging
7- from typing import Optional , List , Tuple
9+ from typing import Optional
810import fire
911import os
1012from UniTSyn .frontend .util import wrap_repo , parallel_subprocess
@@ -159,10 +161,28 @@ def fuzz_repos(repos: list[str], jobs: int, timeout: int):
159161 logging .info (f"Starting parallel fuzzing with { jobs } jobs, timeout={ timeout } s per target" )
160162 parallel_subprocess (all_targets , jobs , lambda p : fuzz_one_target (p , timeout ), on_exit = None )
161163
def transform_repos(repos: list[str], jobs: int):
    """
    Build test templates for every target of every repository.

    Args:
        repos (list[str]): Repository paths to process
        jobs (int): Number of workers handed to the processing pool

    Returns:
        A list with one entry per repository; each entry is the list of
        values returned by generate_test_template for that repository's
        targets.
    """
    logging.info("Generating test templates")

    def _transform_repo(repo: str):
        # Target discovery is keyed by the repo's last path component and
        # by the directory two levels above the repo.
        oss_fuzz_dir = Path(repo).parent.parent
        templates = []
        for target in discover_targets(os.path.basename(repo), oss_fuzz_dir):
            templates.append(generate_test_template(target, repo))
        return templates

    with ProcessingPool(jobs) as pool:
        per_repo_results = pool.map(_transform_repo, repos)
        return list(per_repo_results)
182+
162183def generate_test_template (target_name : str , repo_path : str ):
163184 """
164- Generate Python test template for a single target by stripping license header,
165- main() block, and converting TestInput/TestOneInput to test_ with data=b"".
185+ Generate Python test template using AST for more precise code transformations
166186 """
167187 src_file = pjoin (repo_path , target_name + ".py" )
168188 if not os .path .exists (src_file ):
@@ -184,68 +204,22 @@ def generate_test_template(target_name: str, repo_path: str):
184204 )
185205 code_no_license = re .sub (license_pattern , "" , original_code , count = 1 )
186206
187- # --- 2. Remove main function and if __name__ == '__main__' ---
188- code_no_main = re .sub (
189- r"\n?def\s+main\([\s\S]*?(?=^if\s+__name__\s*==\s*['\"]__main__['\"]:)" ,
190- "" ,
191- code_no_license ,
192- flags = re .MULTILINE
193- )
194- code_no_main = re .sub (
195- r"\n?if\s+__name__\s*==\s*['\"]__main__['\"]:\s*main\(\s*\)\s*" ,
196- "" ,
197- code_no_main ,
198- flags = re .MULTILINE
199- )
207+ # --- 2. Parse code to AST ---
208+ try :
209+ tree = ast .parse (code_no_license )
210+ except SyntaxError as e :
211+ logging .error (f"Syntax error in { src_file } : { e } " )
212+ return None
200213
201- # --- 3. Convert test functions ---
202- def process_test_function (match ):
203- # Extract the complete function definition and body
204- func_str = match .group (0 )
205-
206- # 1. Remove print(data) statements
207- func_str = re .sub (r'print\s*\(\s*data\s*\)\s*' , '' , func_str )
208-
209- # 2. Change TestInput/TestOneInput to test_()
210- func_str = re .sub (r'def\s+(TestInput|TestOneInput)\s*\(data\)' , 'def test_()' , func_str )
211-
212- # 3. Insert data = b"" before the first executable line in the function body
213- # Find the first non-empty line (ignoring empty lines and comments)
214- lines = func_str .splitlines ()
215- if len (lines ) < 2 :
216- return func_str
217-
218- # Find the first non-empty, non-comment line after the function definition
219- insert_idx = None
220- for i in range (1 , len (lines )):
221- line = lines [i ].strip ()
222- if line and not line .startswith ('#' ):
223- insert_idx = i
224- break
225-
226- if insert_idx is None :
227- return func_str
228-
229- # Get the indentation level of that line
230- indent_match = re .match (r'^(\s*)' , lines [insert_idx ])
231- if not indent_match :
232- return func_str
233-
234- indent = indent_match .group (1 )
235-
236- # Insert data = b""
237- lines .insert (insert_idx , f"{ indent } data = b\" \" " )
238-
239- return "\n " .join (lines )
240-
241- cleaned_code = re .sub (
242- r"def\s+(TestInput|TestOneInput)\s*\(data\):[\s\S]*?(?=\n\w|\Z)" ,
243- process_test_function ,
244- code_no_main ,
245- flags = re .MULTILINE
246- )
214+ # --- 3. AST transformation ---
215+ transformer = TestFunctionTransformer ()
216+ new_tree = transformer .visit (tree )
217+ ast .fix_missing_locations (new_tree )
218+
219+ # --- 4. Generate cleaned code ---
220+ cleaned_code = astunparse .unparse (new_tree )
247221
248- # --- 4 . Output to tests-gen directory ---
222+ # --- 5 . Output to tests-gen directory ---
249223 template_dir = pjoin (repo_path , "tests-gen" )
250224 os .makedirs (template_dir , exist_ok = True )
251225
@@ -260,26 +234,144 @@ def process_test_function(match):
260234
261235 logging .info (f"Generated cleaned template: { template_path } " )
262236 return template_path
263-
264- def transform_repos (repos : list [str ], jobs : int ):
265- """
266- Generate test templates for all targets
267237
268- Args:
269- repos (list[str]): List of repository paths
270- jobs (int): Number of parallel tasks
271- """
272- logging .info ("Generating test templates" )
class TestFunctionTransformer(ast.NodeTransformer):
    """AST transformer that converts a fuzz harness into a test template.

    The transformation performs the following rewrites:

    * any function named ``main`` is removed;
    * any ``if __name__ == "__main__":`` block is removed;
    * a function named ``TestInput`` or ``TestOneInput`` is renamed to
      ``test_``, loses all of its parameters, gets ``<param> = b""``
      inserted at the top of its body (after the docstring, if any), and has
      its top-level ``print(<param>)`` statements dropped, where ``<param>``
      is the name of the original first positional parameter.
    """

    # Names of the fuzz entry points that are rewritten into ``test_``.
    _ENTRY_POINTS = ("TestInput", "TestOneInput")

    def visit_FunctionDef(self, node):
        """Drop ``main`` and rewrite fuzz entry points; visit other functions.

        Returns:
            ``None`` for ``main`` (which removes the node from the tree),
            otherwise the (possibly modified) node.
        """
        # Remove the main function entirely.
        if node.name == "main":
            return None

        if node.name in self._ENTRY_POINTS:
            # Remember the first positional parameter name (the harness is
            # assumed to take a single parameter) before the args are cleared.
            param_name = node.args.args[0].arg if node.args.args else None

            node.name = "test_"

            # Replace the signature with an empty parameter list.
            node.args = ast.arguments(
                posonlyargs=[],
                args=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
            )

            if param_name:
                # The body still reads the old parameter, so bind it locally.
                self.add_param_assignment(node, param_name)
                # Drop the harness's print(<param>) statements.
                self.remove_print_param(node, param_name)

        # Keep traversing so nested nodes are transformed as well.
        self.generic_visit(node)
        return node

    def add_param_assignment(self, node, param_name):
        """Insert ``param_name = b""`` at the beginning of the function body.

        When the body starts with a docstring the assignment is placed right
        after it so the docstring stays the first statement.
        """
        assign_node = ast.Assign(
            targets=[ast.Name(id=param_name, ctx=ast.Store())],
            value=ast.Constant(value=b""),
        )

        # ast.Str was removed in Python 3.12; string literals are
        # ast.Constant nodes whose value is a str.
        first_stmt = node.body[0] if node.body else None
        has_docstring = (
            isinstance(first_stmt, ast.Expr)
            and isinstance(first_stmt.value, ast.Constant)
            and isinstance(first_stmt.value.value, str)
        )
        node.body.insert(1 if has_docstring else 0, assign_node)

    def remove_print_param(self, node, param_name):
        """Remove top-level ``print(...)`` statements that print the parameter.

        A statement is removed when it is a bare call to the name ``print``
        and any of its positional arguments is the name ``param_name``.
        Statements nested inside other statements are not inspected.
        """

        def _prints_param(stmt):
            if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call)):
                return False
            call = stmt.value
            if not (isinstance(call.func, ast.Name) and call.func.id == "print"):
                return False
            return any(
                isinstance(arg, ast.Name) and arg.id == param_name
                for arg in call.args
            )

        node.body = [stmt for stmt in node.body if not _prints_param(stmt)]

    def visit_If(self, node):
        """Remove ``if __name__ == '__main__'`` blocks; visit other ``if`` nodes."""
        test = node.test
        is_main_guard = (
            isinstance(test, ast.Compare)
            and isinstance(test.left, ast.Name)
            and test.left.id == "__name__"
            and isinstance(test.ops[0], ast.Eq)
            and isinstance(test.comparators[0], ast.Constant)
            and test.comparators[0].value == "__main__"
        )
        if is_main_guard:
            # Returning None removes the whole if block from the tree.
            return None

        # Keep traversing so nested nodes are transformed as well.
        self.generic_visit(node)
        return node
class TestGenTransformer(ast.NodeTransformer):
    """AST transformer that specializes a test template for one fuzz input.

    Every function named ``test_`` is renamed to ``test_{idx}`` and the first
    top-level ``<name> = b""`` assignment in its body is replaced by
    ``<name> = <fuzz_input>``.

    Attributes set by ``__init__``:
        idx: Index appended to the test function name.
        fuzz_input: Bytes that replace the empty placeholder value.
        found_test_function: Becomes True once a ``test_`` function has been
            visited, so callers can detect templates with no test function.
    """

    def __init__(self, idx: int, fuzz_input: bytes):
        self.idx = idx
        self.fuzz_input = fuzz_input
        self.found_test_function = False

    def visit_FunctionDef(self, node):
        """Rename ``test_`` and inject the fuzz input; leave other functions as is."""
        # Only the function named exactly test_ is rewritten.
        if node.name == "test_":
            self.found_test_function = True
            node.name = f"test_{self.idx}"
            self.replace_data_assignment(node)

        return node

    def replace_data_assignment(self, node):
        """Replace the first ``<name> = b""`` statement with the fuzz input.

        Only direct statements of the function body are examined. The
        assignment must have a single plain-name target and the constant
        ``b""`` as its value. ``b""`` and ``b''`` are the same value once
        parsed, so a single check covers both spellings.
        """
        for position, stmt in enumerate(node.body):
            if not isinstance(stmt, ast.Assign):
                continue

            is_empty_bytes_placeholder = (
                len(stmt.targets) == 1
                and isinstance(stmt.targets[0], ast.Name)
                and isinstance(stmt.value, ast.Constant)
                and stmt.value.value == b""
            )
            if is_empty_bytes_placeholder:
                # Keep the original target; only the value is swapped.
                node.body[position] = ast.Assign(
                    targets=[stmt.targets[0]],
                    value=ast.Constant(value=self.fuzz_input),
                )
                return
283375
284376def substitute_one_repo (
285377 repo : str ,
@@ -291,6 +383,7 @@ def substitute_one_repo(
291383):
292384 """
293385 Copy files from fuzz target template and generate multiple testgen files based on fuzz inputs
386+ using AST transformations
294387 """
295388 input_dir = pjoin (repo , "fuzz_inputs" )
296389 template_dir = pjoin (repo , "tests-gen" )
@@ -307,15 +400,15 @@ def substitute_one_repo(
307400 logging .warning (f"Input file not found: { input_path } " )
308401 continue
309402
310- # Read all valid input data
403+ # 读取所有有效的输入数据
311404 valid_inputs = []
312405 with open (input_path , "rb" ) as f_input :
313406 for line in f_input :
314407 try :
315- # Attempt to decode the line to check content
408+ # 尝试解码行以检查内容
316409 decoded = line .decode ('utf-8' , errors = 'replace' )
317410
318- # Only process lines starting with b' or b"
411+ # 只处理以 b' 或 b" 开头的行
319412 if decoded .startswith (("b'" , 'b"' )):
320413 if decoded .startswith ("b'" ) and decoded .endswith ("'\n " ):
321414 byte_data = line [2 :- 2 ]
@@ -336,7 +429,7 @@ def substitute_one_repo(
336429
337430 logging .info (f"Loaded { len (valid_inputs )} inputs for { target_name } " )
338431
339- # Strategy for selecting inputs
432+ # 策略选择输入
340433 if strategy == "shuffle" :
341434 random .shuffle (valid_inputs )
342435 inputs = valid_inputs [:n_fuzz ]
@@ -345,27 +438,42 @@ def substitute_one_repo(
345438 else :
346439 inputs = valid_inputs [:n_fuzz ]
347440
348- # Generate a separate file for each fuzz input
441+ # 每个 fuzz input 生成一个单独的文件(使用 AST)
349442 for idx , fuzz_input in enumerate (inputs , start = 1 ):
350443 with open (source_file , "r" ) as f_src :
351444 code = f_src .read ()
352445
353- # 1. Change function name from test_ to test_{idx}
354- code = re .sub (r'def\s+test_' , f'def test_{ idx } ' , code )
355-
356- # 2. Replace data = b"" with input data
357- input_repr = repr (fuzz_input )
358- code = code .replace ('data = b""' , f'data = { input_repr } ' )
359-
360- out_path = pjoin (template_dir , f"{ target_name } .testgen_{ idx } .py" )
361- with open (out_path , "w" ) as f_out :
362- f_out .write (code )
363-
364- # Format code
365446 try :
366- subprocess .run (["black" , out_path ], check = False )
367- except FileNotFoundError :
368- logging .warning ("Black formatter not found, skipping formatting" )
447+ # 解析为 AST
448+ tree = ast .parse (code )
449+
450+ # 应用转换器
451+ transformer = TestGenTransformer (idx , fuzz_input )
452+ new_tree = transformer .visit (tree )
453+ ast .fix_missing_locations (new_tree )
454+
455+ # 确保找到并处理了测试函数
456+ if not transformer .found_test_function :
457+ logging .warning (f"No test_ function found in { source_file } " )
458+ continue
459+
460+ # 生成新代码
461+ new_code = astunparse .unparse (new_tree )
462+
463+ out_path = pjoin (template_dir , f"{ target_name } .testgen_{ idx } .py" )
464+ with open (out_path , "w" ) as f_out :
465+ f_out .write (new_code )
466+
467+ # 格式化代码
468+ try :
469+ subprocess .run (["black" , out_path ], check = False )
470+ except FileNotFoundError :
471+ logging .warning ("Black formatter not found, skipping formatting" )
472+
473+ except SyntaxError as e :
474+ logging .error (f"Syntax error when processing { source_file } : { e } " )
475+ except Exception as e :
476+ logging .error (f"Error generating test case for { target_name } : { e } " )
369477def testgen_repos (
370478 repos : list [str ],
371479 jobs : int ,
@@ -405,7 +513,7 @@ def testgen_repos(
405513 ))
406514
407515def main (
408- repo_id : str = "data/valid_projects3 .txt" ,
516+ repo_id : str = "data/valid_projects .txt" ,
409517 repo_root : str = "fuzz/oss-fuzz/projects/" ,
410518 timeout : int = 30 ,
411519 jobs : int = 8 ,
0 commit comments