Skip to content

Commit f1967ff

Browse files
author
Andrew Johnson
committed
Merge branch 'develop' into feature/scalar_binary_complex
2 parents 317970a + 2a7a857 commit f1967ff

166 files changed

Lines changed: 1937 additions & 1219 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

stan/math/fwd/core/operator_division.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ inline fvar<T> operator/(const fvar<T>& x1, const fvar<T>& x2) {
3131
* @param x2 second argument
3232
* @return first argument divided by second argument
3333
*/
34-
template <typename T, typename U, require_arithmetic_t<U>...>
34+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
3535
inline fvar<T> operator/(const fvar<T>& x1, U x2) {
3636
return fvar<T>(x1.val_ / x2, x1.d_ / x2);
3737
}
@@ -44,7 +44,7 @@ inline fvar<T> operator/(const fvar<T>& x1, U x2) {
4444
* @param x2 second argument
4545
* @return first argument divided by second argument
4646
*/
47-
template <typename T, typename U, require_arithmetic_t<U>...>
47+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
4848
inline fvar<T> operator/(U x1, const fvar<T>& x2) {
4949
return fvar<T>(x1 / x2.val_, -x1 * x2.d_ / (x2.val_ * x2.val_));
5050
}
@@ -54,7 +54,7 @@ inline std::complex<fvar<T>> operator/(const std::complex<fvar<T>>& x1,
5454
const std::complex<fvar<T>>& x2) {
5555
return internal::complex_divide(x1, x2);
5656
}
57-
template <typename T, typename U, require_arithmetic_t<U>...>
57+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
5858
inline std::complex<fvar<T>> operator/(const std::complex<fvar<T>>& x1,
5959
const std::complex<U>& x2) {
6060
return internal::complex_divide(x1, x2);
@@ -64,17 +64,17 @@ inline std::complex<fvar<T>> operator/(const std::complex<fvar<T>>& x1,
6464
const fvar<T>& x2) {
6565
return internal::complex_divide(x1, x2);
6666
}
67-
template <typename T, typename U, require_arithmetic_t<U>...>
67+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
6868
inline std::complex<fvar<T>> operator/(const std::complex<fvar<T>>& x1, U x2) {
6969
return internal::complex_divide(x1, x2);
7070
}
7171

72-
template <typename T, typename U, require_arithmetic_t<U>...>
72+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
7373
inline std::complex<fvar<T>> operator/(const std::complex<U>& x1,
7474
const std::complex<fvar<T>>& x2) {
7575
return internal::complex_divide(x1, x2);
7676
}
77-
template <typename T, typename U, require_arithmetic_t<U>...>
77+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
7878
inline std::complex<fvar<T>> operator/(const std::complex<U>& x1,
7979
const fvar<T>& x2) {
8080
return internal::complex_divide(x1, x2);
@@ -92,7 +92,7 @@ inline std::complex<fvar<T>> operator/(const fvar<T>& x1,
9292
return internal::complex_divide(x1, x2);
9393
}
9494

95-
template <typename T, typename U, require_arithmetic_t<U>...>
95+
template <typename T, typename U, require_arithmetic_t<U>* = nullptr>
9696
inline std::complex<fvar<T>> operator/(U x1, const std::complex<fvar<T>>& x2) {
9797
return internal::complex_divide(x1, x2);
9898
}

stan/math/fwd/fun/is_nan.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace math {
1818
* @param x Value to test.
1919
* @return <code>1</code> if the value is NaN and <code>0</code> otherwise.
2020
*/
21-
template <typename T, require_fvar_t<T>...>
21+
template <typename T, require_fvar_t<T>* = nullptr>
2222
inline bool is_nan(T&& x) {
2323
return is_nan(std::forward<decltype(x.val())>(x.val()));
2424
}

stan/math/fwd/fun/log_mix.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
114114
}
115115
}
116116

117-
template <typename T, typename P, require_all_arithmetic_t<P>...>
117+
template <typename T, typename P, require_all_arithmetic_t<P>* = nullptr>
118118
inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
119119
P lambda2) {
120120
if (lambda1.val_ > lambda2) {
@@ -132,7 +132,7 @@ inline fvar<T> log_mix(const fvar<T>& theta, const fvar<T>& lambda1,
132132
}
133133
}
134134

135-
template <typename T, typename P, require_all_arithmetic_t<P>...>
135+
template <typename T, typename P, require_all_arithmetic_t<P>* = nullptr>
136136
inline fvar<T> log_mix(const fvar<T>& theta, P lambda1,
137137
const fvar<T>& lambda2) {
138138
if (lambda1 > lambda2.val_) {
@@ -150,7 +150,7 @@ inline fvar<T> log_mix(const fvar<T>& theta, P lambda1,
150150
}
151151
}
152152

153-
template <typename T, typename P, require_all_arithmetic_t<P>...>
153+
template <typename T, typename P, require_all_arithmetic_t<P>* = nullptr>
154154
inline fvar<T> log_mix(P theta, const fvar<T>& lambda1,
155155
const fvar<T>& lambda2) {
156156
if (lambda1.val_ > lambda2.val_) {
@@ -169,7 +169,7 @@ inline fvar<T> log_mix(P theta, const fvar<T>& lambda1,
169169
}
170170

171171
template <typename T, typename P1, typename P2,
172-
require_all_arithmetic_t<P1, P2>...>
172+
require_all_arithmetic_t<P1, P2>* = nullptr>
173173
inline fvar<T> log_mix(const fvar<T>& theta, P1 lambda1, P2 lambda2) {
174174
if (lambda1 > lambda2) {
175175
fvar<T> partial_deriv_array[1];
@@ -185,7 +185,7 @@ inline fvar<T> log_mix(const fvar<T>& theta, P1 lambda1, P2 lambda2) {
185185
}
186186

187187
template <typename T, typename P1, typename P2,
188-
require_all_arithmetic_t<P1, P2>...>
188+
require_all_arithmetic_t<P1, P2>* = nullptr>
189189
inline fvar<T> log_mix(P1 theta, const fvar<T>& lambda1, P2 lambda2) {
190190
if (lambda1.val_ > lambda2) {
191191
fvar<T> partial_deriv_array[1];
@@ -201,7 +201,7 @@ inline fvar<T> log_mix(P1 theta, const fvar<T>& lambda1, P2 lambda2) {
201201
}
202202

203203
template <typename T, typename P1, typename P2,
204-
require_all_arithmetic_t<P1, P2>...>
204+
require_all_arithmetic_t<P1, P2>* = nullptr>
205205
inline fvar<T> log_mix(P1 theta, P2 lambda1, const fvar<T>& lambda2) {
206206
if (lambda1 > lambda2.val_) {
207207
fvar<T> partial_deriv_array[1];

stan/math/fwd/fun/log_sum_exp.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ inline fvar<T> log_sum_exp(const fvar<T>& x1, double x2) {
5050
* @param[in] x Matrix of specified values.
5151
* @return The log of the sum of the exponentiated vector values.
5252
*/
53-
template <typename T, require_container_st<is_fvar, T>...>
53+
template <typename T, require_container_st<is_fvar, T>* = nullptr>
5454
inline auto log_sum_exp(const T& x) {
5555
return apply_vector_unary<T>::reduce(x, [&](const auto& v) {
5656
using T_fvar_inner = typename value_type_t<decltype(v)>::Scalar;

stan/math/opencl/copy.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace math {
4141
* @return matrix_cl with a copy of the data in the source matrix
4242
*/
4343
template <typename Mat, typename Mat_scalar = scalar_type_t<Mat>,
44-
require_eigen_vt<std::is_arithmetic, Mat>...>
44+
require_eigen_vt<std::is_arithmetic, Mat>* = nullptr>
4545
inline matrix_cl<Mat_scalar> to_matrix_cl(Mat&& src) {
4646
return matrix_cl<Mat_scalar>(std::forward<Mat>(src));
4747
}
@@ -60,7 +60,7 @@ inline matrix_cl<Mat_scalar> to_matrix_cl(Mat&& src) {
6060
* @return matrix_cl with a copy of the data in the source matrix
6161
*/
6262
template <typename Vec, typename Vec_scalar = scalar_type_t<Vec>,
63-
require_std_vector_vt<std::is_arithmetic, Vec>...>
63+
require_std_vector_vt<std::is_arithmetic, Vec>* = nullptr>
6464
inline matrix_cl<Vec_scalar> to_matrix_cl(Vec&& src) {
6565
return matrix_cl<Vec_scalar>(std::forward<Vec>(src));
6666
}
@@ -158,7 +158,7 @@ inline std::vector<T> packed_copy(const matrix_cl<T>& src) {
158158
*/
159159
template <matrix_cl_view matrix_view, typename Vec,
160160
typename Vec_scalar = scalar_type_t<Vec>,
161-
require_vector_vt<std::is_arithmetic, Vec>...>
161+
require_vector_vt<std::is_arithmetic, Vec>* = nullptr>
162162
inline matrix_cl<Vec_scalar> packed_copy(Vec&& src, int rows) {
163163
const int packed_size = rows * (rows + 1) / 2;
164164
check_size_match("copy (packed std::vector -> OpenCL)", "src.size()",

stan/math/opencl/kernel_cl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ inline void assign_event(const cl::Event& e,
112112
helper.set(e, m);
113113
}
114114

115-
template <typename T, require_same_t<T, cl::Event>...>
115+
template <typename T, require_same_t<T, cl::Event>* = nullptr>
116116
inline void assign_events(const T&) {}
117117

118118
/** \ingroup kernel_executor_opencl

stan/math/opencl/kernel_generator/holder_cl.hpp

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -74,35 +74,6 @@ auto holder_cl(T&& a, Ptrs*... ptrs) {
7474
}
7575

7676
namespace internal {
77-
/**
78-
* Handles single element (moving rvalues to heap) for construction of
79-
* `holder_cl` from a functor. For lvalues just sets the `res` pointer.
80-
* @tparam T type of the element
81-
* @param a element to handle
82-
* @param res resulting pointer to element
83-
* @return tuple of pointer allocated on heap (empty).
84-
*/
85-
template <typename T>
86-
auto holder_cl_handle_element(const T& a, const T*& res) {
87-
res = &a;
88-
return std::make_tuple();
89-
}
90-
91-
/**
92-
* Handles single element (moving rvalues to heap) for construction of
93-
* `holder_cl` from a functor. Rvalue is moved to heap and the pointer to heap
94-
* memory is assigned to res and returned in a tuple.
95-
* @tparam T type of the element
96-
* @param a element to handle
97-
* @param res resulting pointer to element
98-
* @return tuple of pointer allocated on heap (containing single pointer).
99-
*/
100-
template <typename T>
101-
auto holder_cl_handle_element(std::remove_reference_t<T>&& a, const T*& res) {
102-
res = new T(std::move(a));
103-
return std::make_tuple(res);
104-
}
105-
10677
/**
10778
* Second step in implementation of construction `holder_cl` from a functor.
10879
* @tparam T type of the result expression
@@ -114,8 +85,8 @@ auto holder_cl_handle_element(std::remove_reference_t<T>&& a, const T*& res) {
11485
* @return `holder_cl` referencing given expression
11586
*/
11687
template <typename T, std::size_t... Is, typename... Args>
117-
auto make_holder_cl_impl2(T&& expr, std::index_sequence<Is...>,
118-
const std::tuple<Args*...>& ptrs) {
88+
auto make_holder_cl_impl_step2(T&& expr, std::index_sequence<Is...>,
89+
const std::tuple<Args*...>& ptrs) {
11990
return holder_cl(std::forward<T>(expr), std::get<Is>(ptrs)...);
12091
}
12192

@@ -129,12 +100,12 @@ auto make_holder_cl_impl2(T&& expr, std::index_sequence<Is...>,
129100
* @return `holder_cl` referencing given expression
130101
*/
131102
template <typename T, std::size_t... Is, typename... Args>
132-
auto make_holder_cl_impl(const T& func, std::index_sequence<Is...>,
133-
Args&&... args) {
134-
std::tuple<const std::remove_reference_t<Args>*...> res;
103+
auto make_holder_cl_impl_step1(const T& func, std::index_sequence<Is...>,
104+
Args&&... args) {
105+
std::tuple<std::remove_reference_t<Args>*...> res;
135106
auto ptrs = std::tuple_cat(
136-
holder_cl_handle_element(std::forward<Args>(args), std::get<Is>(res))...);
137-
return make_holder_cl_impl2(
107+
holder_handle_element(std::forward<Args>(args), std::get<Is>(res))...);
108+
return make_holder_cl_impl_step2(
138109
func(*std::get<Is>(res)...),
139110
std::make_index_sequence<std::tuple_size<decltype(ptrs)>::value>(), ptrs);
140111
}
@@ -156,7 +127,7 @@ template <typename T, typename... Args,
156127
decltype(std::declval<T>()(std::declval<Args&>()...)),
157128
Args...>* = nullptr>
158129
auto make_holder_cl(const T& func, Args&&... args) {
159-
return internal::make_holder_cl_impl(
130+
return internal::make_holder_cl_impl_step1(
160131
func, std::make_index_sequence<sizeof...(Args)>(),
161132
std::forward<Args>(args)...);
162133
}

stan/math/opencl/matrix_cl.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ class matrix_cl<T, require_arithmetic_t<T>> {
233233
* @throw <code>std::system_error</code> if the memory on the device could not
234234
* be allocated
235235
*/
236-
template <typename Vec, require_std_vector_vt<is_eigen, Vec>...,
237-
require_st_same<Vec, T>...>
236+
template <typename Vec, require_std_vector_vt<is_eigen, Vec>* = nullptr,
237+
require_st_same<Vec, T>* = nullptr>
238238
explicit matrix_cl(Vec&& A) try : rows_(A.empty() ? 0 : A[0].size()),
239239
cols_(A.size()) {
240240
if (this->size() == 0) {
@@ -305,7 +305,8 @@ class matrix_cl<T, require_arithmetic_t<T>> {
305305
* @throw <code>std::system_error</code> if the memory on the device could not
306306
* be allocated
307307
*/
308-
template <typename Mat, require_eigen_t<Mat>..., require_vt_same<Mat, T>...>
308+
template <typename Mat, require_eigen_t<Mat>* = nullptr,
309+
require_vt_same<Mat, T>* = nullptr>
309310
explicit matrix_cl(Mat&& A,
310311
matrix_cl_view partial_view = matrix_cl_view::Entire)
311312
: rows_(A.rows()), cols_(A.cols()), view_(partial_view) {
@@ -360,8 +361,8 @@ class matrix_cl<T, require_arithmetic_t<T>> {
360361
* @throw <code>std::system_error</code> if the memory on the device could not
361362
* be allocated
362363
*/
363-
template <typename Vec, require_std_vector_t<Vec>...,
364-
require_vt_same<Vec, T>...>
364+
template <typename Vec, require_std_vector_t<Vec>* = nullptr,
365+
require_vt_same<Vec, T>* = nullptr>
365366
explicit matrix_cl(Vec&& A,
366367
matrix_cl_view partial_view = matrix_cl_view::Entire)
367368
: matrix_cl(std::forward<Vec>(A), A.size(), 1) {}
@@ -384,8 +385,8 @@ class matrix_cl<T, require_arithmetic_t<T>> {
384385
* @throw <code>std::system_error</code> if the memory on the device could not
385386
* be allocated
386387
*/
387-
template <typename Vec, require_std_vector_t<Vec>...,
388-
require_vt_same<Vec, T>...>
388+
template <typename Vec, require_std_vector_t<Vec>* = nullptr,
389+
require_vt_same<Vec, T>* = nullptr>
389390
explicit matrix_cl(Vec&& A, const int& R, const int& C,
390391
matrix_cl_view partial_view = matrix_cl_view::Entire)
391392
: rows_(R), cols_(C), view_(partial_view) {
@@ -409,7 +410,7 @@ class matrix_cl<T, require_arithmetic_t<T>> {
409410
* @throw <code>std::system_error</code> if the memory on the device could not
410411
* be allocated
411412
*/
412-
template <typename U, require_same_t<T, U>...>
413+
template <typename U, require_same_t<T, U>* = nullptr>
413414
explicit matrix_cl(const U* A, const int& R, const int& C,
414415
matrix_cl_view partial_view = matrix_cl_view::Entire)
415416
: rows_(R), cols_(C), view_(partial_view) {

stan/math/opencl/prim/bernoulli_logit_glm_lpmf.hpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ return_type_t<T_alpha, T_beta> bernoulli_logit_glm_lpmf(
4949
const T_alpha& alpha, const T_beta& beta) {
5050
static const char* function = "bernoulli_logit_glm_lpmf(OpenCL)";
5151
using T_partials_return = partials_return_t<T_alpha, T_beta>;
52-
52+
using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
53+
using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;
5354
using Eigen::Dynamic;
5455
using Eigen::Matrix;
5556

@@ -71,15 +72,12 @@ return_type_t<T_alpha, T_beta> bernoulli_logit_glm_lpmf(
7172
if (N == 0) {
7273
return 0;
7374
}
74-
7575
if (!include_summand<propto, T_alpha, T_beta>::value) {
7676
return 0;
7777
}
7878

79-
T_partials_return logp(0);
80-
81-
const auto& beta_ref = to_ref_if<!is_constant<T_beta>::value>(beta);
82-
const auto& alpha_ref = to_ref_if<!is_constant<T_alpha>::value>(alpha);
79+
T_beta_ref beta_ref = beta;
80+
T_alpha_ref alpha_ref = alpha;
8381

8482
const auto& beta_val = value_of_rec(beta_ref);
8583
const auto& alpha_val = value_of_rec(alpha_ref);
@@ -112,7 +110,6 @@ return_type_t<T_alpha, T_beta> bernoulli_logit_glm_lpmf(
112110
(exp_m_ytheta_expr + 1))));
113111

114112
const int wgs = logp_expr.rows();
115-
116113
matrix_cl<double> logp_cl(wgs, 1);
117114
constexpr bool need_theta_derivative
118115
= !is_constant_all<T_beta, T_alpha>::value;
@@ -126,8 +123,7 @@ return_type_t<T_alpha, T_beta> bernoulli_logit_glm_lpmf(
126123
logp_expr, calc_if<need_theta_derivative>(theta_derivative_expr),
127124
calc_if<need_theta_derivative_sum>(colwise_sum(theta_derivative_expr)));
128125

129-
logp += sum(from_matrix_cl<Eigen::Dynamic, 1>(logp_cl));
130-
126+
T_partials_return logp = sum(from_matrix_cl<Eigen::Dynamic, 1>(logp_cl));
131127
if (!std::isfinite(logp)) {
132128
check_bounded(function, "Vector of dependent variables",
133129
from_matrix_cl(y_cl), 0, 1);
@@ -137,8 +133,8 @@ return_type_t<T_alpha, T_beta> bernoulli_logit_glm_lpmf(
137133
from_matrix_cl(x_cl));
138134
}
139135

140-
operands_and_partials<decltype(alpha_ref), decltype(beta_ref)> ops_partials(
141-
alpha_ref, beta_ref);
136+
operands_and_partials<T_alpha_ref, T_beta_ref> ops_partials(alpha_ref,
137+
beta_ref);
142138
// Compute the necessary derivatives.
143139
if (!is_constant_all<T_alpha>::value) {
144140
if (is_alpha_vector) {

stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ return_type_t<T_alpha, T_beta> categorical_logit_glm_lpmf(
4545
const matrix_cl<int>& y_cl, const matrix_cl<double>& x_cl,
4646
const T_alpha& alpha, const T_beta& beta) {
4747
using T_partials_return = partials_return_t<T_alpha, T_beta>;
48-
static const char* function = "categorical_logit_glm_lpmf";
49-
48+
using T_alpha_ref = ref_type_if_t<!is_constant<T_alpha>::value, T_alpha>;
49+
using T_beta_ref = ref_type_if_t<!is_constant<T_beta>::value, T_beta>;
5050
using Eigen::Array;
5151
using Eigen::Dynamic;
5252
using Eigen::Matrix;
@@ -55,6 +55,7 @@ return_type_t<T_alpha, T_beta> categorical_logit_glm_lpmf(
5555
const size_t N_attributes = x_cl.cols();
5656
const size_t N_classes = beta.cols();
5757

58+
static const char* function = "categorical_logit_glm_lpmf";
5859
if (y_cl.size() != 1) {
5960
check_size_match(function, "x.rows()", N_instances, "y.size()",
6061
y_cl.size());
@@ -66,13 +67,12 @@ return_type_t<T_alpha, T_beta> categorical_logit_glm_lpmf(
6667
if (N_instances == 0 || N_classes <= 1) {
6768
return 0;
6869
}
69-
7070
if (!include_summand<propto, T_alpha, T_beta>::value) {
7171
return 0;
7272
}
7373

74-
const auto& alpha_ref = to_ref_if<!is_constant<T_alpha>::value>(alpha);
75-
const auto& beta_ref = to_ref_if<!is_constant<T_beta>::value>(beta);
74+
T_alpha_ref alpha_ref = alpha;
75+
T_beta_ref beta_ref = beta;
7676

7777
const auto& alpha_val = value_of_rec(alpha_ref);
7878
const auto& beta_val = value_of_rec(beta_ref);
@@ -118,8 +118,8 @@ return_type_t<T_alpha, T_beta> categorical_logit_glm_lpmf(
118118
from_matrix_cl(x_cl));
119119
}
120120

121-
operands_and_partials<decltype(alpha_ref), decltype(beta_ref)> ops_partials(
122-
alpha_ref, beta_ref);
121+
operands_and_partials<T_alpha_ref, T_beta_ref> ops_partials(alpha_ref,
122+
beta_ref);
123123
if (!is_constant_all<T_alpha>::value) {
124124
ops_partials.edge1_.partials_
125125
= from_matrix_cl(alpha_derivative_cl).colwise().sum();

0 commit comments

Comments
 (0)