Skip to content

Commit 450f974

Browse files
authored
Merge pull request #3373 from stan-dev/fix/3372-gq-timings
Add timing information to standalone GQ
2 parents 11d1639 + c0b5a7c commit 450f974

7 files changed

Lines changed: 67 additions & 4 deletions

File tree

src/stan/services/sample/standalone_gqs.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws,
6767

6868
std::vector<double> unconstrained_params_r;
6969
std::vector<double> row(draws.cols());
70+
auto start = std::chrono::steady_clock::now();
71+
7072
try {
7173
for (size_t i = 0; i < draws.rows(); ++i) {
7274
Eigen::Map<Eigen::VectorXd>(&row[0], draws.cols()) = draws.row(i);
@@ -85,6 +87,13 @@ int standalone_generate(const Model &model, const Eigen::MatrixXd &draws,
8587
logger.error(e.what());
8688
return error_codes::SOFTWARE;
8789
}
90+
auto end = std::chrono::steady_clock::now();
91+
double gq_delta_t
92+
= std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
93+
.count()
94+
/ 1000.0;
95+
writer.write_timing(gq_delta_t);
96+
8897
return error_codes::OK;
8998
}
9099

@@ -161,6 +170,7 @@ int standalone_generate(const Model &model, const int num_chains,
161170
std::stringstream msg;
162171
for (size_t slice_idx = r.begin(); slice_idx != r.end();
163172
++slice_idx) {
173+
auto start = std::chrono::steady_clock::now();
164174
for (size_t i = 0; i < draws[slice_idx].rows(); ++i) {
165175
if (error_any)
166176
return;
@@ -178,6 +188,13 @@ int standalone_generate(const Model &model, const int num_chains,
178188
writers[slice_idx].write_gq_values(model, rngs[slice_idx],
179189
unconstrained_params_r);
180190
}
191+
auto end = std::chrono::steady_clock::now();
192+
double gq_delta_t
193+
= std::chrono::duration_cast<std::chrono::milliseconds>(end
194+
- start)
195+
.count()
196+
/ 1000.0;
197+
writers[slice_idx].write_timing(gq_delta_t);
181198
}
182199
},
183200
tbb::simple_partitioner());

src/stan/services/util/gq_writer.hpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,28 @@ namespace stan {
1616
namespace services {
1717
namespace util {
1818

19+
namespace internal {
20+
21+
/**
22+
* Internal method
23+
*
24+
* Logs timing information
25+
*
26+
* @param[in] deltaT time in seconds
27+
*/
28+
template <typename F>
29+
void write_timing(double deltaT, F writer) {
30+
std::string title(" Elapsed Time: ");
31+
writer("");
32+
33+
std::stringstream ss1;
34+
ss1 << title << deltaT << " seconds (Generated Quantities)";
35+
writer(ss1.str());
36+
37+
writer("");
38+
}
39+
} // namespace internal
40+
1941
/**
2042
* gq_writer writes out
2143
*
@@ -128,6 +150,18 @@ class gq_writer {
128150
}
129151
sample_writer_(values);
130152
}
153+
154+
/**
155+
* Print timing information to all streams
156+
*
157+
* @param[in] deltaT time in seconds
158+
*/
159+
void write_timing(double deltaT) {
160+
internal::write_timing(
161+
deltaT, [this](const std::string& msg) { this->sample_writer_(msg); });
162+
internal::write_timing(
163+
deltaT, [this](const std::string& msg) { this->logger_.info(msg); });
164+
}
131165
};
132166

133167
} // namespace util

src/test/unit/services/sample/standalone_gqs_2390_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,6 @@ TEST_F(ServicesStandaloneGQ4, genDraws_gq_test_vec_len_1) {
4848
sample_writer);
4949
EXPECT_EQ(return_code, stan::services::error_codes::OK);
5050
EXPECT_EQ(count_matches("y_est", sample_ss.str()), 5);
51-
EXPECT_EQ(count_matches("\n", sample_ss.str()), 1001);
51+
EXPECT_EQ(count_matches("\n", sample_ss.str()), 1004);
5252
match_csv_columns(multidim_csv.samples, sample_ss.str(), 1000, 0, 6);
5353
}

src/test/unit/services/sample/standalone_gqs_multidim_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,6 @@ TEST_F(ServicesStandaloneGQ2, genDraws_gq_test_multidim) {
5555
sample_writer);
5656
EXPECT_EQ(return_code, stan::services::error_codes::OK);
5757
EXPECT_EQ(count_matches("gq_ar_mat", sample_ss.str()), 120);
58-
EXPECT_EQ(count_matches("\n", sample_ss.str()), 1001);
58+
EXPECT_EQ(count_matches("\n", sample_ss.str()), 1004);
5959
match_csv_columns(multidim_csv.samples, sample_ss.str(), 1000, 120, 127);
6060
}

src/test/unit/services/sample/standalone_gqs_parallel_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ TEST_F(ServicesStandaloneGQ, genDraws_bernoulli) {
7474
for (int i = 0; i < num_chains; i++) {
7575
EXPECT_EQ(count_matches("mu", sample_ss[i].str()), 1);
7676
EXPECT_EQ(count_matches("y_rep", sample_ss[i].str()), 10);
77-
EXPECT_EQ(count_matches("\n", sample_ss[i].str()), 1001);
77+
EXPECT_EQ(count_matches("\n", sample_ss[i].str()), 1004);
7878
match_csv_columns(bern_csv.samples, sample_ss[i].str(), 1000, 1, 8);
7979
}
8080
}

src/test/unit/services/sample/standalone_gqs_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ TEST_F(ServicesStandaloneGQ, genDraws_bernoulli) {
5555
EXPECT_EQ(return_code, stan::services::error_codes::OK);
5656
EXPECT_EQ(count_matches("mu", sample_ss.str()), 1);
5757
EXPECT_EQ(count_matches("y_rep", sample_ss.str()), 10);
58-
EXPECT_EQ(count_matches("\n", sample_ss.str()), 1001);
58+
EXPECT_EQ(count_matches("\n", sample_ss.str()), 1004);
5959
match_csv_columns(bern_csv.samples, sample_ss.str(), 1000, 1, 8);
6060
}
6161

src/test/unit/services/util/gq_writer_test.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,18 @@ TEST_F(ServicesUtilGQWriter, t2) {
4646
EXPECT_EQ(count_matches("nan", sample_ss.str()), 0);
4747
}
4848

49+
TEST_F(ServicesUtilGQWriter, timing) {
50+
stan::callbacks::stream_writer sample_writer(sample_ss, "#");
51+
stan::callbacks::stream_logger logger(logger_ss, logger_ss, logger_ss,
52+
logger_ss, logger_ss);
53+
stan::services::util::gq_writer writer(sample_writer, logger, 2);
54+
writer.write_timing(4.31);
55+
// model test_gq.stan generates 4 values, 3 commas
56+
EXPECT_EQ(count_matches("4.31 seconds", logger_ss.str()), 1);
57+
EXPECT_EQ(count_matches("4.31 seconds", sample_ss.str()), 1)
58+
<< sample_ss.str();
59+
}
60+
4961
TEST_F(ServicesUtilGQWriter, TestExceptions) {
5062
stan::callbacks::stream_writer sample_writer(sample_ss, "");
5163
stan::callbacks::stream_logger logger(logger_ss, logger_ss, logger_ss,

0 commit comments

Comments
 (0)