Skip to content

Commit 7bcb613

Browse files
xiedeyantuCopilot
andauthored
fix sqrt(-1.0::float8) should error, not return NaN (apache#22308)
## Which issue does this PR close? - Closes apache#22260. ## Rationale for this change DataFusion previously returned `NaN` for `sqrt` on negative floating-point inputs, for example `sqrt((-1.0)::float8)`. This differs from PostgreSQL semantics, which raise an error for square root of a negative number. This change makes `sqrt` return an execution error for out-of-domain negative inputs so its behavior is closer to PostgreSQL and avoids silently producing `NaN` for invalid inputs. ## What changes are included in this PR? - Updated the unary math UDF helper to support an optional validator callback for runtime input validation. - Switched `sqrt` to use a named validator helper instead of inline predicate and error-string arguments. - Added runtime validation for `sqrt` so negative inputs now raise `cannot take square root of a negative number`. - Updated sqllogictests for `sqrt`: - negative literal inputs now expect an error - negative column inputs now expect an error - positive column coverage was retained using in-domain inputs ## Are these changes tested? Yes. The change is covered by existing SQL logic tests and targeted validation runs: - `cargo test -p datafusion-functions sqrt` - `cargo test -p datafusion-sqllogictest --test sqllogictests scalar` ## Are there any user-facing changes? Yes. `sqrt` now raises an execution error for negative inputs instead of returning `NaN`. This changes user-visible query behavior to better align with PostgreSQL semantics. --------- Co-authored-by: Copilot <copilot@github.com>
1 parent 7ad8e2c commit 7bcb613

3 files changed

Lines changed: 66 additions & 16 deletions

File tree

datafusion/functions/src/macros.rs

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,17 @@ macro_rules! downcast_arg {
210210
/// $GET_DOC: the function to get the documentation of the UDF
211211
macro_rules! make_math_unary_udf {
212212
($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => {
213+
make_math_unary_udf!(
214+
$UDF,
215+
$NAME,
216+
$UNARY_FUNC,
217+
$OUTPUT_ORDERING,
218+
$EVALUATE_BOUNDS,
219+
$GET_DOC,
220+
None::<fn(f64) -> Result<()>>
221+
);
222+
};
223+
($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr, $VALIDATOR:expr) => {
213224
$crate::make_udf_function!($NAME::$UDF, $NAME);
214225

215226
mod $NAME {
@@ -218,6 +229,7 @@ macro_rules! make_math_unary_udf {
218229

219230
use arrow::array::{ArrayRef, AsArray};
220231
use arrow::datatypes::{DataType, Float32Type, Float64Type};
232+
use arrow::error::ArrowError;
221233
use datafusion_common::{Result, exec_err};
222234
use datafusion_expr::interval_arithmetic::Interval;
223235
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
@@ -280,16 +292,38 @@ macro_rules! make_math_unary_udf {
280292
) -> Result<ColumnarValue> {
281293
let args = ColumnarValue::values_to_arrays(&args.args)?;
282294
let arr: ArrayRef = match args[0].data_type() {
283-
DataType::Float64 => Arc::new(
284-
args[0]
295+
DataType::Float64 => {
296+
let values = args[0]
285297
.as_primitive::<Float64Type>()
286-
.unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)),
287-
) as ArrayRef,
288-
DataType::Float32 => Arc::new(
289-
args[0]
298+
.try_unary::<_, Float64Type, _>(
299+
|x: f64| -> std::result::Result<f64, ArrowError> {
300+
if let Some(validate) = $VALIDATOR {
301+
validate(x).map_err(|error| {
302+
ArrowError::ComputeError(error.to_string())
303+
})?;
304+
}
305+
306+
Ok(f64::$UNARY_FUNC(x))
307+
},
308+
)?;
309+
Arc::new(values) as ArrayRef
310+
}
311+
DataType::Float32 => {
312+
let values = args[0]
290313
.as_primitive::<Float32Type>()
291-
.unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)),
292-
) as ArrayRef,
314+
.try_unary::<_, Float32Type, _>(
315+
|x: f32| -> std::result::Result<f32, ArrowError> {
316+
if let Some(validate) = $VALIDATOR {
317+
validate(x as f64).map_err(|error| {
318+
ArrowError::ComputeError(error.to_string())
319+
})?;
320+
}
321+
322+
Ok(f32::$UNARY_FUNC(x))
323+
},
324+
)?;
325+
Arc::new(values) as ArrayRef
326+
}
293327
other => {
294328
return exec_err!(
295329
"Unsupported data type {other:?} for function {}",

datafusion/functions/src/math/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! "math" DataFusion functions
1919
2020
use crate::math::monotonicity::*;
21+
use datafusion_common::{Result, exec_err};
2122
use datafusion_expr::ScalarUDF;
2223
use std::sync::Arc;
2324

@@ -42,6 +43,14 @@ pub mod round;
4243
pub mod signum;
4344
pub mod trunc;
4445

46+
fn validate_sqrt_input(value: f64) -> Result<()> {
47+
if value < 0.0 {
48+
exec_err!("cannot take square root of a negative number")
49+
} else {
50+
Ok(())
51+
}
52+
}
53+
4554
// Create UDFs
4655
make_udf_function!(abs::AbsFunc, abs);
4756
make_math_unary_udf!(
@@ -208,7 +217,8 @@ make_math_unary_udf!(
208217
sqrt,
209218
super::sqrt_order,
210219
super::bounds::sqrt_bounds,
211-
super::get_sqrt_doc
220+
super::get_sqrt_doc,
221+
Some(super::validate_sqrt_input)
212222
);
213223
make_math_unary_udf!(
214224
TanFunc,

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,12 +1107,16 @@ NULL
11071107

11081108
# sqrt with columns (round is needed to normalize the outputs of different operating systems)
11091109
query RRR rowsort
1110-
select round(sqrt(a), 5), round(sqrt(b), 5), round(sqrt(c), 5) from signed_integers;
1110+
select round(sqrt(abs(a)), 5), round(sqrt(abs(b)), 5), round(sqrt(abs(c)), 5) from signed_integers;
11111111
----
1112-
1.41421 NaN 11.09054
1112+
1 10 23.81176
1113+
1.41421 31.62278 11.09054
1114+
1.73205 100 31.27299
11131115
2 NULL NULL
1114-
NaN 10 NaN
1115-
NaN 100 NaN
1116+
1117+
# sqrt with negative column values should error
1118+
query error cannot take square root of a negative number
1119+
select round(sqrt(a), 5), round(sqrt(b), 5), round(sqrt(c), 5) from signed_integers;
11161120

11171121
# sqrt scalar fraction
11181122
query RR rowsort
@@ -1128,10 +1132,12 @@ select sqrt(cast(10e8 as double));
11281132

11291133

11301134
# sqrt scalar negative
1131-
query R rowsort
1135+
query error cannot take square root of a negative number
11321136
select sqrt(-1);
1133-
----
1134-
NaN
1137+
1138+
# sqrt scalar negative float8
1139+
query error cannot take square root of a negative number
1140+
select sqrt((-1.0)::float8);
11351141

11361142
## tan
11371143

0 commit comments

Comments
 (0)