diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 71528b4d16bf0..1be40b4055f40 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -309,6 +309,142 @@ macro_rules! make_math_unary_udf { }; } +/// Macro to create a unary logarithm UDF that rejects negative inputs. +/// +/// A unary logarithm function takes an argument of type Float32 or Float64, +/// applies a unary floating function to the argument, and returns a value of the same type. +/// +/// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` +/// $NAME: the name of the function +/// $UNARY_FUNC: the unary function to apply to the argument +/// $OUTPUT_ORDERING: the output ordering calculation method of the function +/// $GET_DOC: the function to get the documentation of the UDF +macro_rules! make_checked_log_unary_udf { + ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { + $crate::make_udf_function!($NAME::$UDF, $NAME); + + mod $NAME { + + use std::sync::Arc; + + use arrow::array::{Array, ArrayRef, AsArray, Float32Array, Float64Array}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use arrow::error::ArrowError; + use datafusion_common::{Result, exec_err}; + use datafusion_expr::interval_arithmetic::Interval; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, + }; + + #[derive(Debug, PartialEq, Eq, Hash)] + pub struct $UDF { + signature: Signature, + } + + impl $UDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } + } + + impl ScalarUDFImpl for $UDF { + fn name(&self) -> &str { + stringify!($NAME) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + // For other types (possible values float64/null/int), use Float64 + _ => Ok(DataType::Float64), + } + } + + fn output_ordering( + &self, + input: &[ExprProperties], + ) -> Result { + $OUTPUT_ORDERING(input) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + $EVALUATE_BOUNDS(inputs) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => { + let result: Float64Array = args[0] + .as_primitive::() + .try_unary(checked_log_f64)?; + Arc::new(result) as ArrayRef + } + DataType::Float32 => { + let result: Float32Array = args[0] + .as_primitive::() + .try_unary(checked_log_f32)?; + Arc::new(result) as ArrayRef + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } + } + + fn checked_log_f64(value: f64) -> Result { + if value < 0.0 { + Err(negative_log_error()) + } else { + Ok(f64::$UNARY_FUNC(value)) + } + } + + fn checked_log_f32(value: f32) -> Result { + if value < 0.0 { + Err(negative_log_error()) + } else { + Ok(f32::$UNARY_FUNC(value)) + } + } + + fn negative_log_error() -> ArrowError { + ArrowError::ComputeError( + "Cannot take logarithm of a negative number".to_string(), + ) + } + } + }; +} + /// Macro to create a binary math UDF. /// /// A binary math function takes two arguments of types Float32 or Float64, diff --git a/datafusion/functions/src/math/ln.rs b/datafusion/functions/src/math/ln.rs new file mode 100644 index 0000000000000..24079d27181cc --- /dev/null +++ b/datafusion/functions/src/math/ln.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use arrow::error::ArrowError; +use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +use super::{bounds::unbounded_bounds, get_ln_doc, ln_order}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct LnFunc { + signature: Signature, +} + +impl Default for LnFunc { + fn default() -> Self { + Self::new() + } +} + +impl LnFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LnFunc { + fn name(&self) -> &str { + "ln" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + ln_order(input) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + unbounded_bounds(inputs) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [arg] = take_function_args(self.name(), args)?; + + let arr: ArrayRef = match arg.data_type() { + DataType::Float64 => { + let result: Float64Array = arg + .as_primitive::() + .try_unary(checked_ln_f64)?; + Arc::new(result) as ArrayRef + } + DataType::Float32 => { + let result: Float32Array = arg + .as_primitive::() + .try_unary(checked_ln_f32)?; + Arc::new(result) as ArrayRef + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ln_doc()) + } +} + +fn checked_ln_f64(value: f64) -> Result { + if value < 0.0 { + Err(ArrowError::ComputeError( + "Cannot take logarithm of a negative number".to_string(), + )) + } else { + Ok(value.ln()) + } +} + +fn checked_ln_f32(value: f32) -> Result { + if value < 0.0 { + Err(ArrowError::ComputeError( + "Cannot take logarithm of a negative number".to_string(), + )) + } else { + Ok(value.ln()) + } +} diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ac94f78e0c723..39785522a6e21 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -106,10 +106,27 @@ fn is_valid_integer_base(base: f64) -> bool { base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 } +#[inline] +fn negative_log_error() -> ArrowError { + ArrowError::ComputeError("Cannot take logarithm of a negative number".to_string()) +} + +#[inline] +fn checked_log(value: T, base: T) -> Result { + if value < T::zero() { + Err(negative_log_error()) + } else { + Ok(value.log(base)) + } +} + /// Calculate logarithm for Decimal32 values. /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { + if value < 0 { + return Err(negative_log_error()); + } if scale == 0 && is_valid_integer_base(base) && let Ok(unscaled) = u32::try_from(value) @@ -128,6 +145,9 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { + if value < 0 { + return Err(negative_log_error()); + } if scale == 0 && is_valid_integer_base(base) && let Ok(unscaled) = u64::try_from(value) @@ -146,6 +166,9 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { /// For integer bases >= 2 with zero scale, return an exact integer log when the /// value is a perfect power of the base. Otherwise falls back to f64 computation. fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { + if value < 0 { + return Err(negative_log_error()); + } if scale == 0 && is_valid_integer_base(base) && let Ok(unscaled) = u128::try_from(value) @@ -171,6 +194,9 @@ fn decimal_to_f64(value: T, scale: i8) -> Result Result { + if value < i256::ZERO { + return Err(negative_log_error()); + } // Try to convert to i128 for the optimized path match value.to_i128() { Some(v) => log_decimal128(v, scale, base), @@ -247,27 +273,24 @@ impl ScalarUDFImpl for LogFunc { let value = value.to_array(args.number_rows)?; let output: ArrayRef = match value.data_type() { - DataType::Float16 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float32 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float64 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } + DataType::Float16 => calculate_binary_math::< + Float16Type, + Float16Type, + Float16Type, + _, + >(&value, &base, checked_log)?, + DataType::Float32 => calculate_binary_math::< + Float32Type, + Float32Type, + Float32Type, + _, + >(&value, &base, checked_log)?, + DataType::Float64 => calculate_binary_math::< + Float64Type, + Float64Type, + Float64Type, + _, + >(&value, &base, checked_log)?, DataType::Decimal32(_, scale) => { calculate_binary_math::( &value, @@ -672,14 +695,14 @@ mod tests { let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 2.0, 2.0, 3.0, 5.0, 5.0, + 2.0, 2.0, 3.0, 5.0, ]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, -123.0, + 8.0, 4.0, 81.0, 625.0, ]))), // num ], arg_fields, - number_rows: 5, + number_rows: 4, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -692,12 +715,11 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 5); + assert_eq!(floats.len(), 4); assert!((floats.value(0) - 3.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 4.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); - assert!(floats.value(4).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -705,6 +727,26 @@ mod tests { } } + #[test] + fn test_log_f64_rejects_negative_input() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("a", DataType::Float64, false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0]))), // base + ColumnarValue::Array(Arc::new(Float64Array::from(vec![-123.0]))), // num + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + assert_negative_log_error(LogFunc::new().invoke_with_args(args)); + } + #[test] fn test_log_f32() { let arg_fields = vec![ @@ -957,13 +999,13 @@ mod tests { let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![10, 100, 1000, 10000, 12600, -123]) + Decimal128Array::from(vec![10, 100, 1000, 10000, 12600]) .with_precision_and_scale(38, 0) .unwrap(), )), // num ], arg_fields: vec![arg_field], - number_rows: 6, + number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -976,14 +1018,13 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 6); + assert_eq!(floats.len(), 5); assert!((floats.value(0) - 1.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 3.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); let expected = 12600_f64.log(10.0); assert!((floats.value(4) - expected).abs() < 1e-10); - assert!(floats.value(5).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -991,6 +1032,11 @@ mod tests { } } + #[test] + fn test_log_decimal128_rejects_negative_input() { + assert_negative_log_error(log_decimal128(-123, 0, 10.0)); + } + #[test] fn test_log_decimal128_base_decimal() { // Base stays 2 despite scaling @@ -1090,15 +1136,13 @@ mod tests { Some(i256::from(12600)), // Slightly lower than i128 max - can calculate Some(i256::from_i128(i128::MAX) - i256::from(1000)), - // Give NaN for incorrect inputs, as in f64::log - Some(i256::from(-123)), ]) .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0) .unwrap(), )), // num ], arg_fields: vec![arg_field], - number_rows: 7, + number_rows: 6, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -1111,7 +1155,7 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 7); + assert_eq!(floats.len(), 6); assert!((floats.value(0) - 1.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 3.0).abs() < 1e-10); @@ -1120,7 +1164,6 @@ mod tests { assert!((floats.value(4) - expected).abs() < 1e-10); let expected = ((i128::MAX - 1000) as f64).log(10.0); assert!((floats.value(5) - expected).abs() < 1e-10); - assert!(floats.value(6).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -1128,6 +1171,11 @@ mod tests { } } + #[test] + fn test_log_decimal256_rejects_negative_input() { + assert_negative_log_error(log_decimal256(i256::from(-123), 0, 10.0)); + } + #[test] fn test_log_decimal128_invalid_base() { // Invalid base (-2.0) should return NaN, matching f64::log behavior @@ -1200,4 +1248,16 @@ mod tests { } } } + + fn assert_negative_log_error( + result: std::result::Result, + ) { + let error = result.expect_err("expected negative logarithm error"); + assert!( + error + .to_string() + .contains("Cannot take logarithm of a negative number"), + "{error}" + ); + } } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 610e773d68fd0..de9518d5aa023 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -31,6 +31,7 @@ pub mod floor; pub mod gcd; pub mod iszero; pub mod lcm; +pub mod ln; pub mod log; pub mod monotonicity; pub mod nans; @@ -148,15 +149,8 @@ make_udf_function!(gcd::GcdFunc, gcd); make_udf_function!(nans::IsNanFunc, isnan); make_udf_function!(iszero::IsZeroFunc, iszero); make_udf_function!(lcm::LcmFunc, lcm); -make_math_unary_udf!( - LnFunc, - ln, - ln, - super::ln_order, - super::bounds::unbounded_bounds, - super::get_ln_doc -); -make_math_unary_udf!( +make_udf_function!(ln::LnFunc, ln); +make_checked_log_unary_udf!( Log2Func, log2, log2, @@ -164,7 +158,7 @@ make_math_unary_udf!( super::bounds::unbounded_bounds, super::get_log2_doc ); -make_math_unary_udf!( +make_checked_log_unary_udf!( Log10Func, log10, log10, diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 89ae30e3c047b..64eb9e77a41ac 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -581,14 +581,26 @@ select ln(0); ---- -Infinity -# ln with columns (round is needed to normalize the outputs of different operating systems) +# ln with positive column values (round is needed to normalize the outputs of different operating systems) query RRR rowsort -select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from signed_integers; +select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from unsigned_integers; ---- -0.69315 NaN 4.81218 +0 4.60517 6.34036 +0.69315 6.90776 4.81218 +1.09861 9.21034 6.88551 1.38629 NULL NULL -NaN 4.60517 NaN -NaN 9.21034 NaN + +# ln rejects negative scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select ln((-1.0)::float8); + +# ln rejects negative Float32 scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select ln((-1.0)::float4); + +# ln rejects negative column inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select round(ln(a), 5) from signed_integers; ## log @@ -630,6 +642,14 @@ select log(0) a, log(1, 64) b; ---- -Infinity Infinity +# log rejects negative scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select log((-1.0)::float8); + +# log rejects negative Float32 scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select log((-1.0)::float4); + # log with columns #1 query RRR rowsort select log(a, 64) a, log(b), log(10, b) from unsigned_integers; @@ -640,13 +660,8 @@ select log(a, 64) a, log(b), log(10, b) from unsigned_integers; Infinity 2 2 # log with columns #2 -query RRR rowsort +query error Arrow error: Compute error: Cannot take logarithm of a negative number select log(a, 64) a, log(b), log(10, b) from signed_integers; ----- -3 NULL NULL -6 NaN NaN -NaN 2 2 -NaN 4 4 # log overloaded base 10 float64 and float32 casting scalar query RR rowsort @@ -655,13 +670,8 @@ select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b; 1 2 # log overloaded base 10 float64 and float32 casting with columns -query RR rowsort +query error Arrow error: Compute error: Cannot take logarithm of a negative number select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers; ----- -0.301029995664 NaN -0.602059991328 NULL -NaN 2 -NaN 4 # log float64 and float32 casting scalar query RR rowsort @@ -670,13 +680,8 @@ select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b; 3 4 # log float64 and float32 casting with columns -query RR rowsort +query error Arrow error: Compute error: Cannot take logarithm of a negative number select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers; ----- -1 NaN -2 NULL -NaN 3.321928 -NaN 6.643856 ## log10 @@ -706,14 +711,17 @@ select log10(0); ---- -Infinity +# log10 rejects negative scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select log10((-1.0)::float8); + +# log10 rejects negative Float32 scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select log10((-1.0)::float4); + # log10 with columns (round is needed to normalize the outputs of different operating systems) -query RRR rowsort -select round(log(a), 5), round(log(b), 5), round(log(c), 5) from signed_integers; ----- -0.30103 NaN 2.08991 -0.60206 NULL NULL -NaN 2 NaN -NaN 4 NaN +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select round(log10(a), 5), round(log10(b), 5), round(log10(c), 5) from signed_integers; ## log2 @@ -742,14 +750,17 @@ select log2(0); ---- -Infinity +# log2 rejects negative scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select log2((-1.0)::float8); + +# log2 rejects negative Float32 scalar inputs +query error Arrow error: Compute error: Cannot take logarithm of a negative number +select log2((-1.0)::float4); + # log2 with columns (round is needed to normalize the outputs of different operating systems) -query RRR rowsort +query error Arrow error: Compute error: Cannot take logarithm of a negative number select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integers; ----- -1 NaN 6.94251 -2 NULL NULL -NaN 13.28771 NaN -NaN 6.64386 NaN ## nanvl