Skip to content

Commit 3a66a33

Browse files
authored
Merge pull request #1754 from bstatcomp/generalize_fun_cr_di
Generalize */fun starting with cr-d
2 parents 48c75fc + 44fdfd7 commit 3a66a33

26 files changed

Lines changed: 165 additions & 506 deletions

stan/math/fwd/fun.hpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717
#include <stan/math/fwd/fun/ceil.hpp>
1818
#include <stan/math/fwd/fun/cos.hpp>
1919
#include <stan/math/fwd/fun/cosh.hpp>
20-
#include <stan/math/fwd/fun/crossprod.hpp>
2120
#include <stan/math/fwd/fun/determinant.hpp>
2221
#include <stan/math/fwd/fun/digamma.hpp>
23-
#include <stan/math/fwd/fun/divide.hpp>
24-
#include <stan/math/fwd/fun/dot_self.hpp>
2522
#include <stan/math/fwd/fun/Eigen_NumTraits.hpp>
2623
#include <stan/math/fwd/fun/erf.hpp>
2724
#include <stan/math/fwd/fun/erfc.hpp>
@@ -96,7 +93,6 @@
9693
#include <stan/math/fwd/fun/softmax.hpp>
9794
#include <stan/math/fwd/fun/sqrt.hpp>
9895
#include <stan/math/fwd/fun/square.hpp>
99-
#include <stan/math/fwd/fun/squared_distance.hpp>
10096
#include <stan/math/fwd/fun/sum.hpp>
10197
#include <stan/math/fwd/fun/tan.hpp>
10298
#include <stan/math/fwd/fun/tanh.hpp>

stan/math/fwd/fun/crossprod.hpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

stan/math/fwd/fun/determinant.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
#ifndef STAN_MATH_FWD_FUN_DETERMINANT_HPP
22
#define STAN_MATH_FWD_FUN_DETERMINANT_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/err.hpp>
56
#include <stan/math/prim/fun/Eigen.hpp>
67
#include <stan/math/fwd/core.hpp>
78

89
namespace stan {
910
namespace math {
1011

11-
template <typename T, int R, int C>
12-
inline fvar<T> determinant(const Eigen::Matrix<fvar<T>, R, C>& m) {
12+
template <typename EigMat, require_eigen_vt<is_fvar, EigMat>* = nullptr>
13+
inline value_type_t<EigMat> determinant(const EigMat& m) {
1314
check_square("determinant", "m", m);
1415

15-
const T vals = m.val().determinant();
16-
return fvar<T>(vals, vals * (m.val().inverse() * m.d()).trace());
16+
const typename value_type_t<EigMat>::Scalar vals = m.val().determinant();
17+
return {vals, vals * (m.val().inverse() * m.d()).trace()};
1718
}
1819

1920
} // namespace math

stan/math/fwd/fun/divide.hpp

Lines changed: 0 additions & 56 deletions
This file was deleted.

stan/math/fwd/fun/dot_self.hpp

Lines changed: 0 additions & 20 deletions
This file was deleted.

stan/math/fwd/fun/squared_distance.hpp

Lines changed: 0 additions & 167 deletions
This file was deleted.

stan/math/fwd/fun/tcrossprod.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
#ifndef STAN_MATH_FWD_FUN_TCROSSPROD_HPP
22
#define STAN_MATH_FWD_FUN_TCROSSPROD_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/fun/Eigen.hpp>
56
#include <stan/math/prim/fun/transpose.hpp>
67
#include <stan/math/fwd/fun/multiply.hpp>
78

89
namespace stan {
910
namespace math {
1011

11-
template <typename T, int R, int C>
12-
inline Eigen::Matrix<fvar<T>, R, R> tcrossprod(
13-
const Eigen::Matrix<fvar<T>, R, C>& m) {
12+
template <typename EigMat, require_eigen_vt<is_fvar, EigMat>* = nullptr>
13+
inline Eigen::Matrix<value_type_t<EigMat>, EigMat::RowsAtCompileTime,
14+
EigMat::RowsAtCompileTime>
15+
tcrossprod(const EigMat& m) {
1416
if (m.rows() == 0) {
1517
return {};
1618
}
17-
return multiply(m, transpose(m));
19+
return multiply(m, m.transpose());
1820
}
1921

2022
} // namespace math

stan/math/fwd/fun/unit_vector_constrain.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
#define STAN_MATH_FWD_FUN_UNIT_VECTOR_CONSTRAIN_HPP
33

44
#include <stan/math/fwd/core.hpp>
5-
#include <stan/math/fwd/fun/divide.hpp>
6-
#include <stan/math/fwd/fun/dot_self.hpp>
75
#include <stan/math/fwd/fun/tcrossprod.hpp>
86
#include <stan/math/fwd/fun/sqrt.hpp>
97
#include <stan/math/prim/fun/divide.hpp>
8+
#include <stan/math/prim/fun/dot_self.hpp>
109
#include <stan/math/prim/fun/Eigen.hpp>
1110
#include <stan/math/prim/fun/inv.hpp>
1211
#include <stan/math/prim/fun/unit_vector_constrain.hpp>

stan/math/prim/fun/crossprod.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_PRIM_FUN_CROSSPROD_HPP
22
#define STAN_MATH_PRIM_FUN_CROSSPROD_HPP
33

4+
#include <stan/math/prim/meta.hpp>
45
#include <stan/math/prim/fun/typedefs.hpp>
56
#include <stan/math/prim/fun/tcrossprod.hpp>
67

@@ -11,11 +12,13 @@ namespace math {
1112
* Returns the result of pre-multiplying a matrix by its
1213
* own transpose.
1314
*
15+
* @tparam EigMat type of the matrix (must be derived from \c Eigen::MatrixBase)
1416
* @param M Matrix to multiply.
1517
* @return Transpose of M times M
1618
*/
17-
inline matrix_d crossprod(const matrix_d& M) {
18-
return tcrossprod(static_cast<matrix_d>(M.transpose()));
19+
template <typename EigMat, require_eigen_t<EigMat>* = nullptr>
20+
inline auto crossprod(const EigMat& M) {
21+
return tcrossprod(M.transpose());
1922
}
2023

2124
} // namespace math

0 commit comments

Comments
 (0)