1- use crate :: pregel:: { PREGEL_MSG , pregel_dst} ;
1+ use crate :: pregel:: { MessageDirection , PREGEL_MSG , pregel_dst, pregel_src } ;
22use crate :: { GraphFrame , VERTEX_ID } ;
33use arrow:: compute:: min;
44use datafusion:: arrow;
@@ -114,6 +114,8 @@ pub struct ShortestPathsBuilder<'a> {
114114 max_iterations : usize ,
115115 /// Interval at which to checkpoint the computation state
116116 checkpoint_interval : usize ,
117+ /// When true, runs the algorithm in reverse to compute distances from the landmarks to each vertex
118+ reversed : bool ,
117119}
118120
119121impl < ' a > ShortestPathsBuilder < ' a > {
@@ -130,9 +132,20 @@ impl<'a> ShortestPathsBuilder<'a> {
130132 landmarks : sorted_landmarks,
131133 max_iterations : i32:: MAX as usize ,
132134 checkpoint_interval : 1 ,
135+ reversed : false ,
133136 }
134137 }
135138
139+ /// Sets the direction for shortest paths' computation.
140+ ///
141+ /// # Arguments
142+ /// * `reversed` - If true, computes shortest paths from landmarks to vertices.
143+ /// If false, computes the shortest paths from vertices to landmarks.
144+ pub fn reversed ( mut self , reversed : bool ) -> Self {
145+ self . reversed = reversed;
146+ self
147+ }
148+
136149 /// Sets the maximum number of iterations for the algorithm.
137150 ///
138151 /// # Arguments
@@ -242,7 +255,11 @@ impl<'a> ShortestPathsBuilder<'a> {
242255 . iter ( )
243256 . flat_map ( |lm| {
244257 let col_name = lm. to_string ( ) ;
245- let d_col = pregel_dst ( DISTANCES ) . field ( col_name. clone ( ) ) ;
258+ let d_col = if self . reversed {
259+ pregel_src ( DISTANCES ) . field ( col_name. clone ( ) )
260+ } else {
261+ pregel_dst ( DISTANCES ) . field ( col_name. clone ( ) )
262+ } ;
246263 vec ! [
247264 lit( col_name) ,
248265 when( d_col. clone( ) . lt( lit( i32 :: MAX ) ) , d_col + lit( 1i32 ) )
@@ -266,7 +283,14 @@ impl<'a> ShortestPathsBuilder<'a> {
266283 update_participating. clone ( ) ,
267284 )
268285 // Add a message
269- . add_message ( message_expr, crate :: pregel:: MessageDirection :: DstToSrc )
286+ . add_message (
287+ message_expr,
288+ if self . reversed {
289+ MessageDirection :: SrcToDst
290+ } else {
291+ MessageDirection :: DstToSrc
292+ } ,
293+ )
270294 // Set aggregate expression
271295 . with_aggregate_expr ( aggregate_expr_udaf. call ( vec ! [ col( PREGEL_MSG ) ] ) )
272296 // Set voting condition
@@ -300,6 +324,7 @@ impl GraphFrame {
300324#[ cfg( test) ]
301325mod tests {
302326 use super :: * ;
327+ use crate :: tests:: create_ldbc_test_graph;
303328 use datafusion:: arrow:: array:: { Int64Array , RecordBatch } ;
304329 use datafusion:: arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
305330 use datafusion:: prelude:: SessionContext ;
@@ -448,4 +473,64 @@ mod tests {
448473 ) ;
449474 Ok ( ( ) )
450475 }
476+
477+ async fn get_ldbc_bfs_results ( dataset : & str ) -> Result < DataFrame > {
478+ let ctx = SessionContext :: new ( ) ;
479+ let manifest_dir = env ! ( "CARGO_MANIFEST_DIR" ) ;
480+ let expected_pr_schema = Schema :: new ( vec ! [
481+ Field :: new( "vertex_id" , DataType :: Int64 , false ) ,
482+ Field :: new( "expected_distance" , DataType :: Int64 , false ) ,
483+ ] ) ;
484+ let expected_sp_path = format ! (
485+ "{}/testing/data/ldbc/{}/{}-BFS.csv" ,
486+ manifest_dir, dataset, dataset
487+ ) ;
488+ let expected_sp = ctx
489+ . read_csv (
490+ & expected_sp_path,
491+ CsvReadOptions :: new ( )
492+ . delimiter ( b' ' )
493+ . has_header ( false )
494+ . schema ( & expected_pr_schema) ,
495+ )
496+ . await ?;
497+ Ok ( expected_sp)
498+ }
499+
500+ #[ tokio:: test]
501+ async fn test_ldbc ( ) -> Result < ( ) > {
502+ let expected_distances = get_ldbc_bfs_results ( "test-bfs-directed" ) . await ?;
503+ let graph = create_ldbc_test_graph ( "test-bfs-directed" ) . await ?;
504+
505+ let results = graph
506+ . shortest_paths ( vec ! [ 1 ] )
507+ . reversed ( true ) // In LDBC the task is formulated as find distance from the root
508+ . checkpoint_interval ( 1 )
509+ . run ( )
510+ . await ?;
511+ let diff = results
512+ . join (
513+ expected_distances,
514+ JoinType :: Left ,
515+ & [ VERTEX_ID ] ,
516+ & [ "vertex_id" ] ,
517+ None ,
518+ ) ?
519+ . select ( vec ! [
520+ col( VERTEX_ID ) ,
521+ col( DISTANCES ) . field( "1" ) . alias( "got_distance" ) ,
522+ when(
523+ col( "expected_distance" ) . eq( lit( 9223372036854775807i64 ) ) ,
524+ lit( i32 :: MAX as i64 ) ,
525+ )
526+ . otherwise( col( "expected_distance" ) )
527+ . unwrap( )
528+ . alias( "expected_distance" ) ,
529+ ] ) ?
530+ . filter ( col ( "got_distance" ) . not_eq ( col ( "expected_distance" ) ) ) ?;
531+
532+ assert_eq ! ( diff. count( ) . await ?, 0 ) ;
533+
534+ Ok ( ( ) )
535+ }
451536}
0 commit comments