@@ -323,7 +323,9 @@ impl FusedArrayTransformStream {
323323 for col_name in & self . array_columns {
324324 let idx = input_schema. index_of ( col_name) ?;
325325 let col = batch. column ( idx) ;
326- let output_array = self . apply_identity_transform ( col, col_name) ?;
326+
327+ let filtered_col = datafusion:: arrow:: compute:: filter ( col. as_ref ( ) , & bool_mask) ?;
328+ let output_array = self . apply_identity_transform ( & filtered_col, col_name) ?;
327329 output_columns. push ( output_array) ;
328330 }
329331 } else {
@@ -627,28 +629,35 @@ mod tests {
627629 use datafusion:: arrow:: datatypes:: { Field , Schema } ;
628630 use datafusion:: physical_plan:: test:: TestMemoryExec ;
629631
630- fn create_test_batch ( ) -> RecordBatch {
632+ fn create_test_batch (
633+ row0_lista : Vec < f64 > ,
634+ row0_listb : Vec < f64 > ,
635+ row1_lista : Vec < f64 > ,
636+ row1_listb : Vec < f64 > ,
637+ ) -> RecordBatch {
631638 let mut list_builder_a = ListBuilder :: new ( Float64Builder :: new ( ) ) ;
632639 let mut list_builder_b = ListBuilder :: new ( Float64Builder :: new ( ) ) ;
633640
634- // Row 0: [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
635- list_builder_a . values ( ) . append_value ( 1.0 ) ;
636- list_builder_a. values ( ) . append_value ( 2.0 ) ;
637- list_builder_a . values ( ) . append_value ( 3.0 ) ;
641+ // Row 0
642+ for val in row0_lista {
643+ list_builder_a. values ( ) . append_value ( val ) ;
644+ }
638645 list_builder_a. append ( true ) ;
639646
640- list_builder_b . values ( ) . append_value ( 10.0 ) ;
641- list_builder_b. values ( ) . append_value ( 20.0 ) ;
642- list_builder_b . values ( ) . append_value ( 30.0 ) ;
647+ for val in row0_listb {
648+ list_builder_b. values ( ) . append_value ( val ) ;
649+ }
643650 list_builder_b. append ( true ) ;
644651
645- // Row 1: [4.0, 5.0], [40.0, 50.0]
646- list_builder_a. values ( ) . append_value ( 4.0 ) ;
647- list_builder_a. values ( ) . append_value ( 5.0 ) ;
652+ // Row 1
653+ for val in row1_lista {
654+ list_builder_a. values ( ) . append_value ( val) ;
655+ }
648656 list_builder_a. append ( true ) ;
649657
650- list_builder_b. values ( ) . append_value ( 40.0 ) ;
651- list_builder_b. values ( ) . append_value ( 50.0 ) ;
658+ for val in row1_listb {
659+ list_builder_b. values ( ) . append_value ( val) ;
660+ }
652661 list_builder_b. append ( true ) ;
653662
654663 let arr_a = list_builder_a. finish ( ) ;
@@ -677,9 +686,23 @@ mod tests {
677686 . unwrap ( )
678687 }
679688
689+ macro_rules! create_test_batch {
690+ ( $row0_lista: expr, $row0_listb: expr, $row1_lista: expr, $row1_listb: expr) => {
691+ create_test_batch( $row0_lista, $row0_listb, $row1_lista, $row1_listb)
692+ } ;
693+ ( ) => {
694+ create_test_batch(
695+ vec![ 1.0 , 2.0 , 3.0 ] ,
696+ vec![ 10.0 , 20.0 , 30.0 ] ,
697+ vec![ 4.0 , 5.0 ] ,
698+ vec![ 40.0 , 50.0 ] ,
699+ )
700+ } ;
701+ }
702+
680703 #[ tokio:: test]
681704 async fn test_identity_transform ( ) {
682- let batch = create_test_batch ( ) ;
705+ let batch = create_test_batch ! ( ) ;
683706 let schema = batch. schema ( ) ;
684707
685708 let mem_exec = TestMemoryExec :: try_new ( & [ vec ! [ batch. clone( ) ] ] , schema, None ) . unwrap ( ) ;
@@ -695,11 +718,41 @@ mod tests {
695718
696719 // Schema should have 2 fields: metadata + values_a_out
697720 assert_eq ! ( fused. schema( ) . fields( ) . len( ) , 2 ) ;
721+
722+ let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
723+ let mut stream = fused. execute ( 0 , ctx) . unwrap ( ) ;
724+ let result_batch = stream. next ( ) . await . unwrap ( ) . unwrap ( ) ;
725+ assert_eq ! ( result_batch. num_rows( ) , 2 ) ;
726+ }
727+
728+ #[ tokio:: test]
729+ async fn test_identity_transform_with_empty_array ( ) {
730+ let batch = create_test_batch ! ( vec![ ] , vec![ ] , vec![ 4.0 , 5.0 ] , vec![ 40.0 , 50.0 ] ) ;
731+ let schema = batch. schema ( ) ;
732+
733+ let mem_exec = TestMemoryExec :: try_new ( & [ vec ! [ batch. clone( ) ] ] , schema, None ) . unwrap ( ) ;
734+
735+ let fused = FusedArrayTransformExec :: try_new (
736+ Arc :: new ( mem_exec) ,
737+ vec ! [ "values_a" . to_string( ) , "values_b" . to_string( ) ] ,
738+ vec ! [ "metadata" . to_string( ) ] ,
739+ vec ! [ "values_a_out" . to_string( ) , "values_b_out" . to_string( ) ] ,
740+ vec ! [ ] ,
741+ )
742+ . unwrap ( ) ;
743+
744+ // Schema should have 3 fields: metadata + values_a_out + values_b_out
745+ assert_eq ! ( fused. schema( ) . fields( ) . len( ) , 3 ) ;
746+ // row0 should be filtered out due to empty array, so only row1 remains
747+ let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
748+ let mut stream = fused. execute ( 0 , ctx) . unwrap ( ) ;
749+ let result_batch = stream. next ( ) . await . unwrap ( ) . unwrap ( ) ;
750+ assert_eq ! ( result_batch. num_rows( ) , 1 ) ;
698751 }
699752
700753 #[ tokio:: test]
701754 async fn test_execution ( ) {
702- let batch = create_test_batch ( ) ;
755+ let batch = create_test_batch ! ( ) ;
703756 let schema = batch. schema ( ) ;
704757
705758 let mem_exec = TestMemoryExec :: try_new ( & [ vec ! [ batch. clone( ) ] ] , schema, None ) . unwrap ( ) ;
0 commit comments