Skip to content

Commit 185dc7d

Browse files
authored
Allow boolean vs integer comparison (#31)
* Allow comparison between boolean and int values * Allow comparison between boolean and int values * Add integer types * clippy+fmt+tests fix * tests fix * Fix incompatible types test * Fix slt test * Fix docs
1 parent 087ebef commit 185dc7d

8 files changed

Lines changed: 67 additions & 17 deletions

File tree

datafusion-examples/examples/dataframe.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ async fn main() -> Result<()> {
6666
write_out(&ctx).await?;
6767
register_aggregate_test_data("t1", &ctx).await?;
6868
register_aggregate_test_data("t2", &ctx).await?;
69-
where_scalar_subquery(&ctx).await?;
70-
where_in_subquery(&ctx).await?;
69+
Box::pin(where_scalar_subquery(&ctx)).await?;
70+
Box::pin(where_in_subquery(&ctx)).await?;
7171
where_exist_subquery(&ctx).await?;
7272
Ok(())
7373
}

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
734734
.or_else(|| binary_coercion(lhs_type, rhs_type))
735735
.or_else(|| struct_coercion(lhs_type, rhs_type))
736736
.or_else(|| map_coercion(lhs_type, rhs_type))
737+
.or_else(|| boolean_coercion(lhs_type, rhs_type))
737738
}
738739

739740
/// Similar to [`comparison_coercion`] but prefers numeric if compares with
@@ -1007,6 +1008,20 @@ fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
10071008
}
10081009
}
10091010

1011+
/// Coercion rules for boolean types: If at least one argument is
1012+
/// a boolean type and both arguments can be coerced into a boolean type, coerce
1013+
/// to boolean type.
1014+
fn boolean_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
1015+
use arrow::datatypes::DataType::*;
1016+
match (lhs_type, rhs_type) {
1017+
(Boolean, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64)
1018+
| (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Boolean) => {
1019+
Some(Boolean)
1020+
}
1021+
_ => None,
1022+
}
1023+
}
1024+
10101025
/// Returns the output type of applying mathematics operations such as
10111026
/// `+` to arguments of `lhs_type` and `rhs_type`.
10121027
fn mathematics_numerical_coercion(
@@ -2434,6 +2449,32 @@ mod tests {
24342449
DataType::List(Arc::clone(&inner_field))
24352450
);
24362451

2452+
// boolean
2453+
let int_types = vec![
2454+
DataType::Int8,
2455+
DataType::Int16,
2456+
DataType::Int32,
2457+
DataType::Int64,
2458+
DataType::UInt8,
2459+
DataType::UInt16,
2460+
DataType::UInt32,
2461+
DataType::UInt64,
2462+
];
2463+
for int_type in int_types {
2464+
test_coercion_binary_rule!(
2465+
DataType::Boolean,
2466+
int_type,
2467+
Operator::Eq,
2468+
DataType::Boolean
2469+
);
2470+
test_coercion_binary_rule!(
2471+
int_type,
2472+
DataType::Boolean,
2473+
Operator::Eq,
2474+
DataType::Boolean
2475+
);
2476+
}
2477+
24372478
// Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible
24382479
let inner_timestamp_field = Arc::new(Field::new_list_field(
24392480
DataType::Timestamp(TimeUnit::Microsecond, None),

datafusion/expr/src/logical_plan/display.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -797,12 +797,12 @@ mod tests {
797797
let pivot = Pivot {
798798
input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
799799
produce_one_row: false,
800-
schema: schema.clone(),
800+
schema: Arc::clone(&schema),
801801
})),
802802
aggregate_expr: Expr::Column(Column::from_name("sum_value")),
803803
pivot_column: Column::from_name("category"),
804804
pivot_values,
805-
schema: schema.clone(),
805+
schema: Arc::clone(&schema),
806806
value_subquery: None,
807807
default_on_null_expr: None,
808808
};
@@ -834,15 +834,15 @@ mod tests {
834834
let pivot = Pivot {
835835
input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
836836
produce_one_row: false,
837-
schema: schema.clone(),
837+
schema: Arc::clone(&schema),
838838
})),
839839
aggregate_expr: Expr::Column(Column::from_name("sum_value")),
840840
pivot_column: Column::from_name("category"),
841841
pivot_values: vec![],
842-
schema: schema.clone(),
842+
schema: Arc::clone(&schema),
843843
value_subquery: Some(Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
844844
produce_one_row: false,
845-
schema: schema.clone(),
845+
schema: Arc::clone(&schema),
846846
}))),
847847
default_on_null_expr: None,
848848
};

datafusion/functions/src/datetime/to_date.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ mod tests {
166166
use arrow::datatypes::DataType;
167167
use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type};
168168
use datafusion_common::ScalarValue;
169-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
169+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
170170
use std::sync::Arc;
171171

172172
#[test]

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,11 +1572,18 @@ mod test {
15721572
let expected = "Projection: a IS TRUE\n EmptyRelation";
15731573
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
15741574

1575-
let empty = empty_with_type(DataType::Int64);
1575+
let empty = empty_with_type(DataType::Float64);
15761576
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
15771577
let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, "");
15781578
let err = ret.unwrap_err().to_string();
1579-
assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}");
1579+
assert!(err.contains("Cannot infer common argument type for comparison operation Float64 IS DISTINCT FROM Boolean"), "{err}");
1580+
1581+
// integer
1582+
let expr = col("a").is_true();
1583+
let empty = empty_with_type(DataType::Int64);
1584+
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?);
1585+
let expected = "Projection: CAST(a AS Boolean) IS TRUE\n EmptyRelation";
1586+
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?;
15801587

15811588
// is not true
15821589
let expr = col("a").is_not_true();

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,7 +1086,7 @@ mod tests {
10861086

10871087
#[test]
10881088
fn case_test_incompatible() -> Result<()> {
1089-
// 1 then is int64
1089+
// 1 then is float64
10901090
// 2 then is boolean
10911091
let batch = case_test_batch()?;
10921092
let schema = batch.schema();
@@ -1098,7 +1098,7 @@ mod tests {
10981098
lit("foo"),
10991099
&batch.schema(),
11001100
)?;
1101-
let then1 = lit(123i32);
1101+
let then1 = lit(1.23f64);
11021102
let when2 = binary(
11031103
col("a", &schema)?,
11041104
Operator::Eq,

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1536,8 +1536,10 @@ SELECT not(true), not(false)
15361536
----
15371537
false true
15381538

1539-
query error type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean
1539+
query BB
15401540
SELECT not(1), not(0)
1541+
----
1542+
false true
15411543

15421544
query ?B
15431545
SELECT null, not(null)

docs/source/library-user-guide/adding-udfs.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,8 +1075,8 @@ use datafusion_expr::Expr;
10751075
pub struct EchoFunction {}
10761076

10771077
impl TableFunctionImpl for EchoFunction {
1078-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
1079-
let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
1078+
fn call(&self, exprs: &[(datafusion_expr::Expr, Option<std::string::String>)]) -> Result<Arc<dyn TableProvider>> {
1079+
let Some((Expr::Literal(ScalarValue::Int64(Some(value))), _)) = exprs.get(0) else {
10801080
return plan_err!("First argument must be an integer");
10811081
};
10821082

@@ -1116,8 +1116,8 @@ With the UDTF implemented, you can register it with the `SessionContext`:
11161116
# pub struct EchoFunction {}
11171117
#
11181118
# impl TableFunctionImpl for EchoFunction {
1119-
# fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
1120-
# let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else {
1119+
# fn call(&self, exprs: &[(datafusion_expr::Expr, Option<std::string::String>)]) -> Result<Arc<dyn TableProvider>> {
1120+
# let Some((Expr::Literal(ScalarValue::Int64(Some(value))), _)) = exprs.get(0) else {
11211121
# return plan_err!("First argument must be an integer");
11221122
# };
11231123
#

0 commit comments

Comments
 (0)