@@ -26,6 +26,19 @@ impl Direction {
2626 Direction :: W => Direction :: N ,
2727 }
2828 }
29+
30+ pub fn mirror ( & self ) -> Self {
31+ match self {
32+ Direction :: N => Direction :: S ,
33+ Direction :: E => Direction :: W ,
34+ Direction :: S => Direction :: N ,
35+ Direction :: W => Direction :: E ,
36+ }
37+ }
38+
39+ pub fn all ( ) -> Vec < Self > {
40+ vec ! [ Direction :: N , Direction :: E , Direction :: S , Direction :: W ]
41+ }
2942}
3043
3144#[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash ) ]
@@ -114,97 +127,62 @@ impl Grid {
114127 } ) . collect ( )
115128 }
116129
117- fn n_valid_neighbors ( & self , pos : & Point ) -> usize {
118- [ Direction :: N , Direction :: E , Direction :: S , Direction :: W ] . iter ( ) . filter ( |& dir| {
119- let new_pos = pos. step ( * dir) ;
120- match new_pos {
121- Some ( new_pos) => !self . walls . contains ( & new_pos) ,
122- None => false ,
123- }
124- } ) . count ( )
125- }
126-
127- pub fn shortest_path ( & self ) -> Option < usize > {
128- let mut visited = HashSet :: new ( ) ;
129- let mut heap = BinaryHeap :: from ( [ State { pos : self . start , dir : Direction :: E , cost : 0 } ] ) ;
130+ pub fn shortest_distances ( & self , start_point : Point , start_direction : Direction ) -> HashMap < ( Point , Direction ) , usize > {
131+ let mut distance= HashMap :: new ( ) ;
132+ let mut heap = BinaryHeap :: from ( [ State { pos : start_point, dir : start_direction, cost : 0 } ] ) ;
130133
131134 while let Some ( current) = heap. pop ( ) {
132- if current. pos == self . end {
133- return Some ( current. cost ) ;
134- }
135-
136- if !visited. insert ( ( current. pos , current. dir ) ) {
135+ if distance. contains_key ( & ( current. pos , current. dir ) ) {
137136 continue ;
138137 }
139138
139+ distance. insert ( ( current. pos , current. dir ) , current. cost ) ;
140+
140141 for child in self . new_states ( & current) {
141142 heap. push ( child ) ;
142143 }
143144 }
144- None
145+ distance
145146 }
146147
147- pub fn shortest_path_with_chain ( & self , start : State ) -> ( usize , HashSet < State > ) {
148- let mut visited = HashSet :: new ( ) ;
149- let mut heap = BinaryHeap :: from ( [ start] ) ;
150- let mut parent = HashMap :: new ( ) ;
151-
152- let mut chain = HashSet :: new ( ) ;
153- let mut cost = 0 ;
154- while let Some ( current) = heap. pop ( ) {
155- if current. pos == self . end {
156- let mut current = current;
157- cost = current. cost ;
158- while let Some ( parent) = parent. get ( & current) {
159- chain. insert ( current) ;
160- current = * parent;
161- }
162- chain. insert ( current) ;
163- break
164- }
165-
166- if !visited. insert ( ( current. pos , current. dir ) ) {
167- continue ;
168- }
148+ pub fn shortest_path ( & self ) -> Option < usize > {
149+ let distances = self . shortest_distances ( self . start , Direction :: E ) ;
169150
170- for child in self . new_states ( & current) {
171- heap. push ( child ) ;
172- parent. insert ( child, current) ;
151+ Direction :: all ( ) . iter ( ) . filter_map ( |& dir| {
152+ match distances. get ( & ( self . end , dir) ) {
153+ Some ( & distance) => Some ( distance) ,
154+ None => None ,
173155 }
174- }
175- ( cost, chain)
156+ } ) . min ( )
176157 }
177158
178- pub fn count_all_shortest_paths ( & self ) -> usize {
179- let start_state = State { pos : self . start , dir : Direction :: E , cost : 0 } ;
180- let ( cost, mut chain) = self . shortest_path_with_chain ( start_state) ;
181- println ! ( "Cost: {}" , cost) ;
182- let mut visited = chain. clone ( ) ;
183- while chain. len ( ) != 0 {
184- let current = chain. iter ( ) . next ( ) . cloned ( ) . unwrap ( ) ;
185- chain. remove ( & current) ;
159+ // Based on idea by @jenuk
160+ pub fn points_on_shortest_path ( & self ) -> usize {
161+ let forward_distances = self . shortest_distances ( self . start , Direction :: E ) ;
186162
187- if self . n_valid_neighbors ( & current. pos ) <= 2 {
188- continue ;
163+ let end_direction = Direction :: all ( ) . iter ( ) . filter_map ( |& dir| {
164+ match forward_distances. get ( & ( self . end , dir) ) {
165+ Some ( & distance) => Some ( ( dir, distance) ) ,
166+ None => None ,
189167 }
190-
191- for new_state in self . new_states ( & current) {
192- if visited. contains ( & new_state) {
193- continue ;
194- }
195- let ( new_cost, new_chain) = self . shortest_path_with_chain ( new_state) ;
196- if new_cost == cost {
197- chain = chain. union ( & new_chain) . cloned ( ) . collect ( ) ;
198- visited = visited. union ( & new_chain) . cloned ( ) . collect ( ) ;
199- }
168+ } ) . min_by_key ( |& ( _, distance) | distance) . unwrap ( ) . 0 ;
169+
170+ let backward_distances = self . shortest_distances ( self . end , end_direction. mirror ( ) ) ;
171+
172+ HashSet :: < Point > :: from_iter ( forward_distances. iter ( ) . filter_map ( |( ( point, direction) , & distance) | {
173+ match backward_distances. get ( & ( * point, direction. mirror ( ) ) ) {
174+ Some ( & backward_distance) => { if distance + backward_distance == forward_distances[ & ( self . end , end_direction) ] {
175+ Some ( * point)
176+ } else {
177+ None
178+ }
179+ } ,
180+ None => None ,
200181 }
201- }
202- let visited_points: HashSet < Point > = HashSet :: from_iter ( visited. iter ( ) . map ( |state| state. pos ) ) ;
203- self . print_maze ( & visited_points) ;
204- visited_points. len ( )
205-
182+ } ) ) . len ( )
206183 }
207184
185+ #[ allow( dead_code) ]
208186 fn print_maze ( & self , visited : & HashSet < Point > ) {
209187 let ( n_rows, n_cols) = self . walls . iter ( ) . fold ( ( 0 , 0 ) , |( max_row, max_col) , point| {
210188 ( max_row. max ( point. row +1 ) , max_col. max ( point. col +1 ) )
@@ -237,7 +215,5 @@ pub fn task01(input: &str) -> String {
237215
238216pub fn task02 ( input : & str ) -> String {
239217 let grid = Grid :: from_input ( input) ;
240- // grid.count_tiles_on_shortest_paths_dfs().unwrap().to_string()
241- println ! ( "Caution super inefficient. Takes up to 3 min on my machine." ) ;
242- grid. count_all_shortest_paths ( ) . to_string ( )
218+ grid. points_on_shortest_path ( ) . to_string ( )
243219}
0 commit comments