Skip to content

Commit 4c9e72a

Browse files
committed
Merge branch 'strassen_matrix_multiply' of https://github.com/sourav-625/Python into strassen_matrix_multiply
2 parents 981f0c9 + 7889f30 commit 4c9e72a

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

matrix/strassen_matrix_multiply.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,21 @@
1212
https://en.wikipedia.org/wiki/Strassen_algorithm
1313
"""
1414

15-
1615
Matrix = list[list[int]]
1716

1817
def add(A: Matrix, B: Matrix) -> Matrix:
1918
n = len(A)
2019
return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)]
2120

21+
2222
def sub(A: Matrix, B: Matrix) -> Matrix:
2323
n = len(A)
2424
return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)]
2525

26+
2627
def naive_mul(A: Matrix, B: Matrix) -> Matrix:
2728
n = len(A)
28-
C = [[0]*n for _ in range(n)]
29+
C = [[0] * n for _ in range(n)]
2930
for i in range(n):
3031
ai = A[i]
3132
ci = C[i]
@@ -36,23 +37,27 @@ def naive_mul(A: Matrix, B: Matrix) -> Matrix:
3637
ci[j] += a_ik * bk[j]
3738
return C
3839

40+
3941
def next_power_of_two(n: int) -> int:
4042
p = 1
4143
while p < n:
4244
p <<= 1
4345
return p
4446

47+
4548
def pad_matrix(A: Matrix, size: int) -> Matrix:
4649
n = len(A)
47-
padded = [[0]*size for _ in range(size)]
50+
padded = [[0] * size for _ in range(size)]
4851
for i in range(n):
4952
for j in range(len(A[0])):
5053
padded[i][j] = A[i][j]
5154
return padded
5255

56+
5357
def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix:
5458
return [row[:cols] for row in A[:rows]]
5559

60+
5661
def split(A: Matrix) -> tuple:
5762
n = len(A)
5863
mid = n // 2
@@ -62,10 +67,11 @@ def split(A: Matrix) -> tuple:
6267
A22 = [[A[i][j] for j in range(mid, n)] for i in range(mid, n)]
6368
return A11, A12, A21, A22
6469

70+
6571
def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
6672
n2 = len(C11)
6773
n = n2 * 2
68-
C = [[0]*n for _ in range(n)]
74+
C = [[0] * n for _ in range(n)]
6975
for i in range(n2):
7076
for j in range(n2):
7177
C[i][j] = C11[i][j]
@@ -74,19 +80,21 @@ def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
7480
C[i + n2][j + n2] = C22[i][j]
7581
return C
7682

83+
7784
def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
7885
"""
7986
Multiply square matrices A and B using Strassen algorithm.
8087
threshold: below this size, uses naive multiplication (tweakable).
8188
"""
82-
assert len(A) == len(A[0]) == len(B) == len(B[0]), "Only square matrices supported in this implementation"
89+
assert len(A) == len(A[0]) == len(B) == len(B[0]), (
90+
"Only square matrices supported in this implementation"
91+
)
8392

8493
n_orig = len(A)
8594
if n_orig == 0:
8695
return []
8796

88-
m = next_power_of_two(n_orig)
89-
if m != n_orig:
97+
if (m := next_power_of_two(n_orig)) != n_orig:
9098
A_pad = pad_matrix(A, m)
9199
B_pad = pad_matrix(B, m)
92100
else:
@@ -97,6 +105,7 @@ def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
97105
C = unpad_matrix(C_pad, n_orig, n_orig)
98106
return C
99107

108+
100109
def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
101110
n = len(A)
102111
if n <= threshold:
@@ -122,17 +131,10 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
122131

123132
return join(C11, C12, C21, C22)
124133

134+
125135
if __name__ == "__main__":
126-
A = [
127-
[1, 2, 3],
128-
[4, 5, 6],
129-
[7, 8, 9]
130-
]
131-
B = [
132-
[9, 8, 7],
133-
[6, 5, 4],
134-
[3, 2, 1]
135-
]
136+
A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
137+
B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
136138

137139
C = strassen(A, B, threshold=1)
138140
print("A * B =")
@@ -142,4 +144,4 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
142144
# verify against naive
143145
expected = naive_mul(A, B)
144146
assert C == expected, "Strassen result differs from naive multiplication!"
145-
print("Verified: result matches naive multiplication.")
147+
print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)