@@ -10,11 +10,14 @@ use datafusion::prelude::*;
/// Name of the column in the vertices `DataFrame` that holds the vertex identifier.
pub const VERTEX_ID: &str = "id";
/// Name of the column in the edges `DataFrame` that holds the source vertex identifier.
pub const EDGE_SRC: &str = "src";
/// Name of the column in the edges `DataFrame` that holds the destination vertex identifier.
pub const EDGE_DST: &str = "dst";
/// Name of the triplets output column that carries the edge packed as a struct.
pub const EDGE_COL: &str = "edge";
/// Name of the triplets output column that carries the source vertex packed as a struct.
pub const SRC_VERTEX: &str = "src_vertex";
/// Name of the triplets output column that carries the destination vertex packed as a struct.
pub const DST_VERTEX: &str = "dst_vertex";
1316
1417#[ derive( Debug , Clone ) ]
1518pub struct GraphFrame {
16- vertices : DataFrame ,
17- edges : DataFrame ,
19+ pub vertices : DataFrame ,
20+ pub edges : DataFrame ,
1821}
1922
2023impl GraphFrame {
@@ -43,13 +46,107 @@ impl GraphFrame {
4346 ) ?;
4447 Ok ( df. select ( vec ! [ col( EDGE_SRC ) . alias( VERTEX_ID ) , col( "out_degree" ) ] ) ?)
4548 }
49+
50+ /// Generates a DataFrame containing "triplets" by combining information from edges and vertices.
51+ ///
52+ /// This method aggregates data about source vertices, edges, and destination vertices,
53+ /// producing a combined representation of these relationships as triplets.
54+ /// It constructs structured representations of edges and vertices, then performs
55+ /// joins to associate source and destination vertices with their respective edges.
56+ ///
57+ /// # Returns
58+ ///
59+ /// Returns a `Result<DataFrame>` which can either:
60+ /// - Contain the `DataFrame` representing the triplets (source vertex, edge, destination vertex).
61+ /// - Return an error if an operation (e.g., selection or join) fails during the process.
62+ ///
63+ /// Output `DataFrame` contains the following columns:
64+ /// - `SRC_VERTEX` - struct with all the columns of vertices, associated with a source of the triple
65+ /// - `EDGE_COL` - struct with all the columns of edges, associated with an edge
66+ /// - `DST_VERTEX` - struct with all the columns of vertices, associated with a destination of the triplet
67+ ///
68+ /// # Errors
69+ ///
70+ /// This method will return an error if:
71+ /// - Either the source vertices or destination vertices cannot be joined with edges due to schema mismatches.
72+ /// - Any selection or transformation process internally fails due to invalid queries.
73+ ///
74+ /// # Example
75+ ///
76+ /// ```
77+ /// use datafusion::dataframe;
78+ /// use graphframes_rs::{GraphFrame, VERTEX_ID, EDGE_SRC, EDGE_DST};
79+ /// let vertices = dataframe!(
80+ /// VERTEX_ID => vec![1i64, 2i64, 3i64],
81+ /// "attr" => vec!["a", "b", "c"]
82+ /// ).unwrap();
83+ /// let edges = dataframe!(
84+ /// EDGE_SRC => vec![1i64, 2i64, 3i64],
85+ /// EDGE_DST => vec![3i64, 1i64, 2i64],
86+ /// "attr" => vec!["d", "j", "h"]
87+ /// ).unwrap();
88+ ///
89+ /// let graph = GraphFrame { vertices, edges };
90+ /// let triplets = graph.triplets();
91+ /// ```
92+ /// // Assuming `edges_df` and `vertices_df` are initialized DataFrames for
93+ pub async fn triplets ( & self ) -> Result < DataFrame > {
94+ let edges_struct = self . edges . clone ( ) . select ( vec ! [
95+ col( EDGE_SRC ) ,
96+ col( EDGE_DST ) ,
97+ named_struct(
98+ self . edges
99+ . clone( )
100+ . schema( )
101+ . fields( )
102+ . iter( )
103+ . map( |field| field. name( ) )
104+ . flat_map( |name| vec![ lit( name) , col( name) ] )
105+ . collect( ) ,
106+ )
107+ . alias( EDGE_COL ) ,
108+ ] ) ?;
109+ let vertices_struct = self . vertices . clone ( ) . select ( vec ! [
110+ col( VERTEX_ID ) ,
111+ named_struct(
112+ self . vertices
113+ . clone( )
114+ . schema( )
115+ . fields( )
116+ . iter( )
117+ . map( |field| field. name( ) )
118+ . flat_map( |name| vec![ lit( name) , col( name) ] )
119+ . collect( ) ,
120+ )
121+ . alias( "_vertex_struct" ) ,
122+ ] ) ?;
123+ edges_struct
124+ . join_on (
125+ vertices_struct. clone ( ) . select ( vec ! [
126+ col( VERTEX_ID ) ,
127+ col( "_vertex_struct" ) . alias( SRC_VERTEX ) ,
128+ ] ) ?,
129+ JoinType :: Left ,
130+ vec ! [ col( EDGE_SRC ) . eq( col( VERTEX_ID ) ) ] ,
131+ ) ?
132+ . select ( vec ! [ col( SRC_VERTEX ) , col( EDGE_DST ) , col( EDGE_COL ) ] ) ?
133+ . join_on (
134+ vertices_struct. select ( vec ! [
135+ col( VERTEX_ID ) ,
136+ col( "_vertex_struct" ) . alias( DST_VERTEX ) ,
137+ ] ) ?,
138+ JoinType :: Left ,
139+ vec ! [ col( EDGE_DST ) . eq( col( VERTEX_ID ) ) ] ,
140+ ) ?
141+ . select ( vec ! [ col( SRC_VERTEX ) , col( EDGE_COL ) , col( DST_VERTEX ) ] )
142+ }
46143}
47144
48145#[ cfg( test) ]
49146mod tests {
50147 use super :: * ;
51148 use datafusion:: arrow:: array:: { Int64Array , RecordBatch , StringArray } ;
52- use datafusion:: arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
149+ use datafusion:: arrow:: datatypes:: { DataType , Field , Fields , Schema , SchemaRef } ;
53150 use std:: collections:: HashMap ;
54151 use std:: sync:: Arc ;
55152
@@ -183,4 +280,51 @@ mod tests {
183280
184281 Ok ( ( ) )
185282 }
283+
284+ #[ tokio:: test]
285+ async fn test_triplets ( ) -> Result < ( ) > {
286+ let vertices =
287+ dataframe ! ( VERTEX_ID => vec![ 1i64 , 2i64 , 3i64 ] , "attr" => vec![ "a" , "b" , "c" ] ) ?;
288+ let edges = dataframe ! ( EDGE_SRC => vec![ 1i64 , 2i64 , 3i64 ] , EDGE_DST => vec![ 3i64 , 1i64 , 2i64 ] , "attr" => vec![ "d" , "j" , "h" ] ) ?;
289+
290+ let graph = GraphFrame { vertices, edges } ;
291+ let triplets = graph. triplets ( ) . await ?;
292+
293+ // Check schema
294+ let schema = triplets. schema ( ) ;
295+ assert_eq ! ( schema. fields( ) . len( ) , 3 ) ;
296+ assert_eq ! ( schema. field( 0 ) . name( ) , SRC_VERTEX ) ;
297+ assert_eq ! ( schema. field( 1 ) . name( ) , EDGE_COL ) ;
298+ assert_eq ! ( schema. field( 2 ) . name( ) , DST_VERTEX ) ;
299+ assert ! (
300+ schema
301+ . field( 0 )
302+ . data_type( )
303+ . eq( & DataType :: Struct ( Fields :: from( vec![
304+ Field :: new( VERTEX_ID , DataType :: Int64 , true ) ,
305+ Field :: new( "attr" , DataType :: Utf8 , true )
306+ ] ) ) )
307+ ) ;
308+ assert ! (
309+ schema
310+ . field( 1 )
311+ . data_type( )
312+ . eq( & DataType :: Struct ( Fields :: from( vec![
313+ Field :: new( EDGE_SRC , DataType :: Int64 , true ) ,
314+ Field :: new( EDGE_DST , DataType :: Int64 , true ) ,
315+ Field :: new( "attr" , DataType :: Utf8 , true ) ,
316+ ] ) ) )
317+ ) ;
318+ assert ! (
319+ schema
320+ . field( 2 )
321+ . data_type( )
322+ . eq( & DataType :: Struct ( Fields :: from( vec![
323+ Field :: new( VERTEX_ID , DataType :: Int64 , true ) ,
324+ Field :: new( "attr" , DataType :: Utf8 , true )
325+ ] ) ) )
326+ ) ;
327+
328+ Ok ( ( ) )
329+ }
186330}
0 commit comments