Skip to content

Commit 9b1a5c3

Browse files
committed
AI written strassen algorithm implementation
1 parent eb6186c commit 9b1a5c3

2 files changed

Lines changed: 68 additions & 4 deletions

File tree

src/alfred/all

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "algorithm/utils.hpp"
66

77
#include "data_structure/appear-statistics.hpp"
8-
#include "data_structure/binary-trie.hpp"
8+
// #include "data_structure/binary-trie.hpp" // Temporarily removed the trie.
99
#include "data_structure/chtholly.hpp"
1010
#include "data_structure/discretization.hpp"
1111
#include "data_structure/dsu/cancel-dsu.hpp"

src/alfred/math/linear.hpp

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef AFMT_LINEAR
22
#define AFMT_LINEAR
33

4+
#include <algorithm>
45
#include <cassert>
56
#include <vector>
67

@@ -24,14 +25,77 @@ struct Matrix {
2425
inline size_t m(void) { return M[0].size(); }
2526
};
2627

28+
// Strassen's algorithm for square matrices
29+
namespace detail {
30+
template <class T>
31+
Matrix<T> strassen(const Matrix<T> &A, const Matrix<T> &B, int threshold = 64) {
32+
size_t n = A.n();
33+
if (n <= threshold) {
34+
Matrix<T> ans(n, n, T());
35+
for (size_t i = 0; i < n; i++)
36+
for (size_t k = 0; k < n; k++) {
37+
const T &Aik = A.M[i][k];
38+
for (size_t j = 0; j < n; j++)
39+
ans[i][j] += Aik * B.M[k][j];
40+
}
41+
return ans;
42+
}
43+
size_t m = n / 2;
44+
auto sub = [&](const Matrix<T> &X, int r, int c) {
45+
Matrix<T> res(m, m);
46+
for (size_t i = 0; i < m; i++)
47+
for (size_t j = 0; j < m; j++)
48+
res[i][j] = X.M[i + r * m][j + c * m];
49+
return res;
50+
};
51+
auto add = [](const Matrix<T> &X, const Matrix<T> &Y) {
52+
Matrix<T> res(X.n(), X.m());
53+
for (size_t i = 0; i < X.n(); i++)
54+
for (size_t j = 0; j < X.m(); j++)
55+
res[i][j] = X.M[i][j] + Y.M[i][j];
56+
return res;
57+
};
58+
auto subm = [](const Matrix<T> &X, const Matrix<T> &Y) {
59+
Matrix<T> res(X.n(), X.m());
60+
for (size_t i = 0; i < X.n(); i++)
61+
for (size_t j = 0; j < X.m(); j++)
62+
res[i][j] = X.M[i][j] - Y.M[i][j];
63+
return res;
64+
};
65+
Matrix<T> A11 = sub(A, 0, 0), A12 = sub(A, 0, 1), A21 = sub(A, 1, 0), A22 = sub(A, 1, 1);
66+
Matrix<T> B11 = sub(B, 0, 0), B12 = sub(B, 0, 1), B21 = sub(B, 1, 0), B22 = sub(B, 1, 1);
67+
auto M1 = strassen(add(A11, A22), add(B11, B22), threshold);
68+
auto M2 = strassen(add(A21, A22), B11, threshold);
69+
auto M3 = strassen(A11, subm(B12, B22), threshold);
70+
auto M4 = strassen(A22, subm(B21, B11), threshold);
71+
auto M5 = strassen(add(A11, A12), B22, threshold);
72+
auto M6 = strassen(subm(A21, A11), add(B11, B12), threshold);
73+
auto M7 = strassen(subm(A12, A22), add(B21, B22), threshold);
74+
Matrix<T> ans(n, n);
75+
for (size_t i = 0; i < m; i++)
76+
for (size_t j = 0; j < m; j++) {
77+
ans[i][j] = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j];
78+
ans[i][j + m] = M3[i][j] + M5[i][j];
79+
ans[i + m][j] = M2[i][j] + M4[i][j];
80+
ans[i + m][j + m] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j];
81+
}
82+
return ans;
83+
}
84+
}
85+
2786
template <class T>
2887
Matrix<T> operator*(Matrix<T> A, Matrix<T> B) {
88+
if (A.n() == A.m() && B.n() == B.m() && A.n() == B.n() && (A.n() & (A.n() - 1)) == 0) {
89+
// Use Strassen for square matrices with size power of 2
90+
return detail::strassen(A, B);
91+
}
2992
assert(A.m() == B.n());
3093
Matrix<T> ans(A.n(), B.m());
31-
for (size_t k = 0; k < A.m(); k++) {
32-
for (size_t i = 0; i < A.n(); i++) {
94+
for (size_t i = 0; i < A.n(); i++) {
95+
for (size_t k = 0; k < A.m(); k++) {
96+
const T &Aik = A[i][k];
3397
for (size_t j = 0; j < B.m(); j++) {
34-
ans[i][j] += A[i][k] * B[k][j];
98+
ans[i][j] += Aik * B[k][j];
3599
}
36100
}
37101
}

0 commit comments

Comments
 (0)