Skip to content

Commit 605867b

Browse files
Update sol1.py
1 parent 639e922 commit 605867b

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

project_euler/problem_345/sol1.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,29 +65,31 @@
6565

6666

6767
def solve(
68-
arr: NDArray, row_ind: int, include_set: set[int], cache: dict[str, int]
68+
arr: NDArray, row: int, cols: set[int], cache: dict[str, int]
6969
) -> int:
7070
"""
71-
finds the max sum for array arr starting with row number row_ind, and with columns
72-
included in include_set. cache is used for caching intermediate results.
71+
Finds the max sum for array `arr` starting with row index `row`, and with columns
72+
included in `cols`. `cache` is used for caching intermediate results.
7373
74-
>>> solve(np.array([[1, 2], [3, 4]]), 0, {0, 1}, {})
74+
>>> solve(arr=np.array([[1, 2], [3, 4]]), row=0, cols={0, 1}, cache={})
7575
np.int64(5)
7676
"""
7777

7878
cache_id = f"{row_ind}, {sorted(include_set)}"
7979
if cache_id in cache:
8080
return cache[cache_id]
81-
if row_ind == len(arr):
81+
82+
if row == len(arr):
8283
return 0
83-
sub_max = 0
84-
for i in include_set:
85-
new_set = include_set - {i}
86-
sub_max = max(
87-
sub_max, arr[row_ind, i] + solve(arr, row_ind + 1, new_set, cache)
84+
85+
max_sum = 0
86+
for col in cols:
87+
new_cols = cols - {col}
88+
max_sum = max(
89+
max_sum, arr[row, col] + solve(arr=arr, row=row + 1, cols=new_cols, cache=cache)
8890
)
89-
cache[cache_id] = sub_max
90-
return sub_max
91+
cache[cache_id] = max_sum
92+
return max_sum
9193

9294

9395
def solution(matrix_str: list[str] = MATRIX_2) -> int:

0 commit comments

Comments
 (0)