|
3 | 3 | #include <stan/io/empty_var_context.hpp> |
4 | 4 | #include <test/test-models/good/optimization/rosenbrock.hpp> |
5 | 5 | #include <test/unit/services/instrumented_callbacks.hpp> |
| 6 | +#include <boost/algorithm/string.hpp> |
6 | 7 | #include <iostream> |
7 | 8 |
|
8 | 9 | class ServicesSampleHmcNutsDiagEAdapt : public testing::Test { |
@@ -82,13 +83,13 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, parameter_checks) { |
82 | 83 | delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
83 | 84 | logger, init, parameter, diagnostic); |
84 | 85 |
|
85 | | - std::vector<std::vector<std::string> > parameter_names; |
| 86 | + std::vector<std::vector<std::string>> parameter_names; |
86 | 87 | parameter_names = parameter.vector_string_values(); |
87 | | - std::vector<std::vector<double> > parameter_values; |
| 88 | + std::vector<std::vector<double>> parameter_values; |
88 | 89 | parameter_values = parameter.vector_double_values(); |
89 | | - std::vector<std::vector<std::string> > diagnostic_names; |
| 90 | + std::vector<std::vector<std::string>> diagnostic_names; |
90 | 91 | diagnostic_names = diagnostic.vector_string_values(); |
91 | | - std::vector<std::vector<double> > diagnostic_values; |
| 92 | + std::vector<std::vector<double>> diagnostic_values; |
92 | 93 | diagnostic_values = diagnostic.vector_double_values(); |
93 | 94 |
|
94 | 95 | // Expectations of parameter parameter names. |
@@ -143,13 +144,13 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, output_sizes) { |
143 | 144 | delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
144 | 145 | logger, init, parameter, diagnostic); |
145 | 146 |
|
146 | | - std::vector<std::vector<std::string> > parameter_names; |
| 147 | + std::vector<std::vector<std::string>> parameter_names; |
147 | 148 | parameter_names = parameter.vector_string_values(); |
148 | | - std::vector<std::vector<double> > parameter_values; |
| 149 | + std::vector<std::vector<double>> parameter_values; |
149 | 150 | parameter_values = parameter.vector_double_values(); |
150 | | - std::vector<std::vector<std::string> > diagnostic_names; |
| 151 | + std::vector<std::vector<std::string>> diagnostic_names; |
151 | 152 | diagnostic_names = diagnostic.vector_string_values(); |
152 | | - std::vector<std::vector<double> > diagnostic_values; |
| 153 | + std::vector<std::vector<double>> diagnostic_values; |
153 | 154 | diagnostic_values = diagnostic.vector_double_values(); |
154 | 155 |
|
155 | 156 | EXPECT_EQ(return_code, 0); |
@@ -194,3 +195,267 @@ TEST_F(ServicesSampleHmcNutsDiagEAdapt, output_regression) { |
194 | 195 | EXPECT_EQ(1, logger.find_info("seconds (Total)")); |
195 | 196 | EXPECT_EQ(0, logger.call_count_error()); |
196 | 197 | } |
| 198 | + |
| 199 | +TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_0) { |
| 200 | + unsigned int random_seed = 0; |
| 201 | + unsigned int chain = 1; |
| 202 | + double init_radius = 0; |
| 203 | + int num_warmup = 150; |
| 204 | + int num_samples = 10; |
| 205 | + int num_thin = 1; |
| 206 | + bool save_warmup = true; |
| 207 | + int refresh = 0; |
| 208 | + double stepsize = 1.0; |
| 209 | + double stepsize_jitter = 0.0; |
| 210 | + int max_depth = 10; |
| 211 | + double delta = .8; |
| 212 | + double gamma = .05; |
| 213 | + double kappa = .75; |
| 214 | + double t0 = 10; |
| 215 | + unsigned int init_buffer = 50; |
| 216 | + unsigned int term_buffer = 0; |
| 217 | + unsigned int window = 100; |
| 218 | + stan::test::unit::instrumented_interrupt interrupt; |
| 219 | + EXPECT_EQ(interrupt.call_count(), 0); |
| 220 | + |
| 221 | + stan::services::sample::hmc_nuts_diag_e_adapt( |
| 222 | + model, context, random_seed, chain, init_radius, num_warmup, num_samples, |
| 223 | + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, |
| 224 | + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
| 225 | + logger, init, parameter, diagnostic); |
| 226 | + |
| 227 | + EXPECT_EQ(0, logger.call_count_error()); |
| 228 | + int num_output_lines = (num_warmup + num_samples) / num_thin; |
| 229 | + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); |
| 230 | + |
| 231 | + std::vector<std::string> messages = parameter.string_values(); |
| 232 | + for (auto msg : messages) { |
| 233 | + if (msg.find("Step size") != std::string::npos) { |
| 234 | + EXPECT_NE("Step size = 1", msg); |
| 235 | + } |
| 236 | + } |
| 237 | +} |
| 238 | + |
| 239 | +TEST_F(ServicesSampleHmcNutsDiagEAdapt, term_buffer_1) { |
| 240 | + unsigned int random_seed = 0; |
| 241 | + unsigned int chain = 1; |
| 242 | + double init_radius = 0; |
| 243 | + int num_warmup = 150; |
| 244 | + int num_samples = 10; |
| 245 | + int num_thin = 1; |
| 246 | + bool save_warmup = true; |
| 247 | + int refresh = 0; |
| 248 | + double stepsize = 1.0; |
| 249 | + double stepsize_jitter = 0.0; |
| 250 | + int max_depth = 10; |
| 251 | + double delta = .8; |
| 252 | + double gamma = .05; |
| 253 | + double kappa = .75; |
| 254 | + double t0 = 10; |
| 255 | + unsigned int init_buffer = 49; |
| 256 | + unsigned int term_buffer = 1; |
| 257 | + unsigned int window = 100; |
| 258 | + stan::test::unit::instrumented_interrupt interrupt; |
| 259 | + EXPECT_EQ(interrupt.call_count(), 0); |
| 260 | + |
| 261 | + stan::services::sample::hmc_nuts_diag_e_adapt( |
| 262 | + model, context, random_seed, chain, init_radius, num_warmup, num_samples, |
| 263 | + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, |
| 264 | + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
| 265 | + logger, init, parameter, diagnostic); |
| 266 | + |
| 267 | + EXPECT_EQ(0, logger.call_count_error()); |
| 268 | + int num_output_lines = (num_warmup + num_samples) / num_thin; |
| 269 | + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); |
| 270 | + |
| 271 | + std::vector<std::vector<double>> draws = parameter.vector_double_values(); |
| 272 | + auto draw = draws[draws.size() - 1]; |
| 273 | + |
| 274 | + std::vector<std::string> messages = parameter.string_values(); |
| 275 | + for (auto msg : messages) { |
| 276 | + if (msg.find("Step size") != std::string::npos) { |
| 277 | + std::vector<std::string> toks; |
| 278 | + boost::split(toks, msg, boost::is_any_of(" ")); |
| 279 | + auto adapted = std::stod(toks[toks.size() - 1]); |
| 280 | + EXPECT_NEAR(draw[2], adapted, 1e-5); |
| 281 | + } |
| 282 | + } |
| 283 | +} |
| 284 | + |
| 285 | +TEST_F(ServicesSampleHmcNutsDiagEAdapt, no_stepsize_adapt) { |
| 286 | + unsigned int random_seed = 0; |
| 287 | + unsigned int chain = 1; |
| 288 | + double init_radius = 0; |
| 289 | + int num_warmup = 150; |
| 290 | + int num_samples = 10; |
| 291 | + int num_thin = 1; |
| 292 | + bool save_warmup = true; |
| 293 | + int refresh = 0; |
| 294 | + double stepsize = 1.0; |
| 295 | + double stepsize_jitter = 0.0; |
| 296 | + int max_depth = 10; |
| 297 | + double delta = .8; |
| 298 | + double gamma = .05; |
| 299 | + double kappa = .75; |
| 300 | + double t0 = 10; |
| 301 | + unsigned int init_buffer = 0; |
| 302 | + unsigned int term_buffer = 0; |
| 303 | + unsigned int window = 50; |
| 304 | + stan::test::unit::instrumented_interrupt interrupt; |
| 305 | + EXPECT_EQ(interrupt.call_count(), 0); |
| 306 | + |
| 307 | + stan::services::sample::hmc_nuts_diag_e_adapt( |
| 308 | + model, context, random_seed, chain, init_radius, num_warmup, num_samples, |
| 309 | + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, |
| 310 | + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
| 311 | + logger, init, parameter, diagnostic); |
| 312 | + |
| 313 | + EXPECT_EQ(0, logger.call_count_error()); |
| 314 | + int num_output_lines = (num_warmup + num_samples) / num_thin; |
| 315 | + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); |
| 316 | + |
| 317 | + std::vector<std::string> messages = parameter.string_values(); |
| 318 | + for (auto msg : messages) { |
| 319 | + if (msg.find("Step size") != std::string::npos) { |
| 320 | + EXPECT_NE("Step size = 1", msg); |
| 321 | + } |
| 322 | + } |
| 323 | +} |
| 324 | + |
| 325 | +TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_a) { |
| 326 | + unsigned int random_seed = 0; |
| 327 | + unsigned int chain = 1; |
| 328 | + double init_radius = 0; |
| 329 | + int num_warmup = 35; |
| 330 | + int num_samples = 2; |
| 331 | + int num_thin = 1; |
| 332 | + bool save_warmup = true; |
| 333 | + int refresh = 0; |
| 334 | + double stepsize = 1.0; |
| 335 | + double stepsize_jitter = 0.0; |
| 336 | + int max_depth = 10; |
| 337 | + double delta = .8; |
| 338 | + double gamma = .05; |
| 339 | + double kappa = .75; |
| 340 | + double t0 = 10; |
| 341 | + unsigned int init_buffer = 5; |
| 342 | + unsigned int term_buffer = 0; |
| 343 | + unsigned int window = 20; |
| 344 | + stan::test::unit::instrumented_interrupt interrupt; |
| 345 | + EXPECT_EQ(interrupt.call_count(), 0); |
| 346 | + |
| 347 | + stan::services::sample::hmc_nuts_diag_e_adapt( |
| 348 | + model, context, random_seed, chain, init_radius, num_warmup, num_samples, |
| 349 | + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, |
| 350 | + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
| 351 | + logger, init, parameter, diagnostic); |
| 352 | + |
| 353 | + EXPECT_EQ(0, logger.call_count_error()); |
| 354 | + int num_output_lines = (num_warmup + num_samples) / num_thin; |
| 355 | + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); |
| 356 | + |
| 357 | + std::vector<std::vector<double>> draws = parameter.vector_double_values(); |
| 358 | + auto draw = draws[draws.size() - 1]; |
| 359 | + |
| 360 | + std::vector<std::string> messages = parameter.string_values(); |
| 361 | + for (auto msg : messages) { |
| 362 | + if (msg.find("Step size") != std::string::npos) { |
| 363 | + std::vector<std::string> toks; |
| 364 | + boost::split(toks, msg, boost::is_any_of(" ")); |
| 365 | + auto adapted = std::stod(toks[toks.size() - 1]); |
| 366 | + EXPECT_NEAR(draw[2], adapted, 1e-5); |
| 367 | + } |
| 368 | + } |
| 369 | +} |
| 370 | + |
| 371 | +TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_b) { |
| 372 | + unsigned int random_seed = 0; |
| 373 | + unsigned int chain = 1; |
| 374 | + double init_radius = 0; |
| 375 | + int num_warmup = 36; |
| 376 | + int num_samples = 2; |
| 377 | + int num_thin = 1; |
| 378 | + bool save_warmup = true; |
| 379 | + int refresh = 0; |
| 380 | + double stepsize = 1.0; |
| 381 | + double stepsize_jitter = 0.0; |
| 382 | + int max_depth = 10; |
| 383 | + double delta = .8; |
| 384 | + double gamma = .05; |
| 385 | + double kappa = .75; |
| 386 | + double t0 = 10; |
| 387 | + unsigned int init_buffer = 5; |
| 388 | + unsigned int term_buffer = 1; |
| 389 | + unsigned int window = 30; |
| 390 | + stan::test::unit::instrumented_interrupt interrupt; |
| 391 | + EXPECT_EQ(interrupt.call_count(), 0); |
| 392 | + |
| 393 | + stan::services::sample::hmc_nuts_diag_e_adapt( |
| 394 | + model, context, random_seed, chain, init_radius, num_warmup, num_samples, |
| 395 | + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, |
| 396 | + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
| 397 | + logger, init, parameter, diagnostic); |
| 398 | + |
| 399 | + EXPECT_EQ(0, logger.call_count_error()); |
| 400 | + int num_output_lines = (num_warmup + num_samples) / num_thin; |
| 401 | + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); |
| 402 | + |
| 403 | + std::vector<std::vector<double>> draws = parameter.vector_double_values(); |
| 404 | + auto draw = draws[draws.size() - 1]; |
| 405 | + |
| 406 | + std::vector<std::string> messages = parameter.string_values(); |
| 407 | + for (auto msg : messages) { |
| 408 | + if (msg.find("Step size") != std::string::npos) { |
| 409 | + std::vector<std::string> toks; |
| 410 | + boost::split(toks, msg, boost::is_any_of(" ")); |
| 411 | + auto adapted = std::stod(toks[toks.size() - 1]); |
| 412 | + EXPECT_NEAR(draw[2], adapted, 1e-5); |
| 413 | + } |
| 414 | + } |
| 415 | +} |
| 416 | + |
| 417 | +TEST_F(ServicesSampleHmcNutsDiagEAdapt, schedule_c) { |
| 418 | + unsigned int random_seed = 0; |
| 419 | + unsigned int chain = 1; |
| 420 | + double init_radius = 0; |
| 421 | + int num_warmup = 35; |
| 422 | + int num_samples = 2; |
| 423 | + int num_thin = 1; |
| 424 | + bool save_warmup = true; |
| 425 | + int refresh = 0; |
| 426 | + double stepsize = 1.0; |
| 427 | + double stepsize_jitter = 0.0; |
| 428 | + int max_depth = 10; |
| 429 | + double delta = .8; |
| 430 | + double gamma = .05; |
| 431 | + double kappa = .75; |
| 432 | + double t0 = 10; |
| 433 | + unsigned int init_buffer = 0; |
| 434 | + unsigned int term_buffer = 0; |
| 435 | + unsigned int window = 25; |
| 436 | + stan::test::unit::instrumented_interrupt interrupt; |
| 437 | + EXPECT_EQ(interrupt.call_count(), 0); |
| 438 | + |
| 439 | + stan::services::sample::hmc_nuts_diag_e_adapt( |
| 440 | + model, context, random_seed, chain, init_radius, num_warmup, num_samples, |
| 441 | + num_thin, save_warmup, refresh, stepsize, stepsize_jitter, max_depth, |
| 442 | + delta, gamma, kappa, t0, init_buffer, term_buffer, window, interrupt, |
| 443 | + logger, init, parameter, diagnostic); |
| 444 | + |
| 445 | + EXPECT_EQ(0, logger.call_count_error()); |
| 446 | + int num_output_lines = (num_warmup + num_samples) / num_thin; |
| 447 | + EXPECT_EQ(num_output_lines, parameter.call_count("vector_double")); |
| 448 | + |
| 449 | + std::vector<std::vector<double>> draws = parameter.vector_double_values(); |
| 450 | + auto draw = draws[draws.size() - 1]; |
| 451 | + |
| 452 | + std::vector<std::string> messages = parameter.string_values(); |
| 453 | + for (auto msg : messages) { |
| 454 | + if (msg.find("Step size") != std::string::npos) { |
| 455 | + std::vector<std::string> toks; |
| 456 | + boost::split(toks, msg, boost::is_any_of(" ")); |
| 457 | + auto adapted = std::stod(toks[toks.size() - 1]); |
| 458 | + EXPECT_NEAR(draw[2], adapted, 1e-5); |
| 459 | + } |
| 460 | + } |
| 461 | +} |
0 commit comments