@@ -144,7 +144,9 @@ def parse_args() -> argparse.Namespace:
144144 parser .add_argument ("--input-line-cap-factor" , type = float , default = 1.5 , help = "Maximum allowed factor increase in total lines for input-stage trials" )
145145 parser .add_argument ("--input-fake-line-cap-factor" , type = float , default = 1.5 , help = "Maximum allowed factor increase in fake lines for input-stage trials" )
146146 parser .add_argument ("--input-fake-line-cap-add" , type = int , default = 200 , help = "Minimum additive fake-line allowance for input-stage trials" )
147+ parser .add_argument ("--failed-trial-score" , type = float , default = - 1.0 , help = "Objective value assigned to trials whose workflow fails" )
147148 parser .add_argument ("--artifact-dir" , type = Path , default = Path (".vertex_optuna" ), help = "Directory for trial workflows and summaries" )
149+ parser .add_argument ("--run-tracking" , action = "store_true" , help = "Run ITS tracking after vertexing instead of using vertexer-only trials" )
148150 parser .add_argument ("--dry-run" , action = "store_true" , help = "Patch and print one trial setup without running" )
149151 return parser .parse_args ()
150152
@@ -462,13 +464,16 @@ def patch_workflow_for_trial(
462464 workflow : dict [str , Any ],
463465 task_tag : str ,
464466 overrides : dict [str , Any ],
467+ run_tracking : bool = False ,
465468) -> dict [str , Any ]:
466469 patched = json .loads (json .dumps (workflow ))
467470 _ , patched_tasks = load_workflow_from_object (patched )
468471 prefix = f"{ task_tag } _"
469472 full_overrides = dict (overrides )
470473 full_overrides ["ITSVertexerParam.nIterations" ] = 2
471474 full_overrides ["ITSCATrackerParam.doUPCIteration" ] = 1
475+ if not run_tracking :
476+ full_overrides ["ITSCATrackerParam.nIterations" ] = 0
472477 for task in patched_tasks :
473478 name = str (task .get ("name" , "" ))
474479 if not name .startswith (prefix ):
@@ -530,8 +535,9 @@ def evaluate_trial(
530535 trial_name : str ,
531536 overrides : dict [str , Any ],
532537 selected_tasks : list [TaskInfo ],
538+ run_tracking : bool ,
533539) -> dict [str , Any ]:
534- patched = patch_workflow_for_trial (workflow , task_tag , overrides )
540+ patched = patch_workflow_for_trial (workflow , task_tag , overrides , run_tracking )
535541 workflow_copy = artifact_dir / f"{ trial_name } _workflow.json"
536542 with workflow_copy .open ("w" ) as stream :
537543 json .dump (patched , stream , indent = 2 )
@@ -573,7 +579,7 @@ def main() -> int:
573579 if args .dry_run :
574580 specs = INPUT_STAGE_SPECS if args .stage == "input" else VERTEX_STAGE_SPECS
575581 dry_overrides = {spec .key : base_params [spec .key ] for spec in specs }
576- patched = patch_workflow_for_trial (workflow , args .task_tag , dry_overrides )
582+ patched = patch_workflow_for_trial (workflow , args .task_tag , dry_overrides , args . run_tracking )
577583 workflow_copy = artifact_dir / "dry_run_workflow.json"
578584 with workflow_copy .open ("w" ) as stream :
579585 json .dump (patched , stream , indent = 2 )
@@ -594,6 +600,7 @@ def main() -> int:
594600 trial_name = "baseline" ,
595601 overrides = {},
596602 selected_tasks = selected_tasks ,
603+ run_tracking = args .run_tracking ,
597604 )
598605 print (f"Baseline: { summarize_result (baseline_metrics )} " )
599606
@@ -608,18 +615,25 @@ def main() -> int:
608615
609616 def objective (trial : optuna .Trial ) -> float :
610617 overrides = suggest_params (trial , args .stage , base_params )
611- metrics = evaluate_trial (
612- workflow = workflow ,
613- task_tag = args .task_tag ,
614- run_dir = run_dir ,
615- runner = runner ,
616- artifact_dir = artifact_dir ,
617- trial_name = f"trial_{ trial .number :04d} " ,
618- overrides = overrides ,
619- selected_tasks = selected_tasks ,
620- )
621618 for key , value in overrides .items ():
622619 trial .set_user_attr (key , value )
620+ try :
621+ metrics = evaluate_trial (
622+ workflow = workflow ,
623+ task_tag = args .task_tag ,
624+ run_dir = run_dir ,
625+ runner = runner ,
626+ artifact_dir = artifact_dir ,
627+ trial_name = f"trial_{ trial .number :04d} " ,
628+ overrides = overrides ,
629+ selected_tasks = selected_tasks ,
630+ run_tracking = args .run_tracking ,
631+ )
632+ except (subprocess .CalledProcessError , RuntimeError ) as exc :
633+ trial .set_user_attr ("workflow_failed" , True )
634+ trial .set_user_attr ("failure" , str (exc ))
635+ print (f"Trial { trial .number } : workflow failed -> { exc } " )
636+ return args .failed_trial_score
623637 for key in ("line_findable" , "true_lines" , "fake_lines" , "total_lines" , "unique_truth_ge2_lines" ,
624638 "findable" , "true_found" , "unique_true_findable" , "unique_true_all" , "total_found" ,
625639 "line_eff" , "eff" , "purity" , "fake_rate" , "duplicate_rate" , "f1" , "elapsed_ms" ):
@@ -643,7 +657,7 @@ def objective(trial: optuna.Trial) -> float:
643657 trial .set_user_attr ("guardrail_failed" , True )
644658 trial .set_user_attr ("guardrail_violations" , "," .join (input_guardrail_violations ))
645659 print (f"Trial { trial .number } : guardrail failed ({ ', ' .join (input_guardrail_violations )} ) -> { summarize_result (metrics )} " )
646- return - 1.0
660+ return args . failed_trial_score
647661
648662 violations = []
649663 baseline_eff_bins = baseline_metrics ["line_eff_bins" ] if args .stage == "input" else baseline_metrics ["eff_bins" ]
@@ -665,7 +679,7 @@ def objective(trial: optuna.Trial) -> float:
665679 trial .set_user_attr ("guardrail_failed" , True )
666680 trial .set_user_attr ("guardrail_violations" , "," .join (violations ))
667681 print (f"Trial { trial .number } : guardrail failed ({ ', ' .join (violations )} ) -> { summarize_result (metrics )} " )
668- return - 1.0
682+ return args . failed_trial_score
669683
670684 print (f"Trial { trial .number } : { summarize_result (metrics )} " )
671685 return metrics ["line_eff" ] if args .stage == "input" else metrics ["f1" ]
0 commit comments