Skip to content

Commit ea059ec

Browse files
author
Alexandre Hoffmann
committed
refactor: removed vector and only use Row/Col Matrices. Added a rebracketing module. STill uncertain on how to use it. TODO: the computing the matrices product sould take into acount if the matrix are unit matrices and the types used
1 parent d3680cd commit ea059ec

17 files changed

Lines changed: 531 additions & 24 deletions

include/FSLinalg/BasicLinalg.hpp

Lines changed: 0 additions & 1 deletion
This file was deleted.

include/FSLinalg/Matrix.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
#include <FSLinalg/Matrix/VectorCross.hpp>
1111
#include <FSLinalg/Matrix/StripSymbolsAndEvalMatrix.hpp>
1212
#include <FSLinalg/Matrix/UnitMatrix.hpp>
13+
#include <FSLinalg/Matrix/MatrixProductAnalyzer.hpp>
14+
#include <FSLinalg/Matrix/MatrixProductChain.hpp>
1315

1416
#include <FSLinalg/Matrix/MatrixBase_impl.hpp>
1517
#include <FSLinalg/Matrix/MatrixConj_impl.hpp>
1618
#include <FSLinalg/Matrix/MatrixProduct_impl.hpp>
1719
#include <FSLinalg/Matrix/Matrix_impl.hpp>
1820
#include <FSLinalg/Matrix/MatrixTransposed_impl.hpp>
1921
#include <FSLinalg/Matrix/VectorCross_impl.hpp>
22+
#include <FSLinalg/Matrix/MatrixProductAnalyzer_impl.hpp>
23+
#include <FSLinalg/Matrix/MatrixProductChain_impl.hpp>

include/FSLinalg/Matrix/Matrix.hpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ class Matrix : public MatrixBase< Matrix<T, Nrows, Ncols> >
3333
using Self = Matrix<T, Nrows, Ncols>;
3434
FSLINALG_DEFINE_MATRIX
3535

36+
template<class Dst>
37+
struct CanBeAlisaedTo : std::bool_constant<
38+
IsMatrix<Dst>::value
39+
and Base::nRows == Dst::nRows
40+
and Base::nCols == Dst::nCols
41+
and std::is_same<Scalar, typename Dst::Scalar>::value > {};
42+
3643
static constexpr bool isScalarComplex = IsComplexScalar<Scalar>::value;
3744
static constexpr bool isVector = isRowVector or isColVector;
3845

@@ -46,11 +53,11 @@ class Matrix : public MatrixBase< Matrix<T, Nrows, Ncols> >
4653

4754
Matrix(const Matrix& other) : m_data(other.m_data) {}
4855

49-
template<class Expr> Matrix(const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.assignTo(1., *this, std::false_type{}); }
56+
template<class Expr> Matrix(const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.assignTo(Scalar(1), *this, std::false_type{}); }
5057

51-
template<class Expr> Matrix& operator= (const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.assignTo (1., *this, std::true_type{}); return *this; }
52-
template<class Expr> Matrix& operator+=(const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.increment (1., *this, std::true_type{}); return *this; }
53-
template<class Expr> Matrix& operator-=(const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.decrement (1., *this, std::true_type{}); return *this; }
58+
template<class Expr> Matrix& operator= (const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.assignTo (Scalar(1), *this, std::true_type{}); return *this; }
59+
template<class Expr> Matrix& operator+=(const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.increment (Scalar(1), *this, std::true_type{}); return *this; }
60+
template<class Expr> Matrix& operator-=(const MatrixBase<Expr>& expr) requires(IsConstructibleFrom<Expr>::value) { expr.decrement (Scalar(1), *this, std::true_type{}); return *this; }
5461

5562
Matrix& operator*=(const RealScalar& alpha) requires(isScalarComplex) { for (Size i=0; i!=size; ++i) { m_data[i] *= alpha; } return *this; }
5663
Matrix& operator/=(const RealScalar& alpha) requires(isScalarComplex) { for (Size i=0; i!=size; ++i) { m_data[i] /= alpha; } return *this; }
@@ -64,8 +71,8 @@ class Matrix : public MatrixBase< Matrix<T, Nrows, Ncols> >
6471
const_ReturnType getImpl(const Size i, const Size j) const { return m_data[i*nCols + j]; }
6572
ReturnType getImpl(const Size i, const Size j) { return m_data[i*nCols + j]; }
6673

67-
template<class Dst> bool isAliasedToImpl(const MatrixBase<Dst>& dst) const requires(nRows == Dst::nRows and nCols == Dst::nCols) { return std::addressof(dst.derived()) == this; }
68-
template<class Dst> constexpr bool isAliasedToImpl(const MatrixBase<Dst>& ) const requires(nRows != Dst::nRows or nCols != Dst::nCols) { return false; }
74+
template<class Dst> bool isAliasedToImpl(const MatrixBase<Dst>& dst) const requires( CanBeAlisaedTo<Dst>::value) { return std::addressof(dst.derived()) == this; }
75+
template<class Dst> constexpr bool isAliasedToImpl(const MatrixBase<Dst>& ) const requires(not CanBeAlisaedTo<Dst>::value) { return false; }
6976
private:
7077
std::array<Scalar, size> m_data;
7178
};

include/FSLinalg/Matrix/MatrixBase.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ class MatrixBase : public CRTPBase<Derived>
2525

2626
template<class Dst>
2727
struct IsConvertibleTo : std::bool_constant<
28-
std::is_base_of<MatrixBase<Dst>, Dst>::value
28+
std::is_base_of<MatrixBase<Dst>, Dst>::value
2929
and Dst::hasWriteRandomAccess
3030
and std::is_convertible<Scalar, typename Dst::Scalar>::value> {};
3131

3232
template<class Src>
3333
struct IsConstructibleFrom : std::bool_constant<
34-
std::is_base_of<MatrixBase<Src>, Src>::value
34+
std::is_base_of<MatrixBase<Src>, Src>::value
3535
and DerivedTraits::hasWriteRandomAccess
3636
and std::is_convertible<typename Src::Scalar, Scalar>::value> {};
3737

include/FSLinalg/Matrix/MatrixProduct.hpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
1-
#ifndef FSLINALG_MATRIX_MATRIX_PRODUCT_HPP
2-
#define FSLINALG_MATRIX_MATRIX_PRODUCT_HPP
1+
#ifndef FSLINALG_MATRIX_PRODUCT_HPP
2+
#define FSLINALG_MATRIX_PRODUCT_HPP
33

44
#include <FSLinalg/Matrix/MatrixBase.hpp>
55
#include <FSLinalg/Matrix/MatrixTransposed.hpp>
66
#include <FSLinalg/Matrix/StripSymbolsAndEvalMatrix.hpp>
7+
#include <FSLinalg/Matrix/MatrixProductAnalyzer.hpp>
78

89
namespace FSLinalg
910
{
1011

12+
namespace detail
13+
{
14+
15+
template<class Expr> struct MatrixProductAnalyzerImpl;
16+
17+
} // namespace detail
18+
1119
template<class Lhs, class Rhs> class MatrixProduct;
1220

1321
template<class Lhs, class Rhs>
@@ -19,7 +27,7 @@ struct MatrixTraits< MatrixProduct<Lhs, Rhs> >
1927
using Scalar = decltype(std::declval<typename Lhs::Scalar>() * std::declval<typename Rhs::Scalar>());
2028
using Size = std::common_type_t<typename Lhs::Size, typename Rhs::Size>;
2129

22-
static constexpr bool hasReadRandomAccess = Lhs::isRowVector and Rhs::isColVector and Lhs::hasFlatRandomAccess and Rhs::hasFlatRandomAccess;
30+
static constexpr bool hasReadRandomAccess = (Lhs::isRowVector and Lhs::hasFlatRandomAccess) or (Rhs::isColVector and Rhs::hasFlatRandomAccess);
2331
static constexpr bool hasWriteRandomAccess = false;
2432
static constexpr bool hasFlatRandomAccess = false;
2533
static constexpr bool causesAliasingIssues = true;
@@ -36,10 +44,15 @@ class MatrixProduct : public MatrixBase< MatrixProduct<Lhs,Rhs> >
3644
using Self = MatrixProduct<Lhs,Rhs>;
3745
FSLINALG_DEFINE_MATRIX
3846

47+
using OptimallyBracketedSelf = MatrixProductAnalyzer<Self>;
48+
49+
friend struct detail::MatrixProductAnalyzerImpl< Self >;
50+
3951
MatrixProduct(const MatrixBase<Lhs>& lhs, const MatrixBase<Rhs>& rhs) : m_lhs(lhs.derived()), m_rhs(rhs.derived()) {}
4052

4153
static constexpr bool createTemporaryLhs = StripSymbolsAndEvalMatrix<Lhs>::createsTemporary;
4254
static constexpr bool createTemporaryRhs = StripSymbolsAndEvalMatrix<Rhs>::createsTemporary;
55+
static constexpr bool isOptimallyBracked = std::is_same<Self, OptimallyBracketedSelf>::value;
4356

4457
/**
4558
* @brief Only awailable when multiplying row-vector and col-vector
@@ -86,4 +99,4 @@ MatrixProduct< Lhs,MatrixTransposed<Rhs> > outer(const MatrixBase<Lhs>& lhs, con
8699

87100
} // namespace FSLinalg
88101

89-
#endif // FSLINALG_MATRIX_MATRIX_PRODUCT_HPP
102+
#endif // FSLINALG_MATRIX_PRODUCT_HPP
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#ifndef FSLINALG_MATRIX_PRODUCT_TRAITS_HPP
2+
#define FSLINALG_MATRIX_PRODUCT_TRAITS_HPP
3+
4+
#include <cstddef>
5+
#include <type_traits>
6+
7+
#include <FSLinalg/Matrix/MatrixProductChain.hpp>
8+
9+
namespace FSLinalg
10+
{
11+
12+
template<class Lhs, class Rhs> class MatrixProduct;
13+
14+
namespace detail
15+
{
16+
17+
template<class Expr>
18+
struct MatrixProductAnalyzerImpl
19+
{
20+
static constexpr size_t length = 1;
21+
22+
template<size_t n> requires(n == 0)
23+
using NthMatrix = Expr;
24+
25+
template<size_t n>
26+
static constexpr const NthMatrix<n>& getMatrix(const Expr& expr, std::integral_constant<size_t, n>) { return expr; }
27+
28+
template<size_t idx, size_t chainLenP1> requires (idx+1 <chainLenP1)
29+
static constexpr void fillDims(std::integral_constant<size_t, idx>, std::array<size_t, chainLenP1>& dims) { dims[idx] = Expr::nRows; dims[idx+1] = Expr::nCols; }
30+
};
31+
32+
template<class Lhs, class Rhs>
33+
struct MatrixProductAnalyzerImpl< MatrixProduct<Lhs, Rhs> >
34+
{
35+
static constexpr size_t lhsLength = MatrixProductAnalyzerImpl<Lhs>::length;
36+
static constexpr size_t rhsLength = MatrixProductAnalyzerImpl<Rhs>::length;
37+
static constexpr size_t length = lhsLength + rhsLength;
38+
39+
template<size_t n, class Enable = void>
40+
struct NthMatrixHelper;
41+
42+
template<size_t n>
43+
struct NthMatrixHelper<n, std::enable_if_t<n < lhsLength>>
44+
{
45+
using Type = typename MatrixProductAnalyzerImpl<Lhs>::template NthMatrix<n>;
46+
};
47+
48+
template<size_t n>
49+
struct NthMatrixHelper<n, std::enable_if_t<lhsLength <= n and n < length>>
50+
{
51+
using Type = typename MatrixProductAnalyzerImpl<Rhs>::template NthMatrix<n - lhsLength>;
52+
};
53+
54+
template<size_t n>
55+
using NthMatrix = typename NthMatrixHelper<n>::Type;
56+
57+
template<size_t n>
58+
static constexpr const NthMatrix<n>& getMatrix(const MatrixProduct<Lhs, Rhs>& expr, std::integral_constant<size_t, n>);
59+
60+
template<size_t idx, size_t chainLenP1> requires (idx+1 <chainLenP1)
61+
static constexpr void fillDims(std::integral_constant<size_t, idx>, std::array<size_t, chainLenP1>& dims);
62+
};
63+
64+
} // namespace detail
65+
66+
template<class Expr>
67+
struct MatrixProductAnalyzer
68+
{
69+
static_assert(IsMatrix<Expr>::value, "Expr must be a matrix");
70+
71+
using Impl = detail::MatrixProductAnalyzerImpl<Expr>;
72+
using DimArray = std::array<size_t, Impl::length+1>;
73+
74+
template<size_t n> using NthMatrix = typename Impl::template NthMatrix<n>;
75+
76+
template<size_t n>
77+
static const NthMatrix<n>& getMatrix(const Expr& expr, std::integral_constant<size_t, n> ic) { return Impl::getMatrix(expr, ic); }
78+
79+
static constexpr size_t getLength() { return Impl::length; }
80+
81+
static constexpr DimArray getDims();
82+
83+
// we use an external template class as to not recompute everything for every product
84+
// if the optimal splitting for an chain with the same dims has already been computed
85+
// we can re-use it.
86+
static constexpr size_t getMinalCost() { return MatrixProductChain<getDims()>::template minCostAndSplit<0, getLength()>.first; }
87+
static constexpr size_t getOptimalSplit() { return MatrixProductChain<getDims()>::template minCostAndSplit<0, getLength()>.second; }
88+
private:
89+
template<size_t start, size_t end> requires(start <= end and end <= getLength())
90+
struct OptimalBracketingHelper
91+
{
92+
static constexpr size_t split = MatrixProductChain<getDims()>::template minCostAndSplit<start, end>.second;
93+
94+
static_assert(start <= split and split+1 < end+1, "invalid split");
95+
96+
using LhsBracketing = OptimalBracketingHelper<start, split>;
97+
using RhsBracketing = OptimalBracketingHelper<split, end>;
98+
99+
using Lhs = typename LhsBracketing::Type;
100+
using Rhs = typename RhsBracketing::Type;
101+
102+
using Type = MatrixProduct<Lhs, Rhs>;
103+
using ReBracketType = Type;
104+
105+
static ReBracketType reBracket(const Expr& expr) { return ReBracketType(LhsBracketing::reBracket(expr), RhsBracketing::reBracket(expr)); }
106+
};
107+
108+
template<size_t idx> requires(idx < getLength())
109+
struct OptimalBracketingHelper<idx,idx+1>
110+
{
111+
using Type = NthMatrix<idx>;
112+
using ReBracketType = const Type&;
113+
114+
static ReBracketType reBracket(const Expr& expr) { return MatrixProductAnalyzer<Expr>::getMatrix(expr, std::integral_constant<size_t, idx>{}); }
115+
};
116+
public:
117+
using OptimalBracketing = typename OptimalBracketingHelper<0, getLength()>::Type;
118+
using ReBracketType = typename OptimalBracketingHelper<0, getLength()>::ReBracketType;
119+
120+
static ReBracketType reBracket(const Expr& expr) { return OptimalBracketingHelper<0, getLength()>::reBracket(expr); }
121+
};
122+
123+
} // namespace FSLinalg
124+
125+
#endif // FSLINALG_MATRIX_PRODUCT_TRAITS_HPP
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#ifndef FSLINALG_MATRIX_PRODUCT_TRAITS_IMPL_HPP
2+
#define FSLINALG_MATRIX_PRODUCT_TRAITS_IMPL_HPP
3+
4+
#include <FSLinalg/Matrix/MatrixProductAnalyzer.hpp>
5+
6+
namespace FSLinalg
7+
{
8+
namespace detail
9+
{
10+
11+
template<class Lhs, class Rhs> template<size_t idx, size_t chainLenP1> requires (idx+1 <chainLenP1)
12+
constexpr void MatrixProductAnalyzerImpl< MatrixProduct<Lhs, Rhs> >::fillDims(std::integral_constant<size_t, idx> ic, std::array<size_t, chainLenP1>& dims)
13+
{
14+
if constexpr (IsMatrixProduct<Lhs>::value and IsMatrixProduct<Rhs>::value)
15+
{
16+
MatrixProductAnalyzerImpl<Lhs>::fillDims(ic, dims);
17+
MatrixProductAnalyzerImpl<Rhs>::fillDims(std::integral_constant<size_t, idx + lhsLength>{}, dims);
18+
}
19+
else if constexpr (IsMatrixProduct<Lhs>::value)
20+
{
21+
MatrixProductAnalyzerImpl<Lhs>::fillDims(ic, dims);
22+
dims[idx + lhsLength + 1] = Rhs::nCols;
23+
}
24+
else if constexpr (IsMatrixProduct<Rhs>::value)
25+
{
26+
dims[idx] = Lhs::nRows;
27+
MatrixProductAnalyzerImpl<Rhs>::fillDims(std::integral_constant<size_t, idx + 1>{}, dims);
28+
}
29+
else
30+
{
31+
dims[idx] = Lhs::nRows;
32+
dims[idx + 1] = Lhs::nCols;
33+
dims[idx + 2] = Rhs::nCols;
34+
}
35+
}
36+
37+
template<class Lhs, class Rhs> template<size_t n>
38+
constexpr auto MatrixProductAnalyzerImpl< MatrixProduct<Lhs, Rhs> >::getMatrix(const MatrixProduct<Lhs, Rhs>& expr, std::integral_constant<size_t, n> ic) -> const NthMatrix<n>&
39+
{
40+
if constexpr (n < lhsLength)
41+
{
42+
return MatrixProductAnalyzerImpl<Lhs>::getMatrix(expr.m_lhs, ic);
43+
}
44+
else
45+
{
46+
return MatrixProductAnalyzerImpl<Rhs>::getMatrix(expr.m_rhs, std::integral_constant<size_t, n - lhsLength>{});
47+
}
48+
}
49+
50+
} // namespace detail
51+
52+
template<class Expr>
53+
constexpr auto MatrixProductAnalyzer<Expr>::getDims() -> DimArray
54+
{
55+
DimArray dims;
56+
Impl::fillDims(std::integral_constant<size_t, 0>{}, dims);
57+
return dims;
58+
}
59+
60+
} // namespace FSLinalg
61+
62+
#endif // FSLINALG_MATRIX_PRODUCT_TRAITS_IMPL_HPP
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef FSLINALG_MATRIX_PRODUCT_CHAIN_HPP
2+
#define FSLINALG_MATRIX_PRODUCT_CHAIN_HPP
3+
4+
#include <array>
5+
6+
namespace FSLinalg
7+
{
8+
9+
template<std::array dims>
10+
struct MatrixProductChain
11+
{
12+
template<size_t i, size_t j>
13+
static constexpr std::pair<size_t, size_t> minMulCostAndSplitRec(std::integral_constant<size_t, i>, std::integral_constant<size_t, j>);
14+
15+
template<size_t i, size_t j>
16+
static constexpr std::pair<size_t, size_t> minCostAndSplit = minMulCostAndSplitRec(std::integral_constant<size_t, i>{}, std::integral_constant<size_t, j>{});
17+
};
18+
19+
} // namespace FSLinalg
20+
21+
#endif // FSLINALG_MATRIX_PRODUCT_CHAIN_HPP
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#ifndef FSLINALG_MATRIX_PRODUCT_CHAIN_IMPL_HPP
2+
#define FSLINALG_MATRIX_PRODUCT_CHAIN_IMPL_HPP
3+
4+
#include <FSLinalg/Matrix/MatrixProductChain.hpp>
5+
#include <FSLinalg/misc/StaticFor.hpp>
6+
7+
namespace FSLinalg
8+
{
9+
10+
template<std::array dims> template<size_t i, size_t j>
11+
constexpr std::pair<size_t, size_t> MatrixProductChain<dims>::minMulCostAndSplitRec(std::integral_constant<size_t, i>, std::integral_constant<size_t, j>)
12+
{
13+
if constexpr (i + 1 == j)
14+
{
15+
return std::make_pair(0u, i + 1);
16+
}
17+
else
18+
{
19+
size_t minCost = std::numeric_limits<size_t>::max();
20+
size_t optSpliting = i + 1;
21+
22+
misc::StaticFor<i + 1, j>::run([&minCost, &optSpliting](const auto k) -> void
23+
{
24+
constexpr size_t curr = minCostAndSplit<i, k>.first + minCostAndSplit<k, j>.first + dims[i]*dims[k]*dims[j];
25+
if (curr < minCost)
26+
{
27+
minCost = curr;
28+
optSpliting = k;
29+
}
30+
});
31+
32+
return std::make_pair(minCost, optSpliting);
33+
}
34+
}
35+
36+
} //namespace FSLinalg
37+
38+
#endif // FSLINALG_MATRIX_PRODUCT_CHAIN_IMPL_HPP

include/FSLinalg/Matrix/MatrixProduct_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef FSLINALG_MATRIX_MATRIX_PRODUCT_IMPL_HPP
2-
#define FSLINALG_MATRIX_MATRIX_PRODUCT_IMPL_HPP
1+
#ifndef FSLINALG_MATRIX_PRODUCT_IMPL_HPP
2+
#define FSLINALG_MATRIX_PRODUCT_IMPL_HPP
33

44
#include <FSLinalg/Matrix/MatrixProduct.hpp>
55
#include <FSLinalg/BasicLinalg/GeneralMatrixMatrixProduct.hpp>
@@ -86,4 +86,4 @@ void MatrixProduct<Lhs,Rhs>::decrementImpl(const Alpha& alpha, MatrixBase<Dst>&
8686

8787
} // namespace FSLinalg
8888

89-
#endif // FSLINALG_MATRIX_MATRIX_PRODUCT_IMPL_HPP
89+
#endif // FSLINALG_MATRIX_PRODUCT_IMPL_HPP

0 commit comments

Comments
 (0)