1212https://en.wikipedia.org/wiki/Strassen_algorithm
1313"""
1414
15-
1615Matrix = list [list [int ]]
1716
1817def add (A : Matrix , B : Matrix ) -> Matrix :
1918 n = len (A )
2019 return [[A [i ][j ] + B [i ][j ] for j in range (n )] for i in range (n )]
2120
21+
2222def sub (A : Matrix , B : Matrix ) -> Matrix :
2323 n = len (A )
2424 return [[A [i ][j ] - B [i ][j ] for j in range (n )] for i in range (n )]
2525
26+
2627def naive_mul (A : Matrix , B : Matrix ) -> Matrix :
2728 n = len (A )
28- C = [[0 ]* n for _ in range (n )]
29+ C = [[0 ] * n for _ in range (n )]
2930 for i in range (n ):
3031 ai = A [i ]
3132 ci = C [i ]
@@ -36,23 +37,27 @@ def naive_mul(A: Matrix, B: Matrix) -> Matrix:
3637 ci [j ] += a_ik * bk [j ]
3738 return C
3839
40+
3941def next_power_of_two (n : int ) -> int :
4042 p = 1
4143 while p < n :
4244 p <<= 1
4345 return p
4446
47+
4548def pad_matrix (A : Matrix , size : int ) -> Matrix :
4649 n = len (A )
47- padded = [[0 ]* size for _ in range (size )]
50+ padded = [[0 ] * size for _ in range (size )]
4851 for i in range (n ):
4952 for j in range (len (A [0 ])):
5053 padded [i ][j ] = A [i ][j ]
5154 return padded
5255
56+
5357def unpad_matrix (A : Matrix , rows : int , cols : int ) -> Matrix :
5458 return [row [:cols ] for row in A [:rows ]]
5559
60+
5661def split (A : Matrix ) -> tuple :
5762 n = len (A )
5863 mid = n // 2
@@ -62,10 +67,11 @@ def split(A: Matrix) -> tuple:
6267 A22 = [[A [i ][j ] for j in range (mid , n )] for i in range (mid , n )]
6368 return A11 , A12 , A21 , A22
6469
70+
6571def join (C11 : Matrix , C12 : Matrix , C21 : Matrix , C22 : Matrix ) -> Matrix :
6672 n2 = len (C11 )
6773 n = n2 * 2
68- C = [[0 ]* n for _ in range (n )]
74+ C = [[0 ] * n for _ in range (n )]
6975 for i in range (n2 ):
7076 for j in range (n2 ):
7177 C [i ][j ] = C11 [i ][j ]
@@ -74,19 +80,21 @@ def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
7480 C [i + n2 ][j + n2 ] = C22 [i ][j ]
7581 return C
7682
83+
7784def strassen (A : Matrix , B : Matrix , threshold : int = 64 ) -> Matrix :
7885 """
7986 Multiply square matrices A and B using Strassen algorithm.
8087 threshold: below this size, uses naive multiplication (tweakable).
8188 """
82- assert len (A ) == len (A [0 ]) == len (B ) == len (B [0 ]), "Only square matrices supported in this implementation"
89+ assert len (A ) == len (A [0 ]) == len (B ) == len (B [0 ]), (
90+ "Only square matrices supported in this implementation"
91+ )
8392
8493 n_orig = len (A )
8594 if n_orig == 0 :
8695 return []
8796
88- m = next_power_of_two (n_orig )
89- if m != n_orig :
97+ if (m := next_power_of_two (n_orig )) != n_orig :
9098 A_pad = pad_matrix (A , m )
9199 B_pad = pad_matrix (B , m )
92100 else :
@@ -97,6 +105,7 @@ def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
97105 C = unpad_matrix (C_pad , n_orig , n_orig )
98106 return C
99107
108+
100109def _strassen_recursive (A : Matrix , B : Matrix , threshold : int ) -> Matrix :
101110 n = len (A )
102111 if n <= threshold :
@@ -122,17 +131,10 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
122131
123132 return join (C11 , C12 , C21 , C22 )
124133
134+
125135if __name__ == "__main__" :
126- A = [
127- [1 , 2 , 3 ],
128- [4 , 5 , 6 ],
129- [7 , 8 , 9 ]
130- ]
131- B = [
132- [9 , 8 , 7 ],
133- [6 , 5 , 4 ],
134- [3 , 2 , 1 ]
135- ]
136+ A = [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]
137+ B = [[9 , 8 , 7 ], [6 , 5 , 4 ], [3 , 2 , 1 ]]
136138
137139 C = strassen (A , B , threshold = 1 )
138140 print ("A * B =" )
@@ -142,4 +144,4 @@ def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
142144 # verify against naive
143145 expected = naive_mul (A , B )
144146 assert C == expected , "Strassen result differs from naive multiplication!"
145- print ("Verified: result matches naive multiplication." )
147+ print ("Verified: result matches naive multiplication." )
0 commit comments