Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions stan/math/fwd/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,67 @@
#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/fun/softmax.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/log_softmax.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>

namespace stan {
namespace math {

/**
* Return the log softmax of the specified vector or container of vectors.
* Return the log softmax of the specified row vector of `fvar` values.
* Delegates to the column-vector overload via transposition.
*
* @tparam T Type of input vector or matrix.
* @param[in] x Unconstrained input vector.
* @return Softmax of the input.
* @throw std::domain_error If the input vector is size 0.
* @tparam RowVec Eigen row vector with `fvar` scalar
* @param x row vector to transform
* @return log softmax of the row vector
*/
template <typename T, require_vector_st<is_fvar, T>* = nullptr>
inline auto log_softmax(T&& x) {
return apply_vector_unary<T>::apply(std::forward<T>(x), [](auto&& alpha) {
using T_alpha = decltype(alpha);
using T_fvar = value_type_t<T_alpha>;
using T_fvar_inner = typename T_fvar::Scalar;

auto&& alpha_ref = to_ref(std::forward<decltype(alpha)>(alpha));
Eigen::Matrix<T_fvar_inner, -1, 1> alpha_t = alpha_ref.val();
Eigen::Matrix<T_fvar_inner, -1, 1> softmax_alpha_t = softmax(alpha_t);

Eigen::Matrix<T_fvar, -1, 1> log_softmax_alpha(alpha_ref.size());
log_softmax_alpha.val() = log_softmax(alpha_t);
log_softmax_alpha.d().setZero();

for (int m = 0; m < alpha_ref.size(); ++m) {
T_fvar_inner negative_alpha_m_d_times_softmax_alpha_t_m
= -alpha_ref.coeff(m).d_ * softmax_alpha_t(m);
for (int k = 0; k < alpha_ref.size(); ++k) {
if (m == k) {
log_softmax_alpha(k).d_
+= alpha_ref.coeff(m).d_
+ negative_alpha_m_d_times_softmax_alpha_t_m;
} else {
log_softmax_alpha(k).d_ += negative_alpha_m_d_times_softmax_alpha_t_m;
}
}
}
/**
 * Return the log softmax of the specified row vector of `fvar` values.
 * Delegates to the column-vector overload via transposition.
 *
 * @tparam RowVec Eigen row vector with `fvar` scalar
 * @param x row vector to transform
 * @return log softmax of the row vector
 */
template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr,
          require_t<is_fvar<value_type_t<RowVec>>>* = nullptr>
inline auto log_softmax(const RowVec& x) {
  // Reuse the column-vector implementation, then restore row orientation.
  auto col_form = log_softmax(x.transpose());
  return col_form.transpose().eval();
}

return log_softmax_alpha;
/**
 * Return the log softmax of each vector in a container of `fvar` values.
 *
 * @tparam T `std::vector` whose scalar type is `fvar`
 * @param x container of vectors to transform
 * @return container of log softmax results
 */
template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
inline auto log_softmax(T&& x) {
  // Apply the vector overload to every element of the container.
  auto per_vector = [](auto&& inner) {
    return log_softmax(std::forward<decltype(inner)>(inner));
  };
  return apply_vector_unary<T>::apply(std::forward<T>(x), per_vector);
}

/**
 * Return the log softmax of the specified column vector of `fvar` values.
 *
 * The value is computed with the shifted-log formulation of the prim
 * `log_softmax` (rather than `log(softmax(x))`) so that large-magnitude
 * inputs do not underflow to `-inf`. The tangent follows from the
 * log-softmax Jacobian: d_k = d_k - sum_m s_m d_m.
 *
 * @tparam ColVec Eigen column vector with `fvar` scalar
 * @param x column vector to transform
 * @return log softmax of the column vector
 * @throw std::domain_error if the input size is 0
 */
template <typename ColVec,
          require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
inline auto log_softmax(const ColVec& x) {
  using Eigen::Dynamic;
  using Eigen::Matrix;
  using T = typename value_type_t<ColVec>::Scalar;
  check_nonzero_size("log_softmax", "x", x);
  const auto& x_ref = to_ref(x);
  // Softmax of the value part; only needed for the tangent contraction.
  const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref));
  const auto d_in = x_ref.d().eval();
  const T dot_sd = s.dot(d_in);
  Matrix<fvar<T>, Dynamic, 1> result(x_ref.size());
  // Stable direct computation; log(s) would be -inf when s underflows.
  result.val() = log_softmax(value_of(x_ref));
  result.d() = (d_in.array() - dot_sd).matrix();
  return result;
}

} // namespace math
} // namespace stan
#endif
76 changes: 46 additions & 30 deletions stan/math/fwd/fun/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,63 @@
#include <stan/math/fwd/fun/value_of.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/softmax.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>

namespace stan {
namespace math {

/**
 * Return the softmax of the specified row vector of `fvar` values.
 * Delegates to the column-vector overload via transposition.
 *
 * @tparam RowVec Eigen row vector with `fvar` scalar
 * @param x row vector to transform
 * @return softmax of the row vector
 */
template <typename RowVec, require_eigen_row_vector_t<RowVec>* = nullptr,
          require_t<is_fvar<value_type_t<RowVec>>>* = nullptr>
inline auto softmax(const RowVec& x) {
  // Reuse the column-vector implementation, then restore row orientation.
  auto col_form = softmax(x.transpose());
  return col_form.transpose().eval();
}

/**
 * Return the softmax of each vector in a container of `fvar` values.
 *
 * @tparam T `std::vector` whose scalar type is `fvar`
 * @param x container of vectors to transform
 * @return container of softmax results
 */
template <typename T, require_std_vector_st<is_fvar, T>* = nullptr>
inline auto softmax(T&& x) {
  // Apply the vector overload to every element of the container.
  auto per_vector = [](auto&& inner) {
    return softmax(std::forward<decltype(inner)>(inner));
  };
  return apply_vector_unary<T>::apply(std::forward<T>(x), per_vector);
}

/**
* Return the softmax of the specified column vector of `fvar` values.
*
* @tparam ColVec Eigen column vector with `fvar` scalar
* @param x column vector to transform
* @return softmax of the column vector
*/
template <typename ColVec,
require_eigen_col_vector_vt<is_fvar, ColVec>* = nullptr>
inline auto softmax(const ColVec& alpha) {
inline auto softmax(const ColVec& x) {
using Eigen::Dynamic;
using Eigen::Matrix;
using T = typename value_type_t<ColVec>::Scalar;
if (alpha.size() == 0) {
if (x.size() == 0) {
return Matrix<fvar<T>, Dynamic, 1>();
}
const auto& alpha_ref = to_ref(alpha);

Matrix<T, Dynamic, 1> softmax_alpha_t = softmax(value_of(alpha_ref));

Matrix<fvar<T>, Dynamic, 1> softmax_alpha(alpha.size());
for (int k = 0; k < alpha.size(); ++k) {
softmax_alpha.coeffRef(k).val_ = softmax_alpha_t.coeff(k);
softmax_alpha.coeffRef(k).d_ = 0;
}

for (int m = 0; m < alpha.size(); ++m) {
T negative_alpha_m_d_times_softmax_alpha_t_m
= -alpha_ref.coeff(m).d_ * softmax_alpha_t.coeff(m);
for (int k = 0; k < alpha.size(); ++k) {
if (m == k) {
softmax_alpha.coeffRef(k).d_
+= softmax_alpha_t.coeff(k)
* (alpha_ref.coeff(m).d_
+ negative_alpha_m_d_times_softmax_alpha_t_m);
} else {
softmax_alpha.coeffRef(k).d_
+= softmax_alpha_t.coeff(k)
* negative_alpha_m_d_times_softmax_alpha_t_m;
}
}
}

return softmax_alpha;
const auto& x_ref = to_ref(x);
const Matrix<T, Dynamic, 1> s = softmax(value_of(x_ref));
const auto d_in = x_ref.d().eval();
const T dot_sd = s.dot(d_in);
Matrix<fvar<T>, Dynamic, 1> result(x_ref.size());
result.val() = s;
result.d() = (s.array() * (d_in.array() - dot_sd)).matrix();
return result;
}

} // namespace math
Expand Down
21 changes: 13 additions & 8 deletions stan/math/prim/fun/log_softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace math {

/**
* Return the natural logarithm of the softmax of the specified
* vector.
* vector, or of each vector in a container.
*
* \f$
* \log \mbox{softmax}(y)
Expand All @@ -23,26 +23,31 @@ namespace math {
*
* For the log softmax function, the entries in the Jacobian are
* \f$
* \frac{\partial}{\partial y_m} \mbox{softmax}(y)[k]
* \frac{\partial}{\partial y_m} \log\mbox{softmax}(y)[k]
* = \left\{
* \begin{array}{ll}
* 1 - \mbox{softmax}(y)[m]
* & \mbox{ if } m = k, \mbox{ and}
* \\[6pt]
* \mbox{softmax}(y)[m]
* -\mbox{softmax}(y)[m]
* & \mbox{ if } m \neq k.
* \end{array}
* \right.
* \f$
*
* @tparam Container type of input vector to transform
* @param[in] x vector to transform
* @return log unit simplex result of the softmax transform of the vector.
* @tparam Container type of input: an Eigen vector, `std::vector` of doubles,
* or nested container whose scalar type is arithmetic
* @param[in] x vector or container of vectors to transform
* @return log softmax of the input, preserving the container structure
* @throw std::domain_error if any input vector is empty
*/
template <typename Container, require_st_arithmetic<Container>* = nullptr,
require_container_t<Container>* = nullptr>
require_container_t<Container>* = nullptr,
require_not_t<bool_constant<
is_eigen<std::decay_t<Container>>::value
&& !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
inline auto log_softmax(Container&& x) {
check_nonzero_size("log_softmax", "v", x);
check_nonzero_size("log_softmax", "x", x);
return make_holder(
[](auto&& a) {
return apply_vector_unary<ref_type_t<Container>>::apply(
Expand Down
26 changes: 20 additions & 6 deletions stan/math/prim/fun/softmax.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/functor/apply_vector_unary.hpp>
#include <cmath>

namespace stan {
Expand Down Expand Up @@ -38,20 +39,33 @@ namespace math {
* \end{array}
* \f$
*
* @tparam ColVec type of elements in the vector
* @tparam Vec type of the input vector
* @param[in] v Vector to transform.
* @return Unit simplex result of the softmax transform of the vector.
*/
template <typename ColVec,
require_eigen_col_vector_vt<std::is_arithmetic, ColVec>* = nullptr>
inline plain_type_t<ColVec> softmax(const ColVec& v) {
using std::exp;
template <typename Vec,
require_eigen_vector_vt<std::is_arithmetic, Vec>* = nullptr>
inline plain_type_t<Vec> softmax(const Vec& v) {
if (v.size() == 0) {
return v;
}
const auto& v_ref = to_ref(v);
const auto theta = (v_ref.array() - v_ref.maxCoeff()).exp().eval();
return theta.array() / theta.sum();
return (theta / theta.sum()).matrix();
}

/**
 * Return the softmax of each vector in an array.
 *
 * @tparam T `std::vector` whose scalar type is arithmetic
 * @param[in] x Array of vectors to transform.
 * @return Array of unit simplex results.
 */
template <typename T, require_std_vector_st<std::is_arithmetic, T>* = nullptr>
inline auto softmax(T&& x) {
  // Delegate to the Eigen-vector overload for every element.
  auto per_vector = [](auto&& inner) {
    return softmax(std::forward<decltype(inner)>(inner));
  };
  return apply_vector_unary<T>::apply(std::forward<T>(x), per_vector);
}

} // namespace math
Expand Down
Loading
Loading