33import ast
44import fire
55
6+
7+ class InsertPrintTransformer (ast .NodeTransformer ):
8+ def visit_FunctionDef (self , node ):
9+ if node .name in ("TestOneInput" , "TestInput" ) and node .args .args :
10+ first_arg_name = node .args .args [0 ].arg
11+ print_stmt = ast .Expr (
12+ value = ast .Call (
13+ func = ast .Name (id = 'print' , ctx = ast .Load ()),
14+ args = [ast .Name (id = first_arg_name , ctx = ast .Load ())],
15+ keywords = []
16+ )
17+ )
18+ # 添加空body检查
19+ if not node .body :
20+ node .body .append (print_stmt )
21+ else :
22+ # 增强重复检查逻辑
23+ first_stmt = node .body [0 ]
24+ if not (isinstance (first_stmt , ast .Expr )
25+ and isinstance (first_stmt .value , ast .Call )
26+ and hasattr (first_stmt .value .func , 'id' )
27+ and first_stmt .value .func .id == 'print' ):
28+ node .body .insert (0 , print_stmt )
29+ return node
30+
631def add_print_to_testoneinput (file_path ):
732 with open (file_path , 'r' ) as f :
833 content = f .read ()
934
10- # 解析 AST
1135 tree = ast .parse (content )
12-
13- class InsertPrintTransformer (ast .NodeTransformer ):
14- def visit_FunctionDef (self , node ):
15- if node .name in ("TestOneInput" , "TestInput" ) and node .args .args :
16- first_arg_name = node .args .args [0 ].arg
17- # 创建 print(参数名) 语句
18- print_stmt = ast .Expr (
19- value = ast .Call (
20- func = ast .Name (id = 'print' , ctx = ast .Load ()),
21- args = [ast .Name (id = first_arg_name , ctx = ast .Load ())],
22- keywords = []
23- )
24- )
25- # 确保没有重复插入
26- if not (
27- isinstance (node .body [0 ], ast .Expr )
28- and isinstance (node .body [0 ].value , ast .Call )
29- and getattr (node .body [0 ].value .func , "id" , None ) == "print"
30- ):
31- node .body .insert (0 , print_stmt )
32- return node
33-
3436 transformer = InsertPrintTransformer ()
3537 new_tree = transformer .visit (tree )
3638 ast .fix_missing_locations (new_tree )
3739
38- # 转回代码
3940 import astor
4041 new_content = astor .to_source (new_tree )
4142 return new_content
@@ -44,31 +45,21 @@ def main(
4445 projects_path = "fuzz/oss-fuzz/projects" ,
4546 valid_projects_file = "data/valid_projects.txt"
4647):
47- """
48- 给 fuzz target 的 TestOneInput / TestInput 函数开头插入 print(参数名)
49-
50- Args:
51- projects_path (str): OSS-Fuzz 项目的根目录
52- valid_projects_file (str): 包含有效项目名的文件路径
53- """
48+ """为fuzz target添加打印语句"""
5449 with open (valid_projects_file , 'r' ) as f :
5550 projects = [line .strip () for line in f if line .strip ()]
5651
5752 for project in projects :
5853 project_dir = os .path .join (projects_path , project )
59-
6054 if not os .path .isdir (project_dir ):
6155 continue
6256
6357 for root , _ , files in os .walk (project_dir ):
6458 for file in files :
6559 if file .startswith ('fuzz_' ) and file .endswith ('.py' ):
6660 file_path = os .path .join (root , file )
67-
6861 try :
6962 new_content = add_print_to_testoneinput (file_path )
70-
71- # 保存修改后的文件
7263 new_file_path = file_path .rsplit ('.' , 1 )[0 ] + '_print1.py'
7364 with open (new_file_path , 'w' ) as f :
7465 f .write (new_content )
@@ -78,4 +69,4 @@ def main(
7869 print (f"Error processing { file_path } : { str (e )} " )
7970
8071if __name__ == "__main__" :
81- fire .Fire (main )
72+ fire .Fire (main )
0 commit comments