Skip to content

Commit e51700b

Browse files
committed
Merge remote-tracking branch 'origin/develop' into feature/parameter-pack-odes
2 parents a3f438b + e5f00e2 commit e51700b

39 files changed

+975
-346
lines changed

stan/math/fwd/fun/Eigen_NumTraits.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,32 @@ struct ScalarBinaryOpTraits<double, stan::math::fvar<T>, BinaryOp> {
8080
using ReturnType = stan::math::fvar<T>;
8181
};
8282

83+
/**
84+
* Traits specialization for Eigen binary operations for autodiff and
85+
* `int` arguments.
86+
*
87+
* @tparam T value and tangent type of autodiff variable
88+
* @tparam BinaryOp type of binary operation for which traits are
89+
* defined
90+
*/
91+
template <typename T, typename BinaryOp>
92+
struct ScalarBinaryOpTraits<stan::math::fvar<T>, int, BinaryOp> {
93+
using ReturnType = stan::math::fvar<T>;
94+
};
95+
96+
/**
97+
* Traits specialization for Eigen binary operations for `int` and
98+
* autodiff arguments.
99+
*
100+
* @tparam T value and tangent type of autodiff variable
101+
* @tparam BinaryOp type of binary operation for which traits are
102+
* defined
103+
*/
104+
template <typename T, typename BinaryOp>
105+
struct ScalarBinaryOpTraits<int, stan::math::fvar<T>, BinaryOp> {
106+
using ReturnType = stan::math::fvar<T>;
107+
};
108+
83109
/**
84110
* Traits specialization for Eigen binary operations for `double` and
85111
* complex autodiff arguments.
@@ -108,6 +134,32 @@ struct ScalarBinaryOpTraits<std::complex<stan::math::fvar<T>>, double,
108134
using ReturnType = std::complex<stan::math::fvar<T>>;
109135
};
110136

137+
/**
138+
* Traits specialization for Eigen binary operations for `int` and
139+
* complex autodiff arguments.
140+
*
141+
* @tparam T value and tangent type of autodiff variable
142+
* @tparam BinaryOp type of binary operation for which traits are
143+
* defined
144+
*/
145+
template <typename T, typename BinaryOp>
146+
struct ScalarBinaryOpTraits<int, std::complex<stan::math::fvar<T>>, BinaryOp> {
147+
using ReturnType = std::complex<stan::math::fvar<T>>;
148+
};
149+
150+
/**
151+
* Traits specialization for Eigen binary operations for complex
152+
* autodiff and `int` arguments.
153+
*
154+
* @tparam T value and tangent type of autodiff variable
155+
* @tparam BinaryOp type of binary operation for which traits are
156+
* defined
157+
*/
158+
template <typename T, typename BinaryOp>
159+
struct ScalarBinaryOpTraits<std::complex<stan::math::fvar<T>>, int, BinaryOp> {
160+
using ReturnType = std::complex<stan::math::fvar<T>>;
161+
};
162+
111163
/**
112164
* Traits specialization for Eigen binary operations for autodiff
113165
* and complex `double` arguments.

stan/math/prim/err/check_finite.hpp

Lines changed: 8 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,25 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/is_scal_finite.hpp>
6-
#include <stan/math/prim/err/throw_domain_error.hpp>
7-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
8-
#include <stan/math/prim/fun/Eigen.hpp>
9-
#include <stan/math/prim/fun/get.hpp>
10-
#include <stan/math/prim/fun/size.hpp>
11-
#include <stan/math/prim/fun/value_of.hpp>
12-
#include <stan/math/prim/fun/value_of_rec.hpp>
13-
#include <cmath>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
145

156
namespace stan {
167
namespace math {
17-
namespace internal {
18-
template <typename T_y>
19-
bool is_finite(const T_y& y) {
20-
return is_scal_finite(y);
21-
}
22-
23-
template <typename T_y, int R, int C>
24-
bool is_finite(const Eigen::Matrix<T_y, R, C>& y) {
25-
bool all = true;
26-
for (size_t n = 0; n < y.size(); ++n) {
27-
all &= is_finite(y(n));
28-
}
29-
return all;
30-
}
31-
32-
template <typename T_y>
33-
bool is_finite(const std::vector<T_y>& y) {
34-
bool all = true;
35-
for (size_t n = 0; n < stan::math::size(y); ++n) {
36-
all &= is_finite(y[n]);
37-
}
38-
return all;
39-
}
40-
} // namespace internal
418

429
/**
4310
* Check if <code>y</code> is finite.
4411
* This function is vectorized and will check each element of
4512
* <code>y</code>.
46-
* @tparam T_y Type of y
47-
* @param function Function name (for error messages)
48-
* @param name Variable name (for error messages)
49-
* @param y Variable to check
13+
* @tparam T_y type of y
14+
* @param function function name (for error messages)
15+
* @param name variable name (for error messages)
16+
* @param y variable to check
5017
* @throw <code>domain_error</code> if y is infinity, -infinity, or NaN
5118
*/
52-
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
19+
template <typename T_y>
5320
inline void check_finite(const char* function, const char* name, const T_y& y) {
54-
if (!internal::is_finite(y)) {
55-
throw_domain_error(function, name, y, "is ", ", but must be finite!");
56-
}
57-
}
58-
59-
/**
60-
* Return <code>true</code> if all values in the std::vector are finite.
61-
*
62-
* @tparam T_y type of elements in the std::vector
63-
*
64-
* @param function name of function (for error messages)
65-
* @param name variable name (for error messages)
66-
* @param y std::vector to test
67-
* @return <code>true</code> if all values are finite
68-
**/
69-
template <typename T_y, require_stan_scalar_t<T_y>* = nullptr>
70-
inline void check_finite(const char* function, const char* name,
71-
const std::vector<T_y>& y) {
72-
for (size_t n = 0; n < stan::math::size(y); n++) {
73-
if (!internal::is_finite(stan::get(y, n))) {
74-
throw_domain_error_vec(function, name, y, n, "is ",
75-
", but must be finite!");
76-
}
77-
}
78-
}
79-
80-
/**
81-
* Return <code>true</code> is the specified matrix is finite.
82-
*
83-
* @tparam Derived Eigen derived type
84-
*
85-
* @param function name of function (for error messages)
86-
* @param name variable name (for error messages)
87-
* @param y matrix to test
88-
* @return <code>true</code> if the matrix is finite
89-
**/
90-
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
91-
inline void check_finite(const char* function, const char* name,
92-
const EigMat& y) {
93-
if (!value_of(y).allFinite()) {
94-
for (int n = 0; n < y.size(); ++n) {
95-
if (!std::isfinite(value_of_rec(y(n)))) {
96-
throw_domain_error_vec(function, name, y, n, "is ",
97-
", but must be finite!");
98-
}
99-
}
100-
}
101-
}
102-
103-
/**
104-
* Return <code>true</code> if all values in the std::vector are finite.
105-
*
106-
* @tparam T_y type of elements in the std::vector
107-
*
108-
* @param function name of function (for error messages)
109-
* @param name variable name (for error messages)
110-
* @param y std::vector to test
111-
* @return <code>true</code> if all values are finite
112-
**/
113-
template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr>
114-
inline void check_finite(const char* function, const char* name,
115-
const std::vector<T_y>& y) {
116-
for (size_t n = 0; n < stan::math::size(y); n++) {
117-
if (!internal::is_finite(stan::get(y, n))) {
118-
throw_domain_error(function, name, "", "", "is not finite!");
119-
}
120-
}
21+
auto is_good = [](const auto& y) { return std::isfinite(y); };
22+
elementwise_check(is_good, function, name, y, ", but must be finite!");
12123
}
12224

12325
} // namespace math
Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,11 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7-
#include <stan/math/prim/fun/get.hpp>
8-
#include <stan/math/prim/fun/size.hpp>
9-
#include <type_traits>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
105

116
namespace stan {
127
namespace math {
138

14-
namespace internal {
15-
template <typename T_y, bool is_vec>
16-
struct nonnegative {
17-
static void check(const char* function, const char* name, const T_y& y) {
18-
// have to use not is_unsigned. is_signed will be false for
19-
// floating point types that have no unsigned versions.
20-
if (!std::is_unsigned<T_y>::value && !(y >= 0)) {
21-
throw_domain_error(function, name, y, "is ", ", but must be >= 0!");
22-
}
23-
}
24-
};
25-
26-
template <typename T_y>
27-
struct nonnegative<T_y, true> {
28-
static void check(const char* function, const char* name, const T_y& y) {
29-
for (size_t n = 0; n < stan::math::size(y); n++) {
30-
if (!std::is_unsigned<typename value_type<T_y>::type>::value
31-
&& !(stan::get(y, n) >= 0)) {
32-
throw_domain_error_vec(function, name, y, n, "is ",
33-
", but must be >= 0!");
34-
}
35-
}
36-
}
37-
};
38-
} // namespace internal
39-
409
/**
4110
* Check if <code>y</code> is non-negative.
4211
* This function is vectorized and will check each element of <code>y</code>.
@@ -50,9 +19,10 @@ struct nonnegative<T_y, true> {
5019
template <typename T_y>
5120
inline void check_nonnegative(const char* function, const char* name,
5221
const T_y& y) {
53-
internal::nonnegative<T_y, is_vector_like<T_y>::value>::check(function, name,
54-
y);
22+
auto is_good = [](const auto& y) { return y >= 0; };
23+
elementwise_check(is_good, function, name, y, ", but must be >= 0!");
5524
}
25+
5626
} // namespace math
5727
} // namespace stan
5828
#endif

stan/math/prim/err/check_not_nan.hpp

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,11 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7-
#include <stan/math/prim/fun/get.hpp>
8-
#include <stan/math/prim/fun/is_nan.hpp>
9-
#include <stan/math/prim/fun/size.hpp>
10-
#include <stan/math/prim/fun/value_of_rec.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
115

126
namespace stan {
137
namespace math {
148

15-
namespace internal {
16-
template <typename T_y, bool is_vec>
17-
struct not_nan {
18-
static void check(const char* function, const char* name, const T_y& y) {
19-
if (is_nan(value_of_rec(y))) {
20-
throw_domain_error(function, name, y, "is ", ", but must not be nan!");
21-
}
22-
}
23-
};
24-
25-
template <typename T_y>
26-
struct not_nan<T_y, true> {
27-
static void check(const char* function, const char* name, const T_y& y) {
28-
for (size_t n = 0; n < stan::math::size(y); n++) {
29-
if (is_nan(value_of_rec(stan::get(y, n)))) {
30-
throw_domain_error_vec(function, name, y, n, "is ",
31-
", but must not be nan!");
32-
}
33-
}
34-
}
35-
};
36-
} // namespace internal
37-
389
/**
3910
* Check if <code>y</code> is not <code>NaN</code>.
4011
* This function is vectorized and will check each element of
@@ -49,7 +20,8 @@ struct not_nan<T_y, true> {
4920
template <typename T_y>
5021
inline void check_not_nan(const char* function, const char* name,
5122
const T_y& y) {
52-
internal::not_nan<T_y, is_vector_like<T_y>::value>::check(function, name, y);
23+
auto is_good = [](const auto& y) { return !std::isnan(y); };
24+
elementwise_check(is_good, function, name, y, ", but must not be nan!");
5325
}
5426

5527
} // namespace math

stan/math/prim/err/check_positive.hpp

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,14 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_POSITIVE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_POSITIVE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
75
#include <stan/math/prim/err/invalid_argument.hpp>
8-
#include <stan/math/prim/fun/get.hpp>
9-
#include <stan/math/prim/fun/size.hpp>
10-
#include <type_traits>
116
#include <string>
7+
#include <sstream>
128

139
namespace stan {
1410
namespace math {
1511

16-
namespace {
17-
18-
template <typename T_y, bool is_vec>
19-
struct positive {
20-
static void check(const char* function, const char* name, const T_y& y) {
21-
// have to use not is_unsigned. is_signed will be false
22-
// floating point types that have no unsigned versions.
23-
if (!std::is_unsigned<T_y>::value && !(y > 0)) {
24-
throw_domain_error(function, name, y, "is ", ", but must be > 0!");
25-
}
26-
}
27-
};
28-
29-
template <typename T_y>
30-
struct positive<T_y, true> {
31-
static void check(const char* function, const char* name, const T_y& y) {
32-
for (size_t n = 0; n < stan::math::size(y); n++) {
33-
if (!std::is_unsigned<typename value_type<T_y>::type>::value
34-
&& !(stan::get(y, n) > 0)) {
35-
throw_domain_error_vec(function, name, y, n, "is ",
36-
", but must be > 0!");
37-
}
38-
}
39-
}
40-
};
41-
42-
} // namespace
43-
4412
/**
4513
* Check if <code>y</code> is positive.
4614
* This function is vectorized and will check each element of
@@ -55,7 +23,8 @@ struct positive<T_y, true> {
5523
template <typename T_y>
5624
inline void check_positive(const char* function, const char* name,
5725
const T_y& y) {
58-
positive<T_y, is_vector_like<T_y>::value>::check(function, name, y);
26+
auto is_good = [](const auto& y) { return y > 0; };
27+
elementwise_check(is_good, function, name, y, ", but must be > 0!");
5928
}
6029

6130
/**
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_POSITIVE_FINITE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_POSITIVE_FINITE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/check_positive.hpp>
6-
#include <stan/math/prim/err/check_finite.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
75

86
namespace stan {
97
namespace math {
@@ -17,15 +15,16 @@ namespace math {
1715
* @param name Variable name (for error messages)
1816
* @param y Variable to check
1917
* @throw <code>domain_error</code> if any element of y is not positive or
20-
* if any element of y is NaN.
18+
* if any element of y is NaN or infinity.
2119
*/
20+
2221
template <typename T_y>
2322
inline void check_positive_finite(const char* function, const char* name,
2423
const T_y& y) {
25-
check_positive(function, name, y);
26-
check_finite(function, name, y);
24+
auto is_good = [](const auto& y) { return y > 0 && std::isfinite(y); };
25+
elementwise_check(is_good, function, name, y,
26+
", but must be positive and finite!");
2727
}
28-
2928
} // namespace math
3029
} // namespace stan
3130
#endif

0 commit comments

Comments
 (0)