Skip to content

Commit e5f00e2

Browse files
authored
Merge pull request #1798 from peterwicksstringfield/feature/elementwise_checks_part_2
feature/elementwise_check (2)
2 parents 0ee81d6 + 09336d0 commit e5f00e2

27 files changed

+325
-287
lines changed

stan/math/prim/err/check_finite.hpp

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,27 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_FINITE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/is_scal_finite.hpp>
6-
#include <stan/math/prim/err/throw_domain_error.hpp>
7-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
8-
#include <stan/math/prim/fun/Eigen.hpp>
9-
#include <stan/math/prim/fun/get.hpp>
10-
#include <stan/math/prim/fun/size.hpp>
11-
#include <stan/math/prim/fun/value_of.hpp>
12-
#include <stan/math/prim/fun/value_of_rec.hpp>
13-
#include <cmath>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
145

156
namespace stan {
167
namespace math {
178

18-
namespace internal {
19-
template <typename T_y, bool is_vec>
20-
struct finite {
21-
static void check(const char* function, const char* name, const T_y& y) {
22-
if (!is_scal_finite(y)) {
23-
throw_domain_error(function, name, y, "is ", ", but must be finite!");
24-
}
25-
}
26-
};
27-
28-
template <typename T_y>
29-
struct finite<T_y, true> {
30-
static void check(const char* function, const char* name, const T_y& y) {
31-
for (size_t n = 0; n < stan::math::size(y); n++) {
32-
if (!is_scal_finite(stan::get(y, n))) {
33-
throw_domain_error_vec(function, name, y, n, "is ",
34-
", but must be finite!");
35-
}
36-
}
37-
}
38-
};
39-
} // namespace internal
40-
419
/**
4210
* Check if <code>y</code> is finite.
4311
* This function is vectorized and will check each element of
4412
* <code>y</code>.
45-
* @tparam T_y Type of y
46-
* @param function Function name (for error messages)
47-
* @param name Variable name (for error messages)
48-
* @param y Variable to check
13+
* @tparam T_y type of y
14+
* @param function function name (for error messages)
15+
* @param name variable name (for error messages)
16+
* @param y variable to check
4917
* @throw <code>domain_error</code> if y is infinity, -infinity, or NaN
5018
*/
5119
template <typename T_y>
5220
inline void check_finite(const char* function, const char* name, const T_y& y) {
53-
internal::finite<T_y, is_vector_like<T_y>::value>::check(function, name, y);
21+
auto is_good = [](const auto& y) { return std::isfinite(y); };
22+
elementwise_check(is_good, function, name, y, ", but must be finite!");
5423
}
5524

56-
/**
57-
* Return <code>true</code> is the specified matrix is finite.
58-
*
59-
* @tparam T type of elements in the matrix
60-
* @tparam R number of rows, can be Eigen::Dynamic
61-
* @tparam C number of columns, can be Eigen::Dynamic
62-
*
63-
* @param function name of function (for error messages)
64-
* @param name variable name (for error messages)
65-
* @param y matrix to test
66-
* @return <code>true</code> if the matrix is finite
67-
**/
68-
namespace internal {
69-
template <typename T, int R, int C>
70-
struct finite<Eigen::Matrix<T, R, C>, true> {
71-
static void check(const char* function, const char* name,
72-
const Eigen::Matrix<T, R, C>& y) {
73-
if (!value_of(y).allFinite()) {
74-
for (int n = 0; n < y.size(); ++n) {
75-
if (!std::isfinite(value_of_rec(y(n)))) {
76-
throw_domain_error_vec(function, name, y, n, "is ",
77-
", but must be finite!");
78-
}
79-
}
80-
}
81-
}
82-
};
83-
84-
} // namespace internal
8525
} // namespace math
8626
} // namespace stan
8727

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,11 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_NONNEGATIVE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7-
#include <stan/math/prim/fun/get.hpp>
8-
#include <stan/math/prim/fun/size.hpp>
9-
#include <type_traits>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
105

116
namespace stan {
127
namespace math {
138

14-
namespace internal {
15-
template <typename T_y, bool is_vec>
16-
struct nonnegative {
17-
static void check(const char* function, const char* name, const T_y& y) {
18-
// have to use not is_unsigned. is_signed will be false for
19-
// floating point types that have no unsigned versions.
20-
if (!std::is_unsigned<T_y>::value && !(y >= 0)) {
21-
throw_domain_error(function, name, y, "is ", ", but must be >= 0!");
22-
}
23-
}
24-
};
25-
26-
template <typename T_y>
27-
struct nonnegative<T_y, true> {
28-
static void check(const char* function, const char* name, const T_y& y) {
29-
for (size_t n = 0; n < stan::math::size(y); n++) {
30-
if (!std::is_unsigned<typename value_type<T_y>::type>::value
31-
&& !(stan::get(y, n) >= 0)) {
32-
throw_domain_error_vec(function, name, y, n, "is ",
33-
", but must be >= 0!");
34-
}
35-
}
36-
}
37-
};
38-
} // namespace internal
39-
409
/**
4110
* Check if <code>y</code> is non-negative.
4211
* This function is vectorized and will check each element of <code>y</code>.
@@ -50,9 +19,10 @@ struct nonnegative<T_y, true> {
5019
template <typename T_y>
5120
inline void check_nonnegative(const char* function, const char* name,
5221
const T_y& y) {
53-
internal::nonnegative<T_y, is_vector_like<T_y>::value>::check(function, name,
54-
y);
22+
auto is_good = [](const auto& y) { return y >= 0; };
23+
elementwise_check(is_good, function, name, y, ", but must be >= 0!");
5524
}
25+
5626
} // namespace math
5727
} // namespace stan
5828
#endif

stan/math/prim/err/check_not_nan.hpp

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,11 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_NOT_NAN_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
7-
#include <stan/math/prim/fun/get.hpp>
8-
#include <stan/math/prim/fun/is_nan.hpp>
9-
#include <stan/math/prim/fun/size.hpp>
10-
#include <stan/math/prim/fun/value_of_rec.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
115

126
namespace stan {
137
namespace math {
148

15-
namespace internal {
16-
template <typename T_y, bool is_vec>
17-
struct not_nan {
18-
static void check(const char* function, const char* name, const T_y& y) {
19-
if (is_nan(value_of_rec(y))) {
20-
throw_domain_error(function, name, y, "is ", ", but must not be nan!");
21-
}
22-
}
23-
};
24-
25-
template <typename T_y>
26-
struct not_nan<T_y, true> {
27-
static void check(const char* function, const char* name, const T_y& y) {
28-
for (size_t n = 0; n < stan::math::size(y); n++) {
29-
if (is_nan(value_of_rec(stan::get(y, n)))) {
30-
throw_domain_error_vec(function, name, y, n, "is ",
31-
", but must not be nan!");
32-
}
33-
}
34-
}
35-
};
36-
} // namespace internal
37-
389
/**
3910
* Check if <code>y</code> is not <code>NaN</code>.
4011
* This function is vectorized and will check each element of
@@ -49,7 +20,8 @@ struct not_nan<T_y, true> {
4920
template <typename T_y>
5021
inline void check_not_nan(const char* function, const char* name,
5122
const T_y& y) {
52-
internal::not_nan<T_y, is_vector_like<T_y>::value>::check(function, name, y);
23+
auto is_good = [](const auto& y) { return !std::isnan(y); };
24+
elementwise_check(is_good, function, name, y, ", but must not be nan!");
5325
}
5426

5527
} // namespace math

stan/math/prim/err/check_positive.hpp

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,14 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_POSITIVE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_POSITIVE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/throw_domain_error.hpp>
6-
#include <stan/math/prim/err/throw_domain_error_vec.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
75
#include <stan/math/prim/err/invalid_argument.hpp>
8-
#include <stan/math/prim/fun/get.hpp>
9-
#include <stan/math/prim/fun/size.hpp>
10-
#include <type_traits>
116
#include <string>
7+
#include <sstream>
128

139
namespace stan {
1410
namespace math {
1511

16-
namespace {
17-
18-
template <typename T_y, bool is_vec>
19-
struct positive {
20-
static void check(const char* function, const char* name, const T_y& y) {
21-
// have to use not is_unsigned. is_signed will be false
22-
// floating point types that have no unsigned versions.
23-
if (!std::is_unsigned<T_y>::value && !(y > 0)) {
24-
throw_domain_error(function, name, y, "is ", ", but must be > 0!");
25-
}
26-
}
27-
};
28-
29-
template <typename T_y>
30-
struct positive<T_y, true> {
31-
static void check(const char* function, const char* name, const T_y& y) {
32-
for (size_t n = 0; n < stan::math::size(y); n++) {
33-
if (!std::is_unsigned<typename value_type<T_y>::type>::value
34-
&& !(stan::get(y, n) > 0)) {
35-
throw_domain_error_vec(function, name, y, n, "is ",
36-
", but must be > 0!");
37-
}
38-
}
39-
}
40-
};
41-
42-
} // namespace
43-
4412
/**
4513
* Check if <code>y</code> is positive.
4614
* This function is vectorized and will check each element of
@@ -55,7 +23,8 @@ struct positive<T_y, true> {
5523
template <typename T_y>
5624
inline void check_positive(const char* function, const char* name,
5725
const T_y& y) {
58-
positive<T_y, is_vector_like<T_y>::value>::check(function, name, y);
26+
auto is_good = [](const auto& y) { return y > 0; };
27+
elementwise_check(is_good, function, name, y, ", but must be > 0!");
5928
}
6029

6130
/**
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#ifndef STAN_MATH_PRIM_ERR_CHECK_POSITIVE_FINITE_HPP
22
#define STAN_MATH_PRIM_ERR_CHECK_POSITIVE_FINITE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/err/check_positive.hpp>
6-
#include <stan/math/prim/err/check_finite.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
75

86
namespace stan {
97
namespace math {
@@ -17,15 +15,16 @@ namespace math {
1715
* @param name Variable name (for error messages)
1816
* @param y Variable to check
1917
* @throw <code>domain_error</code> if any element of y is not positive or
20-
* if any element of y is NaN.
18+
* if any element of y is NaN or infinity.
2119
*/
20+
2221
template <typename T_y>
2322
inline void check_positive_finite(const char* function, const char* name,
2423
const T_y& y) {
25-
check_positive(function, name, y);
26-
check_finite(function, name, y);
24+
auto is_good = [](const auto& y) { return y > 0 && std::isfinite(y); };
25+
elementwise_check(is_good, function, name, y,
26+
", but must be positive and finite!");
2727
}
28-
2928
} // namespace math
3029
} // namespace stan
3130
#endif
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef STAN_MATH_PRIM_ERR_IS_NONNEGATIVE_HPP
2+
#define STAN_MATH_PRIM_ERR_IS_NONNEGATIVE_HPP
3+
4+
#include <stan/math/prim/err/elementwise_check.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
9+
/**
10+
* Return <code>true</code> if <code>y</code> is nonnegative.
11+
* This function is vectorized and will check each element of
12+
* <code>y</code>.
13+
* @tparam T_y Type of y
14+
* @param y Variable to check
15+
* @return <code>true</code> if every element of y is >=0.
16+
*/
17+
template <typename T_y>
18+
inline bool is_nonnegative(const T_y& y) {
19+
auto is_good = [](const auto& y) { return y >= 0; };
20+
return elementwise_is(is_good, y);
21+
}
22+
} // namespace math
23+
} // namespace stan
24+
#endif

stan/math/prim/err/is_not_nan.hpp

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
#ifndef STAN_MATH_PRIM_ERR_IS_NOT_NAN_HPP
22
#define STAN_MATH_PRIM_ERR_IS_NOT_NAN_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/fun/get.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
65
#include <stan/math/prim/fun/is_nan.hpp>
7-
#include <stan/math/prim/fun/size.hpp>
8-
#include <stan/math/prim/fun/value_of_rec.hpp>
96

107
namespace stan {
118
namespace math {
@@ -15,18 +12,14 @@ namespace math {
1512
* This function is vectorized and will check each element of
1613
* <code>y</code>. If no element is <code>NaN</code>, this
1714
* function will return <code>true</code>.
18-
* @tparam T_y Type of y
19-
* @param y Variable to check
15+
* @tparam T_y type of y
16+
* @param y variable to check
2017
* @return <code>true</code> if no element of y is NaN
2118
*/
2219
template <typename T_y>
2320
inline bool is_not_nan(const T_y& y) {
24-
for (size_t n = 0; n < stan::math::size(y); ++n) {
25-
if (is_nan(value_of_rec(stan::get(y, n)))) {
26-
return false;
27-
}
28-
}
29-
return true;
21+
auto is_good = [](const auto& y) { return !std::isnan(y); };
22+
return elementwise_is(is_good, y);
3023
}
3124

3225
} // namespace math

stan/math/prim/err/is_positive.hpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
#ifndef STAN_MATH_PRIM_ERR_IS_POSITIVE_HPP
22
#define STAN_MATH_PRIM_ERR_IS_POSITIVE_HPP
33

4-
#include <stan/math/prim/meta.hpp>
5-
#include <stan/math/prim/fun/get.hpp>
6-
#include <stan/math/prim/fun/size.hpp>
4+
#include <stan/math/prim/err/elementwise_check.hpp>
75

86
namespace stan {
97
namespace math {
@@ -14,16 +12,12 @@ namespace math {
1412
* <code>y</code>.
1513
* @tparam T_y Type of y
1614
* @param y Variable to check
17-
* @return <code>true</code> if vector contains only positive elements
15+
* @return <code>true</code> if y contains only positive elements
1816
*/
1917
template <typename T_y>
2018
inline bool is_positive(const T_y& y) {
21-
for (size_t n = 0; n < stan::math::size(y); ++n) {
22-
if (!(stan::get(y, n) > 0)) {
23-
return false;
24-
}
25-
}
26-
return true;
19+
auto is_good = [](const auto& y) { return y > 0; };
20+
return elementwise_is(is_good, y);
2721
}
2822

2923
/**

0 commit comments

Comments
 (0)