11#ifndef AFMT_LINEAR
22#define AFMT_LINEAR
33
4+ #include < algorithm>
45#include < cassert>
56#include < vector>
67
@@ -24,14 +25,77 @@ struct Matrix {
2425 inline size_t m (void ) { return M[0 ].size (); }
2526};
2627
28+ // Strassen's algorithm for square matrices
29+ namespace detail {
30+ template <class T >
31+ Matrix<T> strassen (const Matrix<T> &A, const Matrix<T> &B, int threshold = 64 ) {
32+ size_t n = A.n ();
33+ if (n <= threshold) {
34+ Matrix<T> ans (n, n, T ());
35+ for (size_t i = 0 ; i < n; i++)
36+ for (size_t k = 0 ; k < n; k++) {
37+ const T &Aik = A.M [i][k];
38+ for (size_t j = 0 ; j < n; j++)
39+ ans[i][j] += Aik * B.M [k][j];
40+ }
41+ return ans;
42+ }
43+ size_t m = n / 2 ;
44+ auto sub = [&](const Matrix<T> &X, int r, int c) {
45+ Matrix<T> res (m, m);
46+ for (size_t i = 0 ; i < m; i++)
47+ for (size_t j = 0 ; j < m; j++)
48+ res[i][j] = X.M [i + r * m][j + c * m];
49+ return res;
50+ };
51+ auto add = [](const Matrix<T> &X, const Matrix<T> &Y) {
52+ Matrix<T> res (X.n (), X.m ());
53+ for (size_t i = 0 ; i < X.n (); i++)
54+ for (size_t j = 0 ; j < X.m (); j++)
55+ res[i][j] = X.M [i][j] + Y.M [i][j];
56+ return res;
57+ };
58+ auto subm = [](const Matrix<T> &X, const Matrix<T> &Y) {
59+ Matrix<T> res (X.n (), X.m ());
60+ for (size_t i = 0 ; i < X.n (); i++)
61+ for (size_t j = 0 ; j < X.m (); j++)
62+ res[i][j] = X.M [i][j] - Y.M [i][j];
63+ return res;
64+ };
65+ Matrix<T> A11 = sub (A, 0 , 0 ), A12 = sub (A, 0 , 1 ), A21 = sub (A, 1 , 0 ), A22 = sub (A, 1 , 1 );
66+ Matrix<T> B11 = sub (B, 0 , 0 ), B12 = sub (B, 0 , 1 ), B21 = sub (B, 1 , 0 ), B22 = sub (B, 1 , 1 );
67+ auto M1 = strassen (add (A11, A22), add (B11, B22), threshold);
68+ auto M2 = strassen (add (A21, A22), B11, threshold);
69+ auto M3 = strassen (A11, subm (B12, B22), threshold);
70+ auto M4 = strassen (A22, subm (B21, B11), threshold);
71+ auto M5 = strassen (add (A11, A12), B22, threshold);
72+ auto M6 = strassen (subm (A21, A11), add (B11, B12), threshold);
73+ auto M7 = strassen (subm (A12, A22), add (B21, B22), threshold);
74+ Matrix<T> ans (n, n);
75+ for (size_t i = 0 ; i < m; i++)
76+ for (size_t j = 0 ; j < m; j++) {
77+ ans[i][j] = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j];
78+ ans[i][j + m] = M3[i][j] + M5[i][j];
79+ ans[i + m][j] = M2[i][j] + M4[i][j];
80+ ans[i + m][j + m] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j];
81+ }
82+ return ans;
83+ }
84+ }
85+
2786template <class T >
2887Matrix<T> operator *(Matrix<T> A, Matrix<T> B) {
88+ if (A.n () == A.m () && B.n () == B.m () && A.n () == B.n () && (A.n () & (A.n () - 1 )) == 0 ) {
89+ // Use Strassen for square matrices with size power of 2
90+ return detail::strassen (A, B);
91+ }
2992 assert (A.m () == B.n ());
3093 Matrix<T> ans (A.n (), B.m ());
31- for (size_t k = 0 ; k < A.m (); k++) {
32- for (size_t i = 0 ; i < A.n (); i++) {
94+ for (size_t i = 0 ; i < A.n (); i++) {
95+ for (size_t k = 0 ; k < A.m (); k++) {
96+ const T &Aik = A[i][k];
3397 for (size_t j = 0 ; j < B.m (); j++) {
34- ans[i][j] += A[i][k] * B[k][j];
98+ ans[i][j] += Aik * B[k][j];
3599 }
36100 }
37101 }
0 commit comments