|
115 | 115 | use anyhow::Result; |
116 | 116 | use mmlogger::mm_log_progress; |
117 | 117 | use rayon::prelude::*; |
| 118 | +use std::sync::atomic::{AtomicUsize, Ordering}; |
| 119 | +use std::sync::Mutex; |
118 | 120 | use thiserror::Error; |
119 | 121 |
|
120 | 122 | use crate::global::Evaluator; |
121 | 123 |
|
| 124 | +/// How many seconds must pass between periodic progress log lines during |
| 125 | +/// a grid search. Change this value in source to adjust the interval. |
| 126 | +const PROGRESS_LOG_INTERVAL_SECS: f64 = 10.0; |
| 127 | + |
122 | 128 | // ---------------------------------------------------------- |
123 | 129 | // Config / Errors |
124 | 130 | // ---------------------------------------------------------- |
@@ -339,7 +345,7 @@ impl UniformGridSearch { |
339 | 345 | /// # Panics |
340 | 346 | /// |
341 | 347 | /// Panics if `best_out.len()` != `num_dimensions`. |
342 | | - pub fn run<E: Evaluator + Sync, L: mmlogger::Logger>( |
| 348 | + pub fn run<E: Evaluator + Sync, L: mmlogger::Logger + Sync>( |
343 | 349 | &self, |
344 | 350 | evaluator: &E, |
345 | 351 | best_out: &mut [f64], |
@@ -384,11 +390,94 @@ impl UniformGridSearch { |
384 | 390 | self.cfg.num_dimensions |
385 | 391 | ); |
386 | 392 |
|
387 | | - // Step 3: Evaluate all grid points in parallel |
| 393 | + // Progress tracking shared across rayon threads. |
| 394 | + let completed = AtomicUsize::new(0); |
| 395 | + let last_log_time = |
| 396 | + Mutex::new(std::time::Instant::now()); |
| 397 | + // Tracks (best_cost, best_params) seen so far across all threads. |
| 398 | + let best_so_far: Mutex<(f64, Vec<f64>)> = Mutex::new(( |
| 399 | + f64::MAX, |
| 400 | + vec![f64::NAN; self.cfg.num_dimensions], |
| 401 | + )); |
| 402 | + |
| 403 | + // Step 3: Evaluate all grid points in parallel. |
| 404 | + // |
| 405 | + // After each evaluation, update the shared best and emit a |
| 406 | + // progress line at most once per PROGRESS_LOG_INTERVAL_SECS. |
388 | 407 | let results: Vec<(usize, f64)> = grid_points |
389 | 408 | .par_iter() |
390 | 409 | .enumerate() |
391 | | - .map(|(idx, params)| (idx, evaluator.evaluate(params))) |
| 410 | + .map(|(idx, params)| { |
| 411 | + let cost = evaluator.evaluate(params); |
| 412 | + |
| 413 | + // Update shared best if this result is an improvement. |
| 414 | + { |
| 415 | + let mut best = best_so_far.lock().unwrap(); |
| 416 | + if cost < best.0 { |
| 417 | + best.0 = cost; |
| 418 | + best.1.copy_from_slice(params); |
| 419 | + } |
| 420 | + } |
| 421 | + |
| 422 | + let done = |
| 423 | + completed.fetch_add(1, Ordering::Relaxed) + 1; |
| 424 | + |
| 425 | + // Decide whether to emit a progress line. |
| 426 | + let now = std::time::Instant::now(); |
| 427 | + let should_log = { |
| 428 | + let mut last_time = last_log_time.lock().unwrap(); |
| 429 | + if now |
| 430 | + .duration_since(*last_time) |
| 431 | + .as_secs_f64() |
| 432 | + >= PROGRESS_LOG_INTERVAL_SECS |
| 433 | + { |
| 434 | + *last_time = now; |
| 435 | + true |
| 436 | + } else { |
| 437 | + false |
| 438 | + } |
| 439 | + }; |
| 440 | + |
| 441 | + if should_log { |
| 442 | + let elapsed_secs = |
| 443 | + start_time.elapsed().as_secs_f64(); |
| 444 | + let remaining = total_points.saturating_sub(done); |
| 445 | + let eta_secs = if done > 0 && remaining > 0 { |
| 446 | + elapsed_secs * remaining as f64 / done as f64 |
| 447 | + } else { |
| 448 | + 0.0 |
| 449 | + }; |
| 450 | + let (current_best_cost, current_best_params) = { |
| 451 | + let g = best_so_far.lock().unwrap(); |
| 452 | + (g.0, g.1.clone()) |
| 453 | + }; |
| 454 | + |
| 455 | + let num_dimensions = current_best_params.len(); |
| 456 | + let params_str = if num_dimensions <= 10 { |
| 457 | + format!("{:?}", current_best_params) |
| 458 | + } else { |
| 459 | + format!( |
| 460 | + "[{:.6}, {:.6}, ... {} more]", |
| 461 | + current_best_params[0], |
| 462 | + current_best_params[1], |
| 463 | + num_dimensions - 2 |
| 464 | + ) |
| 465 | + }; |
| 466 | + |
| 467 | + mm_log_progress!( |
| 468 | + logger, |
| 469 | + "[UGS] {}/{} | best cost: {:.9}, params: {} | elapsed: {:.1}s (ETA: ~{:.0}s)", |
| 470 | + done, |
| 471 | + total_points, |
| 472 | + current_best_cost, |
| 473 | + params_str, |
| 474 | + elapsed_secs, |
| 475 | + eta_secs, |
| 476 | + ); |
| 477 | + } |
| 478 | + |
| 479 | + (idx, cost) |
| 480 | + }) |
392 | 481 | .collect(); |
393 | 482 |
|
394 | 483 | // Step 4: Find minimum cost |
|
0 commit comments