@@ -29,58 +29,32 @@ use std::collections::HashMap;
2929use std:: sync:: Arc ;
3030
3131/// Convert a DF Expr into a Substrait ExtendedExpressions message
32+ ///
33+ /// The schema needs to contain all of the fields that are referenced in the expression.
34+ /// It is ok if the schema has more fields than are required. However, we cannot currently
35+ /// convert all field types (e.g. extension types, fixed-size lists), and if any such fields
36+ /// are present then the conversion will fail.
37+ ///
38+ /// As a result, it may be a good idea for now to remove those types from the schema before
39+ /// calling this function.
3240pub fn encode_substrait ( expr : Expr , schema : Arc < ArrowSchema > ) -> Result < Vec < u8 > > {
33- use datafusion:: logical_expr:: { builder:: LogicalTableSource , logical_plan, LogicalPlan } ;
34- use datafusion_substrait:: substrait:: proto:: { plan_rel, ExpressionReference , NamedStruct } ;
35-
36- let table_source = Arc :: new ( LogicalTableSource :: new ( schema. clone ( ) ) ) ;
41+ use arrow_schema:: Field ;
42+ use datafusion:: logical_expr:: ExprSchemable ;
43+ use datafusion_common:: DFSchema ;
3744
38- // DF doesn't handled ExtendedExpressions and so we need to create
39- // a dummy plan with a single filter node
40- let plan = LogicalPlan :: Filter ( logical_plan:: Filter :: try_new (
41- expr,
42- Arc :: new ( LogicalPlan :: TableScan ( logical_plan:: TableScan :: try_new (
43- "dummy" ,
44- table_source,
45- None ,
46- vec ! [ ] ,
47- None ,
48- ) ?) ) ,
49- ) ?) ;
45+ let ctx = SessionContext :: new ( ) ;
5046
51- let session_context = SessionContext :: new ( ) ;
52-
53- let substrait_plan = datafusion_substrait:: logical_plan:: producer:: to_substrait_plan (
54- & plan,
55- & session_context. state ( ) ,
47+ let df_schema = Arc :: new ( DFSchema :: try_from ( schema) ?) ;
48+ let output_type = expr. get_type ( & df_schema) ?;
49+ // Nullability doesn't matter
50+ let output_field = Field :: new ( "output" , output_type, /*nullable=*/ true ) ;
51+ let extended_expr = datafusion_substrait:: logical_plan:: producer:: to_substrait_extended_expr (
52+ & [ ( & expr, & output_field) ] ,
53+ & df_schema,
54+ & ctx. state ( ) ,
5655 ) ?;
5756
58- if let Some ( plan_rel:: RelType :: Root ( root) ) = & substrait_plan. relations [ 0 ] . rel_type {
59- if let Some ( rel:: RelType :: Filter ( filt) ) = & root. input . as_ref ( ) . unwrap ( ) . rel_type {
60- let expr = filt. condition . as_ref ( ) . unwrap ( ) . clone ( ) ;
61- let schema = NamedStruct {
62- names : schema. fields ( ) . iter ( ) . map ( |f| f. name ( ) . clone ( ) ) . collect ( ) ,
63- r#struct : None ,
64- } ;
65- let envelope = ExtendedExpression {
66- advanced_extensions : substrait_plan. advanced_extensions . clone ( ) ,
67- base_schema : Some ( schema) ,
68- expected_type_urls : substrait_plan. expected_type_urls . clone ( ) ,
69- extension_uris : substrait_plan. extension_uris . clone ( ) ,
70- extensions : substrait_plan. extensions . clone ( ) ,
71- referred_expr : vec ! [ ExpressionReference {
72- output_names: vec![ ] ,
73- expr_type: Some ( ExprType :: Expression ( * expr) ) ,
74- } ] ,
75- version : substrait_plan. version . clone ( ) ,
76- } ;
77- Ok ( envelope. encode_to_vec ( ) )
78- } else {
79- unreachable ! ( )
80- }
81- } else {
82- unreachable ! ( )
83- }
57+ Ok ( extended_expr. encode_to_vec ( ) )
8458}
8559
8660fn count_fields ( dtype : & Type ) -> usize {
@@ -425,7 +399,7 @@ mod tests {
425399 helpers:: { literals:: literal, schema:: SchemaInfo } ,
426400 } ;
427401
428- use crate :: substrait:: parse_substrait;
402+ use crate :: substrait:: { encode_substrait , parse_substrait} ;
429403
430404 #[ tokio:: test]
431405 async fn test_substrait_conversion ( ) {
@@ -462,4 +436,21 @@ mod tests {
462436 } ) ;
463437 assert_eq ! ( df_expr, expected) ;
464438 }
439+
440+ #[ tokio:: test]
441+ async fn test_expr_substrait_roundtrip ( ) {
442+ let schema = arrow_schema:: Schema :: new ( vec ! [ Field :: new( "x" , DataType :: Int32 , true ) ] ) ;
443+ let expr = Expr :: BinaryExpr ( BinaryExpr {
444+ left : Box :: new ( Expr :: Column ( Column :: new_unqualified ( "x" ) ) ) ,
445+ op : Operator :: Lt ,
446+ right : Box :: new ( Expr :: Literal ( ScalarValue :: Int32 ( Some ( 0 ) ) ) ) ,
447+ } ) ;
448+
449+ let bytes = encode_substrait ( expr. clone ( ) , Arc :: new ( schema. clone ( ) ) ) . unwrap ( ) ;
450+
451+ let decoded = parse_substrait ( bytes. as_slice ( ) , Arc :: new ( schema. clone ( ) ) )
452+ . await
453+ . unwrap ( ) ;
454+ assert_eq ! ( decoded, expr) ;
455+ }
465456}
0 commit comments