@@ -198,7 +198,15 @@ def transform_repos(repos: list[str], jobs: int):
198198 def _transform_repo (repo : str ):
199199 project_name = os .path .basename (repo )
200200 oss_fuzz_dir = Path (repo ).parent .parent
201- targets = discover_targets (project_name , oss_fuzz_dir )
201+ raw_targets = discover_targets (project_name , oss_fuzz_dir )
202+
203+ # 只需移除目标名称中的 "_print1",不要添加任何新后缀
204+ transformed_targets = [t .replace ("_print1" , "" ) for t in raw_targets ]
205+
206+ # 去重
207+ targets = list (set (transformed_targets ))
208+
209+ # 传递给 generate_test_template 的是简单目标名称
202210 return [generate_test_template (t , repo ) for t in targets ]
203211
204212 with ProcessingPool (jobs ) as p :
@@ -209,7 +217,10 @@ def generate_test_template(target_name: str, repo_path: str):
209217 """
210218 Generate Python test template using AST for more precise code transformations
211219 """
212- src_file = pjoin (repo_path , target_name + ".py" )
220+ src_file = pjoin (repo_path , target_name )
221+ logging .info (f"Generating test template for { src_file } " )
222+ if not src_file .endswith (".py" ):
223+ src_file += ".py"
213224 if not os .path .exists (src_file ):
214225 logging .error (f"Source target file not found: { src_file } " )
215226 return None
@@ -253,7 +264,7 @@ def generate_test_template(target_name: str, repo_path: str):
253264 with open (init_path , "w" , encoding = "utf-8" ) as f :
254265 f .write ("" )
255266
256- template_path = pjoin (template_dir , f"{ target_name } .py" )
267+ template_path = pjoin (template_dir , f"{ os . path . splitext ( target_name )[ 0 ] } .py" )
257268 with open (template_path , "w" , encoding = "utf-8" ) as f :
258269 f .write (shebang + cleaned_code .strip () + "\n " )
259270
@@ -294,9 +305,6 @@ def visit_FunctionDef(self, node):
294305 if param_name :
295306 self .add_param_assignment (node , param_name )
296307
297- # f. 删除所有 print(原参数名) 的语句
298- if param_name :
299- self .remove_print_param (node , param_name )
300308
301309 # 确保继续遍历子节点
302310 self .generic_visit (node )
@@ -400,7 +408,7 @@ def visit_FunctionDef(self, node):
400408
401409def substitute_one_repo (
402410 repo : str ,
403- targets : list [str ],
411+ targets : list [tuple ], # 每个元素是 (transformed_target, raw_target)
404412 n_fuzz : int ,
405413 strategy : str ,
406414 max_len : int ,
@@ -414,23 +422,25 @@ def substitute_one_repo(
414422 template_dir = pjoin (repo , "tests-gen" )
415423 os .makedirs (template_dir , exist_ok = True )
416424
417- for target_name in targets :
418- source_file = pjoin (template_dir , f"{ target_name } .py" )
425+ for transformed_target , raw_target in targets :
426+ # 使用转换后的目标名称构建模板文件路径
427+ source_file = pjoin (template_dir , transformed_target + ".py" )
428+
429+ # 使用原始目标名称构建输入文件路径
430+ input_path = pjoin (input_dir , raw_target )
431+
432+ # 确保源文件存在
419433 if not os .path .exists (source_file ):
420434 logging .warning (f"Source file not found: { source_file } " )
421435 continue
422-
423- input_path = pjoin (input_dir , f"{ target_name } " )
424436 if not os .path .exists (input_path ):
425437 logging .warning (f"Input file not found: { input_path } " )
426438 continue
427-
439+
428440 # 读取所有有效的输入数据
429441 valid_inputs = []
430- # 首先读取文件内容,然后关闭文件
431442 with open (input_path , "rb" ) as f_input :
432443 lines = f_input .readlines ()
433-
434444 # 文件已关闭,现在处理数据
435445 for line in lines :
436446 # 使用 errors='replace' 确保解码不会失败
@@ -452,10 +462,12 @@ def substitute_one_repo(
452462 valid_inputs .append (line )
453463
454464 if not valid_inputs :
455- logging .warning (f"No valid inputs found for { target_name } " )
465+ # 使用 transformed_target 而不是 target_name
466+ logging .warning (f"No valid inputs found for { transformed_target } " )
456467 continue
457468
458- logging .info (f"Loaded { len (valid_inputs )} inputs for { target_name } " )
469+ # 使用 transformed_target 而不是 target_name
470+ logging .info (f"Loaded { len (valid_inputs )} inputs for { transformed_target } " )
459471 # 策略选择输入
460472 if strategy == "shuffle" :
461473 random .shuffle (valid_inputs )
@@ -487,28 +499,22 @@ def substitute_one_repo(
487499 # 生成新代码
488500 new_code = astunparse .unparse (new_tree )
489501
490- out_path = pjoin (template_dir , f"{ target_name } .testgen_{ idx } .py" )
502+ # 使用 transformed_target 而不是 target_name
503+ out_path = pjoin (template_dir , f"{ transformed_target } .testgen_{ idx } .py" )
491504 with open (out_path , "w" ) as f_out :
492505 f_out .write (new_code )
493506
494- # 格式化代码
495- formatter_installed = True
496- try :
497- subprocess .run (["black" , out_path ],
498- check = False ,
499- stdout = subprocess .DEVNULL , # 隐藏输出
500- stderr = subprocess .DEVNULL ) # 隐藏错误
501- except FileNotFoundError :
502- if formatter_installed : # 避免多次记录
503- logging .warning ("Black code formatter not found. For better formatting, install with:" )
504- logging .warning ("pip install black" )
505- formatter_installed = False
507+ # 格式化代码
508+ try :
509+ subprocess .run (["black" , out_path ], check = False )
510+ except FileNotFoundError :
511+ logging .warning ("Black formatter not found, skipping formatting" )
506512
507513 except SyntaxError as e :
508514 logging .error (f"Syntax error when processing { source_file } : { e } " )
509515 except Exception as e :
510- logging . error ( f"Error generating test case for { target_name } : { e } " )
511-
516+ # 使用 transformed_target 而不是 target_name
517+ logging . error ( f"Error generating test case for { transformed_target } : { e } " )
512518
513519def testgen_repos (
514520 repos : list [str ],
@@ -529,28 +535,34 @@ def testgen_repos(
529535 max_len (int): Maximum length
530536 sim_thresh (float): Similarity threshold
531537 """
532- # First get all targets
533- targets_list = []
538+ # First get all targets and apply transformation
539+ target_map = {}
534540 for repo in repos :
535541 project_name = os .path .basename (repo )
536542 oss_fuzz_dir = Path (repo ).parent .parent
537- targets = discover_targets (project_name , oss_fuzz_dir )
538- targets_list .append (targets )
539-
540- target_map = {repo : targets for repo , targets in zip (repos , targets_list )}
543+ raw_targets = discover_targets (project_name , oss_fuzz_dir )
544+
545+ # 保存原始目标名称和转换后的目标名称
546+ transformed_targets = [t .replace ("_print1" , "" ) for t in raw_targets ]
547+ targets = list (zip (transformed_targets , raw_targets )) # (转换后, 原始)
548+ target_map [repo ] = targets
541549
542550 # Process each repository in parallel
543551 with ProcessingPool (jobs ) as p :
544552 list (
545553 p .map (
546554 lambda item : substitute_one_repo (
547- item [0 ], item [1 ], n_fuzz , strategy , max_len , sim_thresh
555+ item [0 ], # repo path
556+ item [1 ], # list of (transformed, raw) targets
557+ n_fuzz ,
558+ strategy ,
559+ max_len ,
560+ sim_thresh
548561 ),
549562 target_map .items (),
550563 )
551564 )
552565
553-
554566def main (
555567 repo_id : str = "data/valid_projects.txt" ,
556568 repo_root : str = "fuzz/oss-fuzz/projects/" ,
0 commit comments