@@ -323,7 +323,9 @@ impl FusedArrayTransformStream {
323323 for col_name in & self . array_columns {
324324 let idx = input_schema. index_of ( col_name) ?;
325325 let col = batch. column ( idx) ;
326- let output_array = self . apply_identity_transform ( col, col_name) ?;
326+
327+ let filtered_col = datafusion:: arrow:: compute:: filter ( col. as_ref ( ) , & bool_mask) ?;
328+ let output_array = self . apply_identity_transform ( & filtered_col, col_name) ?;
327329 output_columns. push ( output_array) ;
328330 }
329331 } else {
@@ -627,28 +629,35 @@ mod tests {
627629 use datafusion:: arrow:: datatypes:: { Field , Schema } ;
628630 use datafusion:: physical_plan:: test:: TestMemoryExec ;
629631
630- fn create_test_batch ( ) -> RecordBatch {
632+ fn create_test_batch (
633+ row0_lista : Vec < f64 > ,
634+ row0_listb : Vec < f64 > ,
635+ row1_lista : Vec < f64 > ,
636+ row1_listb : Vec < f64 >
637+ ) -> RecordBatch {
631638 let mut list_builder_a = ListBuilder :: new ( Float64Builder :: new ( ) ) ;
632639 let mut list_builder_b = ListBuilder :: new ( Float64Builder :: new ( ) ) ;
633640
634- // Row 0: [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
635- list_builder_a . values ( ) . append_value ( 1.0 ) ;
636- list_builder_a. values ( ) . append_value ( 2.0 ) ;
637- list_builder_a . values ( ) . append_value ( 3.0 ) ;
641+ // Row 0
642+ for val in row0_lista {
643+ list_builder_a. values ( ) . append_value ( val ) ;
644+ }
638645 list_builder_a. append ( true ) ;
639646
640- list_builder_b . values ( ) . append_value ( 10.0 ) ;
641- list_builder_b. values ( ) . append_value ( 20.0 ) ;
642- list_builder_b . values ( ) . append_value ( 30.0 ) ;
647+ for val in row0_listb {
648+ list_builder_b. values ( ) . append_value ( val ) ;
649+ }
643650 list_builder_b. append ( true ) ;
644651
645- // Row 1: [4.0, 5.0], [40.0, 50.0]
646- list_builder_a. values ( ) . append_value ( 4.0 ) ;
647- list_builder_a. values ( ) . append_value ( 5.0 ) ;
652+ // Row 1
653+ for val in row1_lista {
654+ list_builder_a. values ( ) . append_value ( val) ;
655+ }
648656 list_builder_a. append ( true ) ;
649657
650- list_builder_b. values ( ) . append_value ( 40.0 ) ;
651- list_builder_b. values ( ) . append_value ( 50.0 ) ;
658+ for val in row1_listb {
659+ list_builder_b. values ( ) . append_value ( val) ;
660+ }
652661 list_builder_b. append ( true ) ;
653662
654663 let arr_a = list_builder_a. finish ( ) ;
@@ -677,9 +686,18 @@ mod tests {
677686 . unwrap ( )
678687 }
679688
// Test-fixture shorthand that emulates default arguments for the
// `create_test_batch` function.
//
// * `create_test_batch!()` forwards the default fixture values:
//   `[1.0, 2.0, 3.0]`, `[10.0, 20.0, 30.0]`, `[4.0, 5.0]`, `[40.0, 50.0]`
//   (row 0 list a, row 0 list b, row 1 list a, row 1 list b).
// * `create_test_batch!(row0_a, row0_b, row1_a, row1_b)` forwards the four
//   expressions unchanged, in that order. A trailing comma is accepted so
//   multi-line invocations formatted one argument per line still match.
macro_rules! create_test_batch {
    () => {
        create_test_batch(
            vec![1.0, 2.0, 3.0],
            vec![10.0, 20.0, 30.0],
            vec![4.0, 5.0],
            vec![40.0, 50.0],
        )
    };
    ($row0_lista:expr, $row0_listb:expr, $row1_lista:expr, $row1_listb:expr $(,)?) => {
        create_test_batch($row0_lista, $row0_listb, $row1_lista, $row1_listb)
    };
}
697+
680698 #[ tokio:: test]
681699 async fn test_identity_transform ( ) {
682- let batch = create_test_batch ( ) ;
700+ let batch = create_test_batch ! ( ) ;
683701 let schema = batch. schema ( ) ;
684702
685703 let mem_exec = TestMemoryExec :: try_new ( & [ vec ! [ batch. clone( ) ] ] , schema, None ) . unwrap ( ) ;
@@ -695,11 +713,41 @@ mod tests {
695713
696714 // Schema should have 2 fields: metadata + values_a_out
697715 assert_eq ! ( fused. schema( ) . fields( ) . len( ) , 2 ) ;
716+
717+ let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
718+ let mut stream = fused. execute ( 0 , ctx) . unwrap ( ) ;
719+ let result_batch = stream. next ( ) . await . unwrap ( ) . unwrap ( ) ;
720+ assert_eq ! ( result_batch. num_rows( ) , 2 ) ;
721+ }
722+
723+ #[ tokio:: test]
724+ async fn test_identity_transform_with_empty_array ( ) {
725+ let batch = create_test_batch ! ( vec![ ] , vec![ ] , vec![ 4.0 , 5.0 ] , vec![ 40.0 , 50.0 ] ) ;
726+ let schema = batch. schema ( ) ;
727+
728+ let mem_exec = TestMemoryExec :: try_new ( & [ vec ! [ batch. clone( ) ] ] , schema, None ) . unwrap ( ) ;
729+
730+ let fused = FusedArrayTransformExec :: try_new (
731+ Arc :: new ( mem_exec) ,
732+ vec ! [ "values_a" . to_string( ) , "values_b" . to_string( ) ] ,
733+ vec ! [ "metadata" . to_string( ) ] ,
734+ vec ! [ "values_a_out" . to_string( ) , "values_b_out" . to_string( ) ] ,
735+ vec ! [ ] ,
736+ )
737+ . unwrap ( ) ;
738+
739+ // Schema should have 3 fields: metadata + values_a_out + values_b_out
740+ assert_eq ! ( fused. schema( ) . fields( ) . len( ) , 3 ) ;
741+ // row0 should be filtered out due to empty array, so only row1 remains
742+ let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
743+ let mut stream = fused. execute ( 0 , ctx) . unwrap ( ) ;
744+ let result_batch = stream. next ( ) . await . unwrap ( ) . unwrap ( ) ;
745+ assert_eq ! ( result_batch. num_rows( ) , 1 ) ;
698746 }
699747
700748 #[ tokio:: test]
701749 async fn test_execution ( ) {
702- let batch = create_test_batch ( ) ;
750+ let batch = create_test_batch ! ( ) ;
703751 let schema = batch. schema ( ) ;
704752
705753 let mem_exec = TestMemoryExec :: try_new ( & [ vec ! [ batch. clone( ) ] ] , schema, None ) . unwrap ( ) ;
0 commit comments