Skip to content

Commit 48c75fc

Browse files
Merge pull request #1780 from peterwicksstringfield/feature/elementwise_check
Feature/elementwise check
2 parents a0e4ba0 + 234d5e6 commit 48c75fc

5 files changed

Lines changed: 363 additions & 5 deletions

File tree

stan/math/opencl/matrix_cl.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,15 +452,14 @@ class matrix_cl<T, require_arithmetic_t<T>> {
452452
}
453453

454454
/**
455-
* Construct from a kernel generator expression. It evaluates the ixpression
455+
* Construct from a kernel generator expression. It evaluates the expression
456456
* into \c this.
457457
* @tparam Expr type of the expression
458458
* @param expression expression
459459
*/
460460
template <typename Expr,
461461
require_all_valid_expressions_and_none_scalar_t<Expr>* = nullptr>
462-
matrix_cl(const Expr& expresion); // NOLINT This constructor is intentionally
463-
// implicit
462+
matrix_cl(const Expr& expression); // NOLINT(runtime/explicit)
464463

465464
/** \ingroup opencl
466465
* Move assignment operator.
@@ -492,14 +491,14 @@ class matrix_cl<T, require_arithmetic_t<T>> {
492491
}
493492

494493
/**
495-
* Assignment of a kernel generator expression evaluates the ixpression into
494+
* Assignment of a kernel generator expression evaluates the expression into
496495
* \c this.
497496
* @tparam Expr type of the expression
498497
* @param expression expression
499498
*/
500499
template <typename Expr,
501500
require_all_valid_expressions_and_none_scalar_t<Expr>* = nullptr>
502-
matrix_cl<T>& operator=(const Expr& expresion);
501+
matrix_cl<T>& operator=(const Expr& expression);
503502

504503
private:
505504
/** \ingroup opencl

stan/math/prim/err.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include <stan/math/prim/err/constraint_tolerance.hpp>
4848
#include <stan/math/prim/err/domain_error.hpp>
4949
#include <stan/math/prim/err/domain_error_vec.hpp>
50+
#include <stan/math/prim/err/elementwise_check.hpp>
5051
#include <stan/math/prim/err/invalid_argument.hpp>
5152
#include <stan/math/prim/err/invalid_argument_vec.hpp>
5253
#include <stan/math/prim/err/is_cholesky_factor.hpp>
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
#ifndef STAN_MATH_PRIM_ERR_ELEMENTWISE_ERROR_CHECKER_HPP
2+
#define STAN_MATH_PRIM_ERR_ELEMENTWISE_ERROR_CHECKER_HPP
3+
4+
#include <stan/math/prim/err/throw_domain_error.hpp>
5+
#include <stan/math/prim/fun/get.hpp>
6+
#include <stan/math/prim/fun/size.hpp>
7+
#include <stan/math/prim/fun/value_of_rec.hpp>
8+
#include <stan/math/prim/meta/is_vector.hpp>
9+
#include <string>
10+
#include <sstream>
11+
#include <vector>
12+
13+
namespace stan {
14+
namespace math {
15+
namespace internal {
16+
17+
/** Apply an error check to a container, signal failure by throwing.
18+
* Apply a predicate like is_positive to the double underlying every scalar in a
19+
* container, throw an exception if the predicate fails for any double.
20+
* @tparam F type of predicate
21+
* @tparam E type of exception thrown
22+
*/
23+
template <typename F, typename E>
24+
class Checker {
25+
const F& is_good;
26+
const char* function;
27+
const char* name;
28+
const char* suffix;
29+
30+
/**
31+
* Throw an exception of type `E`.
32+
* The error message is the string inside the provided stringstream.
33+
* @param ss stringstream containing error message
34+
* @throws `E`
35+
*/
36+
void raise_error_ss(std::stringstream& ss) { throw E{ss.str()}; }
37+
38+
/**
39+
* Throw an exception of type `E`.
40+
* The error message is the concatenation of the string inside the provided
41+
* stringstream with all the provided messages.
42+
* @tparam M types of first message
43+
* @tparam Ms types of other messages
44+
* @param ss stringstream to accumulate error message in.
45+
* @param message a message to append to `ss`
46+
* @param messages more messages to append
47+
* @throws `E`
48+
*/
49+
template <typename M, typename... Ms>
50+
void raise_error_ss(std::stringstream& ss, const M& message,
51+
const Ms&... messages) {
52+
ss << message;
53+
raise_error_ss(ss, messages...);
54+
}
55+
56+
/**
57+
* Throw an exception of type `E`.
58+
* The error message is the concatenation of the provided messages.
59+
* @tparam Ms types of messages
60+
* @param messages a list of messages
61+
* @throws `E`
62+
*/
63+
template <typename... Ms>
64+
void raise_error(const Ms&... messages) {
65+
std::stringstream ss{};
66+
raise_error_ss(ss, messages...);
67+
}
68+
69+
public:
70+
/**
71+
* @param is_good predicate to check, must accept doubles and produce bools
72+
* @param function function name (for error messages)
73+
* @param name variable name (for error messages)
74+
* @param suffix message to print at end of error message
75+
*/
76+
Checker(const F& is_good, const char* function, const char* name,
77+
const char* suffix)
78+
: is_good(is_good), function(function), name(name), suffix(suffix) {}
79+
80+
/**
81+
* Check the scalar.
82+
* @tparam T type of scalar
83+
* @tparam Ms types of messages
84+
* @param x scalar
85+
* @param messages a list of messages to append to the error message
86+
* @throws `E` if the scalar fails the error check
87+
*/
88+
template <typename T, typename = require_stan_scalar_t<T>, typename... Ms>
89+
void check(const T& x, Ms... messages) {
90+
double xd = value_of_rec(x);
91+
if (!is_good(xd))
92+
raise_error(function, ": ", name, messages..., " is ", xd, suffix);
93+
}
94+
95+
/**
96+
* Check all the scalars inside the vector.
97+
* @tparam T type of vector
98+
* @tparam Ms types of messages
99+
* @param x vector
100+
* @param messages a list of messages to append to the error message
101+
* @throws `E` if any of the scalars fail the error check
102+
*/
103+
template <typename T, typename = require_vector_t<T>, typename = void,
104+
typename... Ms>
105+
void check(const T& x, Ms... messages) {
106+
for (size_t i = 0; i < stan::math::size(x); ++i)
107+
check(x[i], messages..., "[", i + 1, "]");
108+
}
109+
110+
/**
111+
* Check all the scalars inside the matrix.
112+
* @tparam Derived type of matrix
113+
* @tparam Ms types of messages
114+
* @param x matrix
115+
* @param messages a list of messages to append to the error message
116+
* @throws `E` if any of the scalars fail the error check
117+
*/
118+
template <typename Derived, typename... Ms>
119+
void check(const Eigen::DenseBase<Derived>& x, Ms... messages) {
120+
for (size_t n = 0; n < x.cols(); ++n)
121+
for (size_t m = 0; m < x.rows(); ++m)
122+
check(x(m, n), messages..., "[row=", m + 1, ", col=", n + 1, "]");
123+
}
124+
}; // namespace internal
125+
126+
/** Apply an error check to a container, signal failure with `false`.
127+
* Apply a predicate like is_positive to the double underlying every scalar in a
128+
* container, producing true if the predicate holds everywhere and `false` if it
129+
* fails anywhere.
130+
* @tparam F type of predicate
131+
*/
132+
template <typename F>
133+
class Iser {
134+
const F& is_good;
135+
136+
public:
137+
/**
138+
* @param is_good predicate to check, must accept doubles and produce bools
139+
*/
140+
explicit Iser(const F& is_good) : is_good(is_good) {}
141+
142+
/**
143+
* Check the scalar.
144+
* @tparam T type of scalar
145+
* @param x scalar
146+
* @return `false` if the scalar fails the error check
147+
*/
148+
template <typename T, typename = require_stan_scalar_t<T>>
149+
bool is(const T& x) {
150+
return is_good(value_of_rec(x));
151+
}
152+
153+
/**
154+
* Check all the scalars inside the container.
155+
* @tparam T type of scalar
156+
* @param x container
157+
* @return `false` if any of the scalars fail the error check
158+
*/
159+
template <typename T, typename = require_not_stan_scalar_t<T>,
160+
typename = void>
161+
bool is(const T& x) {
162+
for (size_t i = 0; i < stan::math::size(x); ++i)
163+
if (!is(stan::get(x, i)))
164+
return false;
165+
return true;
166+
}
167+
};
168+
169+
} // namespace internal
170+
171+
/**
172+
* Check that the predicate holds for the value of `x`, working elementwise on
173+
* containers. If `x` is a scalar, check the double underlying the scalar. If
174+
* `x` is a container, check each element inside `x`, recursively.
175+
* @tparam F type of predicate
176+
* @tparam T type of `x`
177+
* @param is_good predicate to check, must accept doubles and produce bools
178+
* @param function function name (for error messages)
179+
* @param name variable name (for error messages)
180+
* @param x variable to check, can be a scalar, a container of scalars, a
181+
* container of containers of scalars, etc
182+
* @param suffix message to print at end of error message
183+
* @throws `std::domain_error` if `is_good` returns `false` for the value
184+
* of any element in `x`
185+
*/
186+
template <typename F, typename T>
187+
inline void elementwise_check(const F& is_good, const char* function,
188+
const char* name, const T& x,
189+
const char* suffix) {
190+
internal::Checker<F, std::domain_error>{is_good, function, name, suffix}
191+
.check(x);
192+
}
193+
194+
/**
195+
* Check that the predicate holds for the value of `x`, working elementwise on
196+
* containers. If `x` is a scalar, check the double underlying the scalar. If
197+
* `x` is a container, check each element inside `x`, recursively.
198+
* @tparam F type of predicate
199+
* @tparam T type of `x`
200+
* @param is_good predicate to check, must accept doubles and produce bools
201+
* @param x variable to check, can be a scalar, a container of scalars, a
202+
* container of containers of scalars, etc
203+
* @return `false` if any of the scalars fail the error check
204+
*/
205+
template <typename F, typename T>
206+
inline bool elementwise_is(const F& is_good, const T& x) {
207+
return internal::Iser<F>{is_good}.is(x);
208+
}
209+
210+
} // namespace math
211+
} // namespace stan
212+
#endif

test/unit/math/prim/err/check_not_nan_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TEST(ErrorHandlingArr, CheckNotNanVectorized_one_indexed_message) {
5252
TEST(ErrorHandlingMatrix, checkNotNanEigenRow) {
5353
stan::math::vector_d y;
5454
y.resize(3);
55+
y << 1, 2, 3;
5556

5657
EXPECT_NO_THROW(stan::math::check_not_nan("checkNotNanEigenRow(%1)", "y", y));
5758
EXPECT_NO_THROW(stan::math::check_not_nan("checkNotNanEigenRow(%1)", "y", y));

0 commit comments

Comments
 (0)