@@ -303,23 +303,33 @@ def visit_FunctionDef(self, node):
303303 return node
304304
305305 def add_param_assignment (self , node , param_name ):
306- """Add param_name = b"" at the beginning of the function body"""
306+ """Add param_name = b"..." at the beginning of the function body with an inline comment"""
307+ # 创建包含赋值和注释的复合值
308+ value_with_comment = ast .JoinedStr (
309+ values = [
310+ ast .FormattedValue (value = ast .Constant (value = b"" ), conversion = - 1 ),
311+ ast .Constant (value = " # This is a test template" )
312+ ]
313+ )
314+
307315 # 创建赋值节点
308316 assign_node = ast .Assign (
309317 targets = [ast .Name (id = param_name , ctx = ast .Store ())],
310- value = ast . Constant ( value = b"" ),
318+ value = value_with_comment
311319 )
312-
320+
313321 # 如果有文档字符串,插入在文档字符串之后
314322 if (
315323 node .body
316324 and isinstance (node .body [0 ], ast .Expr )
317- and isinstance (node .body [0 ].value , ast .Str )
325+ and isinstance (node .body [0 ].value , ast .Constant )
326+ and isinstance (node .body [0 ].value .value , str )
318327 ):
328+ # 插入在文档字符串后面
319329 node .body .insert (1 , assign_node )
320330 else :
331+ # 插入在函数开头
321332 node .body .insert (0 , assign_node )
322-
323333 def remove_print_param (self , node , param_name ):
324334 """Remove print statements for the specific parameter"""
325335 new_body = []
@@ -360,62 +370,30 @@ def visit_If(self, node):
360370
361371
362372class TestGenTransformer (ast .NodeTransformer ):
363- """AST transformer for generating test cases from fuzzing inputs"""
364-
365- def __init__ (self , idx : int , fuzz_input : bytes ):
373+ def __init__ (self , idx , fuzz_input ):
366374 self .idx = idx
367375 self .fuzz_input = fuzz_input
368376 self .found_test_function = False
369377
370378 def visit_FunctionDef (self , node ):
371- # 只处理名为 test_ 的函数
372379 if node .name == "test_" :
373380 self .found_test_function = True
374-
375- # 1. 将函数名改为 test_{idx}
376- node .name = f"test_{ self .idx } "
377-
378- # 2. 找到并替换 data = b"" 赋值语句
379- self .replace_data_assignment (node )
380-
381+ # 遍历函数体,寻找包含注释的赋值语句
382+ for i , stmt in enumerate (node .body ):
383+ # 检查是否是赋值语句
384+ if isinstance (stmt , ast .Assign ):
385+ # 检查赋值语句的值是否是带有注释的复合值
386+ if (
387+ isinstance (stmt .value , ast .JoinedStr )
388+ and len (stmt .value .values ) >= 2
389+ and isinstance (stmt .value .values [1 ], ast .Constant )
390+ and stmt .value .values [1 ].value == " # This is a test template"
391+ ):
392+ # 替换为新的输入值
393+ stmt .value = ast .Constant (value = self .fuzz_input )
394+ break
381395 return node
382396
383- def replace_data_assignment (self , node ):
384- """Replace data assignment with fuzz input"""
385- for i , stmt in enumerate (node .body ):
386- # 查找赋值语句
387- if isinstance (stmt , ast .Assign ):
388- # 检查是否是 data = b"" 格式的赋值
389- if (
390- len (stmt .targets ) == 1
391- and isinstance (stmt .targets [0 ], ast .Name )
392- and isinstance (stmt .value , ast .Constant )
393- and stmt .value .value == b""
394- ):
395-
396- # 替换为新的输入数据
397- node .body [i ] = ast .Assign (
398- targets = [stmt .targets [0 ]],
399- value = ast .Constant (value = self .fuzz_input ),
400- )
401- return
402-
403- # 检查是否是 data = b'' 格式的赋值
404- if (
405- len (stmt .targets ) == 1
406- and isinstance (stmt .targets [0 ], ast .Name )
407- and isinstance (stmt .value , ast .Constant )
408- and stmt .value .value == b""
409- ):
410-
411- # 替换为新的输入数据
412- node .body [i ] = ast .Assign (
413- targets = [stmt .targets [0 ]],
414- value = ast .Constant (value = self .fuzz_input ),
415- )
416- return
417-
418-
419397def substitute_one_repo (
420398 repo : str ,
421399 targets : list [str ],
0 commit comments