Skip to content

Commit 38b002c

Browse files
author
Alexandre Hoffmann
committed
refactoring
1 parent e911c05 commit 38b002c

56 files changed

Lines changed: 1138 additions & 1433 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

demo/demo_Matrix.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,31 @@
1-
#include <FSLinalg/Vector.hpp>
21
#include <FSLinalg/Matrix.hpp>
32

43
int main()
54
{
6-
FSLinalg::UnitVector<4> e0(0);
7-
FSLinalg::RealVector<3> a({2, 3, 0});
5+
FSLinalg::UnitRowVector<4> e0(0);
6+
FSLinalg::RealRowVector<3> a({2, 3, 0});
87

98
const auto expr1 = 0.5*FSLinalg::outer(e0, a)*a;
109

11-
fmt::print("expr = {}\n", FSLinalg::RealVector<4>(expr1));
12-
fmt::print("expr.createTemporaryMatrix = {}\n", expr1.createTemporaryMatrix);
13-
fmt::print("expr.createTemporaryVector = {}\n", expr1.createTemporaryVector);
10+
fmt::print("expr = {}\n", FSLinalg::RealRowVector<4>(expr1));
11+
fmt::print("expr.createTemporaryMatrix = {}\n", expr1.createTemporaryLhs);
12+
fmt::print("expr.createTemporaryVector = {}\n", expr1.createTemporaryRhs);
1413

1514
const auto expr2 = 0.5*FSLinalg::transpose(FSLinalg::outer(e0, a))*e0;
1615

17-
fmt::print("expr = {}\n", FSLinalg::RealVector<3>(expr2));
18-
fmt::print("expr.createTemporaryMatrix = {}\n", expr2.createTemporaryMatrix);
19-
fmt::print("expr.createTemporaryVector = {}\n", expr2.createTemporaryVector);
16+
fmt::print("expr = {}\n", FSLinalg::RealRowVector<3>(expr2));
17+
fmt::print("expr.createTemporaryMatrix = {}\n", expr2.createTemporaryLhs);
18+
fmt::print("expr.createTemporaryVector = {}\n", expr2.createTemporaryRhs);
2019

2120
FSLinalg::RealMatrix<4, 3> A(FSLinalg::outer(e0, a));
2221

2322
fmt::print("{}\n", A);
2423

2524
const auto expr3 = 0.5*FSLinalg::transpose(A)*e0;
2625

27-
fmt::print("expr = {}\n", FSLinalg::RealVector<3>(expr3));
28-
fmt::print("expr.createTemporaryMatrix = {}\n", expr3.createTemporaryMatrix);
29-
fmt::print("expr.createTemporaryVector = {}\n", expr3.createTemporaryVector);
26+
fmt::print("expr = {}\n", FSLinalg::RealRowVector<3>(expr3));
27+
fmt::print("expr.createTemporaryMatrix = {}\n", expr3.createTemporaryLhs);
28+
fmt::print("expr.createTemporaryVector = {}\n", expr3.createTemporaryRhs);
3029

3130
return EXIT_SUCCESS;
3231
}

demo/demo_Vector.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
#include <FSLinalg/Vector.hpp>
1+
#include <FSLinalg/Matrix.hpp>
22

33
int main()
44
{
5-
FSLinalg::UnitVector<3> e0(0);
6-
FSLinalg::RealVector<3> one(1);
7-
FSLinalg::RealVector<3> a({2, 3, 0});
5+
FSLinalg::UnitRowVector<3> e0(0);
6+
FSLinalg::RealRowVector<3> one(1);
7+
FSLinalg::RealRowVector<3> a({2, 3, 0});
88

99
const auto expr = a + 3.*e0 - one/2.;
1010

1111
fmt::print("{} + 3*{} - {}/2 = {}\n", a, e0, one, expr);
1212
fmt::print("expr is aliased to a = {}\n", expr.isAliasedTo(a));
1313

1414
const auto expr2 = -2.*FSLinalg::cross(a, one);
15-
FSLinalg::RealVector<3> c = expr2;
15+
FSLinalg::RealRowVector<3> c = expr2;
1616

1717
fmt::print("cross({}, {}) = {}\n", a, one, c);
1818
fmt::print("expr2 is aliased to a = {}\n", expr2.isAliasedTo(a));
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#ifndef FSLINALG_BASIC_LINALG_GENERAL_MATRIX_MATRIX_PRODUCT_HPP
2+
#define FSLINALG_BASIC_LINALG_GENERAL_MATRIX_MATRIX_PRODUCT_HPP
3+
4+
#include <FSLinalg/Scalar.hpp>
5+
#include <FSLinalg/Matrix.hpp>
6+
7+
namespace FSLinalg
8+
{
9+
namespace BasicLinalg
10+
{
11+
12+
template<bool transposeA, bool conjugateA, unsigned int nRowsA, unsigned int nColsA, bool transposeB, bool conjugateB, unsigned int nRowsB, unsigned int nColsB, bool incrDst>
13+
struct GeneralMatrixMatrixProduct
14+
{
15+
using Size = unsigned int;
16+
17+
static constexpr Size nRowsOpA = (not transposeA) ? nRowsA : nColsA;
18+
static constexpr Size nColsOpA = (not transposeA) ? nColsA : nRowsA;
19+
static constexpr Size nRowsOpB = (not transposeB) ? nRowsB : nColsB;
20+
static constexpr Size nColsOpB = (not transposeB) ? nColsB : nRowsB;
21+
static constexpr Size nRowsY = nRowsOpA;
22+
static constexpr Size nColsY = nColsOpB;
23+
24+
static_assert(nColsOpA == nRowsOpB, "Matrices sizes must match");
25+
26+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarA, Scalar_concept ScalarB, Scalar_concept ScalarY>
27+
static void run(const ScalarAlpha& alpha, const Matrix<ScalarA,nRowsA,nColsA>& A, const Matrix<ScalarB,nRowsB,nColsB>& B, Matrix<ScalarY,nRowsY,nColsY>& Y);
28+
29+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarA, Scalar_concept ScalarY>
30+
static void run(const ScalarAlpha& alpha, const Matrix<ScalarA,nRowsA,nColsA>& A, const UnitMatrix<nRowsB,nColsB>& B, Matrix<ScalarY,nRowsY,nColsY>& Y);
31+
32+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarB, Scalar_concept ScalarY>
33+
static void run(const ScalarAlpha& alpha, const UnitMatrix<nRowsA,nColsA>& A, const Matrix<ScalarB,nRowsB,nColsB>& B, Matrix<ScalarY,nRowsY,nColsY>& Y);
34+
35+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarY>
36+
static void run(const ScalarAlpha& alpha, const UnitMatrix<nRowsA,nColsA>& A, const UnitMatrix<nRowsB,nColsB>& B, Matrix<ScalarY,nRowsY,nColsY>& Y);
37+
};
38+
39+
} // namespace BasicLinalg
40+
} // namespace FSLinalg
41+
42+
#include <FSLinalg/BasicLinalg/GeneralMatrixMatrixProduct_impl.hpp>
43+
44+
#endif // FSLINALG_BASIC_LINALG_GENERAL_MATRIX_MATRIX_PRODUCT_HPP
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#ifndef FSLINALG_BASIC_LINALG_GENERAL_MATRIX_MATRIX_PRODUCT_IMPL_HPP
2+
#define FSLINALG_BASIC_LINALG_GENERAL_MATRIX_MATRIX_PRODUCT_IMPL_HPP
3+
4+
#include <FSLinalg/BasicLinalg/GeneralMatrixMatrixProduct.hpp>
5+
#include <FSLinalg/BasicLinalg/TripleProduct.hpp>
6+
#include <FSLinalg/BasicLinalg/Product.hpp>
7+
8+
namespace FSLinalg
9+
{
10+
namespace BasicLinalg
11+
{
12+
13+
template<bool transposeA, bool conjugateA, unsigned int nRowsA, unsigned int nColsA, bool transposeB, bool conjugateB, unsigned int nRowsB, unsigned int nColsB, bool incrDst>
14+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarA, Scalar_concept ScalarB, Scalar_concept ScalarY>
15+
void GeneralMatrixMatrixProduct<transposeA,conjugateA,nRowsA,nColsA,transposeB,conjugateB,nRowsB,nColsB,incrDst>::run(
16+
const ScalarAlpha& alpha,
17+
const Matrix<ScalarA,nRowsA,nColsA>& A,
18+
const Matrix<ScalarB,nRowsB,nColsB>& B,
19+
Matrix<ScalarY,nRowsY,nColsY>& Y)
20+
{
21+
constexpr Size A_iStride = (not transposeA) ? nColsA : 1;
22+
constexpr Size A_kStride = (not transposeA) ? 1 : nColsA;
23+
constexpr Size B_kStride = (not transposeB) ? nColsB : 1;
24+
constexpr Size B_jStride = (not transposeB) ? 1 : nColsB;
25+
26+
constexpr TripleProduct<false, conjugateA, conjugateB> prod;
27+
28+
if constexpr (not incrDst)
29+
{
30+
for (Size i=0; i!=nRowsY*nColsY; ++i) { Y[i] = 0; }
31+
}
32+
33+
for (Size i=0; i!=nRowsY; ++i)
34+
{
35+
for (Size k=0; k !=nColsOpA; ++k)
36+
{
37+
for (Size j=0; j!=nColsY; ++j)
38+
{
39+
Y(i,j) += prod(alpha, A[i*A_iStride + k*A_kStride], B[k*B_kStride + j*B_jStride]);
40+
}
41+
}
42+
}
43+
}
44+
45+
template<bool transposeA, bool conjugateA, unsigned int nRowsA, unsigned int nColsA, bool transposeB, bool conjugateB, unsigned int nRowsB, unsigned int nColsB, bool incrDst>
46+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarA, Scalar_concept ScalarY>
47+
void GeneralMatrixMatrixProduct<transposeA,conjugateA,nRowsA,nColsA,transposeB,conjugateB,nRowsB,nColsB,incrDst>::run(
48+
const ScalarAlpha& alpha,
49+
const Matrix<ScalarA,nRowsA,nColsA>& A,
50+
const UnitMatrix<nRowsB,nColsB>& B,
51+
Matrix<ScalarY,nRowsY,nColsY>& Y)
52+
{
53+
constexpr Size A_iStride = (not transposeA) ? nColsA : 1;
54+
constexpr Size A_kStride = (not transposeA) ? 1 : nColsA;
55+
56+
constexpr Product<false, conjugateA> prod;
57+
58+
if constexpr (not incrDst)
59+
{
60+
for (Size i=0; i!=nRowsY*nColsY; ++i) { Y[i] = 0; }
61+
}
62+
63+
const Size k = (not transposeB) ? B.getId().i : B.getId().j;
64+
const Size j = (not transposeB) ? B.getId().j : B.getId().i;
65+
66+
for (Size i=0; i!=nRowsY; ++i)
67+
{
68+
Y(i,j) += prod(alpha, A[i*A_iStride + k*A_kStride]);
69+
}
70+
}
71+
72+
template<bool transposeA, bool conjugateA, unsigned int nRowsA, unsigned int nColsA, bool transposeB, bool conjugateB, unsigned int nRowsB, unsigned int nColsB, bool incrDst>
73+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarB, Scalar_concept ScalarY>
74+
void GeneralMatrixMatrixProduct<transposeA,conjugateA,nRowsA,nColsA,transposeB,conjugateB,nRowsB,nColsB,incrDst>::run(
75+
const ScalarAlpha& alpha,
76+
const UnitMatrix<nRowsA,nColsA>& A,
77+
const Matrix<ScalarB,nRowsB,nColsB>& B,
78+
Matrix<ScalarY,nRowsY,nColsY>& Y)
79+
{
80+
constexpr Size B_kStride = (not transposeB) ? nColsB : 1;
81+
constexpr Size B_jStride = (not transposeB) ? 1 : nColsB;
82+
83+
constexpr Product<false, conjugateB> prod;
84+
85+
if constexpr (not incrDst)
86+
{
87+
for (Size i=0; i!=nRowsY*nColsY; ++i) { Y[i] = 0; }
88+
}
89+
90+
const Size i = (not transposeA) ? A.getId().i : A.getId().j;
91+
const Size k = (not transposeA) ? A.getId().j : A.getId().i;
92+
93+
94+
for (Size j=0; j!=nColsY; ++j)
95+
{
96+
Y(i,j) += prod(alpha, B[k*B_kStride + j*B_jStride]);
97+
}
98+
}
99+
100+
template<bool transposeA, bool conjugateA, unsigned int nRowsA, unsigned int nColsA, bool transposeB, bool conjugateB, unsigned int nRowsB, unsigned int nColsB, bool incrDst>
101+
template<Scalar_concept ScalarAlpha, Scalar_concept ScalarY>
102+
void GeneralMatrixMatrixProduct<transposeA,conjugateA,nRowsA,nColsA,transposeB,conjugateB,nRowsB,nColsB,incrDst>::run(
103+
const ScalarAlpha& alpha,
104+
const UnitMatrix<nRowsA,nColsA>& A,
105+
const UnitMatrix<nRowsB,nColsB>& B,
106+
Matrix<ScalarY,nRowsY,nColsY>& Y)
107+
{
108+
if constexpr (not incrDst)
109+
{
110+
for (Size i=0; i!=nRowsY*nColsY; ++i) { Y[i] = 0; }
111+
}
112+
113+
const Size i = (not transposeA) ? A.getId().i : A.getId().j;
114+
const Size j = (not transposeB) ? B.getId().j : B.getId().i;
115+
const Size k1 = (not transposeA) ? A.getId().j : A.getId().i;
116+
const Size k2 = (not transposeB) ? B.getId().i : B.getId().j;
117+
118+
Y(i,j) += alpha*(k1 == k2);
119+
}
120+
121+
} // namespace BasicLinalg
122+
} // namespace FSLinalg
123+
124+
#endif // FSLINALG_BASIC_LINALG_GENERAL_MATRIX_MATRIX_PRODUCT_IMPL_HPP

include/FSLinalg/BasicLinalg/GeneralMatrixVectorProduct.hpp

Lines changed: 0 additions & 94 deletions
This file was deleted.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef FSLINALG_INNER_PRODUCT_HPP
2+
#define FSLINALG_INNER_PRODUCT_HPP
3+
4+
#include <FSLinalg/Vector/VectorBase.hpp>
5+
#include <FSLinalg/Matrix/MatrixBase.hpp>
6+
7+
namespace FSLinalg
8+
{
9+
10+
template<class Lhs, class Rhs>
11+
using InnerProductScalar = decltype(conj(std::declval<typename Lhs::Scalar>()) * std::declval<typename Rhs::Scalar>());
12+
13+
template<class Lhs, class Rhs> InnerProductScalar<Lhs,Rhs> inner(const MatrixBase<Lhs>& base_lhs, const MatrixBase<Rhs>& base_rhs);
14+
15+
16+
} // namespace FSLinalg
17+
18+
#include <FSLinalg/BasicLinalg/InnerProduct_impl.hpp>
19+
20+
#endif // FSLINALG_INNER_PRODUCT_HPP

0 commit comments

Comments
 (0)