Skip to content

Commit 47ef60d

Browse files
authored
Merge pull request #3346 from stan-dev/bugfix/3023-term-buffer-0
Bugfix/3023 term buffer 0 #3023
2 parents 1c10b04 + f651a80 commit 47ef60d

2 files changed

Lines changed: 277 additions & 9 deletions

File tree

src/stan/mcmc/stepsize_adaptation.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ class stepsize_adaptation : public base_adaptation {
7070
epsilon = std::exp(x);
7171
}
7272

73-
void complete_adaptation(double& epsilon) { epsilon = std::exp(x_bar_); }
73+
void complete_adaptation(double& epsilon) {
74+
if (counter_ > 0)
75+
epsilon = std::exp(x_bar_);
76+
}
7477

7578
protected:
7679
double counter_; // Adaptation iteration

src/test/unit/services/sample/hmc_nuts_diag_e_adapt_test.cpp

Lines changed: 273 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <stan/io/empty_var_context.hpp>
44
#include <test/test-models/good/optimization/rosenbrock.hpp>
55
#include <test/unit/services/instrumented_callbacks.hpp>
6+
#include <boost/algorithm/string.hpp>
67
#include <iostream>
78

89
class ServicesSampleHmcNutsDiagEAdapt : public testing::Test {
@@ -82,13 +83,13 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, parameter_checks) {
8283
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
8384
logger, init, parameter, diagnostic);
8485

85-
std::vector<std::vector<std::string> > parameter_names;
86+
std::vector<std::vector<std::string>> parameter_names;
8687
parameter_names = parameter.vector_string_values();
87-
std::vector<std::vector<double> > parameter_values;
88+
std::vector<std::vector<double>> parameter_values;
8889
parameter_values = parameter.vector_double_values();
89-
std::vector<std::vector<std::string> > diagnostic_names;
90+
std::vector<std::vector<std::string>> diagnostic_names;
9091
diagnostic_names = diagnostic.vector_string_values();
91-
std::vector<std::vector<double> > diagnostic_values;
92+
std::vector<std::vector<double>> diagnostic_values;
9293
diagnostic_values = diagnostic.vector_double_values();
9394

9495
// Expectations of parameter parameter names.
@@ -143,13 +144,13 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, output_sizes) {
143144
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
144145
logger, init, parameter, diagnostic);
145146

146-
std::vector<std::vector<std::string> > parameter_names;
147+
std::vector<std::vector<std::string>> parameter_names;
147148
parameter_names = parameter.vector_string_values();
148-
std::vector<std::vector<double> > parameter_values;
149+
std::vector<std::vector<double>> parameter_values;
149150
parameter_values = parameter.vector_double_values();
150-
std::vector<std::vector<std::string> > diagnostic_names;
151+
std::vector<std::vector<std::string>> diagnostic_names;
151152
diagnostic_names = diagnostic.vector_string_values();
152-
std::vector<std::vector<double> > diagnostic_values;
153+
std::vector<std::vector<double>> diagnostic_values;
153154
diagnostic_values = diagnostic.vector_double_values();
154155

155156
EXPECT_EQ(return_code, 0);
@@ -194,3 +195,267 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, output_regression) {
194195
EXPECT_EQ(1, logger.find_info("seconds (Total)"));
195196
EXPECT_EQ(0, logger.call_count_error());
196197
}
198+
199+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_0) {
200+
unsigned int random_seed = 0;
201+
unsigned int chain = 1;
202+
double init_radius = 0;
203+
int num_warmup = 150;
204+
int num_samples = 10;
205+
int num_thin = 1;
206+
bool save_warmup = true;
207+
int refresh = 0;
208+
double stepsize = 1.0;
209+
double stepsize_jitter = 0.0;
210+
int max_depth = 10;
211+
double delta = .8;
212+
double gamma = .05;
213+
double kappa = .75;
214+
double t0 = 10;
215+
unsigned int init_buffer = 50;
216+
unsigned int term_buffer = 0;
217+
unsigned int window = 100;
218+
stan::test::unit::instrumented_interrupt interrupt;
219+
EXPECT_EQ(interrupt.call_count(), 0);
220+
221+
stan::services::sample::hmc_nuts_diag_e_adapt(
222+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
223+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
224+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
225+
logger, init, parameter, diagnostic);
226+
227+
EXPECT_EQ(0, logger.call_count_error());
228+
int num_output_lines = (num_warmup + num_samples) / num_thin;
229+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
230+
231+
std::vector<std::string> messages = parameter.string_values();
232+
for (auto msg : messages) {
233+
if (msg.find("Step size") != std::string::npos) {
234+
EXPECT_NE("Step size = 1", msg);
235+
}
236+
}
237+
}
238+
239+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_1) {
240+
unsigned int random_seed = 0;
241+
unsigned int chain = 1;
242+
double init_radius = 0;
243+
int num_warmup = 150;
244+
int num_samples = 10;
245+
int num_thin = 1;
246+
bool save_warmup = true;
247+
int refresh = 0;
248+
double stepsize = 1.0;
249+
double stepsize_jitter = 0.0;
250+
int max_depth = 10;
251+
double delta = .8;
252+
double gamma = .05;
253+
double kappa = .75;
254+
double t0 = 10;
255+
unsigned int init_buffer = 49;
256+
unsigned int term_buffer = 1;
257+
unsigned int window = 100;
258+
stan::test::unit::instrumented_interrupt interrupt;
259+
EXPECT_EQ(interrupt.call_count(), 0);
260+
261+
stan::services::sample::hmc_nuts_diag_e_adapt(
262+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
263+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
264+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
265+
logger, init, parameter, diagnostic);
266+
267+
EXPECT_EQ(0, logger.call_count_error());
268+
int num_output_lines = (num_warmup + num_samples) / num_thin;
269+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
270+
271+
std::vector<std::vector<double>> draws = parameter.vector_double_values();
272+
auto draw = draws[draws.size() - 1];
273+
274+
std::vector<std::string> messages = parameter.string_values();
275+
for (auto msg : messages) {
276+
if (msg.find("Step size") != std::string::npos) {
277+
std::vector<std::string> toks;
278+
boost::split(toks, msg, boost::is_any_of(" "));
279+
auto adapted = std::stod(toks[toks.size() - 1]);
280+
EXPECT_NEAR(draw[2], adapted, 1e-5);
281+
}
282+
}
283+
}
284+
285+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt) {
286+
unsigned int random_seed = 0;
287+
unsigned int chain = 1;
288+
double init_radius = 0;
289+
int num_warmup = 150;
290+
int num_samples = 10;
291+
int num_thin = 1;
292+
bool save_warmup = true;
293+
int refresh = 0;
294+
double stepsize = 1.0;
295+
double stepsize_jitter = 0.0;
296+
int max_depth = 10;
297+
double delta = .8;
298+
double gamma = .05;
299+
double kappa = .75;
300+
double t0 = 10;
301+
unsigned int init_buffer = 0;
302+
unsigned int term_buffer = 0;
303+
unsigned int window = 50;
304+
stan::test::unit::instrumented_interrupt interrupt;
305+
EXPECT_EQ(interrupt.call_count(), 0);
306+
307+
stan::services::sample::hmc_nuts_diag_e_adapt(
308+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
309+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
310+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
311+
logger, init, parameter, diagnostic);
312+
313+
EXPECT_EQ(0, logger.call_count_error());
314+
int num_output_lines = (num_warmup + num_samples) / num_thin;
315+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
316+
317+
std::vector<std::string> messages = parameter.string_values();
318+
for (auto msg : messages) {
319+
if (msg.find("Step size") != std::string::npos) {
320+
EXPECT_NE("Step size = 1", msg);
321+
}
322+
}
323+
}
324+
325+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_a) {
326+
unsigned int random_seed = 0;
327+
unsigned int chain = 1;
328+
double init_radius = 0;
329+
int num_warmup = 35;
330+
int num_samples = 2;
331+
int num_thin = 1;
332+
bool save_warmup = true;
333+
int refresh = 0;
334+
double stepsize = 1.0;
335+
double stepsize_jitter = 0.0;
336+
int max_depth = 10;
337+
double delta = .8;
338+
double gamma = .05;
339+
double kappa = .75;
340+
double t0 = 10;
341+
unsigned int init_buffer = 5;
342+
unsigned int term_buffer = 0;
343+
unsigned int window = 20;
344+
stan::test::unit::instrumented_interrupt interrupt;
345+
EXPECT_EQ(interrupt.call_count(), 0);
346+
347+
stan::services::sample::hmc_nuts_diag_e_adapt(
348+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
349+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
350+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
351+
logger, init, parameter, diagnostic);
352+
353+
EXPECT_EQ(0, logger.call_count_error());
354+
int num_output_lines = (num_warmup + num_samples) / num_thin;
355+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
356+
357+
std::vector<std::vector<double>> draws = parameter.vector_double_values();
358+
auto draw = draws[draws.size() - 1];
359+
360+
std::vector<std::string> messages = parameter.string_values();
361+
for (auto msg : messages) {
362+
if (msg.find("Step size") != std::string::npos) {
363+
std::vector<std::string> toks;
364+
boost::split(toks, msg, boost::is_any_of(" "));
365+
auto adapted = std::stod(toks[toks.size() - 1]);
366+
EXPECT_NEAR(draw[2], adapted, 1e-5);
367+
}
368+
}
369+
}
370+
371+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_b) {
372+
unsigned int random_seed = 0;
373+
unsigned int chain = 1;
374+
double init_radius = 0;
375+
int num_warmup = 36;
376+
int num_samples = 2;
377+
int num_thin = 1;
378+
bool save_warmup = true;
379+
int refresh = 0;
380+
double stepsize = 1.0;
381+
double stepsize_jitter = 0.0;
382+
int max_depth = 10;
383+
double delta = .8;
384+
double gamma = .05;
385+
double kappa = .75;
386+
double t0 = 10;
387+
unsigned int init_buffer = 5;
388+
unsigned int term_buffer = 1;
389+
unsigned int window = 30;
390+
stan::test::unit::instrumented_interrupt interrupt;
391+
EXPECT_EQ(interrupt.call_count(), 0);
392+
393+
stan::services::sample::hmc_nuts_diag_e_adapt(
394+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
395+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
396+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
397+
logger, init, parameter, diagnostic);
398+
399+
EXPECT_EQ(0, logger.call_count_error());
400+
int num_output_lines = (num_warmup + num_samples) / num_thin;
401+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
402+
403+
std::vector<std::vector<double>> draws = parameter.vector_double_values();
404+
auto draw = draws[draws.size() - 1];
405+
406+
std::vector<std::string> messages = parameter.string_values();
407+
for (auto msg : messages) {
408+
if (msg.find("Step size") != std::string::npos) {
409+
std::vector<std::string> toks;
410+
boost::split(toks, msg, boost::is_any_of(" "));
411+
auto adapted = std::stod(toks[toks.size() - 1]);
412+
EXPECT_NEAR(draw[2], adapted, 1e-5);
413+
}
414+
}
415+
}
416+
417+
TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_c) {
418+
unsigned int random_seed = 0;
419+
unsigned int chain = 1;
420+
double init_radius = 0;
421+
int num_warmup = 35;
422+
int num_samples = 2;
423+
int num_thin = 1;
424+
bool save_warmup = true;
425+
int refresh = 0;
426+
double stepsize = 1.0;
427+
double stepsize_jitter = 0.0;
428+
int max_depth = 10;
429+
double delta = .8;
430+
double gamma = .05;
431+
double kappa = .75;
432+
double t0 = 10;
433+
unsigned int init_buffer = 0;
434+
unsigned int term_buffer = 0;
435+
unsigned int window = 25;
436+
stan::test::unit::instrumented_interrupt interrupt;
437+
EXPECT_EQ(interrupt.call_count(), 0);
438+
439+
stan::services::sample::hmc_nuts_diag_e_adapt(
440+
model, context, random_seed, chain, init_radius, num_warmup, num_samples,
441+
num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth,
442+
delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt,
443+
logger, init, parameter, diagnostic);
444+
445+
EXPECT_EQ(0, logger.call_count_error());
446+
int num_output_lines = (num_warmup + num_samples) / num_thin;
447+
EXPECT_EQ(num_output_lines, parameter.call_count("vector_double"));
448+
449+
std::vector<std::vector<double>> draws = parameter.vector_double_values();
450+
auto draw = draws[draws.size() - 1];
451+
452+
std::vector<std::string> messages = parameter.string_values();
453+
for (auto msg : messages) {
454+
if (msg.find("Step size") != std::string::npos) {
455+
std::vector<std::string> toks;
456+
boost::split(toks, msg, boost::is_any_of(" "));
457+
auto adapted = std::stod(toks[toks.size() - 1]);
458+
EXPECT_NEAR(draw[2], adapted, 1e-5);
459+
}
460+
}
461+
}

0 commit comments

Comments
 (0)