@@ -93,7 +93,7 @@ def solve(arr: NDArray, row: int, cols: set[int], cache: dict[str, int]) -> int:
9393
9494def solution (matrix_str : list [str ] = MATRIX_2 ) -> int :
9595 """
96- Takes list of strings matrix_str to parse the matrix and calculates the max sum.
96+ Takes list of strings ` matrix_str` to parse the matrix and calculates the max sum.
9797
9898 >>> solution(["1 2", "3 4"])
9999 5
@@ -102,15 +102,14 @@ def solution(matrix_str: list[str] = MATRIX_2) -> int:
102102 """
103103
104104 n = len (matrix_str )
105- arr = np .empty ((n , n ), dtype = np .int64 )
106- for i in range ( n ):
107- els = matrix_str [ i ]. strip (). split (" " )
108- for j in range ( len ( els ) ):
109- arr [i , j ] = int (els [ j ] )
105+ arr = np .empty (shape = (n , n ), dtype = np .int64 )
106+ for row , matrix_row_str in enumerate ( matrix_str ):
107+ matrix_row_list_str = matrix_row_str . split ()
108+ for col , elem_str in enumerate ( matrix_row_list_str ):
109+ arr [row , col ] = int (elem_str )
110110
111111 cache : dict [str , int ] = {}
112- ans = solve (arr , 0 , set (range (n )), cache )
113- return int (ans )
112+ return solve (arr = arr , row = 0 , cols = set (range (n )), cache = cache )
114113
115114
116115if __name__ == "__main__" :
0 commit comments