Skip to content

Commit 6fb3ec0

Browse files
committed
Merge branch 'expression_test_framework' of https://github.com/bstatcomp/math into expression_test_framework
2 parents 0153c0e + 5f57686 commit 6fb3ec0

174 files changed

Lines changed: 975 additions & 783 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/FUNDING.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
github: stan-dev
2+
custom: https://mc-stan.org/support/

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,17 @@ inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
5353
*/
5454
template <typename T, require_container_st<is_fvar, T>* = nullptr>
5555
inline auto log_sum_exp(const T& x) {
56-
return apply_vector_unary<ref_type_t<T>>::reduce(to_ref(x), [&](const auto& v) {
57-
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
58-
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
59-
mat_type vals = v.val();
60-
mat_type exp_vals = vals.array().exp();
56+
return apply_vector_unary<ref_type_t<T>>::reduce(
57+
to_ref(x), [&](const auto& v) {
58+
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;
59+
using mat_type = Eigen::Matrix<T_fvar_inner, -1, -1>;
60+
mat_type vals = v.val();
61+
mat_type exp_vals = vals.array().exp();
6162

62-
return fvar<T_fvar_inner>(
63-
log_sum_exp(vals), v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
64-
});
63+
return fvar<T_fvar_inner>(
64+
log_sum_exp(vals),
65+
v.d().cwiseProduct(exp_vals).sum() / exp_vals.sum());
66+
});
6567
}
6668

6769
} // namespace math

stan/math/fwd/fun/quad_form_sym.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ inline promote_scalar_t<return_type_t<EigMat1, EigMat2>, EigMat2> quad_form_sym(
3333
check_multiplicable("quad_form_sym", "A", A, "B", B);
3434
check_symmetric("quad_form_sym", "A", A);
3535
const auto& B_ref = to_ref(B);
36-
promote_scalar_t<T_ret, EigMat2> ret(multiply(B_ref.transpose(), multiply(A, B_ref)));
36+
promote_scalar_t<T_ret, EigMat2> ret(
37+
multiply(B_ref.transpose(), multiply(A, B_ref)));
3738
return T_ret(0.5) * (ret + ret.transpose());
3839
}
3940

stan/math/fwd/functor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_FWD_FUNCTOR_HPP
22
#define STAN_MATH_FWD_FUNCTOR_HPP
33

4+
#include <stan/math/fwd/functor/apply_scalar_unary.hpp>
45
#include <stan/math/fwd/functor/gradient.hpp>
56
#include <stan/math/fwd/functor/hessian.hpp>
67
#include <stan/math/fwd/functor/jacobian.hpp>

stan/math/fwd/meta/apply_scalar_unary.hpp renamed to stan/math/fwd/functor/apply_scalar_unary.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#ifndef STAN_MATH_FWD_META_APPLY_SCALAR_UNARY_HPP
2-
#define STAN_MATH_FWD_META_APPLY_SCALAR_UNARY_HPP
1+
#ifndef STAN_MATH_FWD_FUNCTOR_APPLY_SCALAR_UNARY_HPP
2+
#define STAN_MATH_FWD_FUNCTOR_APPLY_SCALAR_UNARY_HPP
33

4-
#include <stan/math/prim/meta/apply_scalar_unary.hpp>
4+
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
55
#include <stan/math/fwd/core/fvar.hpp>
66

77
namespace stan {

stan/math/fwd/meta.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#ifndef STAN_MATH_FWD_META_HPP
22
#define STAN_MATH_FWD_META_HPP
33

4-
#include <stan/math/fwd/meta/apply_scalar_unary.hpp>
54
#include <stan/math/fwd/meta/is_fvar.hpp>
65
#include <stan/math/fwd/meta/partials_type.hpp>
76
#include <stan/math/fwd/meta/operands_and_partials.hpp>

stan/math/prim/eigen_plugins.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ struct val_Op{
5353
//Returns value from a vari*
5454
template<typename T = Scalar>
5555
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
56-
std::enable_if_t<std::is_pointer<T>::value, reverse_return_t<T>>
56+
std::enable_if_t<std::is_pointer<T>::value, const double&>
5757
operator()(T &v) const { return v->val_; }
5858

5959
//Returns value from a var
6060
template<typename T = Scalar>
6161
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
6262
std::enable_if_t<(!std::is_pointer<T>::value && !is_fvar<T>::value
63-
&& !std::is_arithmetic<T>::value), reverse_return_t<T>>
63+
&& !std::is_arithmetic<T>::value), const double&>
6464
operator()(T &v) const { return v.vi_->val_; }
6565

6666
//Returns value from an fvar

stan/math/prim/err/check_symmetric.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ namespace math {
2828
* main diagonal is <code>NaN</code>
2929
*/
3030
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
31-
inline void check_symmetric(
32-
const char* function, const char* name,
33-
const EigMat& y) {
31+
inline void check_symmetric(const char* function, const char* name,
32+
const EigMat& y) {
3433
check_square(function, name, y);
3534
using std::fabs;
3635

stan/math/prim/fun/Phi.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/prim/fun/erf.hpp>
88
#include <stan/math/prim/fun/erfc.hpp>
99
#include <stan/math/prim/fun/Phi.hpp>
10+
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
1011

1112
namespace stan {
1213
namespace math {

stan/math/prim/fun/Phi_approx.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/fun/inv_logit.hpp>
6+
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
67
#include <cmath>
78

89
namespace stan {

0 commit comments

Comments
 (0)