Skip to content

Commit 7c55ef7

Browse files
authored
Merge pull request #2439 from stan-dev/bugfix/issue-2438
multi_normal_cholesky with optimization
2 parents c09e004 + 5b1c5f9 commit 7c55ef7

1 file changed

Lines changed: 180 additions & 41 deletions

File tree

stan/math/prim/prob/multi_normal_cholesky_lpdf.hpp

Lines changed: 180 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <stan/math/prim/fun/log.hpp>
1010
#include <stan/math/prim/fun/max_size_mvt.hpp>
1111
#include <stan/math/prim/fun/mdivide_left_tri.hpp>
12+
#include <stan/math/prim/fun/mdivide_right_tri.hpp>
1213
#include <stan/math/prim/fun/size_mvt.hpp>
1314
#include <stan/math/prim/fun/sum.hpp>
1415
#include <stan/math/prim/fun/transpose.hpp>
@@ -23,6 +24,8 @@ namespace math {
2324
* a Cholesky factor L of the variance matrix.
2425
* Sigma = LL', a square, positive semi-definite matrix.
2526
*
27+
* This version of the function is vectorized on y and mu.
28+
*
2629
* Analytic expressions taken from
2730
* http://qwone.com/~jason/writing/multivariateNormal.pdf
2831
* written by Jason D. M. Rennie.
@@ -39,6 +42,7 @@ namespace math {
3942
* @tparam T_covar Type of scale.
4043
*/
4144
template <bool propto, typename T_y, typename T_loc, typename T_covar,
45+
require_any_not_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
4246
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
4347
T_y, T_loc, T_covar>* = nullptr>
4448
return_type_t<T_y, T_loc, T_covar> multi_normal_cholesky_lpdf(
@@ -72,27 +76,26 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_cholesky_lpdf(
7276

7377
const int size_y = y_vec[0].size();
7478
const int size_mu = mu_vec[0].size();
75-
if (likely(size_vec > 1)) {
76-
// check size consistency of all random variables y
77-
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
78-
check_size_match(function,
79-
"Size of one of the vectors of "
80-
"the random variable",
81-
y_vec[i].size(),
82-
"Size of the first vector of the "
83-
"random variable",
84-
size_y);
85-
}
86-
// check size consistency of all means mu
87-
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
88-
check_size_match(function,
89-
"Size of one of the vectors of "
90-
"the location variable",
91-
mu_vec[i].size(),
92-
"Size of the first vector of the "
93-
"location variable",
94-
size_mu);
95-
}
79+
80+
// check size consistency of all random variables y
81+
for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
82+
check_size_match(function,
83+
"Size of one of the vectors of "
84+
"the random variable",
85+
y_vec[i].size(),
86+
"Size of the first vector of the "
87+
"random variable",
88+
size_y);
89+
}
90+
// check size consistency of all means mu
91+
for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
92+
check_size_match(function,
93+
"Size of one of the vectors of "
94+
"the location variable",
95+
mu_vec[i].size(),
96+
"Size of the first vector of the "
97+
"location variable",
98+
size_mu);
9699
}
97100

98101
check_size_match(function, "Size of random variable", size_y,
@@ -119,40 +122,176 @@ return_type_t<T_y, T_loc, T_covar> multi_normal_cholesky_lpdf(
119122
logp += NEG_LOG_SQRT_TWO_PI * size_y * size_vec;
120123
}
121124

122-
const matrix_partials_t inv_L_dbl
123-
= mdivide_left_tri<Eigen::Lower>(value_of(L_ref));
124-
125125
if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
126+
Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>
127+
y_val_minus_mu_val(size_y, size_vec);
128+
126129
for (size_t i = 0; i < size_vec; i++) {
127130
decltype(auto) y_val = as_value_column_vector_or_scalar(y_vec[i]);
128131
decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_vec[i]);
132+
y_val_minus_mu_val.col(i) = y_val - mu_val;
133+
}
134+
135+
matrix_partials_t half;
136+
matrix_partials_t scaled_diff;
137+
138+
// If the covariance is not autodiff, we can avoid computing a matrix
139+
// inverse
140+
if (is_constant<T_covar_elem>::value) {
141+
matrix_partials_t L_val = value_of(L_ref);
129142

130-
const row_vector_partials_t half
131-
= (inv_L_dbl.template triangularView<Eigen::Lower>()
132-
* (y_val - mu_val).template cast<T_partials_return>())
133-
.transpose();
134-
const vector_partials_t scaled_diff
135-
= (half * inv_L_dbl.template triangularView<Eigen::Lower>())
136-
.transpose();
143+
half = mdivide_left_tri<Eigen::Lower>(L_val, y_val_minus_mu_val)
144+
.transpose();
145+
146+
scaled_diff = mdivide_right_tri<Eigen::Lower>(half, L_val).transpose();
147+
148+
if (include_summand<propto>::value) {
149+
logp -= sum(log(L_val.diagonal())) * size_vec;
150+
}
151+
} else {
152+
matrix_partials_t inv_L_val
153+
= mdivide_left_tri<Eigen::Lower>(value_of(L_ref));
137154

138-
logp -= 0.5 * dot_self(half);
155+
half = (inv_L_val.template triangularView<Eigen::Lower>()
156+
* y_val_minus_mu_val)
157+
.transpose();
139158

159+
scaled_diff = (half * inv_L_val.template triangularView<Eigen::Lower>())
160+
.transpose();
161+
162+
logp += sum(log(inv_L_val.diagonal())) * size_vec;
163+
ops_partials.edge3_.partials_ -= size_vec * inv_L_val.transpose();
164+
165+
for (size_t i = 0; i < size_vec; i++) {
166+
ops_partials.edge3_.partials_vec_[i]
167+
+= scaled_diff.col(i) * half.row(i);
168+
}
169+
}
170+
171+
logp -= 0.5 * sum(columns_dot_self(half));
172+
173+
for (size_t i = 0; i < size_vec; i++) {
140174
if (!is_constant_all<T_y>::value) {
141-
ops_partials.edge1_.partials_vec_[i] -= scaled_diff;
175+
ops_partials.edge1_.partials_vec_[i] -= scaled_diff.col(i);
142176
}
143177
if (!is_constant_all<T_loc>::value) {
144-
ops_partials.edge2_.partials_vec_[i] += scaled_diff;
145-
}
146-
if (!is_constant_all<T_covar>::value) {
147-
ops_partials.edge3_.partials_ += scaled_diff * half;
178+
ops_partials.edge2_.partials_vec_[i] += scaled_diff.col(i);
148179
}
149180
}
150181
}
151182

152-
if (include_summand<propto, T_covar_elem>::value) {
153-
logp += sum(log(inv_L_dbl.diagonal())) * size_vec;
154-
if (!is_constant_all<T_covar>::value) {
155-
ops_partials.edge3_.partials_ -= size_vec * inv_L_dbl.transpose();
183+
return ops_partials.build(logp);
184+
}
185+
186+
/** \ingroup multivar_dists
187+
* The log of the multivariate normal density for the given y, mu, and
188+
* a Cholesky factor L of the variance matrix.
189+
* Sigma = LL', a square, positive semi-definite matrix.
190+
*
191+
* Analytic expressions taken from
192+
* http://qwone.com/~jason/writing/multivariateNormal.pdf
193+
* written by Jason D. M. Rennie.
194+
*
195+
* @param y A vector of scalars (the random variable).
196+
* @param mu The mean vector of the multivariate normal distribution.
197+
* @param L The Cholesky decomposition of a variance matrix
198+
* of the multivariate normal distribution
199+
* @return The log of the multivariate normal density.
200+
* @throw std::domain_error if LL' is not square, not symmetric,
201+
* or not positive semi-definite.
202+
* @tparam T_y Type of scalar.
203+
* @tparam T_loc Type of location.
204+
* @tparam T_covar Type of scale.
205+
*/
206+
template <bool propto, typename T_y, typename T_loc, typename T_covar,
207+
require_all_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
208+
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
209+
T_y, T_loc, T_covar>* = nullptr>
210+
return_type_t<T_y, T_loc, T_covar> multi_normal_cholesky_lpdf(
211+
const T_y& y, const T_loc& mu, const T_covar& L) {
212+
static const char* function = "multi_normal_cholesky_lpdf";
213+
using T_covar_elem = typename scalar_type<T_covar>::type;
214+
using T_return = return_type_t<T_y, T_loc, T_covar>;
215+
using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
216+
using matrix_partials_t
217+
= Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
218+
using vector_partials_t = Eigen::Matrix<T_partials_return, Eigen::Dynamic, 1>;
219+
using row_vector_partials_t
220+
= Eigen::Matrix<T_partials_return, 1, Eigen::Dynamic>;
221+
using T_y_ref = ref_type_t<T_y>;
222+
using T_mu_ref = ref_type_t<T_loc>;
223+
using T_L_ref = ref_type_t<T_covar>;
224+
225+
T_y_ref y_ref = y;
226+
T_mu_ref mu_ref = mu;
227+
T_L_ref L_ref = L;
228+
decltype(auto) y_val = as_value_column_vector_or_scalar(y_ref);
229+
decltype(auto) mu_val = as_value_column_vector_or_scalar(mu_ref);
230+
231+
const int size_y = y_ref.size();
232+
const int size_mu = mu_ref.size();
233+
234+
check_size_match(function, "Size of random variable", size_y,
235+
"size of location parameter", size_mu);
236+
check_size_match(function, "Size of random variable", size_y,
237+
"rows of covariance parameter", L.rows());
238+
check_size_match(function, "Size of random variable", size_y,
239+
"columns of covariance parameter", L.cols());
240+
241+
check_finite(function, "Location parameter", mu_val);
242+
check_not_nan(function, "Random variable", y_val);
243+
244+
if (unlikely(size_y == 0)) {
245+
return T_return(0);
246+
}
247+
248+
operands_and_partials<T_y_ref, T_mu_ref, T_L_ref> ops_partials(y_ref, mu_ref,
249+
L_ref);
250+
251+
T_partials_return logp(0);
252+
if (include_summand<propto>::value) {
253+
logp += NEG_LOG_SQRT_TWO_PI * size_y;
254+
}
255+
256+
if (include_summand<propto, T_y, T_loc, T_covar_elem>::value) {
257+
row_vector_partials_t half;
258+
vector_partials_t scaled_diff;
259+
260+
// If the covariance is not autodiff, we can avoid computing a matrix
261+
// inverse
262+
if (is_constant<T_covar_elem>::value) {
263+
matrix_partials_t L_val = value_of(L_ref);
264+
265+
half = mdivide_left_tri<Eigen::Lower>(L_val, y_val - mu_val).transpose();
266+
267+
scaled_diff = mdivide_right_tri<Eigen::Lower>(half, L_val).transpose();
268+
269+
if (include_summand<propto>::value) {
270+
logp -= sum(log(L_val.diagonal()));
271+
}
272+
} else {
273+
matrix_partials_t inv_L_val
274+
= mdivide_left_tri<Eigen::Lower>(value_of(L_ref));
275+
276+
half = (inv_L_val.template triangularView<Eigen::Lower>()
277+
* (y_val - mu_val).template cast<T_partials_return>())
278+
.transpose();
279+
280+
scaled_diff = (half * inv_L_val.template triangularView<Eigen::Lower>())
281+
.transpose();
282+
283+
logp += sum(log(inv_L_val.diagonal()));
284+
ops_partials.edge3_.partials_
285+
+= scaled_diff * half - inv_L_val.transpose();
286+
}
287+
288+
logp -= 0.5 * sum(dot_self(half));
289+
290+
if (!is_constant_all<T_y>::value) {
291+
ops_partials.edge1_.partials_ -= scaled_diff;
292+
}
293+
if (!is_constant_all<T_loc>::value) {
294+
ops_partials.edge2_.partials_ += scaled_diff;
156295
}
157296
}
158297

0 commit comments

Comments
 (0)