Skip to content

Commit 7eb7e44

Browse files
committed
change auto& to auto&& in make_holder function lambdas
1 parent ef592d9 commit 7eb7e44

13 files changed

Lines changed: 293 additions & 192 deletions

stan/math/opencl/prim/constraint/lb_constrain.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ template <typename T, typename L,
3131
require_all_kernel_expressions_t<L>* = nullptr>
3232
inline auto lb_constrain(T&& x, L&& lb) {
3333
return make_holder_cl(
34-
[](auto& x_, auto& lb_) {
34+
[](auto&& x_, auto& lb_) {
3535
return select(lb_ == NEGATIVE_INFTY, x_, lb_ + exp(x_));
3636
},
3737
std::forward<T>(x), std::forward<L>(lb));

stan/math/opencl/prim/constraint/ub_constrain.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ template <typename T, typename U,
3131
require_all_kernel_expressions_t<U>* = nullptr>
3232
inline auto ub_constrain(T&& x, U&& ub) {
3333
return make_holder_cl(
34-
[](auto& x_, auto& ub_) {
34+
[](auto&& x_, auto&& ub_) {
3535
return select(ub_ == INFTY, x_, ub_ - exp(x_));
3636
},
3737
std::forward<T>(x), std::forward<U>(ub));

stan/math/opencl/prim/symmetrize_from_lower_tri.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ template <typename T_x,
2121
inline auto symmetrize_from_lower_tri(T_x&& x) {
2222
check_square("symmetrize_from_lower_tri", "x", x);
2323
return make_holder_cl(
24-
[](auto& arg) {
24+
[](auto&& arg) {
2525
return select(row_index() < col_index(), transpose(arg), arg);
2626
},
2727
std::forward<T_x>(x));

stan/math/opencl/prim/symmetrize_from_upper_tri.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ template <typename T_x,
2121
inline auto symmetrize_from_upper_tri(T_x&& x) {
2222
check_square("symmetrize_from_upper_tri", "x", x);
2323
return make_holder_cl(
24-
[](auto& arg) {
24+
[](auto&& arg) {
2525
return select(col_index() < row_index(), transpose(arg), arg);
2626
},
2727
std::forward<T_x>(x));

stan/math/prim/fun/as_array_or_scalar.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ inline T as_array_or_scalar(T&& v) {
5454
template <typename T, typename = require_eigen_t<T>,
5555
require_not_eigen_array_t<T>* = nullptr>
5656
inline auto as_array_or_scalar(T&& v) {
57-
return make_holder([](auto& x) { return x.array(); }, std::forward<T>(v));
57+
return make_holder([](auto&& x) { return x.array(); }, std::forward<T>(v));
5858
}
5959

6060
/**
@@ -69,7 +69,7 @@ template <typename T, require_std_vector_t<T>* = nullptr,
6969
inline auto as_array_or_scalar(T&& v) {
7070
using T_map
7171
= Eigen::Map<const Eigen::Array<value_type_t<T>, Eigen::Dynamic, 1>>;
72-
return make_holder([](auto& x) { return T_map(x.data(), x.size()); },
72+
return make_holder([](auto&& x) { return T_map(x.data(), x.size()); },
7373
std::forward<T>(v));
7474
}
7575

stan/math/prim/fun/as_column_vector_or_scalar.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ inline T&& as_column_vector_or_scalar(T&& a) {
5757
template <typename T, require_eigen_row_vector_t<T>* = nullptr,
5858
require_not_eigen_col_vector_t<T>* = nullptr>
5959
inline auto as_column_vector_or_scalar(T&& a) {
60-
return make_holder([](auto& x) { return x.transpose(); }, std::forward<T>(a));
60+
return make_holder([](auto&& x) { return x.transpose(); }, std::forward<T>(a));
6161
}
6262

6363
/**
@@ -74,7 +74,7 @@ inline auto as_column_vector_or_scalar(T&& a) {
7474
= std::conditional_t<std::is_const<std::remove_reference_t<T>>::value,
7575
const plain_vector, plain_vector>;
7676
using T_map = Eigen::Map<optionally_const_vector>;
77-
return make_holder([](auto& x) { return T_map(x.data(), x.size()); },
77+
return make_holder([](auto&& x) { return T_map(x.data(), x.size()); },
7878
std::forward<T>(a));
7979
}
8080

stan/math/prim/fun/value_of.hpp

Lines changed: 114 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -10,100 +10,128 @@
1010

1111
namespace stan {
1212
namespace math {
13+
template <typename Tuple, require_tuple_t<Tuple>* = nullptr>
14+
inline auto value_of(Tuple&& tup);
15+
template <typename T, require_std_vector_t<T>* = nullptr,
16+
require_not_st_arithmetic<T>* = nullptr>
17+
inline auto value_of(const T& x);
18+
/**
19+
* Inputs that are arithmetic types or containers of airthmetric types
20+
* are returned from value_of unchanged
21+
*
22+
* @tparam T Input type
23+
* @param[in] x Input argument
24+
* @return Forwarded input argument
25+
**/
26+
template <typename T, require_st_arithmetic<T>* = nullptr>
27+
inline T value_of(T&& x) {
28+
return std::forward<T>(x);
29+
}
30+
31+
template <typename T, require_complex_t<T>* = nullptr,
32+
require_t<std::is_arithmetic<
33+
typename std::decay_t<T>::value_type>>* = nullptr>
34+
inline auto value_of(T&& x) {
35+
return std::forward<T>(x);
36+
}
37+
38+
template <
39+
typename T, require_complex_t<T>* = nullptr,
40+
require_not_arithmetic_t<typename std::decay_t<T>::value_type>* = nullptr>
41+
inline auto value_of(T&& x) {
42+
using inner_t = partials_type_t<typename std::decay_t<T>::value_type>;
43+
return std::complex<inner_t>{value_of(x.real()), value_of(x.imag())};
44+
}
1345

1446
/**
15-
* Return the value of the specified argument.
16-
* For types with a \ref base_type of double or int returns itself.
17-
* For types with a \ref base_type of \ref var or \ref fvar
18-
* the `value` member of their class is returned.
47+
* For std::vectors of non-arithmetic types, return a std::vector composed
48+
* of value_of applied to each element.
1949
*
20-
* So for `std::complex<fvar<var>>` this will return
21-
* a `std::complex<var>`. And for `std::vector<var>`
22-
* this will return a `std:vector<double>` whose
23-
* values are the `val_` members of the `var`s.
50+
* @tparam T Input element type
51+
* @param[in] x Input std::vector
52+
* @return std::vector of values
53+
**/
54+
template <typename T, require_std_vector_t<T>*,
55+
require_not_st_arithmetic<T>*>
56+
inline auto value_of(const T& x) {
57+
std::vector<plain_type_t<decltype(value_of(std::declval<value_type_t<T>>()))>>
58+
out;
59+
out.reserve(x.size());
60+
for (auto&& x_elem : x) {
61+
out.emplace_back(value_of(x_elem));
62+
}
63+
return out;
64+
}
65+
66+
/**
67+
* For Eigen matrices and expressions of non-arithmetic types, return an
68+
*expression that represents the Eigen::Matrix resulting from applying value_of
69+
*elementwise
2470
*
25-
* <p>See the <code>primitive_value</code> function to
26-
* extract values without casting to <code>double</code>.
27-
* @tparam T A container or scalar type
28-
* @param x The object whose values will be extracted.
29-
* @return An object whose \ref scalar_type
30-
*/
31-
template <typename T>
32-
inline constexpr decltype(auto) value_of(T&& x) {
33-
using val_t = std::decay_t<T>;
34-
if constexpr (is_tuple_v<val_t>) {
35-
return stan::math::apply(
36-
[](auto&&... args) {
37-
return partially_forward_as_tuple(
38-
value_of(std::forward<decltype(args)>(args))...);
39-
},
40-
std::forward<T>(x));
41-
} else {
42-
constexpr bool is_float_or_int
43-
= std::is_floating_point_v<val_t> || std::is_integral_v<val_t>;
44-
constexpr bool is_base_float_or_int
45-
= std::is_floating_point_v<
46-
base_type_t<val_t>> || std::is_integral_v<base_type_t<val_t>>;
47-
if constexpr (is_float_or_int) {
48-
return val_t{x};
49-
} else if constexpr (is_base_float_or_int && !is_eigen_v<val_t>) {
50-
if constexpr (std::is_rvalue_reference_v<T&&>) {
51-
return plain_type_t<T>(std::forward<T>(x));
52-
} else {
53-
return x;
54-
}
55-
} else if constexpr (is_complex<val_t>::value) {
56-
return std::complex<double>{value_of(x.real()), value_of(x.imag())};
57-
} else if constexpr (is_std_vector_v<val_t>) {
58-
std::vector<
59-
plain_type_t<decltype(value_of(std::declval<value_type_t<T>>()))>>
60-
ret;
61-
ret.reserve(x.size());
62-
for (auto&& x_i : x) {
63-
ret.push_back(value_of(std::forward<decltype(x_i)>(x_i)));
64-
}
65-
return ret;
66-
} else if constexpr (is_eigen_v<val_t>) {
67-
/**
68-
* Because of lifetimes of eigen expressions we have to account
69-
* for a few choices.
70-
* 1. If a base type of double
71-
* a. and it is an rvalue reference and not a holder
72-
* - Wrap it in a holder and forward the object
73-
* b. and it is an rvalue holder
74-
* - pass x to decayed holder
75-
* c. it is an rvalue
76-
* - pass x
77-
* 2. Any other value type that does not have a base type of double
78-
* - wrap it ina a holder with an unary expr to pull out the values
79-
*/
80-
if constexpr (is_base_float_or_int) {
81-
if constexpr (std::is_rvalue_reference_v<T&&> && !is_holder_v<val_t>) {
82-
return make_holder([](auto&& x_inner) { return x_inner; },
83-
std::forward<T>(x));
84-
} else if constexpr (is_holder_v<val_t>) {
85-
return std::decay_t<T>(std::forward<T>(x));
86-
} else {
87-
return x;
88-
}
89-
} else {
90-
return make_holder(
91-
[](auto& m) {
92-
return m.unaryExpr([](auto x_i) { return value_of(x_i); });
93-
},
94-
std::forward<T>(x));
95-
}
96-
} else if constexpr (is_var_v<val_t>) {
97-
return x.vi_->val_;
98-
} else if constexpr (is_fvar<val_t>::value) {
99-
return x.val();
100-
} else {
101-
static_assert(1, "Type not caught!");
71+
* @tparam EigMat type of the matrix
72+
*
73+
* @param[in] M Matrix to be converted
74+
* @return Matrix of values
75+
**/
76+
template <typename EigMat, require_eigen_dense_base_t<EigMat>* = nullptr,
77+
require_not_st_arithmetic<EigMat>* = nullptr>
78+
inline auto value_of(EigMat&& M) {
79+
return make_holder(
80+
[](auto&& a) {
81+
return a.unaryExpr([](const auto& scal) { return value_of(scal); });
82+
},
83+
std::forward<EigMat>(M));
84+
}
85+
86+
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
87+
require_not_st_arithmetic<EigMat>* = nullptr>
88+
inline auto value_of(EigMat&& M) {
89+
auto&& M_ref = to_ref(M);
90+
using scalar_t = decltype(value_of(std::declval<value_type_t<EigMat>>()));
91+
promote_scalar_t<scalar_t, plain_type_t<EigMat>> ret(M_ref.rows(),
92+
M_ref.cols());
93+
ret.reserve(M_ref.nonZeros());
94+
for (int k = 0; k < M_ref.outerSize(); ++k) {
95+
for (typename std::decay_t<EigMat>::InnerIterator it(M_ref, k); it; ++it) {
96+
ret.insert(it.row(), it.col()) = value_of(it.valueRef());
10297
}
10398
}
99+
ret.makeCompressed();
100+
return ret;
101+
}
102+
103+
/*
104+
* For Sparse Eigen matrices and expressions of non-arithmetic types, return an
105+
*expression that represents the Eigen::Matrix resulting from applying value_of
106+
*elementwise
107+
*
108+
* @tparam EigMat type of the matrix
109+
*
110+
* @param[in] M Matrix to be converted
111+
* @return Matrix of values
112+
*/
113+
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
114+
require_st_arithmetic<EigMat>* = nullptr>
115+
inline auto value_of(EigMat&& M) {
116+
return std::forward<EigMat>(M);
117+
}
118+
119+
/**
120+
* Converts a tuples elements scalar types from ad to their child type.
121+
* @tparam Tuple type of tuple
122+
* @param[in] tup tuple to be converted
123+
*/
124+
template <typename Tuple, require_tuple_t<Tuple>*>
125+
inline auto value_of(Tuple&& tup) {
126+
return stan::math::apply(
127+
[](auto&&... args) {
128+
return partially_forward_as_tuple(
129+
value_of(std::forward<decltype(args)>(args))...);
130+
},
131+
std::forward<Tuple>(tup));
104132
}
105133

106134
} // namespace math
107135
} // namespace stan
108136

109-
#endif
137+
#endif

0 commit comments

Comments
 (0)