Skip to content

Commit 0ee81d6

Browse files
authored
Merge pull request #1966 from andrjohns/feature/apply_binary_int
Allow binary vectorisation to take combinations of integer vectors
2 parents d48a487 + 6134a07 commit 0ee81d6

File tree

12 files changed

+649
-20
lines changed

12 files changed

+649
-20
lines changed

stan/math/fwd/fun/Eigen_NumTraits.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,32 @@ struct ScalarBinaryOpTraits<double, stan::math::fvar<T>, BinaryOp> {
8080
using ReturnType = stan::math::fvar<T>;
8181
};
8282

83+
/**
84+
* Traits specialization for Eigen binary operations for autodiff and
85+
* `int` arguments.
86+
*
87+
* @tparam T value and tangent type of autodiff variable
88+
* @tparam BinaryOp type of binary operation for which traits are
89+
* defined
90+
*/
91+
template <typename T, typename BinaryOp>
92+
struct ScalarBinaryOpTraits<stan::math::fvar<T>, int, BinaryOp> {
93+
using ReturnType = stan::math::fvar<T>;
94+
};
95+
96+
/**
97+
* Traits specialization for Eigen binary operations for `int` and
98+
* autodiff arguments.
99+
*
100+
* @tparam T value and tangent type of autodiff variable
101+
* @tparam BinaryOp type of binary operation for which traits are
102+
* defined
103+
*/
104+
template <typename T, typename BinaryOp>
105+
struct ScalarBinaryOpTraits<int, stan::math::fvar<T>, BinaryOp> {
106+
using ReturnType = stan::math::fvar<T>;
107+
};
108+
83109
/**
84110
* Traits specialization for Eigen binary operations for `double` and
85111
* complex autodiff arguments.
@@ -108,6 +134,32 @@ struct ScalarBinaryOpTraits<std::complex<stan::math::fvar<T>>, double,
108134
using ReturnType = std::complex<stan::math::fvar<T>>;
109135
};
110136

137+
/**
138+
* Traits specialization for Eigen binary operations for `int` and
139+
* complex autodiff arguments.
140+
*
141+
* @tparam T value and tangent type of autodiff variable
142+
* @tparam BinaryOp type of binary operation for which traits are
143+
* defined
144+
*/
145+
template <typename T, typename BinaryOp>
146+
struct ScalarBinaryOpTraits<int, std::complex<stan::math::fvar<T>>, BinaryOp> {
147+
using ReturnType = std::complex<stan::math::fvar<T>>;
148+
};
149+
150+
/**
151+
* Traits specialization for Eigen binary operations for complex
152+
* autodiff and `int` arguments.
153+
*
154+
* @tparam T value and tangent type of autodiff variable
155+
* @tparam BinaryOp type of binary operation for which traits are
156+
* defined
157+
*/
158+
template <typename T, typename BinaryOp>
159+
struct ScalarBinaryOpTraits<std::complex<stan::math::fvar<T>>, int, BinaryOp> {
160+
using ReturnType = std::complex<stan::math::fvar<T>>;
161+
};
162+
111163
/**
112164
* Traits specialization for Eigen binary operations for autodiff
113165
* and complex `double` arguments.

stan/math/prim/fun/Eigen.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,56 @@
2424
#include <Eigen/QR>
2525
#include <Eigen/src/Core/NumTraits.h>
2626

27+
namespace Eigen {
28+
29+
/**
30+
* Traits specialization for Eigen binary operations for `int`
31+
* and `double` arguments.
32+
*
33+
* @tparam BinaryOp type of binary operation for which traits are
34+
* defined
35+
*/
36+
template <typename BinaryOp>
37+
struct ScalarBinaryOpTraits<int, double, BinaryOp> {
38+
using ReturnType = double;
39+
};
40+
41+
/**
42+
* Traits specialization for Eigen binary operations for `double`
43+
* and `int` arguments.
44+
*
45+
* @tparam BinaryOp type of binary operation for which traits are
46+
* defined
47+
*/
48+
template <typename BinaryOp>
49+
struct ScalarBinaryOpTraits<double, int, BinaryOp> {
50+
using ReturnType = double;
51+
};
52+
53+
/**
54+
* Traits specialization for Eigen binary operations for `int`
55+
* and complex `double` arguments.
56+
*
57+
* @tparam BinaryOp type of binary operation for which traits are
58+
* defined
59+
*/
60+
template <typename BinaryOp>
61+
struct ScalarBinaryOpTraits<int, std::complex<double>, BinaryOp> {
62+
using ReturnType = std::complex<double>;
63+
};
64+
65+
/**
66+
* Traits specialization for Eigen binary operations for complex
67+
* `double` and `int` arguments.
68+
*
69+
* @tparam BinaryOp type of binary operation for which traits are
70+
* defined
71+
*/
72+
template <typename BinaryOp>
73+
struct ScalarBinaryOpTraits<std::complex<double>, int, BinaryOp> {
74+
using ReturnType = std::complex<double>;
75+
};
76+
77+
} // namespace Eigen
78+
2779
#endif

stan/math/prim/fun/bessel_first_kind.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <boost/math/special_functions/bessel.hpp>
7+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
78

89
namespace stan {
910
namespace math {
@@ -34,12 +35,29 @@ namespace math {
3435
\f]
3536
*
3637
*/
37-
template <typename T2>
38+
template <typename T2, require_arithmetic_t<T2>* = nullptr>
3839
inline T2 bessel_first_kind(int v, const T2 z) {
3940
check_not_nan("bessel_first_kind", "z", z);
4041
return boost::math::cyl_bessel_j(v, z);
4142
}
4243

44+
/**
45+
* Enables the vectorised application of the bessel first kind function, when
46+
* the first and/or second arguments are containers.
47+
*
48+
* @tparam T1 type of first input
49+
* @tparam T2 type of second input
50+
* @param a First input
51+
* @param b Second input
52+
* @return Bessel first kind function applied to the two inputs.
53+
*/
54+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
55+
inline auto bessel_first_kind(const T1& a, const T2& b) {
56+
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
57+
return bessel_first_kind(c, d);
58+
});
59+
}
60+
4361
} // namespace math
4462
} // namespace stan
4563
#endif

stan/math/prim/fun/falling_factorial.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
66
#include <stan/math/prim/fun/boost_policy.hpp>
7+
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
78
#include <boost/math/special_functions/factorials.hpp>
89

910
namespace stan {
@@ -59,14 +60,31 @@ namespace math {
5960
\f]
6061
*
6162
*/
62-
template <typename T>
63+
template <typename T, require_arithmetic_t<T>* = nullptr>
6364
inline return_type_t<T> falling_factorial(const T& x, int n) {
6465
static const char* function = "falling_factorial";
6566
check_not_nan(function, "first argument", x);
6667
check_nonnegative(function, "second argument", n);
6768
return boost::math::falling_factorial(x, n, boost_policy_t<>());
6869
}
6970

71+
/**
72+
* Enables the vectorised application of the falling factorial function, when
73+
* the first and/or second arguments are containers.
74+
*
75+
* @tparam T1 type of first input
76+
* @tparam T2 type of second input
77+
* @param a First input
78+
* @param b Second input
79+
* @return Falling factorial function applied to the two inputs.
80+
*/
81+
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
82+
inline auto falling_factorial(const T1& a, const T2& b) {
83+
return apply_scalar_binary(a, b, [&](const auto& c, const auto& d) {
84+
return falling_factorial(c, d);
85+
});
86+
}
87+
7088
} // namespace math
7189
} // namespace stan
7290

stan/math/prim/functor/apply_scalar_binary.hpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <stan/math/prim/meta/is_container.hpp>
77
#include <stan/math/prim/meta/is_eigen.hpp>
88
#include <stan/math/prim/meta/require_generics.hpp>
9+
#include <stan/math/prim/fun/num_elements.hpp>
910
#include <vector>
1011

1112
namespace stan {
@@ -55,6 +56,116 @@ inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
5556
return x.binaryExpr(y, f).eval();
5657
}
5758

59+
/**
60+
* Specialisation for use with one Eigen vector (row or column) and
61+
* a one-dimensional std::vector of integer types
62+
*
63+
* @tparam T1 Type of first argument to which functor is applied.
64+
* @tparam T2 Type of second argument to which functor is applied.
65+
* @tparam F Type of functor to apply.
66+
* @param x Eigen input to which operation is applied.
67+
* @param y Integer std::vector input to which operation is applied.
68+
* @param f functor to apply to inputs.
69+
* @return Eigen object with result of applying functor to inputs.
70+
*/
71+
template <typename T1, typename T2, typename F,
72+
require_eigen_vector_vt<is_stan_scalar, T1>* = nullptr,
73+
require_std_vector_vt<std::is_integral, T2>* = nullptr>
74+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
75+
check_matching_sizes("Binary function", "x", x, "y", y);
76+
using int_vec_t = promote_scalar_t<value_type_t<T2>, plain_type_t<T1>>;
77+
Eigen::Map<const int_vec_t> y_map(y.data(), y.size());
78+
return x.binaryExpr(y_map, f).eval();
79+
}
80+
81+
/**
82+
* Specialisation for use with a one-dimensional std::vector of integer types
83+
* and one Eigen vector (row or column).
84+
*
85+
* @tparam T1 Type of first argument to which functor is applied.
86+
* @tparam T2 Type of second argument to which functor is applied.
87+
* @tparam F Type of functor to apply.
88+
* @param x Integer std::vector input to which operation is applied.
89+
* @param y Eigen input to which operation is applied.
90+
* @param f functor to apply to inputs.
91+
* @return Eigen object with result of applying functor to inputs.
92+
*/
93+
template <typename T1, typename T2, typename F,
94+
require_std_vector_vt<std::is_integral, T1>* = nullptr,
95+
require_eigen_vector_vt<is_stan_scalar, T2>* = nullptr>
96+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
97+
check_matching_sizes("Binary function", "x", x, "y", y);
98+
using int_vec_t = promote_scalar_t<value_type_t<T1>, plain_type_t<T2>>;
99+
Eigen::Map<const int_vec_t> x_map(x.data(), x.size());
100+
return x_map.binaryExpr(y, f).eval();
101+
}
102+
103+
/**
104+
* Specialisation for use with one Eigen matrix and
105+
* a two-dimensional std::vector of integer types
106+
*
107+
* @tparam T1 Type of first argument to which functor is applied.
108+
* @tparam T2 Type of second argument to which functor is applied.
109+
* @tparam F Type of functor to apply.
110+
* @param x Eigen matrix input to which operation is applied.
111+
* @param y Nested integer std::vector input to which operation is applied.
112+
* @param f functor to apply to inputs.
113+
* @return Eigen object with result of applying functor to inputs.
114+
*/
115+
template <typename T1, typename T2, typename F,
116+
require_eigen_matrix_vt<is_stan_scalar, T1>* = nullptr,
117+
require_std_vector_vt<is_std_vector, T2>* = nullptr,
118+
require_std_vector_st<std::is_integral, T2>* = nullptr>
119+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
120+
if (num_elements(x) != num_elements(y)) {
121+
std::ostringstream msg;
122+
msg << "Inputs to vectorized binary function must match in"
123+
<< " size if one is not a scalar";
124+
throw std::invalid_argument(msg.str());
125+
}
126+
using return_st = decltype(f(x(0), y[0][0]));
127+
Eigen::Matrix<return_st, Eigen::Dynamic, Eigen::Dynamic> result(x.rows(),
128+
x.cols());
129+
for (size_t i = 0; i < y.size(); ++i) {
130+
result.row(i) = apply_scalar_binary(x.row(i).transpose(),
131+
as_column_vector_or_scalar(y[i]), f);
132+
}
133+
return result;
134+
}
135+
136+
/**
137+
* Specialisation for use with a two-dimensional std::vector of integer types
138+
* and one Eigen matrix.
139+
*
140+
* @tparam T1 Type of first argument to which functor is applied.
141+
* @tparam T2 Type of second argument to which functor is applied.
142+
* @tparam F Type of functor to apply.
143+
* @param x Nested integer std::vector input to which operation is applied.
144+
* @param y Eigen matrix input to which operation is applied.
145+
* @param f functor to apply to inputs.
146+
* @return Eigen object with result of applying functor to inputs.
147+
*/
148+
template <typename T1, typename T2, typename F,
149+
require_std_vector_vt<is_std_vector, T1>* = nullptr,
150+
require_std_vector_st<std::is_integral, T1>* = nullptr,
151+
require_eigen_matrix_vt<is_stan_scalar, T2>* = nullptr>
152+
inline auto apply_scalar_binary(const T1& x, const T2& y, const F& f) {
153+
if (num_elements(x) != num_elements(y)) {
154+
std::ostringstream msg;
155+
msg << "Inputs to vectorized binary function must match in"
156+
<< " size if one is not a scalar";
157+
throw std::invalid_argument(msg.str());
158+
}
159+
using return_st = decltype(f(x[0][0], y(0)));
160+
Eigen::Matrix<return_st, Eigen::Dynamic, Eigen::Dynamic> result(y.rows(),
161+
y.cols());
162+
for (size_t i = 0; i < x.size(); ++i) {
163+
result.row(i) = apply_scalar_binary(as_column_vector_or_scalar(x[i]),
164+
y.row(i).transpose(), f);
165+
}
166+
return result;
167+
}
168+
58169
/**
59170
* Specialisation for use when the first input is an Eigen type and the second
60171
* is a scalar. Eigen's unaryExpr framework is used for more efficient indexing

stan/math/rev/fun/Eigen_NumTraits.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,30 @@ struct ScalarBinaryOpTraits<double, stan::math::var, BinaryOp> {
105105
using ReturnType = stan::math::var;
106106
};
107107

108+
/**
109+
* Traits specialization for Eigen binary operations for reverse-mode
110+
* autodiff and `int` arguments.
111+
*
112+
* @tparam BinaryOp type of binary operation for which traits are
113+
* defined
114+
*/
115+
template <typename BinaryOp>
116+
struct ScalarBinaryOpTraits<stan::math::var, int, BinaryOp> {
117+
using ReturnType = stan::math::var;
118+
};
119+
120+
/**
121+
* Traits specialization for Eigen binary operations for `int` and
122+
* reverse-mode autodiff arguments.
123+
*
124+
* @tparam BinaryOp type of binary operation for which traits are
125+
* defined
126+
*/
127+
template <typename BinaryOp>
128+
struct ScalarBinaryOpTraits<int, stan::math::var, BinaryOp> {
129+
using ReturnType = stan::math::var;
130+
};
131+
108132
/**
109133
* Traits specialization for Eigen binary operations for reverse-mode
110134
autodiff
@@ -142,6 +166,30 @@ struct ScalarBinaryOpTraits<std::complex<stan::math::var>, double, BinaryOp> {
142166
using ReturnType = std::complex<stan::math::var>;
143167
};
144168

169+
/**
170+
* Traits specialization for Eigen binary operations for `int` and
171+
* complex autodiff arguments.
172+
*
173+
* @tparam BinaryOp type of binary operation for which traits are
174+
* defined
175+
*/
176+
template <typename BinaryOp>
177+
struct ScalarBinaryOpTraits<int, std::complex<stan::math::var>, BinaryOp> {
178+
using ReturnType = std::complex<stan::math::var>;
179+
};
180+
181+
/**
182+
* Traits specialization for Eigen binary operations for complex
183+
* autodiff and `int` arguments.
184+
*
185+
* @tparam BinaryOp type of binary operation for which traits are
186+
* defined
187+
*/
188+
template <typename BinaryOp>
189+
struct ScalarBinaryOpTraits<std::complex<stan::math::var>, int, BinaryOp> {
190+
using ReturnType = std::complex<stan::math::var>;
191+
};
192+
145193
/**
146194
* Traits specialization for Eigen binary operations for autodiff and
147195
* complex `double` arguments.

0 commit comments

Comments
 (0)