11import os
2- import re
2+ import ast
33
44def add_print_to_testoneinput (file_path ):
55 with open (file_path , 'r' ) as f :
66 content = f .read ()
77
8- # 正则表达式匹配TestOneInput或TestInput函数定义及其函数体
9- pattern = r'(\bdef\s+(TestOneInput|TestInput)\(data\):\s*\n)((?:[ \t]+.*\n|\s*\n)*)'
10- matches = re .finditer (pattern , content , re .MULTILINE )
8+ # 解析 AST
9+ tree = ast .parse (content )
1110
12- new_content = content
13- for match in reversed (list (matches )):
14- function_def = match .group (1 )
15- function_body = match .group (3 )
16-
17- # 在函数体开头添加print(data)语句
18- new_function_body = re .sub (
19- r'^([ \t]*)(.*\n)' ,
20- r'\g<1>\2\g<1>print(data)\n' ,
21- function_body ,
22- count = 1
23- )
24-
25- # 只有在函数体非空且未添加过print时才替换
26- if new_function_body != function_body :
27- new_content = (
28- new_content [:match .start (3 )] +
29- new_function_body +
30- new_content [match .end (3 ):]
31- )
11+ class InsertPrintTransformer (ast .NodeTransformer ):
12+ def visit_FunctionDef (self , node ):
13+ if node .name in ("TestOneInput" , "TestInput" ) and node .args .args :
14+ first_arg_name = node .args .args [0 ].arg
15+ # 创建 print(参数名) 语句
16+ print_stmt = ast .Expr (
17+ value = ast .Call (
18+ func = ast .Name (id = 'print' , ctx = ast .Load ()),
19+ args = [ast .Name (id = first_arg_name , ctx = ast .Load ())],
20+ keywords = []
21+ )
22+ )
23+ # 确保没有重复插入
24+ if not (
25+ isinstance (node .body [0 ], ast .Expr )
26+ and isinstance (node .body [0 ].value , ast .Call )
27+ and getattr (node .body [0 ].value .func , "id" , None ) == "print"
28+ ):
29+ node .body .insert (0 , print_stmt )
30+ return node
31+
32+ transformer = InsertPrintTransformer ()
33+ new_tree = transformer .visit (tree )
34+ ast .fix_missing_locations (new_tree )
35+
36+ # 转回代码
37+ import astor
38+ new_content = astor .to_source (new_tree )
3239
3340 return new_content
3441
@@ -41,26 +48,26 @@ def main():
4148
4249 for project in projects :
4350 project_dir = os .path .join (projects_path , project )
44-
51+
4552 if not os .path .isdir (project_dir ):
4653 continue
4754
4855 for root , _ , files in os .walk (project_dir ):
4956 for file in files :
5057 if file .startswith ('fuzz_' ) and file .endswith ('.py' ):
5158 file_path = os .path .join (root , file )
52-
59+
5360 try :
5461 new_content = add_print_to_testoneinput (file_path )
55-
56- # 保存修改后的文件(添加_print后缀)
62+
63+ # 保存修改后的文件
5764 new_file_path = file_path .rsplit ('.' , 1 )[0 ] + '_print1.py'
5865 with open (new_file_path , 'w' ) as f :
5966 f .write (new_content )
6067 print (f"Processed: { file_path } -> { new_file_path } " )
61-
68+
6269 except Exception as e :
6370 print (f"Error processing { file_path } : { str (e )} " )
6471
6572if __name__ == "__main__" :
66- main ()
73+ main ()
0 commit comments