Skip to content

Commit dd4c98d

Browse files
committed
refactor to support array-of-vectors instead of matrices + some other cleanup
1 parent 9b9c4a5 commit dd4c98d

11 files changed

Lines changed: 241 additions & 403 deletions

File tree

stan/math/fwd/fun/log_softmax.hpp

Lines changed: 42 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,72 +5,67 @@
55
#include <stan/math/fwd/core.hpp>
66
#include <stan/math/fwd/meta.hpp>
77
#include <stan/math/fwd/fun/softmax.hpp>
8+
#include <stan/math/prim/err.hpp>
89
#include <stan/math/prim/fun/log_softmax.hpp>
910
#include <stan/math/prim/fun/to_ref.hpp>
11+
#include <stan/math/prim/functor/apply_vector_unary.hpp>
1012

1113
namespace stan {
1214
namespace math {
1315

1416
/**
15-
* Return the log softmax of the rows of the specified matrix.
16-
* Each row is transformed independently; the result has the same shape
17-
* as the input.
17+
* Return the log softmax of the specified row vector of `fvar` values.
18+
* Delegates to the column-vector overload via transposition.
1819
*
19-
* @tparam Mat type of input matrix (Eigen matrix with fvar scalar)
20-
* @param[in] m Matrix to transform row-wise.
21-
* @return Log-softmax applied row-wise.
22-
* @throw std::domain_error If the input matrix is size 0.
20+
* @tparam RowVec Eigen row vector with `fvar` scalar
21+
* @param x row vector to transform
22+
* @return log softmax of the row vector
2323
*/
24-
template <typename Mat, require_eigen_t<Mat>* = nullptr,
25-
require_not_eigen_vector_t<Mat>* = nullptr,
26-
require_t<is_fvar<value_type_t<Mat>>>* = nullptr>
27-
inline auto log_softmax(const Mat& m) {
28-
check_nonzero_size("log_softmax", "m", m);
29-
const auto& m_ref = to_ref(m);
30-
const auto val = m_ref.val().eval();
31-
const auto shifted
32-
= (val.array().colwise() - val.rowwise().maxCoeff().array()).eval();
33-
const auto exp_s = shifted.exp().eval();
34-
const auto row_sums = exp_s.rowwise().sum().eval();
35-
const auto lsm_val = (shifted.colwise() - row_sums.log()).matrix().eval();
36-
// softmax values needed for the tangent: d_in - softmax(x) ⊙ dot(softmax(x),
37-
// d_in)
38-
const auto s = (exp_s.colwise() / row_sums).eval();
39-
const auto d_in = m_ref.d().eval();
40-
const auto dots = (s.array() * d_in.array()).rowwise().sum().eval();
41-
plain_type_t<Mat> result(m_ref.rows(), m_ref.cols());
42-
result.val() = lsm_val;
43-
result.d() = (d_in.array().colwise() - dots.array()).matrix();
44-
return result;
45-
}
46-
4724
template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr,
4825
require_t<is_fvar<value_type_t<RowVec>>>* = nullptr>
4926
inline auto log_softmax(const RowVec& x) {
5027
return log_softmax(x.transpose()).transpose().eval();
5128
}
5229

53-
template <typename T, require_vector_st<is_fvar, T>* = nullptr,
54-
require_not_t<is_eigen_row_vector<std::decay_t<T>>>* = nullptr>
30+
/**
31+
* Return the log softmax of each vector in a container of `fvar` values.
32+
*
33+
* @tparam T `std::vector` whose scalar type is `fvar`
34+
* @param x container of vectors to transform
35+
* @return container of log softmax results
36+
*/
37+
template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
5538
inline auto log_softmax(T&& x) {
56-
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& alpha) {
57-
using T_alpha = std::decay_t<decltype(alpha)>;
58-
using T_fvar = value_type_t<T_alpha>;
59-
using T_inner = typename T_fvar::Scalar;
60-
61-
auto&& alpha_ref = to_ref(std::forward<decltype(alpha)>(alpha));
62-
const Eigen::Matrix<T_inner, -1, 1> val = alpha_ref.val();
63-
const Eigen::Matrix<T_inner, -1, 1> s = softmax(val);
64-
const auto d_in = alpha_ref.d().eval();
65-
const T_inner dot_sd = s.dot(d_in);
66-
67-
Eigen::Matrix<T_fvar, -1, 1> result(alpha_ref.size());
68-
result.val() = log_softmax(val);
69-
result.d() = (d_in.array() - dot_sd).matrix();
70-
return result;
39+
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
40+
return log_softmax(std::forward<decltype(v)>(v));
7141
});
7242
}
7343

44+
/**
45+
* Return the log softmax of the specified column vector of `fvar` values.
46+
*
47+
* @tparam ColVec Eigen column vector with `fvar` scalar
48+
* @param x column vector to transform
49+
* @return log softmax of the column vector
50+
* @throw std::domain_error if the input size is 0
51+
*/
52+
template <typename ColVec,
53+
require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
54+
inline auto log_softmax(const ColVec& x) {
55+
using Eigen::Dynamic;
56+
using Eigen::Matrix;
57+
using T = typename value_type_t<ColVec>::Scalar;
58+
check_nonzero_size("log_softmax", "x", x);
59+
const auto& x_ref = to_ref(x);
60+
const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref));
61+
const auto d_in = x_ref.d().eval();
62+
const T dot_sd = s.dot(d_in);
63+
Matrix<fvar<T>, Dynamic, 1> result(x_ref.size());
64+
result.val() = s.array().log().matrix();
65+
result.d() = (d_in.array() - dot_sd).matrix();
66+
return result;
67+
}
68+
7469
} // namespace math
7570
} // namespace stan
7671
#endif

stan/math/fwd/fun/softmax.hpp

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,60 @@
66
#include <stan/math/fwd/fun/value_of.hpp>
77
#include <stan/math/prim/fun/to_ref.hpp>
88
#include <stan/math/prim/fun/softmax.hpp>
9+
#include <stan/math/prim/functor/apply_vector_unary.hpp>
910

1011
namespace stan {
1112
namespace math {
1213

13-
template <typename Mat, require_eigen_t<Mat>* = nullptr,
14-
require_not_eigen_vector_t<Mat>* = nullptr,
15-
require_t<is_fvar<value_type_t<Mat>>>* = nullptr>
16-
inline auto softmax(const Mat& m) {
17-
const auto& m_ref = to_ref(m);
18-
const auto s = softmax(m_ref.val());
19-
const auto d_in = m_ref.d().eval();
20-
// d/dx softmax(x) applied to tangent: s ⊙ (d_in - s · d_in) (per row)
21-
const auto dots = (s.array() * d_in.array()).rowwise().sum().eval();
22-
plain_type_t<Mat> result(m_ref.rows(), m_ref.cols());
23-
result.val() = s;
24-
result.d() = (s.array() * (d_in.array().colwise() - dots.array())).matrix();
25-
return result;
26-
}
27-
14+
/**
15+
* Return the softmax of the specified row vector of `fvar` values.
16+
* Delegates to the column-vector overload via transposition.
17+
*
18+
* @tparam RowVec Eigen row vector with `fvar` scalar
19+
* @param x row vector to transform
20+
* @return softmax of the row vector
21+
*/
2822
template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr,
2923
require_t<is_fvar<value_type_t<RowVec>>>* = nullptr>
30-
inline auto softmax(const RowVec& alpha) {
31-
return softmax(alpha.transpose()).transpose().eval();
24+
inline auto softmax(const RowVec& x) {
25+
return softmax(x.transpose()).transpose().eval();
26+
}
27+
28+
/**
29+
* Return the softmax of each vector in a container of `fvar` values.
30+
*
31+
* @tparam T `std::vector` whose scalar type is `fvar`
32+
* @param x container of vectors to transform
33+
* @return container of softmax results
34+
*/
35+
template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
36+
inline auto softmax(T&& x) {
37+
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
38+
return softmax(std::forward<decltype(v)>(v));
39+
});
3240
}
3341

42+
/**
43+
* Return the softmax of the specified column vector of `fvar` values.
44+
*
45+
* @tparam ColVec Eigen column vector with `fvar` scalar
46+
* @param x column vector to transform
47+
* @return softmax of the column vector
48+
*/
3449
template <typename ColVec,
3550
require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
36-
inline auto softmax(const ColVec& alpha) {
51+
inline auto softmax(const ColVec& x) {
3752
using Eigen::Dynamic;
3853
using Eigen::Matrix;
3954
using T = typename value_type_t<ColVec>::Scalar;
40-
if (alpha.size() == 0) {
55+
if (x.size() == 0) {
4156
return Matrix<fvar<T>, Dynamic, 1>();
4257
}
43-
const auto& alpha_ref = to_ref(alpha);
44-
const Matrix<T, Dynamic, 1> s = softmax(value_of(alpha_ref));
45-
const auto d_in = alpha_ref.d().eval();
58+
const auto& x_ref = to_ref(x);
59+
const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref));
60+
const auto d_in = x_ref.d().eval();
4661
const T dot_sd = s.dot(d_in);
47-
Matrix<fvar<T>, Dynamic, 1> result(alpha.size());
62+
Matrix<fvar<T>, Dynamic, 1> result(x_ref.size());
4863
result.val() = s;
4964
result.d() = (s.array() * (d_in.array() - dot_sd)).matrix();
5065
return result;

stan/math/prim/fun/log_softmax.hpp

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace math {
1313

1414
/**
1515
* Return the natural logarithm of the softmax of the specified
16-
* vector.
16+
* vector, or of each vector in a container.
1717
*
1818
* \f$
1919
* \log \mbox{softmax}(y)
@@ -23,29 +23,31 @@ namespace math {
2323
*
2424
* For the log softmax function, the entries in the Jacobian are
2525
* \f$
26-
* \frac{\partial}{\partial y_m} \mbox{softmax}(y)[k]
26+
* \frac{\partial}{\partial y_m} \log\mbox{softmax}(y)[k]
2727
* = \left\{
2828
* \begin{array}{ll}
2929
* 1 - \mbox{softmax}(y)[m]
3030
* & \mbox{ if } m = k, \mbox{ and}
3131
* \\[6pt]
32-
* \mbox{softmax}(y)[m]
32+
* -\mbox{softmax}(y)[m]
3333
* & \mbox{ if } m \neq k.
3434
* \end{array}
3535
* \right.
3636
* \f$
3737
*
38-
* @tparam Container type of input vector to transform
39-
* @param[in] x vector to transform
40-
* @return log unit simplex result of the softmax transform of the vector.
38+
* @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
39+
* or nested container whose scalar type is arithmetic
40+
* @param[in] x vector or container of vectors to transform
41+
* @return log softmax of the input, preserving the container structure
42+
* @throw std::domain_error if any input vector is empty
4143
*/
4244
template <typename Container, require_st_arithmetic<Container>* = nullptr,
4345
require_container_t<Container>* = nullptr,
4446
require_not_t<bool_constant<
4547
is_eigen<std::decay_t<Container>>::value
4648
&& !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
4749
inline auto log_softmax(Container&& x) {
48-
check_nonzero_size("log_softmax", "v", x);
50+
check_nonzero_size("log_softmax", "x", x);
4951
return make_holder(
5052
[](auto&& a) {
5153
return apply_vector_unary<ref_type_t<Container>>::apply(
@@ -55,26 +57,6 @@ inline auto log_softmax(Container&& x) {
5557
to_ref(std::forward<Container>(x)));
5658
}
5759

58-
/**
59-
* Return the log softmax of the rows of the specified matrix.
60-
* Each row is transformed independently; the result has the same shape
61-
* as the input.
62-
*
63-
* @tparam Mat type of input matrix
64-
* @param[in] m Matrix to transform row-wise.
65-
* @return Log-softmax applied row-wise.
66-
*/
67-
template <typename Mat, require_eigen_vt<std::is_arithmetic, Mat>* = nullptr,
68-
require_not_eigen_vector_t<Mat>* = nullptr>
69-
inline plain_type_t<Mat> log_softmax(const Mat& m) {
70-
check_nonzero_size("log_softmax", "m", m);
71-
const auto& m_ref = to_ref(m);
72-
const auto shifted
73-
= (m_ref.array().colwise() - m_ref.rowwise().maxCoeff().array()).eval();
74-
const auto exp_s = shifted.exp().eval();
75-
return (shifted.colwise() - exp_s.rowwise().sum().log()).matrix();
76-
}
77-
7860
} // namespace math
7961
} // namespace stan
8062
#endif

stan/math/prim/fun/softmax.hpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stan/math/prim/err.hpp>
55
#include <stan/math/prim/fun/Eigen.hpp>
66
#include <stan/math/prim/fun/to_ref.hpp>
7+
#include <stan/math/prim/functor/apply_vector_unary.hpp>
78
#include <cmath>
89

910
namespace stan {
@@ -38,7 +39,7 @@ namespace math {
3839
* \end{array}
3940
* \f$
4041
*
41-
* @tparam Vec type of elements in the vector
42+
* @tparam Vec type of the input vector
4243
* @param[in] v Vector to transform.
4344
* @return Unit simplex result of the softmax transform of the vector.
4445
*/
@@ -54,22 +55,17 @@ inline plain_type_t<Vec> softmax(const Vec& v) {
5455
}
5556

5657
/**
57-
* Return the softmax of the rows of the specified matrix.
58-
* Each row is transformed independently; the result is a row-stochastic
59-
* matrix whose rows each sum to one.
58+
* Return the softmax of each vector in an array.
6059
*
61-
* @tparam Mat type of input matrix
62-
* @param[in] m Matrix to transform row-wise.
63-
* @return Row-stochastic matrix result of applying softmax to each row.
60+
* @tparam T `std::vector` whose scalar type is arithmetic
61+
* @param[in] x Array of vectors to transform.
62+
* @return Array of unit simplex results.
6463
*/
65-
template <typename Mat, require_eigen_vt<std::is_arithmetic, Mat>* = nullptr,
66-
require_not_eigen_vector_t<Mat>* = nullptr>
67-
inline plain_type_t<Mat> softmax(const Mat& m) {
68-
const auto& m_ref = to_ref(m);
69-
const auto shifted
70-
= (m_ref.array().colwise() - m_ref.rowwise().maxCoeff().array()).eval();
71-
const auto exp_s = shifted.exp().eval();
72-
return (exp_s.colwise() / exp_s.rowwise().sum()).matrix();
64+
template <typename T, require_std_vector_st<std::is_arithmetic, T>* = nullptr>
65+
inline auto softmax(T&& x) {
66+
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
67+
return softmax(std::forward<decltype(v)>(v));
68+
});
7369
}
7470

7571
} // namespace math

0 commit comments

Comments
 (0)