@@ -23,7 +23,10 @@ use std::str::FromStr;
2323use std:: sync:: Arc ;
2424
2525use arrow_arith:: boolean:: { and, and_kleene, is_not_null, is_null, not, or, or_kleene} ;
26+ use arrow_array:: cast:: AsArray ;
27+ use arrow_array:: types:: { Float32Type , Float64Type } ;
2628use arrow_array:: { Array , ArrayRef , BooleanArray , Datum as ArrowDatum , RecordBatch , Scalar } ;
29+ use arrow_buffer:: BooleanBuffer ;
2730use arrow_cast:: cast:: cast;
2831use arrow_ord:: cmp:: { eq, gt, gt_eq, lt, lt_eq, neq} ;
2932use arrow_schema:: {
@@ -1509,6 +1512,35 @@ fn project_column(
15091512 }
15101513}
15111514
1515+ fn compute_is_nan ( array : & ArrayRef ) -> std:: result:: Result < BooleanArray , ArrowError > {
1516+ // Compute NaN over the contiguous values slice, then fold the null bitmap
1517+ // in with a single bitwise AND so that null slots become false.
1518+ let ( is_nan, nulls) = match array. data_type ( ) {
1519+ DataType :: Float32 => {
1520+ let arr = array. as_primitive :: < Float32Type > ( ) ;
1521+ (
1522+ BooleanBuffer :: from_iter ( arr. values ( ) . iter ( ) . map ( |v| v. is_nan ( ) ) ) ,
1523+ arr. nulls ( ) ,
1524+ )
1525+ }
1526+ DataType :: Float64 => {
1527+ let arr = array. as_primitive :: < Float64Type > ( ) ;
1528+ (
1529+ BooleanBuffer :: from_iter ( arr. values ( ) . iter ( ) . map ( |v| v. is_nan ( ) ) ) ,
1530+ arr. nulls ( ) ,
1531+ )
1532+ }
1533+ _ => unreachable ! ( "is_nan is only valid for float types" ) ,
1534+ } ;
1535+
1536+ let values = match nulls {
1537+ Some ( nulls) => & is_nan & nulls. inner ( ) ,
1538+ None => is_nan,
1539+ } ;
1540+
1541+ Ok ( BooleanArray :: new ( values, None ) )
1542+ }
1543+
15121544type PredicateResult =
15131545 dyn FnMut ( RecordBatch ) -> std:: result:: Result < BooleanArray , ArrowError > + Send + ' static ;
15141546
@@ -1591,8 +1623,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
15911623 reference : & BoundReference ,
15921624 _predicate : & BoundPredicate ,
15931625 ) -> Result < Box < PredicateResult > > {
1594- if self . bound_reference ( reference) ?. is_some ( ) {
1595- self . build_always_true ( )
1626+ if let Some ( idx) = self . bound_reference ( reference) ? {
1627+ Ok ( Box :: new ( move |batch| {
1628+ let column = project_column ( & batch, idx) ?;
1629+ compute_is_nan ( & column)
1630+ } ) )
15961631 } else {
15971632 // A missing column, treating it as null.
15981633 self . build_always_false ( )
@@ -1604,8 +1639,12 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
16041639 reference : & BoundReference ,
16051640 _predicate : & BoundPredicate ,
16061641 ) -> Result < Box < PredicateResult > > {
1607- if self . bound_reference ( reference) ?. is_some ( ) {
1608- self . build_always_false ( )
1642+ if let Some ( idx) = self . bound_reference ( reference) ? {
1643+ Ok ( Box :: new ( move |batch| {
1644+ let column = project_column ( & batch, idx) ?;
1645+ let is_nan = compute_is_nan ( & column) ?;
1646+ not ( & is_nan)
1647+ } ) )
16091648 } else {
16101649 // A missing column, treating it as null.
16111650 self . build_always_true ( )
@@ -2002,7 +2041,7 @@ mod tests {
20022041 use std:: sync:: Arc ;
20032042
20042043 use arrow_array:: cast:: AsArray ;
2005- use arrow_array:: { ArrayRef , LargeStringArray , RecordBatch , StringArray } ;
2044+ use arrow_array:: { Array , ArrayRef , BooleanArray , LargeStringArray , RecordBatch , StringArray } ;
20062045 use arrow_schema:: { DataType , Field , Schema as ArrowSchema , TimeUnit } ;
20072046 use futures:: TryStreamExt ;
20082047 use parquet:: arrow:: arrow_reader:: { RowSelection , RowSelector } ;
@@ -5464,4 +5503,81 @@ message schema {
54645503 ts_array. value( 0 )
54655504 ) ;
54665505 }
5506+
5507+ fn apply_predicate_to_batch (
5508+ predicate : Predicate ,
5509+ schema : SchemaRef ,
5510+ batch : RecordBatch ,
5511+ ) -> BooleanArray {
5512+ use super :: PredicateConverter ;
5513+
5514+ let bound = predicate. bind ( schema, true ) . unwrap ( ) ;
5515+
5516+ // Build a trivial Parquet schema with one float column at field id 4
5517+ let message_type = "
5518+ message schema {
5519+ optional float qux = 4;
5520+ }
5521+ " ;
5522+ let parquet_type = parse_message_type ( message_type) . expect ( "parse schema" ) ;
5523+ let parquet_schema = SchemaDescriptor :: new ( Arc :: new ( parquet_type) ) ;
5524+
5525+ let column_map = HashMap :: from ( [ ( 4i32 , 0usize ) ] ) ;
5526+ let column_indices = vec ! [ 0usize ] ;
5527+
5528+ let mut converter = PredicateConverter {
5529+ parquet_schema : & parquet_schema,
5530+ column_map : & column_map,
5531+ column_indices : & column_indices,
5532+ } ;
5533+
5534+ let mut predicate_fn = visit ( & mut converter, & bound) . unwrap ( ) ;
5535+ predicate_fn ( batch) . unwrap ( )
5536+ }
5537+
5538+ #[ test]
5539+ fn test_predicate_converter_nan ( ) {
5540+ use arrow_array:: Float32Array ;
5541+
5542+ let schema = table_schema_simple ( ) ;
5543+ let arrow_schema = Arc :: new ( ArrowSchema :: new ( vec ! [ Field :: new(
5544+ "qux" ,
5545+ DataType :: Float32 ,
5546+ true ,
5547+ ) ] ) ) ;
5548+ let values = vec ! [ Some ( 1.0f32 ) , Some ( f32 :: NAN ) , None , Some ( 0.0f32 ) ] ;
5549+
5550+ // is_nan: non-null-propagating per Java's implementation - NULL → false
5551+ let batch = RecordBatch :: try_new ( arrow_schema. clone ( ) , vec ! [ Arc :: new( Float32Array :: from(
5552+ values. clone( ) ,
5553+ ) ) ] )
5554+ . unwrap ( ) ;
5555+ let result =
5556+ apply_predicate_to_batch ( Reference :: new ( "qux" ) . is_nan ( ) , schema. clone ( ) , batch) ;
5557+ assert_eq ! (
5558+ [
5559+ result. value( 0 ) ,
5560+ result. value( 1 ) ,
5561+ result. value( 2 ) ,
5562+ result. value( 3 )
5563+ ] ,
5564+ [ false , true , false , false ]
5565+ ) ;
5566+ assert ! ( !result. is_null( 2 ) ) ;
5567+
5568+ // not_nan: non-null-propagating per Java's implementation - NULL → true
5569+ let batch =
5570+ RecordBatch :: try_new ( arrow_schema, vec ! [ Arc :: new( Float32Array :: from( values) ) ] ) . unwrap ( ) ;
5571+ let result = apply_predicate_to_batch ( Reference :: new ( "qux" ) . is_not_nan ( ) , schema, batch) ;
5572+ assert_eq ! (
5573+ [
5574+ result. value( 0 ) ,
5575+ result. value( 1 ) ,
5576+ result. value( 2 ) ,
5577+ result. value( 3 )
5578+ ] ,
5579+ [ true , false , true , true ]
5580+ ) ;
5581+ assert ! ( !result. is_null( 2 ) ) ;
5582+ }
54675583}
0 commit comments