|
1 | 1 | #ifndef STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP |
2 | 2 | #define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP |
3 | 3 |
|
4 | | -#include <stan/math/prim/meta.hpp> |
5 | | -#include <stan/math/prim/err/is_scal_finite.hpp> |
6 | | -#include <stan/math/prim/err/throw_domain_error.hpp> |
7 | | -#include <stan/math/prim/err/throw_domain_error_vec.hpp> |
8 | | -#include <stan/math/prim/fun/Eigen.hpp> |
9 | | -#include <stan/math/prim/fun/get.hpp> |
10 | | -#include <stan/math/prim/fun/size.hpp> |
11 | | -#include <stan/math/prim/fun/value_of.hpp> |
12 | | -#include <stan/math/prim/fun/value_of_rec.hpp> |
13 | | -#include <cmath> |
| 4 | +#include <stan/math/prim/err/elementwise_check.hpp> |
14 | 5 |
|
15 | 6 | namespace stan { |
16 | 7 | namespace math { |
17 | | -namespace internal { |
18 | | -template <typename T_y> |
19 | | -bool is_finite(const T_y& y) { |
20 | | - return is_scal_finite(y); |
21 | | -} |
22 | | - |
23 | | -template <typename T_y, int R, int C> |
24 | | -bool is_finite(const Eigen::Matrix<T_y, R, C>& y) { |
25 | | - bool all = true; |
26 | | - for (size_t n = 0; n < y.size(); ++n) { |
27 | | - all &= is_finite(y(n)); |
28 | | - } |
29 | | - return all; |
30 | | -} |
31 | | - |
32 | | -template <typename T_y> |
33 | | -bool is_finite(const std::vector<T_y>& y) { |
34 | | - bool all = true; |
35 | | - for (size_t n = 0; n < stan::math::size(y); ++n) { |
36 | | - all &= is_finite(y[n]); |
37 | | - } |
38 | | - return all; |
39 | | -} |
40 | | -} // namespace internal |
41 | 8 |
|
42 | 9 | /** |
43 | 10 | * Check if <code>y</code> is finite. |
44 | 11 | * This function is vectorized and will check each element of |
45 | 12 | * <code>y</code>. |
46 | | - * @tparam T_y Type of y |
47 | | - * @param function Function name (for error messages) |
48 | | - * @param name Variable name (for error messages) |
49 | | - * @param y Variable to check |
| 13 | + * @tparam T_y type of y |
| 14 | + * @param function function name (for error messages) |
| 15 | + * @param name variable name (for error messages) |
| 16 | + * @param y variable to check |
50 | 17 | * @throw <code>domain_error</code> if y is infinity, -infinity, or NaN |
51 | 18 | */ |
52 | | -template <typename T_y, require_stan_scalar_t<T_y>* = nullptr> |
| 19 | +template <typename T_y> |
53 | 20 | inline void check_finite(const char* function, const char* name, const T_y& y) { |
54 | | - if (!internal::is_finite(y)) { |
55 | | - throw_domain_error(function, name, y, "is ", ", but must be finite!"); |
56 | | - } |
57 | | -} |
58 | | - |
59 | | -/** |
60 | | - * Return <code>true</code> if all values in the std::vector are finite. |
61 | | - * |
62 | | - * @tparam T_y type of elements in the std::vector |
63 | | - * |
64 | | - * @param function name of function (for error messages) |
65 | | - * @param name variable name (for error messages) |
66 | | - * @param y std::vector to test |
67 | | - * @return <code>true</code> if all values are finite |
68 | | - **/ |
69 | | -template <typename T_y, require_stan_scalar_t<T_y>* = nullptr> |
70 | | -inline void check_finite(const char* function, const char* name, |
71 | | - const std::vector<T_y>& y) { |
72 | | - for (size_t n = 0; n < stan::math::size(y); n++) { |
73 | | - if (!internal::is_finite(stan::get(y, n))) { |
74 | | - throw_domain_error_vec(function, name, y, n, "is ", |
75 | | - ", but must be finite!"); |
76 | | - } |
77 | | - } |
78 | | -} |
79 | | - |
80 | | -/** |
81 | | - * Return <code>true</code> is the specified matrix is finite. |
82 | | - * |
83 | | - * @tparam Derived Eigen derived type |
84 | | - * |
85 | | - * @param function name of function (for error messages) |
86 | | - * @param name variable name (for error messages) |
87 | | - * @param y matrix to test |
88 | | - * @return <code>true</code> if the matrix is finite |
89 | | - **/ |
90 | | -template <typename EigMat, require_eigen_t<EigMat>* = nullptr> |
91 | | -inline void check_finite(const char* function, const char* name, |
92 | | - const EigMat& y) { |
93 | | - if (!value_of(y).allFinite()) { |
94 | | - for (int n = 0; n < y.size(); ++n) { |
95 | | - if (!std::isfinite(value_of_rec(y(n)))) { |
96 | | - throw_domain_error_vec(function, name, y, n, "is ", |
97 | | - ", but must be finite!"); |
98 | | - } |
99 | | - } |
100 | | - } |
101 | | -} |
102 | | - |
103 | | -/** |
104 | | - * Return <code>true</code> if all values in the std::vector are finite. |
105 | | - * |
106 | | - * @tparam T_y type of elements in the std::vector |
107 | | - * |
108 | | - * @param function name of function (for error messages) |
109 | | - * @param name variable name (for error messages) |
110 | | - * @param y std::vector to test |
111 | | - * @return <code>true</code> if all values are finite |
112 | | - **/ |
113 | | -template <typename T_y, require_not_stan_scalar_t<T_y>* = nullptr> |
114 | | -inline void check_finite(const char* function, const char* name, |
115 | | - const std::vector<T_y>& y) { |
116 | | - for (size_t n = 0; n < stan::math::size(y); n++) { |
117 | | - if (!internal::is_finite(stan::get(y, n))) { |
118 | | - throw_domain_error(function, name, "", "", "is not finite!"); |
119 | | - } |
120 | | - } |
| 21 | + auto is_good = [](const auto& y) { return std::isfinite(y); }; |
| 22 | + elementwise_check(is_good, function, name, y, ", but must be finite!"); |
121 | 23 | } |
122 | 24 |
|
123 | 25 | } // namespace math |
|
0 commit comments