|
1 | | -from collections import deque |
2 | | - |
| 1 | +from functools import cache |
| 2 | +from typing import AbstractSet |
3 | 3 | import networkx as nx |
4 | 4 | from tqdm import tqdm |
5 | 5 |
|
|
10 | 10 | def main(debug: bool) -> None: |
11 | 11 | input_data = load_data(debug) |
12 | 12 |
|
13 | | - source_to_targets = {} |
| 13 | + graph = parse_input(input_data) |
| 14 | + print(graph) |
| 15 | + |
| 16 | + result_part1 = solve_part1(digraph=graph, start="you", end="out") |
| 17 | + result_part2 = solve_part2( |
| 18 | + digraph=graph, start="svr", end="out", via={"fft", "dac"} |
| 19 | + ) |
| 20 | + |
| 21 | + submit_or_print(result_part1, result_part2, debug) |
| 22 | + |
| 23 | + |
| 24 | +def parse_input(input_data: str) -> nx.DiGraph[str]: |
| 25 | + graph: nx.DiGraph[str] = nx.DiGraph() |
14 | 26 | for line in input_data.splitlines(): |
15 | 27 | s = line.split(":") |
16 | 28 | source = s[0].strip() |
17 | 29 | targets = set(s[1].strip().split(" ")) |
18 | | - source_to_targets[source] = targets |
| 30 | + graph.add_edges_from((source, target) for target in targets) |
| 31 | + return graph |
19 | 32 |
|
20 | | - graph = nx.DiGraph() |
21 | | - for source, targets in source_to_targets.items(): |
22 | | - for target in targets: |
23 | | - graph.add_edge(source, target) |
24 | 33 |
|
25 | | - # result_part1 = paths_count(graph, "you", "out") |
| 34 | +def solve_part1(digraph: nx.DiGraph[str], start: str, end: str) -> int: |
| 35 | + return sum(tqdm(1 for _ in nx.all_simple_paths(digraph, start, end))) |
26 | 36 |
|
27 | | - source = "svr" |
28 | | - target = "out" |
29 | 37 |
|
30 | | - graphs = deque([graph]) |
31 | | - while graphs: |
32 | | - g = graphs.pop() |
33 | | - cut_nodes = nx.minimum_node_cut(g, source, target) |
34 | | - print(g, len(cut_nodes), cut_nodes) |
35 | | - if len(cut_nodes) < 10: |
36 | | - g.remove_nodes_from(cut_nodes) |
37 | | - graphs.extend( |
38 | | - [ |
39 | | - g.subgraph(c).copy() |
40 | | - for c in nx.connected_components(g.to_undirected()) |
41 | | - ] |
| 38 | +def solve_part2( |
| 39 | + digraph: nx.DiGraph[str], start: str, end: str, via: AbstractSet[str] |
| 40 | +) -> int: |
| 41 | + @cache |
| 42 | + def path_count(start: str, end: str, visited: frozenset[str]) -> int: |
| 43 | + if start == end and visited == via: |
| 44 | + return 1 |
| 45 | + total = 0 |
| 46 | + for child in digraph.successors(start): |
| 47 | + total += path_count( |
| 48 | + start=child, end=end, visited=frozenset(visited | ({start} & via)) |
42 | 49 | ) |
| 50 | + return total |
43 | 51 |
|
44 | | - # |
45 | | - # inter_fft = paths_count(graph, "fft", "dac") |
46 | | - # inter_dac = paths_count(graph, "dac", "fft") |
47 | | - # print(inter_fft, inter_dac) |
48 | | - # |
49 | | - # fft = paths_count(graph, "svr", "fft") |
50 | | - # dac = paths_count(graph, "svr", "dac") |
51 | | - # print(fft, dac) |
52 | | - # |
53 | | - # fft_out = paths_count(graph, "fft", "out") |
54 | | - # dac_out = paths_count(graph, "dac", "out") |
55 | | - # print(fft_out, dac_out) |
56 | | - |
57 | | - result_part1 = None |
58 | | - result_part2 = None |
59 | | - |
60 | | - # nx.draw_networkx(graph) |
61 | | - # plt.show() |
62 | | - |
63 | | - submit_or_print(result_part1, result_part2, debug) |
64 | | - |
65 | | - |
66 | | -def paths_count(graph, start, end) -> int: |
67 | | - return sum(tqdm(1 for _ in nx.all_simple_paths(graph, start, end))) |
| 52 | + return path_count(start=start, end=end, visited=frozenset()) |
68 | 53 |
|
69 | 54 |
|
70 | 55 | if __name__ == "__main__": |
71 | 56 | debug_mode = True |
72 | | - debug_mode = False |
| 57 | + # debug_mode = False |
73 | 58 | main(debug_mode) |
0 commit comments