@@ -11,6 +11,7 @@ use datafusion_common::tree_node::TreeNode;
1111use datafusion_common:: tree_node:: TreeNodeRecursion ;
1212use datafusion_expr:: Operator as DFOperator ;
1313use datafusion_functions:: core:: getfield:: GetFieldFunc ;
14+ use datafusion_functions:: string:: octet_length:: OctetLengthFunc ;
1415use datafusion_physical_expr:: PhysicalExpr ;
1516use datafusion_physical_expr:: ScalarFunctionExpr ;
1617use datafusion_physical_expr:: projection:: ProjectionExpr ;
@@ -24,6 +25,7 @@ use vortex::dtype::Nullability;
2425use vortex:: dtype:: arrow:: FromArrowType ;
2526use vortex:: expr:: Expression ;
2627use vortex:: expr:: and_collect;
28+ use vortex:: expr:: byte_length;
2729use vortex:: expr:: cast;
2830use vortex:: expr:: get_item;
2931use vortex:: expr:: is_not_null;
@@ -111,8 +113,28 @@ pub trait ExpressionConvertor: Send + Sync {
111113pub struct DefaultExpressionConvertor { }
112114
113115impl DefaultExpressionConvertor {
116+ /// Attempts to convert DataFusion's `octet_length` function to Vortex `byte_length`.
117+ fn try_convert_octet_length ( & self , scalar_fn : & ScalarFunctionExpr ) -> DFResult < Expression > {
118+ let [ input] = scalar_fn. args ( ) else {
119+ return Err ( exec_datafusion_err ! (
120+ "octet_length requires exactly one argument"
121+ ) ) ;
122+ } ;
123+
124+ let input = self . convert ( input. as_ref ( ) ) ?;
125+ let return_dtype =
126+ DType :: from_arrow ( ( scalar_fn. return_type ( ) , scalar_fn. nullable ( ) . into ( ) ) ) ;
127+ Ok ( cast ( byte_length ( input) , return_dtype) )
128+ }
129+
114130 /// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
115131 fn try_convert_scalar_function ( & self , scalar_fn : & ScalarFunctionExpr ) -> DFResult < Expression > {
132+ if let Some ( octet_length_fn) =
133+ ScalarFunctionExpr :: try_downcast_func :: < OctetLengthFunc > ( scalar_fn)
134+ {
135+ return self . try_convert_octet_length ( octet_length_fn) ;
136+ }
137+
116138 if let Some ( get_field_fn) = ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn)
117139 {
118140 // DataFusion's GetFieldFunc flattens nested field access into a single call
@@ -289,7 +311,7 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
289311 let r = projection_expr. expr . apply ( |node| {
290312 // We only pull column children of scalar functions that we can't push into the scan.
291313 if let Some ( scalar_fn_expr) = node. downcast_ref :: < ScalarFunctionExpr > ( )
292- && !can_scalar_fn_be_pushed_down ( scalar_fn_expr)
314+ && !can_scalar_fn_be_pushed_down ( scalar_fn_expr, input_schema )
293315 {
294316 scan_projection. extend (
295317 collect_columns ( node)
@@ -305,8 +327,8 @@ impl ExpressionConvertor for DefaultExpressionConvertor {
305327 // Vortex expects a perfect match so we don't push it down.
306328 if let Some ( binary_expr) = node. downcast_ref :: < df_expr:: BinaryExpr > ( )
307329 && binary_expr. op ( ) . is_numerical_operators ( )
308- && ( is_decimal ( & binary_expr. left ( ) . data_type ( input_schema) ?)
309- && is_decimal ( & binary_expr. right ( ) . data_type ( input_schema) ?) )
330+ && binary_expr. left ( ) . data_type ( input_schema) ?. is_decimal ( )
331+ && binary_expr. right ( ) . data_type ( input_schema) ?. is_decimal ( )
310332 {
311333 scan_projection. extend (
312334 collect_columns ( node)
@@ -430,7 +452,7 @@ fn can_be_pushed_down_impl(expr: &Arc<dyn PhysicalExpr>, schema: &Schema) -> boo
430452 . iter ( )
431453 . all ( |e| can_be_pushed_down_impl ( e, schema) )
432454 } else if let Some ( scalar_fn) = expr. downcast_ref :: < ScalarFunctionExpr > ( ) {
433- can_scalar_fn_be_pushed_down ( scalar_fn)
455+ can_scalar_fn_be_pushed_down ( scalar_fn, schema )
434456 } else if let Some ( case_expr) = expr. downcast_ref :: < df_expr:: CaseExpr > ( ) {
435457 can_case_be_pushed_down ( case_expr, schema)
436458 } else {
@@ -454,9 +476,10 @@ fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
454476 || expr. downcast_ref :: < df_expr:: IsNullExpr > ( ) . is_some ( )
455477 || expr. downcast_ref :: < df_expr:: IsNotNullExpr > ( ) . is_some ( )
456478 || expr. downcast_ref :: < df_expr:: InListExpr > ( ) . is_some ( )
457- || expr
458- . downcast_ref :: < ScalarFunctionExpr > ( )
459- . is_some_and ( |sf| ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( sf) . is_some ( ) )
479+ || expr. downcast_ref :: < ScalarFunctionExpr > ( ) . is_some_and ( |sf| {
480+ ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( sf) . is_some ( )
481+ || ScalarFunctionExpr :: try_downcast_func :: < OctetLengthFunc > ( sf) . is_some ( )
482+ } )
460483}
461484
462485fn can_binary_be_pushed_down ( binary : & df_expr:: BinaryExpr , schema : & Schema ) -> bool {
@@ -502,20 +525,11 @@ fn supported_data_types(dt: &DataType) -> bool {
502525
503526 let is_supported = dt. is_null ( )
504527 || dt. is_numeric ( )
528+ || dt. is_binary ( )
529+ || dt. is_string ( )
505530 || matches ! (
506531 dt,
507- Boolean
508- | Utf8
509- | LargeUtf8
510- | Utf8View
511- | Binary
512- | LargeBinary
513- | BinaryView
514- | Date32
515- | Date64
516- | Timestamp ( _, _)
517- | Time32 ( _)
518- | Time64 ( _)
532+ Boolean | Date32 | Date64 | Timestamp ( _, _) | Time32 ( _) | Time64 ( _)
519533 ) ;
520534
521535 if !is_supported {
@@ -526,20 +540,30 @@ fn supported_data_types(dt: &DataType) -> bool {
526540}
527541
528542/// Checks if a scalar function can be pushed down.
529- /// Currently only GetFieldFunc is supported.
530- fn can_scalar_fn_be_pushed_down ( scalar_fn : & ScalarFunctionExpr ) -> bool {
531- ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( )
543+ /// Currently GetFieldFunc and OctetLengthFunc are supported.
544+ fn can_scalar_fn_be_pushed_down ( scalar_fn : & ScalarFunctionExpr , schema : & Schema ) -> bool {
545+ if ScalarFunctionExpr :: try_downcast_func :: < GetFieldFunc > ( scalar_fn) . is_some ( ) {
546+ return true ;
547+ }
548+
549+ ScalarFunctionExpr :: try_downcast_func :: < OctetLengthFunc > ( scalar_fn)
550+ . is_some_and ( |octet_length| can_octet_length_be_pushed_down ( octet_length, schema) )
532551}
533552
534- // TODO(adam): Replace with `DataType::is_decimal` once its released.
535- fn is_decimal ( dt : & DataType ) -> bool {
536- matches ! (
537- dt,
538- DataType :: Decimal32 ( _, _)
539- | DataType :: Decimal64 ( _, _)
540- | DataType :: Decimal128 ( _, _)
541- | DataType :: Decimal256 ( _, _)
542- )
553+ fn can_octet_length_be_pushed_down ( scalar_fn : & ScalarFunctionExpr , schema : & Schema ) -> bool {
554+ let [ input] = scalar_fn. args ( ) else {
555+ return false ;
556+ } ;
557+
558+ input. data_type ( schema) . as_ref ( ) . is_ok_and ( |data_type| {
559+ let dt = if let DataType :: Dictionary ( _, value_type) = data_type {
560+ value_type. as_ref ( )
561+ } else {
562+ data_type
563+ } ;
564+
565+ dt. is_binary ( ) || dt. is_string ( )
566+ } ) && can_be_pushed_down_impl ( input, schema)
543567}
544568
545569#[ cfg( test) ]
@@ -553,7 +577,9 @@ mod tests {
553577 use datafusion:: arrow:: array:: AsArray ;
554578 use datafusion:: arrow:: datatypes:: Int32Type ;
555579 use datafusion_common:: ScalarValue ;
580+ use datafusion_common:: config:: ConfigOptions ;
556581 use datafusion_expr:: Operator as DFOperator ;
582+ use datafusion_expr:: ScalarUDF ;
557583 use datafusion_physical_expr:: PhysicalExpr ;
558584 use datafusion_physical_plan:: expressions as df_expr;
559585 use insta:: assert_snapshot;
@@ -582,6 +608,18 @@ mod tests {
582608 ] )
583609 }
584610
611+ fn octet_length_expr ( input : Arc < dyn PhysicalExpr > , schema : & Schema ) -> Arc < dyn PhysicalExpr > {
612+ Arc :: new (
613+ ScalarFunctionExpr :: try_new (
614+ Arc :: new ( ScalarUDF :: from ( OctetLengthFunc :: new ( ) ) ) ,
615+ vec ! [ input] ,
616+ schema,
617+ Arc :: new ( ConfigOptions :: new ( ) ) ,
618+ )
619+ . unwrap ( ) ,
620+ )
621+ }
622+
585623 #[ test]
586624 fn test_make_vortex_predicate_empty ( ) {
587625 let expr_convertor = DefaultExpressionConvertor :: default ( ) ;
@@ -711,6 +749,23 @@ mod tests {
711749 ) ;
712750 }
713751
// Converts `octet_length` over the `name` column (index 1) and snapshots the resulting
// Vortex expression tree: a cast to `i32?` wrapping `byte_length` over `get_item(name)`
// on the root.
752+ #[ rstest]
753+ fn test_expr_from_df_octet_length ( test_schema : Schema ) {
754+ let expr = Arc :: new ( df_expr:: Column :: new ( "name" , 1 ) ) as Arc < dyn PhysicalExpr > ;
755+ let octet_length = octet_length_expr ( expr, & test_schema) ;
756+
// Run the default convertor; the test panics if conversion returns an error.
757+ let result = DefaultExpressionConvertor :: default ( )
758+ . convert ( octet_length. as_ref ( ) )
759+ . unwrap ( ) ;
760+
// NOTE(review): the nesting whitespace of the inline snapshot below appears to have been
// lost in this rendering — confirm against the real file before relying on it.
761+ assert_snapshot ! ( result. display_tree( ) . to_string( ) , @r"
762+ vortex.cast(i32?)
763+ └── input: vortex.byte_length()
764+ └── input: vortex.get_item(name)
765+ └── input: vortex.root()
766+ " ) ;
767+ }
768+
714769 #[ rstest]
715770 // Supported types
716771 #[ case:: null( DataType :: Null , true ) ]
@@ -865,6 +920,28 @@ mod tests {
865920 assert ! ( !can_be_pushed_down_impl( & like_expr, & test_schema) ) ;
866921 }
867922
923+ #[ rstest]
924+ fn test_can_be_pushed_down_octet_length_supported ( test_schema : Schema ) {
925+ let expr = Arc :: new ( df_expr:: Column :: new ( "name" , 1 ) ) as Arc < dyn PhysicalExpr > ;
926+ let octet_length = octet_length_expr ( expr, & test_schema) ;
927+
928+ assert ! ( can_be_pushed_down_impl( & octet_length, & test_schema) ) ;
929+ }
930+
931+ #[ rstest]
932+ fn test_can_be_pushed_down_octet_length_unsupported_operand ( test_schema : Schema ) {
933+ let expr = Arc :: new ( df_expr:: Column :: new ( "unsupported_list" , 5 ) ) as Arc < dyn PhysicalExpr > ;
934+ let octet_length = Arc :: new ( ScalarFunctionExpr :: new (
935+ "octet_length" ,
936+ Arc :: new ( ScalarUDF :: from ( OctetLengthFunc :: new ( ) ) ) ,
937+ vec ! [ expr] ,
938+ Arc :: new ( Field :: new ( "octet_length" , DataType :: Int32 , true ) ) ,
939+ Arc :: new ( ConfigOptions :: new ( ) ) ,
940+ ) ) as Arc < dyn PhysicalExpr > ;
941+
942+ assert ! ( !can_be_pushed_down_impl( & octet_length, & test_schema) ) ;
943+ }
944+
868945 // https://github.com/vortex-data/vortex/issues/6211
869946 #[ tokio:: test]
870947 async fn test_cast_int_to_string ( ) -> anyhow:: Result < ( ) > {
0 commit comments