|
| 1 | +#ifndef STAN_MATH_PRIM_FUN_TRACE_DOT_HPP |
| 2 | +#define STAN_MATH_PRIM_FUN_TRACE_DOT_HPP |
| 3 | + |
| 4 | +#include <stan/math/prim/meta.hpp> |
| 5 | +#include <stan/math/prim/err.hpp> |
| 6 | +#include <stan/math/prim/fun/Eigen.hpp> |
| 7 | + |
| 8 | +namespace stan { |
| 9 | +namespace math { |
| 10 | + |
| 11 | +/** |
| 12 | + * Compute the trace of the product of two matrices, |
| 13 | + * \f$ \text{tr}(A \cdot B) = \sum_{i,j} A_{ij} B_{ji} \f$. |
| 14 | + * |
| 15 | + * This is more efficient than computing the full product and |
| 16 | + * taking the trace, as it avoids forming the intermediate matrix. |
| 17 | + * |
| 18 | + * @tparam EigMat1 A type either inheriting from `Eigen::DenseBase` or a |
| 19 | + * `var_value` with an inner type inheriting from `Eigen::DenseBase` |
| 20 | + * @tparam EigMat2 A type either inheriting from `Eigen::DenseBase` or a |
| 21 | + * `var_value` with an inner type inheriting from `Eigen::DenseBase` |
| 22 | + * |
| 23 | + * @param A first matrix (m x n) |
| 24 | + * @param B second matrix (n x m) |
| 25 | + * @return trace of A * B |
| 26 | + * @throw std::invalid_argument if A and B have incompatible dimensions |
| 27 | + */ |
| 28 | +template <typename EigMat1, typename EigMat2, |
| 29 | + require_all_eigen_vt<std::is_arithmetic, EigMat1, EigMat2>* = nullptr> |
| 30 | +inline auto trace_dot(EigMat1&& A, EigMat2&& B) { |
| 31 | + check_size_match("trace_dot", "A.cols()", A.cols(), "B.rows()", B.rows()); |
| 32 | + check_size_match("trace_dot", "A.rows()", A.rows(), "B.cols()", B.cols()); |
| 33 | + return make_holder( |
| 34 | + [](auto&& A_, auto&& B_) { |
| 35 | + return A_.cwiseProduct(B_.transpose()).sum(); |
| 36 | + }, |
| 37 | + std::forward<EigMat1>(A), std::forward<EigMat2>(B)); |
| 38 | +} |
| 39 | + |
| 40 | +} // namespace math |
| 41 | +} // namespace stan |
| 42 | + |
| 43 | +#endif |
0 commit comments