Skip to content

Commit 4639bc4

Browse files
committed
simplify value_of and value_of_rec to just be one function with if constexpr dispatch
1 parent 4c18336 commit 4639bc4

2 files changed

Lines changed: 64 additions & 59 deletions

File tree

stan/math/prim/fun/value_of.hpp

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,43 +32,46 @@ template <typename T>
3232
inline constexpr decltype(auto) value_of(T&& x) {
3333
using val_t = std::decay_t<T>;
3434
// ints are cast to doubles, types with base double are passed along
35-
if constexpr (std::is_integral_v<val_t> || std::is_floating_point_v<val_t>) {
36-
return val_t{x};
37-
} else if constexpr (std::is_floating_point_v<base_type_t<
38-
val_t>> || std::is_integral_v<base_type_t<val_t>>) {
39-
if constexpr (std::is_rvalue_reference_v<T&&>) {
40-
return plain_type_t<T>(std::forward<T>(x));
41-
} else {
42-
return x;
43-
}
44-
} else if constexpr (is_complex<val_t>::value) {
45-
return std::complex<double>{value_of(x.real()), value_of(x.imag())};
46-
} else if constexpr (is_std_vector_v<val_t>) {
47-
std::vector<
48-
plain_type_t<decltype(value_of(std::declval<value_type_t<T>>()))>>
49-
ret;
50-
ret.reserve(x.size());
51-
for (auto&& x_i : x) {
52-
ret.push_back(value_of(std::forward<decltype(x_i)>(x_i)));
53-
}
54-
return ret;
55-
} else if constexpr (is_eigen_v<val_t>) {
56-
return make_holder(
57-
[](auto& m) {
58-
return m.unaryExpr([](auto x_i) { return value_of(x_i); });
59-
},
60-
std::forward<T>(x));
61-
} else if constexpr (is_tuple_v<val_t>) {
35+
if constexpr (is_tuple_v<val_t>) {
6236
return stan::math::apply(
6337
[](auto&&... args) {
6438
return partially_forward_as_tuple(
6539
value_of(std::forward<decltype(args)>(args))...);
6640
},
6741
std::forward<T>(x));
68-
} else if constexpr (is_var_v<val_t>) {
69-
return x.vi_->val_;
70-
} else if constexpr (is_fvar<val_t>::value) {
71-
return x.val();
42+
} else {
43+
constexpr bool is_base_float_or_int = std::is_floating_point_v<base_type_t<val_t>> || std::is_integral_v<base_type_t<val_t>>;
44+
if constexpr (std::is_integral_v<val_t> || std::is_floating_point_v<val_t>) {
45+
return val_t{x};
46+
} else if constexpr (is_base_float_or_int) {
47+
if constexpr (std::is_rvalue_reference_v<T&&>) {
48+
return plain_type_t<T>(std::forward<T>(x));
49+
} else {
50+
return x;
51+
}
52+
} else if constexpr (is_complex<val_t>::value) {
53+
return std::complex<double>{value_of(x.real()), value_of(x.imag())};
54+
} else if constexpr (is_std_vector_v<val_t>) {
55+
std::vector<
56+
plain_type_t<decltype(value_of(std::declval<value_type_t<T>>()))>>
57+
ret;
58+
ret.reserve(x.size());
59+
for (auto&& x_i : x) {
60+
ret.push_back(value_of(std::forward<decltype(x_i)>(x_i)));
61+
}
62+
return ret;
63+
} else if constexpr (is_eigen_v<val_t>) {
64+
return make_holder(
65+
[](auto& m) {
66+
return m.unaryExpr([](auto x_i) { return value_of(x_i); });
67+
},
68+
std::forward<T>(x));
69+
} else if constexpr (is_var_v<val_t>) {
70+
return x.vi_->val_;
71+
} else if constexpr (is_fvar<val_t>::value) {
72+
return x.val();
73+
}
74+
7275
}
7376
}
7477

stan/math/prim/fun/value_of_rec.hpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,40 +29,42 @@ template <typename T>
2929
inline constexpr decltype(auto) value_of_rec(T&& x) {
3030
using val_t = std::decay_t<T>;
3131
// ints are cast to doubles, types with base double are passed along
32-
if constexpr (std::is_integral_v<val_t> || std::is_floating_point_v<val_t>) {
33-
return static_cast<double>(x);
34-
} else if constexpr (std::is_floating_point_v<base_type_t<val_t>>) {
35-
if constexpr (std::is_rvalue_reference_v<T&&>) {
36-
return plain_type_t<T>(std::forward<T>(x));
37-
} else {
38-
return x;
39-
}
40-
} else if constexpr (is_complex<val_t>::value) {
41-
return std::complex<double>{value_of_rec(x.real()), value_of_rec(x.imag())};
42-
} else if constexpr (is_std_vector_v<val_t>) {
43-
promote_scalar_t<double, val_t> ret;
44-
ret.reserve(x.size());
45-
for (auto&& x_i : x) {
46-
ret.push_back(value_of_rec(std::forward<decltype(x_i)>(x_i)));
47-
}
48-
return ret;
49-
} else if constexpr (is_eigen_v<val_t>) {
50-
return make_holder(
51-
[](auto& m) {
52-
return m.unaryExpr([](auto x_i) { return value_of_rec(x_i); });
53-
},
54-
std::forward<T>(x));
55-
} else if constexpr (is_tuple_v<val_t>) {
32+
if constexpr (is_tuple_v<val_t>) {
5633
return stan::math::apply(
5734
[](auto&&... args) {
5835
return partially_forward_as_tuple(
5936
value_of_rec(std::forward<decltype(args)>(args))...);
6037
},
6138
std::forward<T>(x));
62-
} else if constexpr (is_var_v<val_t>) {
63-
return x.vi_->val_;
64-
} else if constexpr (is_fvar<val_t>::value) {
65-
return value_of_rec(x.val());
39+
} else {
40+
if constexpr (std::is_integral_v<val_t> || std::is_floating_point_v<val_t>) {
41+
return static_cast<double>(x);
42+
} else if constexpr (std::is_floating_point_v<base_type_t<val_t>>) {
43+
if constexpr (std::is_rvalue_reference_v<T&&>) {
44+
return plain_type_t<T>(std::forward<T>(x));
45+
} else {
46+
return x;
47+
}
48+
} else if constexpr (is_complex<val_t>::value) {
49+
return std::complex<double>{value_of_rec(x.real()), value_of_rec(x.imag())};
50+
} else if constexpr (is_std_vector_v<val_t>) {
51+
promote_scalar_t<double, val_t> ret;
52+
ret.reserve(x.size());
53+
for (auto&& x_i : x) {
54+
ret.push_back(value_of_rec(std::forward<decltype(x_i)>(x_i)));
55+
}
56+
return ret;
57+
} else if constexpr (is_eigen_v<val_t>) {
58+
return make_holder(
59+
[](auto& m) {
60+
return m.unaryExpr([](auto x_i) { return value_of_rec(x_i); });
61+
},
62+
std::forward<T>(x));
63+
} else if constexpr (is_var_v<val_t>) {
64+
return x.vi_->val_;
65+
} else if constexpr (is_fvar<val_t>::value) {
66+
return value_of_rec(x.val());
67+
}
6668
}
6769
}
6870

0 commit comments

Comments
 (0)