Skip to content

Commit 9b9c4a5

Browse files
author
Jachym.Barvinek
committed
Merge branch 'jachymb_softmax' of github.com:jachymb/math into jachymb_softmax
2 parents 008a131 + 3271c41 commit 9b9c4a5

14 files changed

Lines changed: 288 additions & 15 deletions

File tree

stan/math/fwd/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
#include <stan/math/fwd/fun/tcrossprod.hpp>
117117
#include <stan/math/fwd/fun/tgamma.hpp>
118118
#include <stan/math/fwd/fun/to_fvar.hpp>
119+
#include <stan/math/fwd/fun/trace_dot.hpp>
119120
#include <stan/math/fwd/fun/trace_quad_form.hpp>
120121
#include <stan/math/fwd/fun/trigamma.hpp>
121122
#include <stan/math/fwd/fun/trunc.hpp>

stan/math/fwd/fun/log_softmax.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ inline auto log_softmax(const Mat& m) {
3333
const auto exp_s = shifted.exp().eval();
3434
const auto row_sums = exp_s.rowwise().sum().eval();
3535
const auto lsm_val = (shifted.colwise() - row_sums.log()).matrix().eval();
36-
// softmax values needed for the tangent: d_in - softmax(x) ⊙ dot(softmax(x), d_in)
36+
// softmax values needed for the tangent: d_in - softmax(x) ⊙ dot(softmax(x),
37+
// d_in)
3738
const auto s = (exp_s.colwise() / row_sums).eval();
3839
const auto d_in = m_ref.d().eval();
3940
const auto dots = (s.array() * d_in.array()).rowwise().sum().eval();

stan/math/fwd/fun/trace_dot.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef STAN_MATH_FWD_FUN_TRACE_DOT_HPP
2+
#define STAN_MATH_FWD_FUN_TRACE_DOT_HPP
3+
4+
#include <stan/math/fwd/core.hpp>
5+
#include <stan/math/fwd/fun/multiply.hpp>
6+
#include <stan/math/prim/err.hpp>
7+
#include <stan/math/prim/fun/trace.hpp>
8+
9+
namespace stan {
10+
namespace math {
11+
12+
/**
13+
* Compute the trace of the product of two matrices with
14+
* forward-mode autodiff support.
15+
*
16+
* @tparam EigMat1 A type either inheriting from `Eigen::DenseBase` or a
17+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
18+
* @tparam EigMat2 A type either inheriting from `Eigen::DenseBase` or a
19+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
20+
*
21+
* @param A first matrix (m x n)
22+
* @param B second matrix (n x m)
23+
* @return trace of A * B
24+
* @throw std::invalid_argument if A and B have incompatible dimensions
25+
*/
26+
template <typename EigMat1, typename EigMat2,
27+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
28+
require_any_vt_fvar<EigMat1, EigMat2>* = nullptr>
29+
inline return_type_t<EigMat1, EigMat2> trace_dot(EigMat1&& A, EigMat2&& B) {
30+
check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows());
31+
check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols());
32+
return trace(multiply(std::forward<EigMat1>(A), std::forward<EigMat2>(B)));
33+
}
34+
35+
} // namespace math
36+
} // namespace stan
37+
#endif

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@
328328
#include <stan/math/prim/fun/to_row_vector.hpp>
329329
#include <stan/math/prim/fun/to_vector.hpp>
330330
#include <stan/math/prim/fun/trace.hpp>
331+
#include <stan/math/prim/fun/trace_dot.hpp>
331332
#include <stan/math/prim/fun/trace_gen_inv_quad_form_ldlt.hpp>
332333
#include <stan/math/prim/fun/trace_gen_quad_form.hpp>
333334
#include <stan/math/prim/fun/trace_inv_quad_form_ldlt.hpp>

stan/math/prim/fun/log_softmax.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ namespace math {
4141
*/
4242
template <typename Container, require_st_arithmetic<Container>* = nullptr,
4343
require_container_t<Container>* = nullptr,
44-
require_not_t<bool_constant<is_eigen<std::decay_t<Container>>::value
45-
&& !is_eigen_vector<std::decay_t<
46-
Container>>::value>>* = nullptr>
44+
require_not_t<bool_constant<
45+
is_eigen<std::decay_t<Container>>::value
46+
&& !is_eigen_vector<std::decay_t<Container>>::value>>* = nullptr>
4747
inline auto log_softmax(Container&& x) {
4848
check_nonzero_size("log_softmax", "v", x);
4949
return make_holder(

stan/math/prim/fun/trace.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ namespace math {
1919
*/
2020
template <typename T, require_eigen_t<T>* = nullptr,
2121
require_not_st_var<T>* = nullptr>
22-
inline value_type_t<T> trace(const T& m) {
23-
return m.trace();
22+
inline auto trace(T&& m) {
23+
return make_holder([](auto&& m_) { return m_.trace(); }, std::forward<T>(m));
2424
}
2525

2626
} // namespace math

stan/math/prim/fun/trace_dot.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef STAN_MATH_PRIM_FUN_TRACE_DOT_HPP
2+
#define STAN_MATH_PRIM_FUN_TRACE_DOT_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Compute the trace of the product of two matrices,
13+
* \f$ \text{tr}(A \cdot B) = \sum_{i,j} A_{ij} B_{ji} \f$.
14+
*
15+
* This is more efficient than computing the full product and
16+
* taking the trace, as it avoids forming the intermediate matrix.
17+
*
18+
* @tparam EigMat1 A type either inheriting from `Eigen::DenseBase` or a
19+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
20+
* @tparam EigMat2 A type either inheriting from `Eigen::DenseBase` or a
21+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
22+
*
23+
* @param A first matrix (m x n)
24+
* @param B second matrix (n x m)
25+
* @return trace of A * B
26+
* @throw std::invalid_argument if A and B have incompatible dimensions
27+
*/
28+
template <typename EigMat1, typename EigMat2,
29+
require_all_eigen_vt<std::is_arithmetic, EigMat1, EigMat2>* = nullptr>
30+
inline auto trace_dot(EigMat1&& A, EigMat2&& B) {
31+
check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows());
32+
check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols());
33+
return make_holder(
34+
[](auto&& A_, auto&& B_) {
35+
return A_.cwiseProduct(B_.transpose()).sum();
36+
},
37+
std::forward<EigMat1>(A), std::forward<EigMat2>(B));
38+
}
39+
40+
} // namespace math
41+
} // namespace stan
42+
43+
#endif

stan/math/rev/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@
172172
#include <stan/math/rev/fun/to_var_value.hpp>
173173
#include <stan/math/rev/fun/to_vector.hpp>
174174
#include <stan/math/rev/fun/trace.hpp>
175+
#include <stan/math/rev/fun/trace_dot.hpp>
175176
#include <stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp>
176177
#include <stan/math/rev/fun/trace_gen_quad_form.hpp>
177178
#include <stan/math/rev/fun/trace_inv_quad_form_ldlt.hpp>

stan/math/rev/fun/log_softmax.hpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ template <typename T, require_var_matrix_t<T>* = nullptr,
109109
inline auto log_softmax(const T& x) {
110110
check_nonzero_size("log_softmax", "x", x);
111111
return make_callback_var(
112-
log_softmax(x.val()).eval(),
113-
[x](const auto& res) mutable {
112+
log_softmax(x.val()).eval(), [x](const auto& res) mutable {
114113
// grad: g - sum(g) * softmax(x), where softmax(x) = exp(log_softmax(x))
115114
x.adj().noalias()
116115
+= res.adj() - (res.adj().sum() * res.val().array().exp()).matrix();
@@ -131,9 +130,9 @@ template <typename T, require_var_matrix_t<T>* = nullptr,
131130
inline auto log_softmax(const T& x) {
132131
check_nonzero_size("log_softmax", "x", x);
133132
return make_callback_var(
134-
Eigen::MatrixXd(log_softmax(x.val())),
135-
[x](const auto& res) mutable {
136-
// grad per row: g - softmax(x) * sum(g), softmax(x) = exp(log_softmax(x))
133+
Eigen::MatrixXd(log_softmax(x.val())), [x](const auto& res) mutable {
134+
// grad per row: g - softmax(x) * sum(g), softmax(x) =
135+
// exp(log_softmax(x))
137136
const auto row_sums = res.adj().rowwise().sum().eval();
138137
x.adj().noalias()
139138
+= res.adj()

stan/math/rev/fun/softmax.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ inline auto softmax(const Mat& m) {
7373
reverse_pass_callback([res_val, res, m_arena]() mutable {
7474
const auto& g = to_ref(res.adj());
7575
const auto dots = (res_val.array() * g.array()).rowwise().sum().eval();
76-
m_arena.adj() += (res_val.array() * (g.array().colwise() - dots.array()))
77-
.matrix();
76+
m_arena.adj()
77+
+= (res_val.array() * (g.array().colwise() - dots.array())).matrix();
7878
});
7979

8080
return ret_type(res);

0 commit comments

Comments
 (0)