Skip to content

Commit 0067af3

Browse files
committed
apply transformations on the original unmodified fuzz targets.
1 parent a0bbe56 commit 0067af3

1 file changed

Lines changed: 51 additions & 39 deletions

File tree

fuzz/collect_fuzz_python.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,15 @@ def transform_repos(repos: list[str], jobs: int):
198198
def _transform_repo(repo: str):
199199
project_name = os.path.basename(repo)
200200
oss_fuzz_dir = Path(repo).parent.parent
201-
targets = discover_targets(project_name, oss_fuzz_dir)
201+
raw_targets = discover_targets(project_name, oss_fuzz_dir)
202+
203+
# 只需移除目标名称中的 "_print1",不要添加任何新后缀
204+
transformed_targets = [t.replace("_print1", "") for t in raw_targets]
205+
206+
# 去重
207+
targets = list(set(transformed_targets))
208+
209+
# 传递给 generate_test_template 的是简单目标名称
202210
return [generate_test_template(t, repo) for t in targets]
203211

204212
with ProcessingPool(jobs) as p:
@@ -209,7 +217,10 @@ def generate_test_template(target_name: str, repo_path: str):
209217
"""
210218
Generate Python test template using AST for more precise code transformations
211219
"""
212-
src_file = pjoin(repo_path, target_name + ".py")
220+
src_file = pjoin(repo_path, target_name)
221+
logging.info(f"Generating test template for {src_file}")
222+
if not src_file.endswith(".py"):
223+
src_file += ".py"
213224
if not os.path.exists(src_file):
214225
logging.error(f"Source target file not found: {src_file}")
215226
return None
@@ -253,7 +264,7 @@ def generate_test_template(target_name: str, repo_path: str):
253264
with open(init_path, "w", encoding="utf-8") as f:
254265
f.write("")
255266

256-
template_path = pjoin(template_dir, f"{target_name}.py")
267+
template_path = pjoin(template_dir, f"{os.path.splitext(target_name)[0]}.py")
257268
with open(template_path, "w", encoding="utf-8") as f:
258269
f.write(shebang + cleaned_code.strip() + "\n")
259270

@@ -294,9 +305,6 @@ def visit_FunctionDef(self, node):
294305
if param_name:
295306
self.add_param_assignment(node, param_name)
296307

297-
# f. 删除所有 print(原参数名) 的语句
298-
if param_name:
299-
self.remove_print_param(node, param_name)
300308

301309
# 确保继续遍历子节点
302310
self.generic_visit(node)
@@ -400,7 +408,7 @@ def visit_FunctionDef(self, node):
400408

401409
def substitute_one_repo(
402410
repo: str,
403-
targets: list[str],
411+
targets: list[tuple], # 每个元素是 (transformed_target, raw_target)
404412
n_fuzz: int,
405413
strategy: str,
406414
max_len: int,
@@ -414,23 +422,25 @@ def substitute_one_repo(
414422
template_dir = pjoin(repo, "tests-gen")
415423
os.makedirs(template_dir, exist_ok=True)
416424

417-
for target_name in targets:
418-
source_file = pjoin(template_dir, f"{target_name}.py")
425+
for transformed_target, raw_target in targets:
426+
# 使用转换后的目标名称构建模板文件路径
427+
source_file = pjoin(template_dir, transformed_target + ".py")
428+
429+
# 使用原始目标名称构建输入文件路径
430+
input_path = pjoin(input_dir, raw_target)
431+
432+
# 确保源文件存在
419433
if not os.path.exists(source_file):
420434
logging.warning(f"Source file not found: {source_file}")
421435
continue
422-
423-
input_path = pjoin(input_dir, f"{target_name}")
424436
if not os.path.exists(input_path):
425437
logging.warning(f"Input file not found: {input_path}")
426438
continue
427-
439+
428440
# 读取所有有效的输入数据
429441
valid_inputs = []
430-
# 首先读取文件内容,然后关闭文件
431442
with open(input_path, "rb") as f_input:
432443
lines = f_input.readlines()
433-
434444
# 文件已关闭,现在处理数据
435445
for line in lines:
436446
# 使用 errors='replace' 确保解码不会失败
@@ -452,10 +462,12 @@ def substitute_one_repo(
452462
valid_inputs.append(line)
453463

454464
if not valid_inputs:
455-
logging.warning(f"No valid inputs found for {target_name}")
465+
# 使用 transformed_target 而不是 target_name
466+
logging.warning(f"No valid inputs found for {transformed_target}")
456467
continue
457468

458-
logging.info(f"Loaded {len(valid_inputs)} inputs for {target_name}")
469+
# 使用 transformed_target 而不是 target_name
470+
logging.info(f"Loaded {len(valid_inputs)} inputs for {transformed_target}")
459471
# 策略选择输入
460472
if strategy == "shuffle":
461473
random.shuffle(valid_inputs)
@@ -487,28 +499,22 @@ def substitute_one_repo(
487499
# 生成新代码
488500
new_code = astunparse.unparse(new_tree)
489501

490-
out_path = pjoin(template_dir, f"{target_name}.testgen_{idx}.py")
502+
# 使用 transformed_target 而不是 target_name
503+
out_path = pjoin(template_dir, f"{transformed_target}.testgen_{idx}.py")
491504
with open(out_path, "w") as f_out:
492505
f_out.write(new_code)
493506

494-
# 格式化代码
495-
formatter_installed = True
496-
try:
497-
subprocess.run(["black", out_path],
498-
check=False,
499-
stdout=subprocess.DEVNULL, # 隐藏输出
500-
stderr=subprocess.DEVNULL) # 隐藏错误
501-
except FileNotFoundError:
502-
if formatter_installed: # 避免多次记录
503-
logging.warning("Black code formatter not found. For better formatting, install with:")
504-
logging.warning("pip install black")
505-
formatter_installed = False
507+
# 格式化代码
508+
try:
509+
subprocess.run(["black", out_path], check=False)
510+
except FileNotFoundError:
511+
logging.warning("Black formatter not found, skipping formatting")
506512

507513
except SyntaxError as e:
508514
logging.error(f"Syntax error when processing {source_file}: {e}")
509515
except Exception as e:
510-
logging.error(f"Error generating test case for {target_name}: {e}")
511-
516+
# 使用 transformed_target 而不是 target_name
517+
logging.error(f"Error generating test case for {transformed_target}: {e}")
512518

513519
def testgen_repos(
514520
repos: list[str],
@@ -529,28 +535,34 @@ def testgen_repos(
529535
max_len (int): Maximum length
530536
sim_thresh (float): Similarity threshold
531537
"""
532-
# First get all targets
533-
targets_list = []
538+
# First get all targets and apply transformation
539+
target_map = {}
534540
for repo in repos:
535541
project_name = os.path.basename(repo)
536542
oss_fuzz_dir = Path(repo).parent.parent
537-
targets = discover_targets(project_name, oss_fuzz_dir)
538-
targets_list.append(targets)
539-
540-
target_map = {repo: targets for repo, targets in zip(repos, targets_list)}
543+
raw_targets = discover_targets(project_name, oss_fuzz_dir)
544+
545+
# 保存原始目标名称和转换后的目标名称
546+
transformed_targets = [t.replace("_print1", "") for t in raw_targets]
547+
targets = list(zip(transformed_targets, raw_targets)) # (转换后, 原始)
548+
target_map[repo] = targets
541549

542550
# Process each repository in parallel
543551
with ProcessingPool(jobs) as p:
544552
list(
545553
p.map(
546554
lambda item: substitute_one_repo(
547-
item[0], item[1], n_fuzz, strategy, max_len, sim_thresh
555+
item[0], # repo path
556+
item[1], # list of (transformed, raw) targets
557+
n_fuzz,
558+
strategy,
559+
max_len,
560+
sim_thresh
548561
),
549562
target_map.items(),
550563
)
551564
)
552565

553-
554566
def main(
555567
repo_id: str = "data/valid_projects.txt",
556568
repo_root: str = "fuzz/oss-fuzz/projects/",

0 commit comments

Comments
 (0)