|
5 | 5 | #include <stan/math/fwd/core.hpp> |
6 | 6 | #include <stan/math/fwd/meta.hpp> |
7 | 7 | #include <stan/math/fwd/fun/softmax.hpp> |
| 8 | +#include <stan/math/prim/err.hpp> |
8 | 9 | #include <stan/math/prim/fun/log_softmax.hpp> |
9 | 10 | #include <stan/math/prim/fun/to_ref.hpp> |
| 11 | +#include <stan/math/prim/functor/apply_vector_unary.hpp> |
10 | 12 |
|
11 | 13 | namespace stan { |
12 | 14 | namespace math { |
13 | 15 |
|
14 | 16 | /** |
15 | | - * Return the log softmax of the rows of the specified matrix. |
16 | | - * Each row is transformed independently; the result has the same shape |
17 | | - * as the input. |
| 17 | + * Return the log softmax of the specified row vector of `fvar` values. |
| 18 | + * Delegates to the column-vector overload via transposition. |
18 | 19 | * |
19 | | - * @tparam Mat type of input matrix (Eigen matrix with fvar scalar) |
20 | | - * @param[in] m Matrix to transform row-wise. |
21 | | - * @return Log-softmax applied row-wise. |
22 | | - * @throw std::domain_error If the input matrix is size 0. |
| 20 | + * @tparam RowVec Eigen row vector with `fvar` scalar |
| 21 | + * @param x row vector to transform |
| 22 | + * @return log softmax of the row vector |
23 | 23 | */ |
24 | | -template <typename Mat, require_eigen_t<Mat>* = nullptr, |
25 | | - require_not_eigen_vector_t<Mat>* = nullptr, |
26 | | - require_t<is_fvar<value_type_t<Mat>>>* = nullptr> |
27 | | -inline auto log_softmax(const Mat& m) { |
28 | | - check_nonzero_size("log_softmax", "m", m); |
29 | | - const auto& m_ref = to_ref(m); |
30 | | - const auto val = m_ref.val().eval(); |
31 | | - const auto shifted |
32 | | - = (val.array().colwise() - val.rowwise().maxCoeff().array()).eval(); |
33 | | - const auto exp_s = shifted.exp().eval(); |
34 | | - const auto row_sums = exp_s.rowwise().sum().eval(); |
35 | | - const auto lsm_val = (shifted.colwise() - row_sums.log()).matrix().eval(); |
36 | | - // softmax values needed for the tangent: d_in - softmax(x) ⊙ dot(softmax(x), |
37 | | - // d_in) |
38 | | - const auto s = (exp_s.colwise() / row_sums).eval(); |
39 | | - const auto d_in = m_ref.d().eval(); |
40 | | - const auto dots = (s.array() * d_in.array()).rowwise().sum().eval(); |
41 | | - plain_type_t<Mat> result(m_ref.rows(), m_ref.cols()); |
42 | | - result.val() = lsm_val; |
43 | | - result.d() = (d_in.array().colwise() - dots.array()).matrix(); |
44 | | - return result; |
45 | | -} |
46 | | - |
47 | 24 | template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr, |
48 | 25 | require_t<is_fvar<value_type_t<RowVec>>>* = nullptr> |
49 | 26 | inline auto log_softmax(const RowVec& x) { |
50 | 27 | return log_softmax(x.transpose()).transpose().eval(); |
51 | 28 | } |
52 | 29 |
|
53 | | -template <typename T, require_vector_st<is_fvar, T>* = nullptr, |
54 | | - require_not_t<is_eigen_row_vector<std::decay_t<T>>>* = nullptr> |
| 30 | +/** |
| 31 | + * Return the log softmax of each vector in a container of `fvar` values. |
| 32 | + * |
| 33 | + * @tparam T `std::vector` whose scalar type is `fvar` |
| 34 | + * @param x container of vectors to transform |
| 35 | + * @return container of log softmax results |
| 36 | + */ |
| 37 | +template <typename T, require_std_vector_st<is_fvar, T>* = nullptr> |
55 | 38 | inline auto log_softmax(T&& x) { |
56 | | - return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& alpha) { |
57 | | - using T_alpha = std::decay_t<decltype(alpha)>; |
58 | | - using T_fvar = value_type_t<T_alpha>; |
59 | | - using T_inner = typename T_fvar::Scalar; |
60 | | - |
61 | | - auto&& alpha_ref = to_ref(std::forward<decltype(alpha)>(alpha)); |
62 | | - const Eigen::Matrix<T_inner, -1, 1> val = alpha_ref.val(); |
63 | | - const Eigen::Matrix<T_inner, -1, 1> s = softmax(val); |
64 | | - const auto d_in = alpha_ref.d().eval(); |
65 | | - const T_inner dot_sd = s.dot(d_in); |
66 | | - |
67 | | - Eigen::Matrix<T_fvar, -1, 1> result(alpha_ref.size()); |
68 | | - result.val() = log_softmax(val); |
69 | | - result.d() = (d_in.array() - dot_sd).matrix(); |
70 | | - return result; |
| 39 | + return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) { |
| 40 | + return log_softmax(std::forward<decltype(v)>(v)); |
71 | 41 | }); |
72 | 42 | } |
73 | 43 |
|
| 44 | +/** |
| 45 | + * Return the log softmax of the specified column vector of `fvar` values. |
| 46 | + * |
| 47 | + * @tparam ColVec Eigen column vector with `fvar` scalar |
| 48 | + * @param x column vector to transform |
| 49 | + * @return log softmax of the column vector |
| 50 | + * @throw std::domain_error if the input size is 0 |
| 51 | + */ |
| 52 | +template <typename ColVec, |
| 53 | + require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr> |
| 54 | +inline auto log_softmax(const ColVec& x) { |
| 55 | + using Eigen::Dynamic; |
| 56 | + using Eigen::Matrix; |
| 57 | + using T = typename value_type_t<ColVec>::Scalar; |
| 58 | + check_nonzero_size("log_softmax", "x", x); |
| 59 | + const auto& x_ref = to_ref(x); |
| 60 | + const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref)); |
| 61 | + const auto d_in = x_ref.d().eval(); |
| 62 | + const T dot_sd = s.dot(d_in); |
| 63 | + Matrix<fvar<T>, Dynamic, 1> result(x_ref.size()); |
| 64 | + result.val() = s.array().log().matrix(); |
| 65 | + result.d() = (d_in.array() - dot_sd).matrix(); |
| 66 | + return result; |
| 67 | +} |
| 68 | + |
74 | 69 | } // namespace math |
75 | 70 | } // namespace stan |
76 | 71 | #endif |
0 commit comments