Skip to content

Commit 690fcea

Browse files
author
Andrew Johnson
committed
Allow for integer arrays with vectorisation
1 parent fcbd277 commit 690fcea

7 files changed

Lines changed: 292 additions & 37 deletions

File tree

stan/math/prim/functor/apply_scalar_binary.hpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/meta/is_container.hpp>
77
#include <stan/math/prim/meta/is_eigen.hpp>
88
#include <stan/math/prim/meta/require_generics.hpp>
9+
#include <stan/math/prim/fun/num_elements.hpp>
910
#include <vector>
1011

1112
namespace stan {
@@ -55,6 +56,116 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
5556
return x.binaryExpr(y, f).eval();
5657
}
5758

59+
/**
60+
* Specialisation for use with one Eigen vector (row or column) and
61+
* a one-dimensional std::vector of integer types
62+
*
63+
* @tparam T1 Type of first argument to which functor is applied.
64+
* @tparam T2 Type of second argument to which functor is applied.
65+
* @tparam F Type of functor to apply.
66+
* @param x Eigen input to which operation is applied.
67+
* @param y Integer std::vector input to which operation is applied.
68+
* @param f functor to apply to inputs.
69+
* @return Eigen object with result of applying functor to inputs.
70+
*/
71+
template <typename T1, typename T2, typename F,
72+
require_eigen_vector_vt<is_stan_scalar, T1>* = nullptr,
73+
require_std_vector_vt<std::is_integral, T2>* = nullptr>
74+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
75+
check_matching_sizes("Binary function", "x", x, "y", y);
76+
using int_vec_t = promote_scalar_t<value_type_t<T2>, plain_type_t<T1>>;
77+
Eigen::Map<const int_vec_t> y_map(y.data(), y.size());
78+
return x.binaryExpr(y_map, f).eval();
79+
}
80+
81+
/**
82+
* Specialisation for use with a one-dimensional std::vector of integer types
83+
* and one Eigen vector (row or column).
84+
*
85+
* @tparam T1 Type of first argument to which functor is applied.
86+
* @tparam T2 Type of second argument to which functor is applied.
87+
* @tparam F Type of functor to apply.
88+
* @param x Integer std::vector input to which operation is applied.
89+
* @param y Eigen input to which operation is applied.
90+
* @param f functor to apply to inputs.
91+
* @return Eigen object with result of applying functor to inputs.
92+
*/
93+
template <typename T1, typename T2, typename F,
94+
require_std_vector_vt<std::is_integral, T1>* = nullptr,
95+
require_eigen_vector_vt<is_stan_scalar, T2>* = nullptr>
96+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
97+
check_matching_sizes("Binary function", "x", x, "y", y);
98+
using int_vec_t = promote_scalar_t<value_type_t<T1>, plain_type_t<T2>>;
99+
Eigen::Map<const int_vec_t> x_map(x.data(), x.size());
100+
return x_map.binaryExpr(y, f).eval();
101+
}
102+
103+
/**
104+
* Specialisation for use with one Eigen matrix and
105+
* a two-dimensional std::vector of integer types
106+
*
107+
* @tparam T1 Type of first argument to which functor is applied.
108+
* @tparam T2 Type of second argument to which functor is applied.
109+
* @tparam F Type of functor to apply.
110+
* @param x Eigen matrix input to which operation is applied.
111+
* @param y Nested integer std::vector input to which operation is applied.
112+
* @param f functor to apply to inputs.
113+
* @return Eigen object with result of applying functor to inputs.
114+
*/
115+
template <typename T1, typename T2, typename F,
116+
require_eigen_matrix_vt<is_stan_scalar, T1>* = nullptr,
117+
require_std_vector_vt<is_std_vector, T2>* = nullptr,
118+
require_std_vector_st<std::is_integral, T2>* = nullptr>
119+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
120+
if (num_elements(x) != num_elements(y)) {
121+
std::ostringstream msg;
122+
msg << "Inputs to vectorized binary function must match in"
123+
<< " size if one is not a scalar";
124+
throw std::invalid_argument(msg.str());
125+
}
126+
using return_st = decltype(f(x(0), y[0][0]));
127+
Eigen::Matrix<return_st, Eigen::Dynamic, Eigen::Dynamic>
128+
result(x.rows(), x.cols());
129+
for (size_t i = 0; i < y.size(); ++i) {
130+
result.row(i) = apply_scalar_binary(x.row(i).transpose(),
131+
as_column_vector_or_scalar(y[i]), f);
132+
}
133+
return result;
134+
}
135+
136+
/**
137+
* Specialisation for use with a two-dimensional std::vector of integer types
138+
* and one Eigen matrix.
139+
*
140+
* @tparam T1 Type of first argument to which functor is applied.
141+
* @tparam T2 Type of second argument to which functor is applied.
142+
* @tparam F Type of functor to apply.
143+
* @param x Nested integer std::vector input to which operation is applied.
144+
* @param y Eigen matrix input to which operation is applied.
145+
* @param f functor to apply to inputs.
146+
* @return Eigen object with result of applying functor to inputs.
147+
*/
148+
template <typename T1, typename T2, typename F,
149+
require_std_vector_vt<is_std_vector, T1>* = nullptr,
150+
require_std_vector_st<std::is_integral, T1>* = nullptr,
151+
require_eigen_matrix_vt<is_stan_scalar, T2>* = nullptr>
152+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
153+
if (num_elements(x) != num_elements(y)) {
154+
std::ostringstream msg;
155+
msg << "Inputs to vectorized binary function must match in"
156+
<< " size if one is not a scalar";
157+
throw std::invalid_argument(msg.str());
158+
}
159+
using return_st = decltype(f(x[0][0], y(0)));
160+
Eigen::Matrix<return_st, Eigen::Dynamic, Eigen::Dynamic>
161+
result(y.rows(), y.cols());
162+
for (size_t i = 0; i < x.size(); ++i) {
163+
result.row(i) = apply_scalar_binary(as_column_vector_or_scalar(x[i]),
164+
y.row(i).transpose(), f);
165+
}
166+
return result;
167+
}
168+
58169
/**
59170
* Specialisation for use when the first input is an Eigen type and the second
60171
* is a scalar. Eigen's unaryExpr framework is used for more efficient indexing

test/unit/math/mix/fun/bessel_first_kind_test.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ TEST(mathMixScalFun, besselFirstKind_vec) {
2020
return bessel_first_kind(x1, x2);
2121
};
2222

23-
Eigen::VectorXi in1(2);
24-
in1 << 3, 1;
23+
std::vector<int> std_in1{3, 1};
2524
Eigen::VectorXd in2(2);
2625
in2 << 0.5, 3.4;
27-
stan::test::expect_ad_vectorized_binary(f, in1, in2);
26+
stan::test::expect_ad_vectorized_binary(f, std_in1, in2);
27+
28+
std::vector<std::vector<int>> std_std_in1{std_in1, std_in1};
29+
Eigen::MatrixXd mat_in2 = in2.replicate(1, 2);
30+
stan::test::expect_ad_vectorized_binary(f, std_std_in1, mat_in2);
2831
}

test/unit/math/mix/fun/falling_factorial_test.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ TEST(mathMixScalFun, fallingFactorial_vec) {
2929

3030
Eigen::VectorXd in1(2);
3131
in1 << 0.5, 3.4;
32-
Eigen::VectorXi in2(2);
33-
in2 << 3, 1;
34-
stan::test::expect_ad_vectorized_binary(f, in1, in2);
32+
std::vector<int> std_in2{3, 1};
33+
stan::test::expect_ad_vectorized_binary(f, in1, std_in2);
34+
35+
36+
Eigen::MatrixXd mat_in1 = in1.replicate(1, 2);
37+
std::vector<std::vector<int>> std_std_in2{std_in2, std_in2};
38+
stan::test::expect_ad_vectorized_binary(f, mat_in1, std_std_in2);
3539
}

test/unit/math/prim/fun/bessel_first_kind_test.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ TEST(MathFunctions, bessel_first_kind_vec) {
2626
return bessel_first_kind(x1, x2);
2727
};
2828

29-
Eigen::VectorXi in1(3);
30-
in1 << 1, 3, 1;
29+
std::vector<int> std_in1{1, 3, 1};
3130
Eigen::VectorXd in2(3);
3231
in2 << -1.3, 0.7, 2.8;
33-
stan::test::binary_scalar_tester(f, in1, in2);
32+
stan::test::binary_scalar_tester(f, std_in1, in2);
33+
34+
Eigen::MatrixXd mat_in2 = in2.replicate(1, 3);
35+
std::vector<std::vector<int>> std_std_in1{std_in1, std_in1, std_in1};
36+
stan::test::binary_scalar_tester(f, std_std_in1, mat_in2);
37+
3438
}

test/unit/math/prim/fun/binary_scalar_tester.hpp

Lines changed: 133 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace test {
1717
* @param f functor to apply to inputs.
1818
*/
1919
template <typename F, typename T1, typename T2,
20-
require_all_not_std_vector_t<T1, T2>* = nullptr>
20+
require_all_not_vector_t<T1, T2>* = nullptr>
2121
void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
2222
auto vec_vec = f(x, y);
2323
auto vec_scal = f(x, y(0));
@@ -87,7 +87,7 @@ void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
8787
* @param f functor to apply to inputs.
8888
*/
8989
template <typename F, typename T1, typename T2,
90-
require_all_std_vector_t<T1, T2>* = nullptr>
90+
require_all_vector_t<T1, T2>* = nullptr>
9191
void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
9292
auto vec_vec = f(x, y);
9393
auto vec_scal = f(x, y[0]);
@@ -97,13 +97,13 @@ void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
9797
EXPECT_FLOAT_EQ(f(x[i], y[0]), vec_scal[i]);
9898
EXPECT_FLOAT_EQ(f(x[0], y[i]), scal_vec[i]);
9999
}
100-
T1 x_zero;
101-
T2 y_zero;
100+
plain_type_t<T1> x_zero;
101+
plain_type_t<T2> y_zero;
102102
EXPECT_THROW(f(x_zero, y), std::invalid_argument);
103103
EXPECT_THROW(f(x, y_zero), std::invalid_argument);
104104

105-
std::vector<T1> nest_x{x, x, x};
106-
std::vector<T2> nest_y{y, y, y};
105+
std::vector<plain_type_t<T1>> nest_x{x, x, x};
106+
std::vector<plain_type_t<T2>> nest_y{y, y, y};
107107
auto nestvec_nestvec = f(nest_x, nest_y);
108108
auto nestvec_scal = f(nest_x, y[0]);
109109
auto scal_nestvec = f(x[0], nest_y);
@@ -114,13 +114,13 @@ void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
114114
EXPECT_FLOAT_EQ(f(x[0], nest_y[i][j]), scal_nestvec[i][j]);
115115
}
116116
}
117-
std::vector<T1> nest_x_small{x, x};
118-
std::vector<T2> nest_y_small{y, y};
117+
std::vector<plain_type_t<T1>> nest_x_small{x, x};
118+
std::vector<plain_type_t<T2>> nest_y_small{y, y};
119119
EXPECT_THROW(f(nest_x, nest_y_small), std::invalid_argument);
120120
EXPECT_THROW(f(nest_x_small, nest_y), std::invalid_argument);
121121

122-
std::vector<std::vector<T1>> nest_nest_x{nest_x, nest_x, nest_x};
123-
std::vector<std::vector<T2>> nest_nest_y{nest_y, nest_y, nest_y};
122+
std::vector<std::vector<plain_type_t<T1>>> nest_nest_x{nest_x, nest_x, nest_x};
123+
std::vector<std::vector<plain_type_t<T2>>> nest_nest_y{nest_y, nest_y, nest_y};
124124
auto nestnestvec_nestnestvec = f(nest_nest_x, nest_nest_y);
125125
auto nestnestvec_scal = f(nest_nest_x, y[0]);
126126
auto scal_nestnestvec = f(x[0], nest_nest_y);
@@ -136,12 +136,129 @@ void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
136136
}
137137
}
138138
}
139-
std::vector<std::vector<T1>> nest_nest_x_small{nest_x, nest_x};
139+
std::vector<std::vector<plain_type_t<T1>>> nest_nest_x_small{nest_x, nest_x};
140+
std::vector<std::vector<plain_type_t<T2>>> nest_nest_y_small{nest_y, nest_y};
141+
EXPECT_THROW(f(nest_nest_x, nest_nest_y_small), std::invalid_argument);
142+
EXPECT_THROW(f(nest_nest_x_small, nest_nest_y), std::invalid_argument);
143+
}
144+
145+
146+
/**
147+
* Implementation function which checks that the binary vectorisation
148+
* framework returns the same value as the function with scalar inputs,
149+
* for all valid combinations of scalar/vector/nested vector.
150+
*
151+
* @tparam F Type of functor to apply.
152+
* @tparam T1 Type of first vector.
153+
* @tparam T2 Type of second vector.
154+
* @param x First vector input to which operation is applied.
155+
* @param y Second vector input to which operation is applied.
156+
* @param f functor to apply to inputs.
157+
*/
158+
template <typename F, typename T1, typename T2,
159+
typename T1_plain = plain_type_t<T1>,
160+
require_eigen_matrix_t<T1>* = nullptr,
161+
require_std_vector_t<T2>* = nullptr>
162+
void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
163+
auto vec_vec = f(x, y);
164+
for (int r = 0; r < x.rows(); ++r) {
165+
for(int c = 0; c < x.cols(); ++c) {
166+
EXPECT_FLOAT_EQ(f(x(r, c), y[r][c]), vec_vec(r, c));
167+
}
168+
}
169+
170+
T1_plain x_zero;
171+
T2 y_zero;
172+
EXPECT_THROW(f(x_zero, y), std::invalid_argument);
173+
EXPECT_THROW(f(x, y_zero), std::invalid_argument);
174+
175+
std::vector<T1_plain> nest_x{x, x, x};
176+
std::vector<T2> nest_y{y, y, y};
177+
auto nestvec_nestvec = f(nest_x, nest_y);
178+
for (int i = 0; i < 3; ++i) {
179+
for (int r = 0; r < x.rows(); ++r) {
180+
for(int c = 0; c < x.cols(); ++c) {
181+
EXPECT_FLOAT_EQ(f(nest_x[i](r, c), nest_y[i][r][c]), nestvec_nestvec[i](r, c));
182+
}
183+
}
184+
}
185+
std::vector<T1_plain> nest_x_small{x, x};
186+
std::vector<T2> nest_y_small{y, y};
187+
EXPECT_THROW(f(nest_x, nest_y_small), std::invalid_argument);
188+
EXPECT_THROW(f(nest_x_small, nest_y), std::invalid_argument);
189+
190+
std::vector<std::vector<T1_plain>> nest_nest_x{nest_x, nest_x, nest_x};
191+
std::vector<std::vector<T2>> nest_nest_y{nest_y, nest_y, nest_y};
192+
193+
auto nestnestvec_nestnestvec = f(nest_nest_x, nest_nest_y);
194+
for (int i = 0; i < 3; ++i) {
195+
for (int j = 0; j < 3; ++j) {
196+
for (int r = 0; r < x.rows(); ++r) {
197+
for(int c = 0; c < x.cols(); ++c) {
198+
EXPECT_FLOAT_EQ(f(nest_nest_x[i][j](r, c), nest_nest_y[i][j][r][c]),
199+
nestnestvec_nestnestvec[i][j](r, c));
200+
}
201+
}
202+
}
203+
}
204+
std::vector<std::vector<T1_plain>> nest_nest_x_small{nest_x, nest_x};
140205
std::vector<std::vector<T2>> nest_nest_y_small{nest_y, nest_y};
141206
EXPECT_THROW(f(nest_nest_x, nest_nest_y_small), std::invalid_argument);
142207
EXPECT_THROW(f(nest_nest_x_small, nest_nest_y), std::invalid_argument);
143208
}
144209

210+
template <typename F, typename T1, typename T2,
211+
typename T2_plain = plain_type_t<T2>,
212+
require_std_vector_t<T1>* = nullptr,
213+
require_eigen_matrix_t<T2>* = nullptr>
214+
void binary_scalar_tester_impl(const F& f, const T1& x, const T2& y) {
215+
auto vec_vec = f(x, y);
216+
for (int r = 0; r < y.rows(); ++r) {
217+
for(int c = 0; c < y.cols(); ++c) {
218+
EXPECT_FLOAT_EQ(f(x[r][c], y(r, c)), vec_vec(r, c));
219+
}
220+
}
221+
222+
T1 x_zero;
223+
T2_plain y_zero;
224+
EXPECT_THROW(f(x_zero, y), std::invalid_argument);
225+
EXPECT_THROW(f(x, y_zero), std::invalid_argument);
226+
227+
std::vector<T1> nest_x{x, x, x};
228+
std::vector<T2_plain> nest_y{y, y, y};
229+
auto nestvec_nestvec = f(nest_x, nest_y);
230+
for (int i = 0; i < 3; ++i) {
231+
for (int r = 0; r < y.rows(); ++r) {
232+
for(int c = 0; c < y.cols(); ++c) {
233+
EXPECT_FLOAT_EQ(f(nest_x[i][r][c], nest_y[i](r, c)), nestvec_nestvec[i](r, c));
234+
}
235+
}
236+
}
237+
std::vector<T1> nest_x_small{x, x};
238+
std::vector<T2_plain> nest_y_small{y, y};
239+
EXPECT_THROW(f(nest_x, nest_y_small), std::invalid_argument);
240+
EXPECT_THROW(f(nest_x_small, nest_y), std::invalid_argument);
241+
242+
std::vector<std::vector<T1>> nest_nest_x{nest_x, nest_x, nest_x};
243+
std::vector<std::vector<T2_plain>> nest_nest_y{nest_y, nest_y, nest_y};
244+
245+
auto nestnestvec_nestnestvec = f(nest_nest_x, nest_nest_y);
246+
for (int i = 0; i < 3; ++i) {
247+
for (int j = 0; j < 3; ++j) {
248+
for (int r = 0; r < y.rows(); ++r) {
249+
for(int c = 0; c < y.cols(); ++c) {
250+
EXPECT_FLOAT_EQ(f(nest_nest_x[i][j][r][c], nest_nest_y[i][j](r, c)),
251+
nestnestvec_nestnestvec[i][j](r, c));
252+
}
253+
}
254+
}
255+
}
256+
std::vector<std::vector<T1>> nest_nest_x_small{nest_x, nest_x};
257+
std::vector<std::vector<T2_plain>> nest_nest_y_small{nest_y, nest_y};
258+
EXPECT_THROW(f(nest_nest_x, nest_nest_y_small), std::invalid_argument);
259+
EXPECT_THROW(f(nest_nest_x_small, nest_nest_y), std::invalid_argument);
260+
}
261+
145262
/**
146263
* Testing framework for checking that the vectorisation of binary
147264
* functions returns the same results as the binary function with
@@ -172,5 +289,10 @@ void binary_scalar_tester(const F& f, const T1& x, const T2& y) {
172289
std::vector<typename T2::Scalar>(y.data(), y.data() + y.size()));
173290
}
174291

292+
template <typename F, typename T1, typename T2,
293+
require_any_std_vector_t<T1, T2>* = nullptr>
294+
void binary_scalar_tester(const F& f, const T1& x, const T2& y) {
295+
binary_scalar_tester_impl(f, x, y);
296+
}
175297
} // namespace test
176298
} // namespace stan

test/unit/math/prim/fun/falling_factorial_test.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ TEST(MathFunctions, falling_factorial_vec) {
2525

2626
Eigen::VectorXd in1(3);
2727
in1 << -1.3, 0.7, 2.8;
28-
Eigen::VectorXi in2(3);
29-
in2 << 1, 3, 1;
30-
stan::test::binary_scalar_tester(f, in1, in2);
28+
std::vector<int> std_in2{1, 3, 1};
29+
stan::test::binary_scalar_tester(f, in1, std_in2);
30+
31+
Eigen::MatrixXd mat_in1 = in1.replicate(1, 3);
32+
std::vector<std::vector<int>> std_std_in2{std_in2, std_in2, std_in2};
33+
stan::test::binary_scalar_tester(f, mat_in1, std_std_in2);
3134
}

0 commit comments

Comments
 (0)