Skip to content

Commit 43da7ba

Browse files
Global solver - Uniform Grid Search - Log samples.
Only logs every 10 seconds, to avoid the log being filled with log messages.
1 parent 1a2de77 commit 43da7ba

1 file changed

Lines changed: 92 additions & 3 deletions

File tree

lib/rust/mmoptimise/src/global/uniform_grid_search.rs

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,16 @@
115115
use anyhow::Result;
116116
use mmlogger::mm_log_progress;
117117
use rayon::prelude::*;
118+
use std::sync::atomic::{AtomicUsize, Ordering};
119+
use std::sync::Mutex;
118120
use thiserror::Error;
119121

120122
use crate::global::Evaluator;
121123

124+
/// How many seconds must pass between periodic progress log lines during
125+
/// a grid search. Change this value in source to adjust the interval.
126+
const PROGRESS_LOG_INTERVAL_SECS: f64 = 10.0;
127+
122128
// ----------------------------------------------------------
123129
// Config / Errors
124130
// ----------------------------------------------------------
@@ -339,7 +345,7 @@ impl UniformGridSearch {
339345
/// # Panics
340346
///
341347
/// Panics if `best_out.len()` != `num_dimensions`.
342-
pub fn run<E: Evaluator + Sync, L: mmlogger::Logger>(
348+
pub fn run<E: Evaluator + Sync, L: mmlogger::Logger + Sync>(
343349
&self,
344350
evaluator: &E,
345351
best_out: &mut [f64],
@@ -384,11 +390,94 @@ impl UniformGridSearch {
384390
self.cfg.num_dimensions
385391
);
386392

387-
// Step 3: Evaluate all grid points in parallel
393+
// Progress tracking shared across rayon threads.
394+
let completed = AtomicUsize::new(0);
395+
let last_log_time =
396+
Mutex::new(std::time::Instant::now());
397+
// Tracks (best_cost, best_params) seen so far across all threads.
398+
let best_so_far: Mutex<(f64, Vec<f64>)> = Mutex::new((
399+
f64::MAX,
400+
vec![f64::NAN; self.cfg.num_dimensions],
401+
));
402+
403+
// Step 3: Evaluate all grid points in parallel.
404+
//
405+
// After each evaluation, update the shared best and emit a
406+
// progress line at most once per PROGRESS_LOG_INTERVAL_SECS.
388407
let results: Vec<(usize, f64)> = grid_points
389408
.par_iter()
390409
.enumerate()
391-
.map(|(idx, params)| (idx, evaluator.evaluate(params)))
410+
.map(|(idx, params)| {
411+
let cost = evaluator.evaluate(params);
412+
413+
// Update shared best if this result is an improvement.
414+
{
415+
let mut best = best_so_far.lock().unwrap();
416+
if cost < best.0 {
417+
best.0 = cost;
418+
best.1.copy_from_slice(params);
419+
}
420+
}
421+
422+
let done =
423+
completed.fetch_add(1, Ordering::Relaxed) + 1;
424+
425+
// Decide whether to emit a progress line.
426+
let now = std::time::Instant::now();
427+
let should_log = {
428+
let mut last_time = last_log_time.lock().unwrap();
429+
if now
430+
.duration_since(*last_time)
431+
.as_secs_f64()
432+
>= PROGRESS_LOG_INTERVAL_SECS
433+
{
434+
*last_time = now;
435+
true
436+
} else {
437+
false
438+
}
439+
};
440+
441+
if should_log {
442+
let elapsed_secs =
443+
start_time.elapsed().as_secs_f64();
444+
let remaining = total_points.saturating_sub(done);
445+
let eta_secs = if done > 0 && remaining > 0 {
446+
elapsed_secs * remaining as f64 / done as f64
447+
} else {
448+
0.0
449+
};
450+
let (current_best_cost, current_best_params) = {
451+
let g = best_so_far.lock().unwrap();
452+
(g.0, g.1.clone())
453+
};
454+
455+
let num_dimensions = current_best_params.len();
456+
let params_str = if num_dimensions <= 10 {
457+
format!("{:?}", current_best_params)
458+
} else {
459+
format!(
460+
"[{:.6}, {:.6}, ... {} more]",
461+
current_best_params[0],
462+
current_best_params[1],
463+
num_dimensions - 2
464+
)
465+
};
466+
467+
mm_log_progress!(
468+
logger,
469+
"[UGS] {}/{} | best cost: {:.9}, params: {} | elapsed: {:.1}s (ETA: ~{:.0}s)",
470+
done,
471+
total_points,
472+
current_best_cost,
473+
params_str,
474+
elapsed_secs,
475+
eta_secs,
476+
);
477+
}
478+
479+
(idx, cost)
480+
})
392481
.collect();
393482

394483
// Step 4: Find minimum cost

0 commit comments

Comments
 (0)