Skip to content

Commit 726103d

Browse files
committed
FIX2
Signed-off-by: Felix Schlepper <felix.schlepper@cern.ch>
1 parent a1d1852 commit 726103d

3 files changed

Lines changed: 39 additions & 23 deletions

File tree

Detectors/ITSMFT/ITS/macros/test/optimize_vertexer.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ def parse_args() -> argparse.Namespace:
144144
parser.add_argument("--input-line-cap-factor", type=float, default=1.5, help="Maximum allowed factor increase in total lines for input-stage trials")
145145
parser.add_argument("--input-fake-line-cap-factor", type=float, default=1.5, help="Maximum allowed factor increase in fake lines for input-stage trials")
146146
parser.add_argument("--input-fake-line-cap-add", type=int, default=200, help="Minimum additive fake-line allowance for input-stage trials")
147+
parser.add_argument("--failed-trial-score", type=float, default=-1.0, help="Objective value assigned to trials whose workflow fails")
147148
parser.add_argument("--artifact-dir", type=Path, default=Path(".vertex_optuna"), help="Directory for trial workflows and summaries")
149+
parser.add_argument("--run-tracking", action="store_true", help="Run ITS tracking after vertexing instead of using vertexer-only trials")
148150
parser.add_argument("--dry-run", action="store_true", help="Patch and print one trial setup without running")
149151
return parser.parse_args()
150152

@@ -462,13 +464,16 @@ def patch_workflow_for_trial(
462464
workflow: dict[str, Any],
463465
task_tag: str,
464466
overrides: dict[str, Any],
467+
run_tracking: bool = False,
465468
) -> dict[str, Any]:
466469
patched = json.loads(json.dumps(workflow))
467470
_, patched_tasks = load_workflow_from_object(patched)
468471
prefix = f"{task_tag}_"
469472
full_overrides = dict(overrides)
470473
full_overrides["ITSVertexerParam.nIterations"] = 2
471474
full_overrides["ITSCATrackerParam.doUPCIteration"] = 1
475+
if not run_tracking:
476+
full_overrides["ITSCATrackerParam.nIterations"] = 0
472477
for task in patched_tasks:
473478
name = str(task.get("name", ""))
474479
if not name.startswith(prefix):
@@ -530,8 +535,9 @@ def evaluate_trial(
530535
trial_name: str,
531536
overrides: dict[str, Any],
532537
selected_tasks: list[TaskInfo],
538+
run_tracking: bool,
533539
) -> dict[str, Any]:
534-
patched = patch_workflow_for_trial(workflow, task_tag, overrides)
540+
patched = patch_workflow_for_trial(workflow, task_tag, overrides, run_tracking)
535541
workflow_copy = artifact_dir / f"{trial_name}_workflow.json"
536542
with workflow_copy.open("w") as stream:
537543
json.dump(patched, stream, indent=2)
@@ -573,7 +579,7 @@ def main() -> int:
573579
if args.dry_run:
574580
specs = INPUT_STAGE_SPECS if args.stage == "input" else VERTEX_STAGE_SPECS
575581
dry_overrides = {spec.key: base_params[spec.key] for spec in specs}
576-
patched = patch_workflow_for_trial(workflow, args.task_tag, dry_overrides)
582+
patched = patch_workflow_for_trial(workflow, args.task_tag, dry_overrides, args.run_tracking)
577583
workflow_copy = artifact_dir / "dry_run_workflow.json"
578584
with workflow_copy.open("w") as stream:
579585
json.dump(patched, stream, indent=2)
@@ -594,6 +600,7 @@ def main() -> int:
594600
trial_name="baseline",
595601
overrides={},
596602
selected_tasks=selected_tasks,
603+
run_tracking=args.run_tracking,
597604
)
598605
print(f"Baseline: {summarize_result(baseline_metrics)}")
599606

@@ -608,18 +615,25 @@ def main() -> int:
608615

609616
def objective(trial: optuna.Trial) -> float:
610617
overrides = suggest_params(trial, args.stage, base_params)
611-
metrics = evaluate_trial(
612-
workflow=workflow,
613-
task_tag=args.task_tag,
614-
run_dir=run_dir,
615-
runner=runner,
616-
artifact_dir=artifact_dir,
617-
trial_name=f"trial_{trial.number:04d}",
618-
overrides=overrides,
619-
selected_tasks=selected_tasks,
620-
)
621618
for key, value in overrides.items():
622619
trial.set_user_attr(key, value)
620+
try:
621+
metrics = evaluate_trial(
622+
workflow=workflow,
623+
task_tag=args.task_tag,
624+
run_dir=run_dir,
625+
runner=runner,
626+
artifact_dir=artifact_dir,
627+
trial_name=f"trial_{trial.number:04d}",
628+
overrides=overrides,
629+
selected_tasks=selected_tasks,
630+
run_tracking=args.run_tracking,
631+
)
632+
except (subprocess.CalledProcessError, RuntimeError) as exc:
633+
trial.set_user_attr("workflow_failed", True)
634+
trial.set_user_attr("failure", str(exc))
635+
print(f"Trial {trial.number}: workflow failed -> {exc}")
636+
return args.failed_trial_score
623637
for key in ("line_findable", "true_lines", "fake_lines", "total_lines", "unique_truth_ge2_lines",
624638
"findable", "true_found", "unique_true_findable", "unique_true_all", "total_found",
625639
"line_eff", "eff", "purity", "fake_rate", "duplicate_rate", "f1", "elapsed_ms"):
@@ -643,7 +657,7 @@ def objective(trial: optuna.Trial) -> float:
643657
trial.set_user_attr("guardrail_failed", True)
644658
trial.set_user_attr("guardrail_violations", ",".join(input_guardrail_violations))
645659
print(f"Trial {trial.number}: guardrail failed ({', '.join(input_guardrail_violations)}) -> {summarize_result(metrics)}")
646-
return -1.0
660+
return args.failed_trial_score
647661

648662
violations = []
649663
baseline_eff_bins = baseline_metrics["line_eff_bins"] if args.stage == "input" else baseline_metrics["eff_bins"]
@@ -665,7 +679,7 @@ def objective(trial: optuna.Trial) -> float:
665679
trial.set_user_attr("guardrail_failed", True)
666680
trial.set_user_attr("guardrail_violations", ",".join(violations))
667681
print(f"Trial {trial.number}: guardrail failed ({', '.join(violations)}) -> {summarize_result(metrics)}")
668-
return -1.0
682+
return args.failed_trial_score
669683

670684
print(f"Trial {trial.number}: {summarize_result(metrics)}")
671685
return metrics["line_eff"] if args.stage == "input" else metrics["f1"]

Detectors/ITSMFT/ITS/tracking/src/Configuration.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ std::vector<TrackingParameters> TrackingMode::getTrackingParameters(TrackingMode
264264
if (trackParams.size() > tc.nIterations) {
265265
trackParams.resize(tc.nIterations);
266266
}
267-
267+
trackParams.resize(0);
268268
return trackParams;
269269
}
270270

Detectors/ITSMFT/ITS/tracking/src/TrackingInterface.cxx

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,13 @@ void ITSTrackingInterface::run(framework::ProcessingContext& pc)
198198
return;
199199
}
200200

201-
if (mOverrideBeamEstimation) {
202-
mTimeFrame->setBeamPosition(mMeanVertex->getX(),
203-
mMeanVertex->getY(),
204-
mMeanVertex->getSigmaY2(),
205-
mTracker->getParameters()[0].LayerResolution[0],
206-
mTracker->getParameters()[0].SystErrorY2[0]);
207-
}
201+
const TrackingParameters defaultTrackingParams;
202+
const auto& trackParams = mTracker->getParameters();
203+
const auto& beamParams = trackParams.empty() ? defaultTrackingParams : trackParams[0];
204+
mTimeFrame->setBeamPosition(-0.050641048699617386, -0.02497512847185135,
205+
3.405e-06,
206+
beamParams.LayerResolution[0],
207+
beamParams.SystErrorY2[0]);
208208

209209
mTracker->setBz(o2::base::Propagator::Instance()->getNominalBz());
210210
mTracker->setTimeSlice(tfInfo.timeslice);
@@ -443,7 +443,7 @@ void ITSTrackingInterface::run(framework::ProcessingContext& pc)
443443
LOG(info) << fmt::format(" + Beam position computed for the TF: {}, {}", mTimeFrame->getBeamX(), mTimeFrame->getBeamY());
444444
}
445445

446-
if (hasClusters) {
446+
if (hasClusters && !mTracker->getParameters().empty()) {
447447
mTimeFrame->setMultiplicityCutMask(processMultiplictyMask);
448448
mTimeFrame->setUPCCutMask(processUPCMask);
449449
if (mMode == o2::its::TrackingMode::Async && o2::its::TrackerParamConfig::Instance().fataliseUponFailure) {
@@ -452,6 +452,8 @@ void ITSTrackingInterface::run(framework::ProcessingContext& pc)
452452
trackerElapsedTime = mTracker->clustersToTracks(logger, errorLogger);
453453
}
454454
LOGP(info, " + Tracking total elapse time: {} ms for {} tracks found", trackerElapsedTime, mTimeFrame->getNumberOfTracks());
455+
} else if (hasClusters) {
456+
LOGP(info, " + Tracking skipped: no tracking iterations configured");
455457
}
456458
if constexpr (constants::DoTimeBenchmarks) {
457459
const auto& trackConf = o2::its::TrackerParamConfig::Instance();

0 commit comments

Comments
 (0)