@@ -96,15 +96,18 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
9696
9797#[ cfg( test) ]
9898mod tests {
99- use crate :: logical_plan:: producer:: { DefaultSubstraitProducer , SubstraitProducer } ;
99+ use crate :: logical_plan:: producer:: {
100+ DefaultSubstraitProducer , SubstraitProducer , to_substrait_type,
101+ } ;
100102 use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
101103 use datafusion:: common:: { JoinConstraint , JoinType , NullEquality } ;
102104 use datafusion:: execution:: SessionStateBuilder ;
103105 use datafusion:: logical_expr:: utils:: conjunction;
104106 use datafusion:: logical_expr:: { Join , col, table_scan} ;
105107 use std:: sync:: Arc ;
108+ use substrait:: proto:: expression:: { RexType , ScalarFunction } ;
106109 use substrait:: proto:: rel:: RelType ;
107- use substrait:: proto:: { JoinRel , Rel , join_rel} ;
110+ use substrait:: proto:: { Expression , JoinRel , Rel , join_rel} ;
108111
109112 #[ test]
110113 fn test_from_join ( ) -> datafusion:: common:: Result < ( ) > {
@@ -139,6 +142,9 @@ mod tests {
139142 col( "t1.c" ) . gt( col( "t2.c" ) ) ,
140143 ] )
141144 . unwrap ( ) ;
145+ let expected_join_expression =
146+ producer. handle_expr ( & expected_join_expr, & in_join_schema) ?;
147+
142148 assert_eq ! (
143149 join_expr,
144150 Box :: new( Rel {
@@ -147,14 +153,25 @@ mod tests {
147153 left: Some ( producer. handle_plan( & left_scan) ?) ,
148154 right: Some ( producer. handle_plan( & right_scan) ?) ,
149155 r#type: join_rel:: JoinType :: Inner as i32 ,
150- expression: Some ( Box :: new(
151- producer. handle_expr( & expected_join_expr, & in_join_schema) ?
152- ) ) ,
156+ expression: Some ( Box :: new( expected_join_expression. clone( ) ) ) ,
153157 post_join_filter: None ,
154158 advanced_extension: None ,
155159 } ) ) )
156160 } )
157161 ) ;
162+
163+ // Check that the join_expression has the expected output_type
164+ if let Expression {
165+ rex_type : Some ( RexType :: ScalarFunction ( ScalarFunction { output_type, .. } ) ) ,
166+ } = expected_join_expression
167+ {
168+ let expected_type =
169+ to_substrait_type ( & mut producer, & DataType :: Boolean , false ) ?;
170+ assert_eq ! ( output_type, Some ( expected_type) ) ;
171+ } else {
172+ panic ! ( "Substrait ScalarFunction expected" )
173+ }
174+
158175 Ok ( ( ) )
159176 }
160177}
0 commit comments