Skip to content

Commit 105bfcc

Browse files
authored
Merge pull request #3310 from jachymb/jachymb_trace_dot
Implement trace_dot
2 parents 45d1de2 + 7157ac5 commit 105bfcc

10 files changed

Lines changed: 277 additions & 4 deletions

File tree

stan/math/fwd/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
#include <stan/math/fwd/fun/tcrossprod.hpp>
117117
#include <stan/math/fwd/fun/tgamma.hpp>
118118
#include <stan/math/fwd/fun/to_fvar.hpp>
119+
#include <stan/math/fwd/fun/trace_dot.hpp>
119120
#include <stan/math/fwd/fun/trace_quad_form.hpp>
120121
#include <stan/math/fwd/fun/trigamma.hpp>
121122
#include <stan/math/fwd/fun/trunc.hpp>

stan/math/fwd/fun/trace_dot.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#ifndef STAN_MATH_FWD_FUN_TRACE_DOT_HPP
2+
#define STAN_MATH_FWD_FUN_TRACE_DOT_HPP
3+
4+
#include <stan/math/fwd/core.hpp>
5+
#include <stan/math/fwd/fun/multiply.hpp>
6+
#include <stan/math/prim/err.hpp>
7+
#include <stan/math/prim/fun/trace.hpp>
8+
9+
namespace stan {
10+
namespace math {
11+
12+
/**
13+
* Compute the trace of the product of two matrices with
14+
* forward-mode autodiff support.
15+
*
16+
* @tparam EigMat1 A type either inheriting from `Eigen::DenseBase` or a
17+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
18+
* @tparam EigMat2 A type either inheriting from `Eigen::DenseBase` or a
19+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
20+
*
21+
* @param A first matrix (m x n)
22+
* @param B second matrix (n x m)
23+
* @return trace of A * B
24+
* @throw std::invalid_argument if A and B have incompatible dimensions
25+
*/
26+
template <typename EigMat1, typename EigMat2,
27+
require_all_eigen_t<EigMat1, EigMat2>* = nullptr,
28+
require_any_vt_fvar<EigMat1, EigMat2>* = nullptr>
29+
inline return_type_t<EigMat1, EigMat2> trace_dot(EigMat1&& A, EigMat2&& B) {
30+
check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows());
31+
check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols());
32+
return trace(multiply(std::forward<EigMat1>(A), std::forward<EigMat2>(B)));
33+
}
34+
35+
} // namespace math
36+
} // namespace stan
37+
#endif

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@
328328
#include <stan/math/prim/fun/to_row_vector.hpp>
329329
#include <stan/math/prim/fun/to_vector.hpp>
330330
#include <stan/math/prim/fun/trace.hpp>
331+
#include <stan/math/prim/fun/trace_dot.hpp>
331332
#include <stan/math/prim/fun/trace_gen_inv_quad_form_ldlt.hpp>
332333
#include <stan/math/prim/fun/trace_gen_quad_form.hpp>
333334
#include <stan/math/prim/fun/trace_inv_quad_form_ldlt.hpp>

stan/math/prim/fun/trace.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ namespace math {
1919
*/
2020
template <typename T, require_eigen_t<T>* = nullptr,
2121
require_not_st_var<T>* = nullptr>
22-
inline value_type_t<T> trace(const T& m) {
23-
return m.trace();
22+
inline auto trace(T&& m) {
23+
return make_holder([](auto&& m_) { return m_.trace(); }, std::forward<T>(m));
2424
}
2525

2626
} // namespace math

stan/math/prim/fun/trace_dot.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef STAN_MATH_PRIM_FUN_TRACE_DOT_HPP
2+
#define STAN_MATH_PRIM_FUN_TRACE_DOT_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/Eigen.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Compute the trace of the product of two matrices,
13+
* \f$ \text{tr}(A \cdot B) = \sum_{i,j} A_{ij} B_{ji} \f$.
14+
*
15+
* This is more efficient than computing the full product and
16+
* taking the trace, as it avoids forming the intermediate matrix.
17+
*
18+
* @tparam EigMat1 A type either inheriting from `Eigen::DenseBase` or a
19+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
20+
* @tparam EigMat2 A type either inheriting from `Eigen::DenseBase` or a
21+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
22+
*
23+
* @param A first matrix (m x n)
24+
* @param B second matrix (n x m)
25+
* @return trace of A * B
26+
* @throw std::invalid_argument if A and B have incompatible dimensions
27+
*/
28+
template <typename EigMat1, typename EigMat2,
29+
require_all_eigen_vt<std::is_arithmetic, EigMat1, EigMat2>* = nullptr>
30+
inline auto trace_dot(EigMat1&& A, EigMat2&& B) {
31+
check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows());
32+
check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols());
33+
return make_holder(
34+
[](auto&& A_, auto&& B_) {
35+
return A_.cwiseProduct(B_.transpose()).sum();
36+
},
37+
std::forward<EigMat1>(A), std::forward<EigMat2>(B));
38+
}
39+
40+
} // namespace math
41+
} // namespace stan
42+
43+
#endif

stan/math/rev/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@
172172
#include <stan/math/rev/fun/to_var_value.hpp>
173173
#include <stan/math/rev/fun/to_vector.hpp>
174174
#include <stan/math/rev/fun/trace.hpp>
175+
#include <stan/math/rev/fun/trace_dot.hpp>
175176
#include <stan/math/rev/fun/trace_gen_inv_quad_form_ldlt.hpp>
176177
#include <stan/math/rev/fun/trace_gen_quad_form.hpp>
177178
#include <stan/math/rev/fun/trace_inv_quad_form_ldlt.hpp>

stan/math/rev/fun/trace.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ namespace math {
2121
* @return Trace of the matrix.
2222
*/
2323
template <typename T, require_rev_matrix_t<T>* = nullptr>
24-
inline auto trace(const T& m) {
25-
arena_t<T> arena_m = m;
24+
inline auto trace(T&& m) {
25+
arena_t<T> arena_m(std::forward<T>(m));
2626

2727
return make_callback_var(arena_m.val_op().trace(),
2828
[arena_m](const auto& vi) mutable {

stan/math/rev/fun/trace_dot.hpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#ifndef STAN_MATH_REV_FUN_TRACE_DOT_HPP
2+
#define STAN_MATH_REV_FUN_TRACE_DOT_HPP
3+
4+
#include <stan/math/rev/meta.hpp>
5+
#include <stan/math/rev/core.hpp>
6+
#include <stan/math/rev/fun/value_of.hpp>
7+
#include <stan/math/prim/meta.hpp>
8+
#include <stan/math/prim/err.hpp>
9+
#include <stan/math/prim/fun/trace_dot.hpp>
10+
11+
namespace stan {
12+
namespace math {
13+
14+
/**
15+
* Compute the trace of the product of two matrices with autodiff support.
16+
*
17+
* \f$ \text{tr}(A \cdot B) = \sum_{i,j} A_{ij} B_{ji} \f$
18+
*
19+
* The gradients are:
20+
* \f$ \frac{\partial}{\partial A} \text{tr}(A B) = B^T \f$,
21+
* \f$ \frac{\partial}{\partial B} \text{tr}(A B) = A^T \f$.
22+
*
23+
* @tparam Mat1 A type either inheriting from `Eigen::DenseBase` or a
24+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
25+
* @tparam Mat2 A type either inheriting from `Eigen::DenseBase` or a
26+
* `var_value` with an inner type inheriting from `Eigen::DenseBase`
27+
*
28+
* @param A first matrix (m x n)
29+
* @param B second matrix (n x m)
30+
* @return trace of A * B
31+
* @throw std::invalid_argument if A and B have incompatible dimensions
32+
*/
33+
template <typename Mat1, typename Mat2,
34+
require_all_matrix_t<Mat1, Mat2>* = nullptr,
35+
require_any_rev_matrix_t<Mat1, Mat2>* = nullptr>
36+
inline var trace_dot(Mat1&& A, Mat2&& B) {
37+
check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows());
38+
check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols());
39+
if constexpr (is_autodiff_v<Mat1> && is_autodiff_v<Mat2>) {
40+
arena_t<Mat1> arena_A(std::forward<Mat1>(A));
41+
arena_t<Mat2> arena_B(std::forward<Mat2>(B));
42+
auto res_val = arena_A.val().cwiseProduct(arena_B.val().transpose()).sum();
43+
return make_callback_var(res_val, [arena_A, arena_B](auto&& res) mutable {
44+
if constexpr (is_var_matrix<Mat1>::value) {
45+
arena_A.adj().noalias() += res.adj() * arena_B.val().transpose();
46+
} else {
47+
arena_A.adj() += res.adj() * arena_B.val().transpose();
48+
}
49+
if constexpr (is_var_matrix<Mat2>::value) {
50+
arena_B.adj().noalias() += res.adj() * arena_A.val().transpose();
51+
} else {
52+
arena_B.adj() += res.adj() * arena_A.val().transpose();
53+
}
54+
});
55+
} else if constexpr (is_autodiff_v<Mat2>) {
56+
arena_t<Mat1> arena_A(std::forward<Mat1>(A));
57+
arena_t<Mat2> arena_B(std::forward<Mat2>(B));
58+
auto res_val = arena_A.cwiseProduct(arena_B.val().transpose()).sum();
59+
return make_callback_var(res_val, [arena_A, arena_B](auto&& res) mutable {
60+
if constexpr (is_var_matrix<Mat2>::value) {
61+
arena_B.adj().noalias() += res.adj() * arena_A.transpose();
62+
} else {
63+
arena_B.adj() += res.adj() * arena_A.transpose();
64+
}
65+
});
66+
} else {
67+
arena_t<Mat1> arena_A(std::forward<Mat1>(A));
68+
arena_t<Mat2> arena_B(std::forward<Mat2>(B));
69+
auto res_val = arena_A.val().cwiseProduct(arena_B.transpose()).sum();
70+
return make_callback_var(res_val, [arena_A, arena_B](auto&& res) mutable {
71+
if constexpr (is_var_matrix<Mat1>::value) {
72+
arena_A.adj().noalias() += res.adj() * arena_B.transpose();
73+
} else {
74+
arena_A.adj() += res.adj() * arena_B.transpose();
75+
}
76+
});
77+
}
78+
}
79+
80+
} // namespace math
81+
} // namespace stan
82+
#endif
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
3+
TEST(MathMixMatFun, traceDot) {
4+
auto f = [](const auto& x, const auto& y) {
5+
return stan::math::trace_dot(x, y);
6+
};
7+
8+
// 1x1
9+
Eigen::MatrixXd a11{{3}};
10+
Eigen::MatrixXd b11{{7}};
11+
stan::test::expect_ad(f, a11, b11);
12+
stan::test::expect_ad_matvar(f, a11, b11);
13+
14+
// 0x0
15+
Eigen::MatrixXd m00(0, 0);
16+
stan::test::expect_ad(f, m00, m00);
17+
stan::test::expect_ad_matvar(f, m00, m00);
18+
19+
// 2x2
20+
Eigen::MatrixXd a22{{1, 2}, {3, 4}};
21+
Eigen::MatrixXd b22{{5, 6}, {7, 8}};
22+
stan::test::expect_ad(f, a22, b22);
23+
stan::test::expect_ad_matvar(f, a22, b22);
24+
25+
// 2x3 times 3x2 (rectangular)
26+
Eigen::MatrixXd a23{{1, 2, 3}, {4, 5, 6}};
27+
Eigen::MatrixXd b32{{7, 8}, {9, 10}, {11, 12}};
28+
stan::test::expect_ad(f, a23, b32);
29+
stan::test::expect_ad_matvar(f, a23, b32);
30+
31+
// 3x2 times 2x3 (rectangular, transposed shape)
32+
stan::test::expect_ad(f, b32, a23);
33+
stan::test::expect_ad_matvar(f, b32, a23);
34+
35+
// 3x3
36+
Eigen::MatrixXd a33{{1, -2, 3}, {0.5, 7, -1}, {2, 0, 4}};
37+
Eigen::MatrixXd b33{{3, 1, -2}, {0, 5, 1}, {-1, 2, 6}};
38+
stan::test::expect_ad(f, a33, b33);
39+
stan::test::expect_ad_matvar(f, a33, b33);
40+
41+
// dimension mismatch: A.cols() != B.rows()
42+
stan::test::expect_ad(f, a22, b32);
43+
stan::test::expect_ad_matvar(f, a22, b32);
44+
45+
stan::test::expect_ad(f, a23, b22);
46+
stan::test::expect_ad_matvar(f, a23, b22);
47+
48+
// dimension mismatch: A.cols() == B.rows() but A.rows() != B.cols()
49+
stan::test::expect_ad(f, a23, a33);
50+
stan::test::expect_ad_matvar(f, a23, a33);
51+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
4+
TEST(MathMatrixPrim, trace_dot) {
5+
using stan::math::matrix_d;
6+
using stan::math::trace_dot;
7+
8+
matrix_d a{{1, 2, 3}, {4, 5, 6}};
9+
matrix_d b{{7, 8}, {9, 10}, {11, 12}};
10+
// trace(A * B) = trace([[58,64],[139,154]]) = 58 + 154 = 212
11+
EXPECT_FLOAT_EQ(212, trace_dot(a, b));
12+
}
13+
14+
TEST(MathMatrixPrim, trace_dot_square) {
15+
using stan::math::matrix_d;
16+
using stan::math::trace_dot;
17+
18+
matrix_d a{{1, 2}, {3, 4}};
19+
matrix_d b{{5, 6}, {7, 8}};
20+
21+
// trace(A * B) = trace([[19,22],[43,50]]) = 19 + 50 = 69
22+
EXPECT_FLOAT_EQ(69, trace_dot(a, b));
23+
}
24+
25+
TEST(MathMatrixPrim, trace_dot_1x1) {
26+
using stan::math::matrix_d;
27+
using stan::math::trace_dot;
28+
29+
matrix_d a{{3}};
30+
matrix_d b{{7}};
31+
32+
EXPECT_FLOAT_EQ(21, trace_dot(a, b));
33+
}
34+
35+
TEST(MathMatrixPrim, trace_dot_size_zero) {
36+
using stan::math::matrix_d;
37+
using stan::math::trace_dot;
38+
39+
matrix_d a00{};
40+
matrix_d b00{};
41+
EXPECT_FLOAT_EQ(0, trace_dot(a00, b00));
42+
}
43+
44+
TEST(MathMatrixPrim, trace_dot_dimension_mismatch) {
45+
using stan::math::matrix_d;
46+
using stan::math::trace_dot;
47+
48+
matrix_d a{{1, 2, 3}, {4, 5, 6}};
49+
matrix_d b{{1, 2, 3}, {4, 5, 6}};
50+
51+
// A.cols() != B.rows()
52+
EXPECT_THROW(trace_dot(a, b), std::invalid_argument);
53+
54+
// A.cols() == B.rows() but A.rows() != B.cols() (product not square)
55+
matrix_d c{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
56+
EXPECT_THROW(trace_dot(a, c), std::invalid_argument);
57+
}

0 commit comments

Comments
 (0)