Skip to content

Commit ad78a8f

Browse files
committed
feat: add max_energy_error diagnostic
1 parent 470b73e commit ad78a8f

3 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/adapt_strategy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ mod test {
476476
step_size: (),
477477
mass_matrix: (),
478478
},
479-
hamiltonian: (),
479+
hamiltonian: -1i64,
480480
point: TransformedPointStatsOptions {
481481
store_gradient: true,
482482
store_unconstrained: true,

src/stepsize/adapt.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ pub struct Strategy {
6060
pub last_sym_mean_tree_accept: f64,
6161
/// Last number of steps
6262
pub last_n_steps: u64,
63+
/// Maximum absolute energy error observed in the last trajectory
64+
pub last_max_energy_error: f64,
6365
}
6466

6567
impl Strategy {
@@ -82,6 +84,7 @@ impl Strategy {
8284
last_n_steps: 0,
8385
last_sym_mean_tree_accept: 0.0,
8486
last_mean_tree_accept: 0.0,
87+
last_max_energy_error: 0.0,
8588
}
8689
}
8790

@@ -187,6 +190,7 @@ impl Strategy {
187190
self.last_mean_tree_accept = mean;
188191
self.last_sym_mean_tree_accept = mean_sym;
189192
self.last_n_steps = n_steps;
193+
self.last_max_energy_error = collector.max_energy_error;
190194
}
191195

192196
pub fn update_estimator_early(&mut self) {
@@ -258,6 +262,7 @@ pub struct Stats {
258262
pub mean_tree_accept: f64,
259263
pub mean_tree_accept_sym: f64,
260264
pub n_steps: u64,
265+
pub max_energy_error: f64,
261266
}
262267

263268
impl<M: Math> SamplerStats<M> for Strategy {
@@ -280,6 +285,7 @@ impl<M: Math> SamplerStats<M> for Strategy {
280285
mean_tree_accept: self.last_mean_tree_accept,
281286
mean_tree_accept_sym: self.last_sym_mean_tree_accept,
282287
n_steps: self.last_n_steps,
288+
max_energy_error: self.last_max_energy_error,
283289
}
284290
}
285291
}

src/stepsize/dual_avg.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ pub struct AcceptanceRateCollector {
113113
initial_energy: f64,
114114
pub(crate) mean: RunningMean,
115115
pub(crate) mean_sym: RunningMean,
116+
pub(crate) max_energy_error: f64,
116117
}
117118

118119
impl AcceptanceRateCollector {
@@ -121,6 +122,7 @@ impl AcceptanceRateCollector {
121122
initial_energy: 0.,
122123
mean: RunningMean::new(),
123124
mean_sym: RunningMean::new(),
125+
max_energy_error: 0.,
124126
}
125127
}
126128
}
@@ -137,6 +139,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for AcceptanceRateCollector {
137139
Some(_) => {
138140
self.mean.add(0.);
139141
self.mean_sym.add(0.);
142+
self.max_energy_error = f64::NEG_INFINITY;
140143
}
141144
None => {
142145
let base_energy = self.initial_energy;
@@ -146,6 +149,10 @@ impl<M: Math, P: Point<M>> Collector<M, P> for AcceptanceRateCollector {
146149
self.mean.add(diff.min(0.).exp());
147150
self.mean_sym
148151
.add(2. * diff.min(0.).exp() / (1. + diff.exp()));
152+
let energy_error = diff;
153+
if energy_error.abs() > self.max_energy_error.abs() {
154+
self.max_energy_error = energy_error;
155+
}
149156
}
150157
};
151158
}
@@ -154,5 +161,6 @@ impl<M: Math, P: Point<M>> Collector<M, P> for AcceptanceRateCollector {
154161
self.initial_energy = state.energy();
155162
self.mean.reset();
156163
self.mean_sym.reset();
164+
self.max_energy_error = 0.;
157165
}
158166
}

0 commit comments

Comments
 (0)