Skip to content

Commit d1e6eb4

Browse files
authored
Refactor BinaryTypeCoercer to Handle Null Coercion Early and Avoid Redundant Checks (apache#16768)
Introduced get_result helper for arithmetic operations result type inference. Moved null coercion logic into signature to apply uniformly across arithmetic operators. Removed duplicated inline logic for arithmetic result type inference in signature_inner. Updated existing tests in expr_simplifier to use lit() for constructing null literals.
1 parent e1a1889 commit d1e6eb4

2 files changed

Lines changed: 69 additions & 46 deletions

File tree

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 68 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,57 @@ impl<'a> BinaryTypeCoercer<'a> {
124124

125125
/// Returns a [`Signature`] for applying `op` to arguments of type `lhs` and `rhs`
126126
fn signature(&'a self) -> Result<Signature> {
127+
if let Some(coerced) = null_coercion(self.lhs, self.rhs) {
128+
use Operator::*;
129+
// Special handling for arithmetic + null coercion:
130+
// For arithmetic operators on non-temporal types, we must handle the result type here using Arrow's numeric kernel.
131+
// This is because Arrow expects concrete numeric types, and this ensures the correct result type (e.g., for NULL + Int32, result is Int32).
132+
// For all other cases (including temporal arithmetic and non-arithmetic operators),
133+
// we can delegate to signature_inner(&coerced, &coerced), which handles the necessary logic for those operators.
134+
// In those cases, signature_inner is designed to work with the coerced type, even if it originated from a NULL.
135+
if matches!(self.op, Plus | Minus | Multiply | Divide | Modulo)
136+
&& !coerced.is_temporal()
137+
{
138+
let ret = self.get_result(&coerced, &coerced).map_err(|e| {
139+
plan_datafusion_err!(
140+
"Cannot get result type for arithmetic operation {coerced} {} {coerced}: {e}",
141+
self.op
142+
)
143+
})?;
144+
145+
return Ok(Signature {
146+
lhs: coerced.clone(),
147+
rhs: coerced,
148+
ret,
149+
});
150+
}
151+
return self.signature_inner(&coerced, &coerced);
152+
}
153+
self.signature_inner(self.lhs, self.rhs)
154+
}
155+
156+
/// Returns the result type for arithmetic operations
157+
fn get_result(
158+
&self,
159+
lhs: &DataType,
160+
rhs: &DataType,
161+
) -> arrow::error::Result<DataType> {
162+
use arrow::compute::kernels::numeric::*;
163+
let l = new_empty_array(lhs);
164+
let r = new_empty_array(rhs);
165+
166+
let result = match self.op {
167+
Operator::Plus => add_wrapping(&l, &r),
168+
Operator::Minus => sub_wrapping(&l, &r),
169+
Operator::Multiply => mul_wrapping(&l, &r),
170+
Operator::Divide => div(&l, &r),
171+
Operator::Modulo => rem(&l, &r),
172+
_ => unreachable!(),
173+
};
174+
result.map(|x| x.data_type().clone())
175+
}
176+
177+
fn signature_inner(&'a self, lhs: &DataType, rhs: &DataType) -> Result<Signature> {
127178
use arrow::datatypes::DataType::*;
128179
use Operator::*;
129180
let result = match self.op {
@@ -135,7 +186,7 @@ impl<'a> BinaryTypeCoercer<'a> {
135186
GtEq |
136187
IsDistinctFrom |
137188
IsNotDistinctFrom => {
138-
comparison_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| {
189+
comparison_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
139190
plan_datafusion_err!(
140191
"Cannot infer common argument type for comparison operation {} {} {}",
141192
self.lhs,
@@ -144,7 +195,7 @@ impl<'a> BinaryTypeCoercer<'a> {
144195
)
145196
})
146197
}
147-
And | Or => if matches!((self.lhs, self.rhs), (Boolean | Null, Boolean | Null)) {
198+
And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) {
148199
// Logical binary boolean operators can only be evaluated for
149200
// boolean or null arguments.
150201
Ok(Signature::uniform(Boolean))
@@ -154,78 +205,62 @@ impl<'a> BinaryTypeCoercer<'a> {
154205
)
155206
}
156207
RegexMatch | RegexIMatch | RegexNotMatch | RegexNotIMatch => {
157-
regex_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| {
208+
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
158209
plan_datafusion_err!(
159210
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
160211
)
161212
})
162213
}
163214
LikeMatch | ILikeMatch | NotLikeMatch | NotILikeMatch => {
164-
regex_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| {
215+
regex_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
165216
plan_datafusion_err!(
166217
"Cannot infer common argument type for regex operation {} {} {}", self.lhs, self.op, self.rhs
167218
)
168219
})
169220
}
170221
BitwiseAnd | BitwiseOr | BitwiseXor | BitwiseShiftRight | BitwiseShiftLeft => {
171-
bitwise_coercion(self.lhs, self.rhs).map(Signature::uniform).ok_or_else(|| {
222+
bitwise_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
172223
plan_datafusion_err!(
173224
"Cannot infer common type for bitwise operation {} {} {}", self.lhs, self.op, self.rhs
174225
)
175226
})
176227
}
177228
StringConcat => {
178-
string_concat_coercion(self.lhs, self.rhs).map(Signature::uniform).ok_or_else(|| {
229+
string_concat_coercion(lhs, rhs).map(Signature::uniform).ok_or_else(|| {
179230
plan_datafusion_err!(
180231
"Cannot infer common string type for string concat operation {} {} {}", self.lhs, self.op, self.rhs
181232
)
182233
})
183234
}
184235
AtArrow | ArrowAt => {
185236
// Array contains or search (similar to LIKE) operation
186-
array_coercion(self.lhs, self.rhs)
187-
.or_else(|| like_coercion(self.lhs, self.rhs)).map(Signature::comparison).ok_or_else(|| {
237+
array_coercion(lhs, rhs)
238+
.or_else(|| like_coercion(lhs, rhs)).map(Signature::comparison).ok_or_else(|| {
188239
plan_datafusion_err!(
189240
"Cannot infer common argument type for operation {} {} {}", self.lhs, self.op, self.rhs
190241
)
191242
})
192243
}
193244
AtAt => {
194245
// text search has similar signature to LIKE
195-
like_coercion(self.lhs, self.rhs).map(Signature::comparison).ok_or_else(|| {
246+
like_coercion(lhs, rhs).map(Signature::comparison).ok_or_else(|| {
196247
plan_datafusion_err!(
197248
"Cannot infer common argument type for AtAt operation {} {} {}", self.lhs, self.op, self.rhs
198249
)
199250
})
200251
}
201252
Plus | Minus | Multiply | Divide | Modulo => {
202-
let get_result = |lhs, rhs| {
203-
use arrow::compute::kernels::numeric::*;
204-
let l = new_empty_array(lhs);
205-
let r = new_empty_array(rhs);
206-
207-
let result = match self.op {
208-
Plus => add_wrapping(&l, &r),
209-
Minus => sub_wrapping(&l, &r),
210-
Multiply => mul_wrapping(&l, &r),
211-
Divide => div(&l, &r),
212-
Modulo => rem(&l, &r),
213-
_ => unreachable!(),
214-
};
215-
result.map(|x| x.data_type().clone())
216-
};
217-
218-
if let Ok(ret) = get_result(self.lhs, self.rhs) {
253+
if let Ok(ret) = self.get_result(lhs, rhs) {
219254
// Temporal arithmetic, e.g. Date32 + Interval
220255
Ok(Signature{
221-
lhs: self.lhs.clone(),
222-
rhs: self.rhs.clone(),
256+
lhs: lhs.clone(),
257+
rhs: rhs.clone(),
223258
ret,
224259
})
225-
} else if let Some(coerced) = temporal_coercion_strict_timezone(self.lhs, self.rhs) {
260+
} else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) {
226261
// Temporal arithmetic by first coercing to a common time representation
227262
// e.g. Date32 - Timestamp
228-
let ret = get_result(&coerced, &coerced).map_err(|e| {
263+
let ret = self.get_result(&coerced, &coerced).map_err(|e| {
229264
plan_datafusion_err!(
230265
"Cannot get result type for temporal operation {coerced} {} {coerced}: {e}", self.op
231266
)
@@ -235,9 +270,9 @@ impl<'a> BinaryTypeCoercer<'a> {
235270
rhs: coerced,
236271
ret,
237272
})
238-
} else if let Some((lhs, rhs)) = math_decimal_coercion(self.lhs, self.rhs) {
273+
} else if let Some((lhs, rhs)) = math_decimal_coercion(lhs, rhs) {
239274
// Decimal arithmetic, e.g. Decimal(10, 2) + Decimal(10, 0)
240-
let ret = get_result(&lhs, &rhs).map_err(|e| {
275+
let ret = self.get_result(&lhs, &rhs).map_err(|e| {
241276
plan_datafusion_err!(
242277
"Cannot get result type for decimal operation {} {} {}: {e}", self.lhs, self.op, self.rhs
243278
)
@@ -247,21 +282,9 @@ impl<'a> BinaryTypeCoercer<'a> {
247282
rhs,
248283
ret,
249284
})
250-
} else if let Some(numeric) = mathematics_numerical_coercion(self.lhs, self.rhs) {
285+
} else if let Some(numeric) = mathematics_numerical_coercion(lhs, rhs) {
251286
// Numeric arithmetic, e.g. Int32 + Int32
252287
Ok(Signature::uniform(numeric))
253-
} else if let Some(coerced) = null_coercion(self.lhs, self.rhs) {
254-
// One side is NULL, cast it to the other's type
255-
let ret = get_result(&coerced, &coerced).map_err(|e| {
256-
plan_datafusion_err!(
257-
"Cannot get result type for null arithmetic {coerced} {} {coerced}: {e}", self.op
258-
)
259-
})?;
260-
Ok(Signature {
261-
lhs: coerced.clone(),
262-
rhs: coerced,
263-
ret,
264-
})
265288
} else {
266289
plan_err!(
267290
"Cannot coerce arithmetic expression {} {} {} to valid types", self.lhs, self.op, self.rhs

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2432,7 +2432,7 @@ mod tests {
24322432

24332433
#[test]
24342434
fn test_simplify_multiply_by_null() {
2435-
let null = Expr::Literal(ScalarValue::Null, None);
2435+
let null = lit(ScalarValue::Null);
24362436
// A * null --> null
24372437
{
24382438
let expr = col("c2") * null.clone();

0 commit comments

Comments
 (0)