|
65 | 65 |
|
66 | 66 |
|
67 | 67 | def solve( |
68 | | - arr: NDArray, row_ind: int, include_set: set[int], cache: dict[str, int] |
| 68 | + arr: NDArray, row: int, cols: set[int], cache: dict[str, int] |
69 | 69 | ) -> int: |
70 | 70 | """ |
71 | | - finds the max sum for array arr starting with row number row_ind, and with columns |
72 | | - included in include_set. cache is used for caching intermediate results. |
| 71 | + Finds the max sum for array `arr` starting with row index `row`, and with columns |
| 72 | + included in `cols`. `cache` is used for caching intermediate results. |
73 | 73 |
|
74 | | - >>> solve(np.array([[1, 2], [3, 4]]), 0, {0, 1}, {}) |
| 74 | + >>> solve(arr=np.array([[1, 2], [3, 4]]), row=0, cols={0, 1}, cache={}) |
75 | 75 | np.int64(5) |
76 | 76 | """ |
77 | 77 |
|
78 | 78 | cache_id = f"{row_ind}, {sorted(include_set)}" |
79 | 79 | if cache_id in cache: |
80 | 80 | return cache[cache_id] |
81 | | - if row_ind == len(arr): |
| 81 | + |
| 82 | + if row == len(arr): |
82 | 83 | return 0 |
83 | | - sub_max = 0 |
84 | | - for i in include_set: |
85 | | - new_set = include_set - {i} |
86 | | - sub_max = max( |
87 | | - sub_max, arr[row_ind, i] + solve(arr, row_ind + 1, new_set, cache) |
| 84 | + |
| 85 | + max_sum = 0 |
| 86 | + for col in cols: |
| 87 | + new_cols = cols - {col} |
| 88 | + max_sum = max( |
| 89 | + max_sum, arr[row, col] + solve(arr=arr, row=row + 1, cols=new_cols, cache=cache) |
88 | 90 | ) |
89 | | - cache[cache_id] = sub_max |
90 | | - return sub_max |
| 91 | + cache[cache_id] = max_sum |
| 92 | + return max_sum |
91 | 93 |
|
92 | 94 |
|
93 | 95 | def solution(matrix_str: list[str] = MATRIX_2) -> int: |
|
0 commit comments