Skip to content

Commit 5598ed1

Browse files
authored
Merge pull request #1641 from stan-dev/feature/parameter-pack-odes
Clicked my button for the day
2 parents 6d63a44 + bec1c2a commit 5598ed1

File tree

71 files changed

+8902
-2534
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+8902
-2534
lines changed

Jenkinsfile

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -248,31 +248,31 @@ pipeline {
248248
}
249249
post { always { retry(3) { deleteDir() } } }
250250
}
251-
stage('OpenCL tests async') {
252-
agent { label "gpu-async" }
253-
when {
254-
expression {
255-
runGpuAsync
256-
}
257-
}
258-
steps {
259-
deleteDir()
260-
unstash 'MathSetup'
261-
sh "echo CXX=${env.CXX} -Werror > make/local"
262-
sh "echo STAN_OPENCL=true>> make/local"
263-
sh "echo OPENCL_PLATFORM_ID=0>> make/local"
264-
sh "echo OPENCL_DEVICE_ID=${OPENCL_DEVICE_ID}>> make/local"
265-
sh "make -j${env.PARALLEL} test-headers"
266-
runTests("test/unit/math/opencl")
267-
runTests("test/unit/math/prim/fun/gp_exp_quad_cov_test")
268-
runTests("test/unit/math/prim/fun/mdivide_left_tri_test")
269-
runTests("test/unit/math/prim/fun/mdivide_right_tri_test")
270-
runTests("test/unit/math/prim/fun/multiply_test")
271-
runTests("test/unit/math/rev/fun/mdivide_left_tri_test")
272-
runTests("test/unit/math/rev/fun/multiply_test")
273-
}
274-
post { always { retry(3) { deleteDir() } } }
275-
}
251+
// stage('OpenCL tests async') {
252+
// agent { label "gpu-async" }
253+
// when {
254+
// expression {
255+
// runGpuAsync
256+
// }
257+
// }
258+
// steps {
259+
// deleteDir()
260+
// unstash 'MathSetup'
261+
// sh "echo CXX=${env.CXX} -Werror > make/local"
262+
// sh "echo STAN_OPENCL=true>> make/local"
263+
// sh "echo OPENCL_PLATFORM_ID=0>> make/local"
264+
// sh "echo OPENCL_DEVICE_ID=${OPENCL_DEVICE_ID}>> make/local"
265+
// sh "make -j${env.PARALLEL} test-headers"
266+
// runTests("test/unit/math/opencl")
267+
// runTests("test/unit/math/prim/fun/gp_exp_quad_cov_test")
268+
// runTests("test/unit/math/prim/fun/mdivide_left_tri_test")
269+
// runTests("test/unit/math/prim/fun/mdivide_right_tri_test")
270+
// runTests("test/unit/math/prim/fun/multiply_test")
271+
// runTests("test/unit/math/rev/fun/mdivide_left_tri_test")
272+
// runTests("test/unit/math/rev/fun/multiply_test")
273+
// }
274+
// post { always { retry(3) { deleteDir() } } }
275+
// }
276276
stage('Distribution tests') {
277277
agent { label "distribution-tests" }
278278
steps {

stan/math/prim/err.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <stan/math/prim/err/check_nonzero_size.hpp>
2929
#include <stan/math/prim/err/check_not_nan.hpp>
3030
#include <stan/math/prim/err/check_ordered.hpp>
31+
#include <stan/math/prim/err/check_sorted.hpp>
3132
#include <stan/math/prim/err/check_pos_definite.hpp>
3233
#include <stan/math/prim/err/check_pos_semidefinite.hpp>
3334
#include <stan/math/prim/err/check_positive.hpp>
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#ifndef STAN_MATH_PRIM_ERR_CHECK_SORTED_HPP
2+
#define STAN_MATH_PRIM_ERR_CHECK_SORTED_HPP
3+
4+
#include <stan/math/prim/err/throw_domain_error.hpp>
5+
#include <stan/math/prim/fun/Eigen.hpp>
6+
#include <sstream>
7+
#include <string>
8+
#include <vector>
9+
10+
namespace stan {
11+
namespace math {
12+
13+
/**
14+
* Check if the specified vector is sorted into increasing order (repeated
15+
* values are okay).
16+
* @tparam T_y Type of scalar
17+
* @param function Function name (for error messages)
18+
* @param name Variable name (for error messages)
19+
* @param y Vector to test
20+
* @throw <code>std::domain_error</code> if the vector elements are
21+
* not sorted, or if any element is <code>NaN</code>.
22+
*/
23+
template <typename T_y>
24+
void check_sorted(const char* function, const char* name,
25+
const Eigen::Matrix<T_y, Eigen::Dynamic, 1>& y) {
26+
using size_type = index_type_t<Eigen::Matrix<T_y, Eigen::Dynamic, 1>>;
27+
28+
for (size_type n = 1; n < y.size(); n++) {
29+
if (!(y[n] >= y[n - 1])) {
30+
std::ostringstream msg1;
31+
msg1 << "is not a valid sorted vector."
32+
<< " The element at " << stan::error_index::value + n << " is ";
33+
std::string msg1_str(msg1.str());
34+
std::ostringstream msg2;
35+
msg2 << ", but should be greater than or equal to the previous element, "
36+
<< y[n - 1];
37+
std::string msg2_str(msg2.str());
38+
throw_domain_error(function, name, y[n], msg1_str.c_str(),
39+
msg2_str.c_str());
40+
}
41+
}
42+
}
43+
44+
/**
45+
* Check if the specified vector is sorted into increasing order (repeated
46+
* values are okay).
47+
* @tparam T_y Type of scalar
48+
* @param function Function name (for error messages)
49+
* @param name Variable name (for error messages)
50+
* @param y <code>std::vector</code> to test
51+
* @throw <code>std::domain_error</code> if the vector elements are
52+
* not sorted, or if any element
53+
* is <code>NaN</code>.
54+
*/
55+
template <typename T_y>
56+
void check_sorted(const char* function, const char* name,
57+
const std::vector<T_y>& y) {
58+
for (size_t n = 1; n < y.size(); n++) {
59+
if (!(y[n] >= y[n - 1])) {
60+
std::ostringstream msg1;
61+
msg1 << "is not a valid sorted vector."
62+
<< " The element at " << stan::error_index::value + n << " is ";
63+
std::string msg1_str(msg1.str());
64+
std::ostringstream msg2;
65+
msg2 << ", but should be greater than or equal to the previous element, "
66+
<< y[n - 1];
67+
std::string msg2_str(msg2.str());
68+
throw_domain_error(function, name, y[n], msg1_str.c_str(),
69+
msg2_str.c_str());
70+
}
71+
}
72+
}
73+
74+
} // namespace math
75+
} // namespace stan
76+
#endif

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
#include <stan/math/prim/fun/elt_multiply.hpp>
8383
#include <stan/math/prim/fun/erf.hpp>
8484
#include <stan/math/prim/fun/erfc.hpp>
85+
#include <stan/math/prim/fun/eval.hpp>
8586
#include <stan/math/prim/fun/exp.hpp>
8687
#include <stan/math/prim/fun/exp2.hpp>
8788
#include <stan/math/prim/fun/expm1.hpp>

stan/math/prim/fun/eval.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#ifndef STAN_MATH_PRIM_FUN_EVAL_HPP
2+
#define STAN_MATH_PRIM_FUN_EVAL_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/Eigen.hpp>
6+
7+
namespace stan {
8+
namespace math {
9+
10+
/**
11+
* Inputs which have a plain_type equal to the own time are forwarded
12+
* unmodified (for Eigen expressions these types are different)
13+
*
14+
* @tparam T Input type
15+
* @param[in] arg Input argument
16+
* @return Forwarded input argument
17+
**/
18+
template <typename T,
19+
require_same_t<std::decay_t<T>, plain_type_t<T>>* = nullptr>
20+
inline decltype(auto) eval(T&& arg) {
21+
return std::forward<T>(arg);
22+
}
23+
24+
/**
25+
* Inputs which have a plain_type different from their own type are
26+
* Eval'd (this catches Eigen expressions)
27+
*
28+
* @tparam T Input type
29+
* @param[in] arg Input argument
30+
* @return Eval'd argument
31+
**/
32+
template <typename T,
33+
require_not_same_t<std::decay_t<T>, plain_type_t<T>>* = nullptr>
34+
inline decltype(auto) eval(const T& arg) {
35+
return arg.eval();
36+
}
37+
38+
} // namespace math
39+
} // namespace stan
40+
41+
#endif

stan/math/prim/fun/value_of.hpp

Lines changed: 24 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -10,102 +10,48 @@ namespace stan {
1010
namespace math {
1111

1212
/**
13-
* Return the value of the specified scalar argument
14-
* converted to a double value.
13+
* Inputs that are arithmetic types or containers of airthmetric types
14+
* are returned from value_of unchanged
1515
*
16-
* <p>See the <code>primitive_value</code> function to
17-
* extract values without casting to <code>double</code>.
18-
*
19-
* <p>This function is meant to cover the primitive types. For
20-
* types requiring pass-by-reference, this template function
21-
* should be specialized.
22-
*
23-
* @tparam T type of scalar.
24-
* @param x scalar to convert to double
25-
* @return value of scalar cast to double
26-
*/
27-
template <typename T, require_arithmetic_t<T>* = nullptr>
28-
inline double value_of(const T x) {
29-
return static_cast<double>(x);
30-
}
31-
32-
/**
33-
* Return the specified argument.
34-
*
35-
* <p>See <code>value_of(T)</code> for a polymorphic
36-
* implementation using static casts.
37-
*
38-
* <p>This inline pass-through no-op should be compiled away.
39-
*
40-
* @param x value
41-
* @return input value
42-
*/
43-
template <>
44-
inline double value_of<double>(double x) {
45-
return x;
16+
* @tparam T Input type
17+
* @param[in] x Input argument
18+
* @return Forwarded input argument
19+
**/
20+
template <typename T, require_st_arithmetic<T>* = nullptr>
21+
inline decltype(auto) value_of(T&& x) {
22+
return std::forward<T>(x);
4623
}
4724

4825
/**
49-
* Return the specified argument.
50-
*
51-
* <p>See <code>value_of(T)</code> for a polymorphic
52-
* implementation using static casts.
26+
* For std::vectors of non-arithmetic types, return a std::vector composed
27+
* of value_of applied to each element.
5328
*
54-
* <p>This inline pass-through no-op should be compiled away.
55-
*
56-
* @param x value
57-
* @return input value
58-
*/
59-
inline int value_of(int x) { return x; }
60-
61-
/**
62-
* Convert a std::vector of type T to a std::vector of
63-
* child_type<T>::type.
64-
*
65-
* @tparam T Scalar type in std::vector
66-
* @param[in] x std::vector to be converted
29+
* @tparam T Input element type
30+
* @param[in] x Input std::vector
6731
* @return std::vector of values
6832
**/
69-
template <typename T, require_not_double_or_int_t<T>* = nullptr>
70-
inline std::vector<typename child_type<T>::type> value_of(
71-
const std::vector<T>& x) {
72-
size_t x_size = x.size();
73-
std::vector<typename child_type<T>::type> result(x_size);
74-
for (size_t i = 0; i < x_size; i++) {
75-
result[i] = value_of(x[i]);
33+
template <typename T, require_not_st_arithmetic<T>* = nullptr>
34+
inline auto value_of(const std::vector<T>& x) {
35+
std::vector<plain_type_t<decltype(value_of(std::declval<T>()))>> out;
36+
out.reserve(x.size());
37+
for (auto&& x_elem : x) {
38+
out.emplace_back(value_of(x_elem));
7639
}
77-
return result;
40+
return out;
7841
}
7942

8043
/**
81-
* Return the specified argument.
82-
*
83-
* <p>See <code>value_of(T)</code> for a polymorphic
84-
* implementation using static casts.
85-
*
86-
* <p>This inline pass-through no-op should be compiled away.
87-
*
88-
* @param x Specified std::vector.
89-
* @return Specified std::vector.
90-
*/
91-
template <typename Vec, require_std_vector_vt<is_double_or_int, Vec>* = nullptr>
92-
inline Vec value_of(Vec&& x) {
93-
return std::forward<Vec>(x);
94-
}
95-
96-
/**
97-
* Convert a matrix of type T to a matrix of doubles.
98-
*
99-
* T must implement value_of. See
100-
* test/math/fwd/fun/value_of.cpp for fvar and var usage.
44+
* For Eigen matrices and expressions of non-arithmetic types, return an
45+
*expression that represents the Eigen::Matrix resulting from applying value_of
46+
*elementwise
10147
*
10248
* @tparam EigMat type of the matrix
10349
*
10450
* @param[in] M Matrix to be converted
10551
* @return Matrix of values
10652
**/
10753
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
108-
require_not_vt_double_or_int<EigMat>* = nullptr>
54+
require_not_st_arithmetic<EigMat>* = nullptr>
10955
inline auto value_of(EigMat&& M) {
11056
return make_holder(
11157
[](auto& a) {
@@ -114,25 +60,6 @@ inline auto value_of(EigMat&& M) {
11460
std::forward<EigMat>(M));
11561
}
11662

117-
/**
118-
* Return the specified argument.
119-
*
120-
* <p>See <code>value_of(T)</code> for a polymorphic
121-
* implementation using static casts.
122-
*
123-
* <p>This inline pass-through no-op should be compiled away.
124-
*
125-
* @tparam EigMat type of the matrix
126-
*
127-
* @param x Specified matrix.
128-
* @return Specified matrix.
129-
*/
130-
template <typename EigMat,
131-
require_eigen_vt<is_double_or_int, EigMat>* = nullptr>
132-
inline EigMat value_of(EigMat&& x) {
133-
return std::forward<EigMat>(x);
134-
}
135-
13663
} // namespace math
13764
} // namespace stan
13865

stan/math/prim/functor.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
66
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
77
#include <stan/math/prim/functor/apply_vector_unary.hpp>
8-
#include <stan/math/prim/functor/coupled_ode_observer.hpp>
98
#include <stan/math/prim/functor/coupled_ode_system.hpp>
109
#include <stan/math/prim/functor/finite_diff_gradient.hpp>
1110
#include <stan/math/prim/functor/finite_diff_gradient_auto.hpp>
@@ -14,6 +13,9 @@
1413
#include <stan/math/prim/functor/finite_diff_hessian_helper.hpp>
1514
#include <stan/math/prim/functor/integrate_1d.hpp>
1615
#include <stan/math/prim/functor/integrate_ode_rk45.hpp>
16+
#include <stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp>
17+
#include <stan/math/prim/functor/ode_rk45.hpp>
18+
#include <stan/math/prim/functor/ode_store_sensitivities.hpp>
1719
#include <stan/math/prim/functor/map_rect.hpp>
1820
#include <stan/math/prim/functor/map_rect_combine.hpp>
1921
#include <stan/math/prim/functor/map_rect_concurrent.hpp>
@@ -24,4 +26,5 @@
2426
#include <stan/math/prim/functor/operands_and_partials.hpp>
2527
#include <stan/math/prim/functor/reduce_sum.hpp>
2628
#include <stan/math/prim/functor/reduce_sum_static.hpp>
29+
2730
#endif

stan/math/prim/functor/apply.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ constexpr decltype(auto) apply_impl(F&& f, Tuple&& t,
2626
return f(std::forward<decltype(std::get<I>(t))>(std::get<I>(t))...);
2727
}
2828
} // namespace internal
29+
2930
/*
3031
* Call the functor f with the tuple of arguments t, like:
3132
*

0 commit comments

Comments
 (0)