diff --git a/stan/math/fwd/fun/log_softmax.hpp b/stan/math/fwd/fun/log_softmax.hpp
index acaf71070cb..1cc1d08c448 100644
--- a/stan/math/fwd/fun/log_softmax.hpp
+++ b/stan/math/fwd/fun/log_softmax.hpp
@@ -5,53 +5,67 @@
 #include <stan/math/fwd/core.hpp>
 #include <stan/math/prim/meta.hpp>
 #include <stan/math/prim/err.hpp>
+#include <stan/math/prim/fun/value_of.hpp>
 #include <stan/math/prim/fun/softmax.hpp>
 #include <stan/math/prim/fun/log_softmax.hpp>
+#include <stan/math/prim/functor/apply_vector_unary.hpp>
 
 namespace stan {
 namespace math {
 
 /**
- * Return the log softmax of the specified vector or container of vectors.
+ * Return the log softmax of the specified row vector of `fvar` values.
+ * Delegates to the column-vector overload via transposition.
  *
- * @tparam T Type of input vector or matrix.
- * @param[in] x Unconstrained input vector.
- * @return Softmax of the input.
- * @throw std::domain_error If the input vector is size 0.
+ * @tparam RowVec Eigen row vector with `fvar` scalar
+ * @param x row vector to transform
+ * @return log softmax of the row vector
  */
-template <typename T, require_vector_st<is_fvar, T>* = nullptr>
-inline auto log_softmax(T&& x) {
-  return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& alpha) {
-    using T_alpha = decltype(alpha);
-    using T_fvar = value_type_t<T_alpha>;
-    using T_fvar_inner = typename T_fvar::Scalar;
-
-    auto&& alpha_ref = to_ref(std::forward<T_alpha>(alpha));
-    Eigen::Matrix<T_fvar_inner, -1, 1> alpha_t = alpha_ref.val();
-    Eigen::Matrix<T_fvar_inner, -1, 1> softmax_alpha_t = softmax(alpha_t);
-
-    Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha_ref.size());
-    log_softmax_alpha.val() = log_softmax(alpha_t);
-    log_softmax_alpha.d().setZero();
-
-    for (int m = 0; m < alpha_ref.size(); ++m) {
-      T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
-          = -alpha_ref.coeff(m).d_ * softmax_alpha_t(m);
-      for (int k = 0; k < alpha_ref.size(); ++k) {
-        if (m == k) {
-          log_softmax_alpha(k).d_
-              += alpha_ref.coeff(m).d_
-                 + negative_alpha_m_d_times_softmax_alpha_t_m;
-        } else {
-          log_softmax_alpha(k).d_ += negative_alpha_m_d_times_softmax_alpha_t_m;
-        }
-      }
-    }
+template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr,
+          require_t<is_fvar<value_type_t<RowVec>>>* = nullptr>
+inline auto log_softmax(const RowVec& x) {
+  return log_softmax(x.transpose()).transpose().eval();
+}
 
-    return log_softmax_alpha;
+/**
+ * Return the log softmax of each vector in a container of `fvar` values.
+ *
+ * @tparam T `std::vector` whose scalar type is `fvar`
+ * @param x container of vectors to transform
+ * @return container of log softmax results
+ */
+template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
+inline auto log_softmax(T&& x) {
+  return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
+    return log_softmax(std::forward<decltype(v)>(v));
   });
 }
 
+/**
+ * Return the log softmax of the specified column vector of `fvar` values.
+ *
+ * @tparam ColVec Eigen column vector with `fvar` scalar
+ * @param x column vector to transform
+ * @return log softmax of the column vector
+ * @throw std::domain_error if the input size is 0
+ */
+template <typename ColVec,
+          require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
+inline auto log_softmax(const ColVec& x) {
+  using Eigen::Dynamic;
+  using Eigen::Matrix;
+  using T = typename value_type_t<ColVec>::Scalar;
+  check_nonzero_size("log_softmax", "x", x);
+  const auto& x_ref = to_ref(x);
+  const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref));
+  const auto d_in = x_ref.d().eval();
+  const T dot_sd = s.dot(d_in);
+  Matrix<fvar<T>, Dynamic, 1> result(x_ref.size());
+  result.val() = s.array().log().matrix();
+  result.d() = (d_in.array() - dot_sd).matrix();
+  return result;
+}
+
 }  // namespace math
 }  // namespace stan
 #endif
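Reviewer note: the rewritten column-vector overload replaces the old O(n^2) double loop with the rank-1 tangent rule for log softmax, (d log_softmax(x))_k = d_k - softmax(x)·d. A standalone sketch that validates the rule against a central finite difference (Eigen only, not part of the patch; `log_softmax_ref` is a hypothetical helper written just for this check):

```cpp
// Checks d(log_softmax(x + t * d))/dt at t = 0 against the closed form
// d - dot(softmax(x), d) used by the new fvar overload.
#include <Eigen/Dense>
#include <cmath>
#include <iostream>

static Eigen::VectorXd log_softmax_ref(const Eigen::VectorXd& v) {
  Eigen::ArrayXd shifted = v.array() - v.maxCoeff();
  return (shifted - std::log(shifted.exp().sum())).matrix();
}

int main() {
  Eigen::VectorXd x(3), d(3);
  x << -1.0, 1.0, 2.0;
  d << 0.5, -0.25, 1.0;  // arbitrary direction of input tangents

  Eigen::ArrayXd e = (x.array() - x.maxCoeff()).exp();
  Eigen::VectorXd s = (e / e.sum()).matrix();  // softmax(x)

  Eigen::VectorXd analytic = (d.array() - s.dot(d)).matrix();
  const double h = 1e-6;
  Eigen::VectorXd numeric
      = (log_softmax_ref(x + h * d) - log_softmax_ref(x - h * d)) / (2 * h);
  std::cout << "max abs error: "
            << (analytic - numeric).cwiseAbs().maxCoeff() << "\n";  // ~1e-10
}
```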
diff --git a/stan/math/fwd/fun/softmax.hpp b/stan/math/fwd/fun/softmax.hpp
index 3625332ddf2..39516804c8e 100644
--- a/stan/math/fwd/fun/softmax.hpp
+++ b/stan/math/fwd/fun/softmax.hpp
@@ -6,47 +6,63 @@
 #include <stan/math/fwd/core.hpp>
 #include <stan/math/prim/meta.hpp>
 #include <stan/math/prim/fun/softmax.hpp>
+#include <stan/math/prim/functor/apply_vector_unary.hpp>
 
 namespace stan {
 namespace math {
 
+/**
+ * Return the softmax of the specified row vector of `fvar` values.
+ * Delegates to the column-vector overload via transposition.
+ *
+ * @tparam RowVec Eigen row vector with `fvar` scalar
+ * @param x row vector to transform
+ * @return softmax of the row vector
+ */
+template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr,
+          require_t<is_fvar<value_type_t<RowVec>>>* = nullptr>
+inline auto softmax(const RowVec& x) {
+  return softmax(x.transpose()).transpose().eval();
+}
+
+/**
+ * Return the softmax of each vector in a container of `fvar` values.
+ *
+ * @tparam T `std::vector` whose scalar type is `fvar`
+ * @param x container of vectors to transform
+ * @return container of softmax results
+ */
+template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
+inline auto softmax(T&& x) {
+  return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
+    return softmax(std::forward<decltype(v)>(v));
+  });
+}
+
+/**
+ * Return the softmax of the specified column vector of `fvar` values.
+ *
+ * @tparam ColVec Eigen column vector with `fvar` scalar
+ * @param x column vector to transform
+ * @return softmax of the column vector
+ */
 template <typename ColVec,
           require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
-inline auto softmax(const ColVec& alpha) {
+inline auto softmax(const ColVec& x) {
   using Eigen::Dynamic;
   using Eigen::Matrix;
   using T = typename value_type_t<ColVec>::Scalar;
-  if (alpha.size() == 0) {
+  if (x.size() == 0) {
     return Matrix<fvar<T>, Dynamic, 1>();
   }
-  const auto& alpha_ref = to_ref(alpha);
-
-  Matrix<T, Dynamic, 1> softmax_alpha_t = softmax(value_of(alpha_ref));
-
-  Matrix<fvar<T>, Dynamic, 1> softmax_alpha(alpha.size());
-  for (int k = 0; k < alpha.size(); ++k) {
-    softmax_alpha.coeffRef(k).val_ = softmax_alpha_t.coeff(k);
-    softmax_alpha.coeffRef(k).d_ = 0;
-  }
-
-  for (int m = 0; m < alpha.size(); ++m) {
-    T negative_alpha_m_d_times_softmax_alpha_t_m
-        = -alpha_ref.coeff(m).d_ * softmax_alpha_t.coeff(m);
-    for (int k = 0; k < alpha.size(); ++k) {
-      if (m == k) {
-        softmax_alpha.coeffRef(k).d_
-            += softmax_alpha_t.coeff(k)
-               * (alpha_ref.coeff(m).d_
-                  + negative_alpha_m_d_times_softmax_alpha_t_m);
-      } else {
-        softmax_alpha.coeffRef(k).d_
-            += softmax_alpha_t.coeff(k)
-               * negative_alpha_m_d_times_softmax_alpha_t_m;
-      }
-    }
-  }
-
-  return softmax_alpha;
+  const auto& x_ref = to_ref(x);
+  const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref));
+  const auto d_in = x_ref.d().eval();
+  const T dot_sd = s.dot(d_in);
+  Matrix<fvar<T>, Dynamic, 1> result(x_ref.size());
+  result.val() = s;
+  result.d() = (s.array() * (d_in.array() - dot_sd)).matrix();
+  return result;
 }
 
 }  // namespace math
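A minimal usage sketch of the forward-mode overloads, assuming a build with this patch applied (`stan/math/fwd.hpp` is the forward-mode umbrella header). Each printed tangent should equal s_k (d_k - s·d) for the seeded direction d = (0.5, -0.25, 1.0):

```cpp
// Seeds a direction through softmax with fvar<double> and prints the
// propagated tangents. Sketch only; assumes this PR's overloads exist.
#include <stan/math/fwd.hpp>
#include <Eigen/Dense>
#include <iostream>

int main() {
  using stan::math::fvar;
  Eigen::Matrix<fvar<double>, Eigen::Dynamic, 1> x(3);
  x << fvar<double>(-1.0, 0.5), fvar<double>(1.0, -0.25),
      fvar<double>(2.0, 1.0);

  auto y = stan::math::softmax(x);
  for (int k = 0; k < y.size(); ++k) {
    std::cout << "val: " << y(k).val_ << "  tangent: " << y(k).d_ << "\n";
  }
}
```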
diff --git a/stan/math/prim/fun/log_softmax.hpp b/stan/math/prim/fun/log_softmax.hpp
index 876d75a7f09..ac4393adfa8 100644
--- a/stan/math/prim/fun/log_softmax.hpp
+++ b/stan/math/prim/fun/log_softmax.hpp
@@ -13,7 +13,7 @@ namespace math {
 
 /**
  * Return the natural logarithm of the softmax of the specified
- * vector.
+ * vector, or of each vector in a container.
  *
  * \f$
  * \log \mbox{softmax}(y)
@@ -23,26 +23,31 @@ namespace math {
  *
  * For the log softmax function, the entries in the Jacobian are
  * \f$
- * \frac{\partial}{\partial y_m} \mbox{softmax}(y)[k]
+ * \frac{\partial}{\partial y_m} \log\mbox{softmax}(y)[k]
  * = \left\{
 * \begin{array}{ll}
  * 1 - \mbox{softmax}(y)[m]
  * & \mbox{ if } m = k, \mbox{ and}
  * \\[6pt]
- * \mbox{softmax}(y)[m]
+ * -\mbox{softmax}(y)[m]
  * & \mbox{ if } m \neq k.
  * \end{array}
  * \right.
  * \f$
  *
- * @tparam Container type of input vector to transform
- * @param[in] x vector to transform
- * @return log unit simplex result of the softmax transform of the vector.
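For reviewers, a concrete numeric instance of the documented definition \f$\log \mbox{softmax}(y) = y - \log \sum_k \exp(y_k)\f$, at the same point as the commented-out `x3` case in the prim tests (values computed by hand, rounded to six decimals):

```latex
y = (-1,\ 1,\ 10), \qquad
\log \sum_k e^{y_k} = \log\left(e^{-1} + e^{1} + e^{10}\right) \approx 10.000140,
```
```latex
\log \mbox{softmax}(y) \approx (-11.000140,\ -9.000140,\ -0.000140).
```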
+ * @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
+ *   or nested container whose scalar type is arithmetic
+ * @param[in] x vector or container of vectors to transform
+ * @return log softmax of the input, preserving the container structure
+ * @throw std::domain_error if any input vector is empty
  */
 template <typename Container, require_st_arithmetic<Container>* = nullptr,
-          require_container_t<Container>* = nullptr>
+          require_container_t<Container>* = nullptr,
+          require_not_t<bool_constant<is_std_vector<std::decay_t<Container>>::value
+              && !is_eigen_vector<value_type_t<Container>>::value>>* = nullptr>
 inline auto log_softmax(Container&& x) {
-  check_nonzero_size("log_softmax", "v", x);
+  check_nonzero_size("log_softmax", "x", x);
   return make_holder(
       [](auto&& a) {
         return apply_vector_unary<ref_type_t<decltype(a)>>::apply(
diff --git a/stan/math/prim/fun/softmax.hpp b/stan/math/prim/fun/softmax.hpp
index d3221f7ce72..3ce17dd8b32 100644
--- a/stan/math/prim/fun/softmax.hpp
+++ b/stan/math/prim/fun/softmax.hpp
@@ -4,6 +4,7 @@
 #include <stan/math/prim/meta.hpp>
 #include <stan/math/prim/err.hpp>
 #include <stan/math/prim/fun/Eigen.hpp>
+#include <stan/math/prim/functor/apply_vector_unary.hpp>
 #include <stan/math/prim/fun/to_ref.hpp>
 
 namespace stan {
@@ -38,20 +39,33 @@ namespace math {
  * \end{array}
  * \f$
  *
- * @tparam ColVec type of elements in the vector
+ * @tparam Vec type of the input vector
  * @param[in] v Vector to transform.
  * @return Unit simplex result of the softmax transform of the vector.
  */
-template <typename ColVec,
-          require_eigen_col_vector_vt<std::is_arithmetic, ColVec>* = nullptr>
-inline plain_type_t<ColVec> softmax(const ColVec& v) {
-  using std::exp;
+template <typename Vec,
+          require_eigen_vector_vt<std::is_arithmetic, Vec>* = nullptr>
+inline plain_type_t<Vec> softmax(const Vec& v) {
   if (v.size() == 0) {
     return v;
   }
   const auto& v_ref = to_ref(v);
   const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp().eval();
-  return theta.array() / theta.sum();
+  return (theta / theta.sum()).matrix();
+}
+
+/**
+ * Return the softmax of each vector in an array.
+ *
+ * @tparam T `std::vector` whose scalar type is arithmetic
+ * @param[in] x Array of vectors to transform.
+ * @return Array of unit simplex results.
+ */
+template <typename T, require_std_vector_st<std::is_arithmetic, T>* = nullptr>
+inline auto softmax(T&& x) {
+  return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
+    return softmax(std::forward<decltype(v)>(v));
+  });
 }
 
 }  // namespace math
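Why both prim implementations subtract `maxCoeff()` before exponentiating: softmax is shift-invariant, and the shift keeps every exponent at or below 0 so large inputs do not overflow. A standalone sketch (Eigen only, not part of the patch):

```cpp
// exp(1000.0) overflows to inf, so the naive normalizer is inf and the
// quotient is nan; shifting by max(x) keeps each term in (0, 1].
#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::Vector3d x(1000.0, 1001.0, 1002.0);

  Eigen::Array3d naive = x.array().exp();             // inf, inf, inf
  std::cout << "naive sum: " << naive.sum() << "\n";  // inf -> nan results

  Eigen::Array3d shifted = (x.array() - x.maxCoeff()).exp();
  std::cout << "stable: " << (shifted / shifted.sum()).transpose() << "\n";
}
```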
* - * @tparam T type of input + * @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar * @param x input - * @return softmax of the input + * @return log softmax of the input * @throw std::domain_error if the input size is 0 */ -template * = nullptr> -inline auto log_softmax(const T& x) { - const int a_size = x.size(); - - check_nonzero_size("log_softmax", "x", x); - - const auto& x_ref = to_ref(x); - - vari** x_vi_array - = ChainableStack::instance_->memalloc_.alloc_array(a_size); - Eigen::Map(x_vi_array, a_size) = x_ref.vi(); - - vector_d x_d = x_ref.val(); - - // fold logic of math::softmax() and math::log_softmax() - // to save computations - - vector_d diff = (x_d.array() - x_d.maxCoeff()); - vector_d softmax_x_d = diff.array().exp(); - double sum = softmax_x_d.sum(); - vector_d log_softmax_x_d = diff.array() - std::log(sum); - - // end fold - double* softmax_x_d_array - = ChainableStack::instance_->memalloc_.alloc_array(a_size); - Eigen::Map(softmax_x_d_array, a_size) = softmax_x_d.array() / sum; - - plain_type_t log_softmax_x(a_size); - for (int k = 0; k < a_size; ++k) { - log_softmax_x(k) = var(new internal::log_softmax_elt_vari( - log_softmax_x_d[k], x_vi_array, softmax_x_d_array, a_size, k)); - } - return log_softmax_x; -} - -/** - * Return the log softmax of the specified vector - * - * @tparam T type of input - * @param x input - * @return softmax of the input - * @throw std::domain_error if the input size is 0 - */ -template * = nullptr> -inline auto log_softmax(const T& x) { +template * = nullptr> +inline auto log_softmax(T&& x) { check_nonzero_size("log_softmax", "x", x); - - const auto& theta = (x.val().array() - x.val().maxCoeff()).eval(); - - return make_callback_var( - (theta.array() - log(theta.exp().sum())).matrix(), - [x](const auto& res) mutable { - x.adj().noalias() - += res.adj() - (res.adj().sum() * res.val().array().exp()).matrix(); - }); + auto x_arena = to_arena(std::forward(x)); + using return_t + = return_var_matrix_t, T>; + arena_t res = log_softmax(x_arena.val()); + reverse_pass_callback([x_arena, res]() mutable { + const auto& res_adj = to_ref(res.adj()); + x_arena.adj().array() + += res_adj.array() - res_adj.sum() * res.val().array().exp(); + }); + return return_t(res); } /** - * Return the log softmax of the specified `std::vector` or - * `std::vector` of containers. + * Return the log softmax of each vector in an array. * - * @tparam T type of input - * @param x input - * @return softmax of the input - * @throw std::domain_error if the input size is 0 + * @tparam T `std::vector` whose scalar type is `var` + * @param x array of vectors to transform + * @return array of log softmax results + * @throw std::domain_error if any element size is 0 */ template * = nullptr> inline auto log_softmax(T&& x) { - return apply_vector_unary::apply(std::forward(x), [](auto&& alpha) { - return log_softmax(std::forward(alpha)); + return apply_vector_unary::apply(std::forward(x), [](auto&& v) { + return log_softmax(std::forward(v)); }); } diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp index a1bf786e826..ee3170024d2 100644 --- a/stan/math/rev/fun/softmax.hpp +++ b/stan/math/rev/fun/softmax.hpp @@ -3,44 +3,52 @@ #include #include -#include #include #include -#include +#include #include #include -#include -#include +#include namespace stan { namespace math { /** - * Return the softmax of the specified Eigen vector. Softmax is - * guaranteed to return a simplex. + * Return the softmax of the specified vector or row vector. 
diff --git a/stan/math/rev/fun/softmax.hpp b/stan/math/rev/fun/softmax.hpp
index a1bf786e826..ee3170024d2 100644
--- a/stan/math/rev/fun/softmax.hpp
+++ b/stan/math/rev/fun/softmax.hpp
@@ -3,44 +3,52 @@
 #include <stan/math/rev/core.hpp>
 #include <stan/math/rev/meta.hpp>
-#include <stan/math/rev/fun/value_of.hpp>
 #include <stan/math/prim/meta.hpp>
 #include <stan/math/prim/err.hpp>
-#include <stan/math/prim/fun/Eigen.hpp>
 #include <stan/math/prim/fun/softmax.hpp>
 #include <stan/math/prim/fun/to_ref.hpp>
-#include <cmath>
-#include <vector>
+#include <stan/math/prim/functor/apply_vector_unary.hpp>
 
 namespace stan {
 namespace math {
 
 /**
- * Return the softmax of the specified Eigen vector.  Softmax is
- * guaranteed to return a simplex.
+ * Return the softmax of the specified vector or row vector.
  *
- * @param alpha Unconstrained input vector.
- * @return Softmax of the input.
- * @throw std::domain_error If the input vector is size 0.
+ * @tparam T a `var_value` or Eigen vector/row_vector with `var` scalar
+ * @param x input
+ * @return softmax of the input
  */
-template <typename Mat, require_rev_matrix_t<Mat>* = nullptr>
-inline auto softmax(const Mat& alpha) {
-  using mat_plain = plain_type_t<Mat>;
-  using ret_type = return_var_matrix_t<mat_plain, Mat>;
-  if (alpha.size() == 0) {
-    return ret_type(alpha);
+template <typename T, require_rev_vector_t<T>* = nullptr>
+inline auto softmax(T&& x) {
+  auto x_arena = to_arena(std::forward<T>(x));
+  using return_t
+      = return_var_matrix_t<plain_type_t<decltype(x_arena.val())>, T>;
+  if (x_arena.size() == 0) {
+    return return_t(x_arena);
   }
-  arena_t<mat_plain> alpha_arena = alpha;
-  arena_t<mat_plain> res_val = softmax(value_of(alpha_arena));
-  arena_t<ret_type> res = res_val;
-
-  reverse_pass_callback([res_val, res, alpha_arena]() mutable {
+  arena_t<return_t> res = softmax(x_arena.val());
+  reverse_pass_callback([x_arena, res]() mutable {
+    const auto& s = to_ref(res.val());
     const auto& res_adj = to_ref(res.adj());
-    alpha_arena.adj()
-        += -res_val * res_adj.dot(res_val) + res_val.cwiseProduct(res_adj);
+    x_arena.adj().array() += s.array() * (res_adj.array() - s.dot(res_adj));
   });
+  return return_t(res);
+}
 
-  return ret_type(res);
+/**
+ * Return the softmax of each vector in an array.
+ *
+ * @tparam T `std::vector` whose scalar type is `var`
+ * @param x array of vectors to transform
+ * @return array of softmax results
+ */
+template <typename T, require_std_vector_st<is_var, T>* = nullptr>
+inline auto softmax(T&& x) {
+  return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& v) {
+    return softmax(std::forward<decltype(v)>(v));
+  });
 }
 
 }  // namespace math
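For completeness, the derivation behind the one-line adjoint update in the callback (a standard result, restated here for review):

```latex
J_{km} = \frac{\partial s_k}{\partial x_m} = s_k\,(\delta_{km} - s_m)
\quad\Longrightarrow\quad
J^\top \bar{y} = \mathrm{diag}(s)\,\bar{y} - s\,(s^\top \bar{y})
= s \odot \left(\bar{y} - (s^\top \bar{y})\,\mathbf{1}\right),
```

which is exactly `s.array() * (res_adj.array() - s.dot(res_adj))`; since this Jacobian is symmetric, J^T and J coincide, so the same expression also serves as the forward-mode rule used in the fwd overloads.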
diff --git a/test/unit/math/mix/fun/softmax_test.cpp b/test/unit/math/mix/fun/softmax_test.cpp
index bf748824173..248bd975376 100644
--- a/test/unit/math/mix/fun/softmax_test.cpp
+++ b/test/unit/math/mix/fun/softmax_test.cpp
@@ -9,6 +9,7 @@ TEST(MathMixMatFun, softmax) {
   tols.hessian_hessian_ = 1e-2;
   tols.hessian_fvar_hessian_ = 1e-2;
 
+  // Column vectors
   Eigen::VectorXd a(0);
   stan::test::expect_ad(tols, f, a);
   expect_ad_matvar(f, a);
@@ -41,4 +42,54 @@ TEST(MathMixMatFun, softmax) {
   d4 << 0, 3, -1;
   stan::test::expect_ad(tols, f, d4);
   expect_ad_matvar(f, d4);
+
+  // Row vectors
+  Eigen::RowVectorXd ra(0);
+  stan::test::expect_ad(tols, f, ra);
+  expect_ad_matvar(f, ra);
+
+  Eigen::RowVectorXd rb(1);
+  rb << 0;
+  stan::test::expect_ad(tols, f, rb);
+  expect_ad_matvar(f, rb);
+
+  Eigen::RowVectorXd rc(2);
+  rc << -1, 1;
+  stan::test::expect_ad(tols, f, rc);
+  expect_ad_matvar(f, rc);
+
+  Eigen::RowVectorXd rd(3);
+  rd << -1, 1, 10;
+  stan::test::expect_ad(tols, f, rd);
+  expect_ad_matvar(f, rd);
+
+  Eigen::RowVectorXd rd2(3);
+  rd2 << 0.5, -1, 3;
+  stan::test::expect_ad(tols, f, rd2);
+  expect_ad_matvar(f, rd2);
+
+  // Arrays of vectors (array[] vector and array[] row_vector)
+  std::vector<Eigen::VectorXd> stvx0{a, a};  // error case
+  stan::test::expect_ad(tols, f, stvx0);
+  expect_ad_matvar(f, stvx0);
+
+  std::vector<Eigen::VectorXd> stvx1{b, b};
+  stan::test::expect_ad(tols, f, stvx1);
+  expect_ad_matvar(f, stvx1);
+
+  std::vector<Eigen::VectorXd> stvx2{c, d};
+  stan::test::expect_ad(tols, f, stvx2);
+  expect_ad_matvar(f, stvx2);
+
+  std::vector<Eigen::RowVectorXd> strx0{ra, ra};  // error case
+  stan::test::expect_ad(tols, f, strx0);
+  expect_ad_matvar(f, strx0);
+
+  std::vector<Eigen::RowVectorXd> strx1{rb, rb};
+  stan::test::expect_ad(tols, f, strx1);
+  expect_ad_matvar(f, strx1);
+
+  std::vector<Eigen::RowVectorXd> strx2{rc, rd};
+  stan::test::expect_ad(tols, f, strx2);
+  expect_ad_matvar(f, strx2);
 }
diff --git a/test/unit/math/prim/fun/log_softmax_test.cpp b/test/unit/math/prim/fun/log_softmax_test.cpp
index 27f682a17ff..9649a20e13d 100644
--- a/test/unit/math/prim/fun/log_softmax_test.cpp
+++ b/test/unit/math/prim/fun/log_softmax_test.cpp
@@ -1,5 +1,6 @@
 #include <stan/math/prim.hpp>
 #include <gtest/gtest.h>
+#include <limits>
 #include <vector>
 
 inline void test_log_softmax(
@@ -66,6 +67,23 @@ TEST(MathMatrixPrimMat, log_softmax) {
   //   x3 << -1.0, 1.0, 10.0;
   //   test_log_softmax(x3);
 }
+
+TEST(MathMatrixPrimMat, log_softmax_neg_inf) {
+  using Eigen::Dynamic;
+  using Eigen::Matrix;
+  using stan::math::log_softmax;
+  constexpr double neg_inf = -std::numeric_limits<double>::infinity();
+
+  // -inf in a vector stays -inf in the output; the rest get the
+  // proper restricted log-softmax.
+  Matrix<double, Dynamic, 1> v(3);
+  v << neg_inf, 1.0, 2.0;
+  Matrix<double, Dynamic, 1> result = log_softmax(v);
+  const double lse_finite = std::log(exp(1.0) + exp(2.0));
+  EXPECT_EQ(neg_inf, result[0]);
+  EXPECT_FLOAT_EQ(1.0 - lse_finite, result[1]);
+  EXPECT_FLOAT_EQ(2.0 - lse_finite, result[2]);
+}
+
 TEST(MathMatrixPrimMat, log_softmax_exception) {
   using stan::math::log_softmax;
   stan::math::vector_d v0;  // size == 0
diff --git a/test/unit/math/prim/fun/softmax_test.cpp b/test/unit/math/prim/fun/softmax_test.cpp
index 8e3c8a13328..7a38156d95b 100644
--- a/test/unit/math/prim/fun/softmax_test.cpp
+++ b/test/unit/math/prim/fun/softmax_test.cpp
@@ -1,5 +1,6 @@
 #include <stan/math/prim.hpp>
 #include <gtest/gtest.h>
+#include <limits>
 
 TEST(MathMatrixPrimMat, softmax) {
   using Eigen::Dynamic;
@@ -28,3 +29,46 @@ TEST(MathMatrixPrimMat, softmax) {
   EXPECT_FLOAT_EQ(exp(1) / (exp(-1) + exp(1) + exp(10.0)), theta3[1]);
   EXPECT_FLOAT_EQ(exp(10) / (exp(-1) + exp(1) + exp(10.0)), theta3[2]);
 }
+
+TEST(MathMatrixPrimMat, softmax_neg_inf) {
+  using Eigen::Dynamic;
+  using Eigen::Matrix;
+  using stan::math::softmax;
+  constexpr double neg_inf = -std::numeric_limits<double>::infinity();
+
+  // -inf in a vector pins that component to exactly 0; the rest renormalize.
+  Matrix<double, Dynamic, 1> v(3);
+  v << neg_inf, 1.0, 2.0;
+  Matrix<double, Dynamic, 1> theta = softmax(v);
+  EXPECT_FLOAT_EQ(0.0, theta[0]);
+  EXPECT_FLOAT_EQ(exp(1.0) / (exp(1.0) + exp(2.0)), theta[1]);
+  EXPECT_FLOAT_EQ(exp(2.0) / (exp(1.0) + exp(2.0)), theta[2]);
+  EXPECT_FLOAT_EQ(1.0, theta.sum());
+}
+
+TEST(MathMatrixPrimMat, softmax_row_vector) {
+  using Eigen::Dynamic;
+  using Eigen::Matrix;
+  using stan::math::softmax;
+
+  Matrix<double, 1, Dynamic> x(1);
+  x << 0.0;
+  Matrix<double, 1, Dynamic> theta = softmax(x);
+  EXPECT_EQ(1, theta.size());
+  EXPECT_FLOAT_EQ(1.0, theta[0]);
+
+  Matrix<double, 1, Dynamic> x2(2);
+  x2 << -1.0, 1.0;
+  Matrix<double, 1, Dynamic> theta2 = softmax(x2);
+  EXPECT_EQ(2, theta2.size());
+  EXPECT_FLOAT_EQ(exp(-1) / (exp(-1) + exp(1)), theta2[0]);
+  EXPECT_FLOAT_EQ(exp(1) / (exp(-1) + exp(1)), theta2[1]);
+
+  Matrix<double, 1, Dynamic> x3(3);
+  x3 << -1.0, 1.0, 10.0;
+  Matrix<double, 1, Dynamic> theta3 = softmax(x3);
+  EXPECT_EQ(3, theta3.size());
+  EXPECT_FLOAT_EQ(exp(-1) / (exp(-1) + exp(1) + exp(10.0)), theta3[0]);
+  EXPECT_FLOAT_EQ(exp(1) / (exp(-1) + exp(1) + exp(10.0)), theta3[1]);
+  EXPECT_FLOAT_EQ(exp(10) / (exp(-1) + exp(1) + exp(10.0)), theta3[2]);
+}
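The two `_neg_inf` tests lean on IEEE-754 semantics rather than any special-casing in the implementations; a standalone sketch of the two identities they rely on (not part of the patch):

```cpp
// exp(-inf) is exactly +0.0, so a -inf input component contributes nothing
// to the normalizer and its softmax output is exactly 0, not merely small;
// log(0) is -inf, so the same component maps back to -inf under log_softmax.
#include <cmath>
#include <iostream>
#include <limits>

int main() {
  const double neg_inf = -std::numeric_limits<double>::infinity();
  std::cout << std::exp(neg_inf) << "\n";            // 0
  std::cout << std::log(std::exp(neg_inf)) << "\n";  // -inf
}
```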
diff --git a/test/unit/math/rev/fun/log_softmax_test.cpp b/test/unit/math/rev/fun/log_softmax_test.cpp
new file mode 100644
index 00000000000..0008e372b4c
--- /dev/null
+++ b/test/unit/math/rev/fun/log_softmax_test.cpp
@@ -0,0 +1,62 @@
+#include <stan/math/rev.hpp>
+#include <test/unit/math/rev/util.hpp>
+#include <gtest/gtest.h>
+#include <cmath>
+
+// Direct exercise of the var_value<Matrix> vector overload in
+// stan/math/rev/fun/log_softmax.hpp for both row and column var_value vectors.
+
+TEST_F(AgradRev, log_softmax_var_value_col_vector) {
+  using stan::math::log_softmax;
+  using stan::math::sum;
+  using stan::math::var_value;
+
+  Eigen::VectorXd v(3);
+  v << -1, 1, 2;
+
+  var_value<Eigen::VectorXd> x(v);
+  auto y = log_softmax(x);
+
+  EXPECT_EQ(3, y.val().rows());
+  EXPECT_EQ(1, y.val().cols());
+
+  const double lse = std::log(std::exp(-1.0) + std::exp(1.0) + std::exp(2.0));
+  EXPECT_FLOAT_EQ(-1.0 - lse, y.val()(0));
+  EXPECT_FLOAT_EQ(1.0 - lse, y.val()(1));
+  EXPECT_FLOAT_EQ(2.0 - lse, y.val()(2));
+
+  // d/dx[m] sum(y) = sum_k(delta_km - softmax[m]) = 1 - n * softmax[m]
+  sum(y).grad();
+  const double denom = std::exp(-1.0) + std::exp(1.0) + std::exp(2.0);
+  EXPECT_FLOAT_EQ(1.0 - 3.0 * std::exp(-1.0) / denom, x.adj()(0));
+  EXPECT_FLOAT_EQ(1.0 - 3.0 * std::exp(1.0) / denom, x.adj()(1));
+  EXPECT_FLOAT_EQ(1.0 - 3.0 * std::exp(2.0) / denom, x.adj()(2));
+}
+
+TEST_F(AgradRev, log_softmax_var_value_row_vector) {
+  using stan::math::log_softmax;
+  using stan::math::sum;
+  using stan::math::var_value;
+
+  Eigen::RowVectorXd v(3);
+  v << -1, 1, 2;
+
+  var_value<Eigen::RowVectorXd> x(v);
+  auto y = log_softmax(x);
+
+  // Output should preserve row shape.
+  EXPECT_EQ(1, y.val().rows());
+  EXPECT_EQ(3, y.val().cols());
+
+  const double lse = std::log(std::exp(-1.0) + std::exp(1.0) + std::exp(2.0));
+  EXPECT_FLOAT_EQ(-1.0 - lse, y.val()(0));
+  EXPECT_FLOAT_EQ(1.0 - lse, y.val()(1));
+  EXPECT_FLOAT_EQ(2.0 - lse, y.val()(2));
+
+  // Same gradient formula; one softmax over the entire row vector.
+  sum(y).grad();
+  const double denom = std::exp(-1.0) + std::exp(1.0) + std::exp(2.0);
+  EXPECT_FLOAT_EQ(1.0 - 3.0 * std::exp(-1.0) / denom, x.adj()(0));
+  EXPECT_FLOAT_EQ(1.0 - 3.0 * std::exp(1.0) / denom, x.adj()(1));
+  EXPECT_FLOAT_EQ(1.0 - 3.0 * std::exp(2.0) / denom, x.adj()(2));
+}