diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 2ca2ed1b572be..4eee16f845b82 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -121,7 +121,7 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + log_checked(decimal_to_f64(value, scale)?, base) } /// Calculate logarithm for Decimal64 values. @@ -139,7 +139,7 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + log_checked(decimal_to_f64(value, scale)?, base) } /// Calculate logarithm for Decimal128 values. @@ -157,7 +157,7 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result return Ok(int_log as f64); } } - decimal_to_f64(value, scale).map(|v| v.log(base)) + log_checked(decimal_to_f64(value, scale)?, base) } /// Convert a scaled decimal value to f64. @@ -180,11 +180,22 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result ArrowError::ComputeError(format!("Cannot convert {value} to f64")) })?; let scale_factor = 10f64.powi(scale as i32); - Ok((value_f64 / scale_factor).log(base)) + log_checked(value_f64 / scale_factor, base) } } } +#[inline] +fn log_checked(value: T, base: T) -> Result { + if value < T::zero() { + Err(ArrowError::ComputeError( + "cannot take logarithm of a negative number".to_string(), + )) + } else { + Ok(value.log(base)) + } +} + impl ScalarUDFImpl for LogFunc { fn name(&self) -> &str { "log" @@ -247,27 +258,24 @@ impl ScalarUDFImpl for LogFunc { let value = value.to_array(args.number_rows)?; let output: ArrayRef = match value.data_type() { - DataType::Float16 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float32 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } - DataType::Float64 => { - calculate_binary_math::( - &value, - &base, - |value, base| Ok(value.log(base)), - )? - } + DataType::Float16 => calculate_binary_math::< + Float16Type, + Float16Type, + Float16Type, + _, + >(&value, &base, log_checked)?, + DataType::Float32 => calculate_binary_math::< + Float32Type, + Float32Type, + Float32Type, + _, + >(&value, &base, log_checked)?, + DataType::Float64 => calculate_binary_math::< + Float64Type, + Float64Type, + Float64Type, + _, + >(&value, &base, log_checked)?, DataType::Decimal32(_, scale) => { calculate_binary_math::( &value, @@ -465,6 +473,32 @@ mod tests { result.expect_err("expected error"); } + #[test] + fn test_log_negative_value() { + let arg_fields = vec![ + Field::new("b", DataType::Float64, false).into(), + Field::new("n", DataType::Float64, false).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![10.0]))), + ColumnarValue::Array(Arc::new(Float64Array::from(vec![-1.0]))), + ], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let error = LogFunc::new().invoke_with_args(args).unwrap_err(); + assert!( + error + .to_string() + .contains("cannot take logarithm of a negative number"), + "{error}" + ); + } + #[test] fn test_log_scalar_f32_unary() { let arg_field = Field::new("a", DataType::Float32, false).into(); @@ -675,7 +709,7 @@ mod tests { 2.0, 2.0, 3.0, 5.0, 5.0, ]))), // base ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 8.0, 4.0, 81.0, 625.0, -123.0, + 8.0, 4.0, 81.0, 625.0, 125.0, ]))), // num ], arg_fields, @@ -697,7 +731,7 @@ mod tests { assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 4.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); - assert!(floats.value(4).is_nan()); + assert!((floats.value(4) - 3.0).abs() < 1e-10); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -957,13 +991,13 @@ mod tests { let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![10, 100, 1000, 10000, 12600, -123]) + Decimal128Array::from(vec![10, 100, 1000, 10000, 12600]) .with_precision_and_scale(38, 0) .unwrap(), )), // num ], arg_fields: vec![arg_field], - number_rows: 6, + number_rows: 5, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -976,14 +1010,13 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 6); + assert_eq!(floats.len(), 5); assert!((floats.value(0) - 1.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 3.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); let expected = 12600_f64.log(10.0); assert!((floats.value(4) - expected).abs() < 1e-10); - assert!(floats.value(5).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -1090,15 +1123,13 @@ mod tests { Some(i256::from(12600)), // Slightly lower than i128 max - can calculate Some(i256::from_i128(i128::MAX) - i256::from(1000)), - // Give NaN for incorrect inputs, as in f64::log - Some(i256::from(-123)), ]) .with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0) .unwrap(), )), // num ], arg_fields: vec![arg_field], - number_rows: 7, + number_rows: 6, return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -1111,7 +1142,7 @@ mod tests { let floats = as_float64_array(&arr) .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 7); + assert_eq!(floats.len(), 6); assert!((floats.value(0) - 1.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 3.0).abs() < 1e-10); @@ -1120,7 +1151,6 @@ mod tests { assert!((floats.value(4) - expected).abs() < 1e-10); let expected = ((i128::MAX - 1000) as f64).log(10.0); assert!((floats.value(5) - expected).abs() < 1e-10); - assert!(floats.value(6).is_nan()); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 1754ccb43488a..4a58c146c9d6e 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -51,6 +51,14 @@ fn validate_sqrt_input(value: f64) -> Result<()> { } } +fn validate_log_input(value: f64) -> Result<()> { + if value < 0.0 { + exec_err!("cannot take logarithm of a negative number") + } else { + Ok(()) + } +} + // Create UDFs make_udf_function!(abs::AbsFunc, abs); make_math_unary_udf!( @@ -163,7 +171,8 @@ make_math_unary_udf!( ln, super::ln_order, super::bounds::unbounded_bounds, - super::get_ln_doc + super::get_ln_doc, + Some(super::validate_log_input) ); make_math_unary_udf!( Log2Func, @@ -171,7 +180,8 @@ make_math_unary_udf!( log2, super::log2_order, super::bounds::unbounded_bounds, - super::get_log2_doc + super::get_log2_doc, + Some(super::validate_log_input) ); make_math_unary_udf!( Log10Func, @@ -179,7 +189,8 @@ make_math_unary_udf!( log10, super::log10_order, super::bounds::unbounded_bounds, - super::get_log10_doc + super::get_log10_doc, + Some(super::validate_log_input) ); make_udf_function!(nanvl::NanvlFunc, nanvl); make_udf_function!(pi::PiFunc, pi); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 9dbf8f16d85ab..eccdb522af96f 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -600,14 +600,18 @@ select ln(0); ---- -Infinity +# ln scalar ops with negative edgecases +statement error cannot take logarithm of a negative number +select ln(-1.0); + # ln with columns (round is needed to normalize the outputs of different operating systems) query RRR rowsort -select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from signed_integers; +select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from unsigned_integers; ---- -0.69315 NaN 4.81218 +0 4.60517 6.34036 +0.69315 6.90776 4.81218 +1.09861 9.21034 6.88551 1.38629 NULL NULL -NaN 4.60517 NaN -NaN 9.21034 NaN ## log @@ -649,6 +653,13 @@ select log(0) a, log(1, 64) b; ---- -Infinity Infinity +# log scalar ops with negative edgecases +statement error cannot take logarithm of a negative number +select log(-1.0); + +statement error cannot take logarithm of a negative number +select log(2, -1.0); + # log with columns #1 query RRR rowsort select log(a, 64) a, log(b), log(10, b) from unsigned_integers; @@ -660,12 +671,12 @@ Infinity 2 2 # log with columns #2 query RRR rowsort -select log(a, 64) a, log(b), log(10, b) from signed_integers; +select log(2, abs(b)) a, log(abs(b)), log(10, abs(b)) from signed_integers; ---- -3 NULL NULL -6 NaN NaN -NaN 2 2 -NaN 4 4 +13.287712379549 4 4 +6.643856189775 2 2 +9.965784284662 3 3 +NULL NULL NULL # log overloaded base 10 float64 and float32 casting scalar query RR rowsort @@ -675,12 +686,12 @@ select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b; # log overloaded base 10 float64 and float32 casting with columns query RR rowsort -select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers; +select log(arrow_cast(abs(a), 'Float64')), log(arrow_cast(abs(b), 'Float32')) from signed_integers; ---- -0.301029995664 NaN +0 2 +0.301029995664 3 +0.47712125472 4 0.602059991328 NULL -NaN 2 -NaN 4 # log float64 and float32 casting scalar query RR rowsort @@ -690,12 +701,12 @@ select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b; # log float64 and float32 casting with columns query RR rowsort -select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers; +select log(2,arrow_cast(abs(a), 'Float64')), log(4,arrow_cast(abs(b), 'Float32')) from signed_integers; ---- -1 NaN +0 3.321928 +1 4.982892 +1.584962500721 6.643856 2 NULL -NaN 3.321928 -NaN 6.643856 ## log10 @@ -725,14 +736,18 @@ select log10(0); ---- -Infinity +# log10 scalar ops with negative edgecases +statement error cannot take logarithm of a negative number +select log10(-1.0); + # log10 with columns (round is needed to normalize the outputs of different operating systems) query RRR rowsort -select round(log(a), 5), round(log(b), 5), round(log(c), 5) from signed_integers; +select round(log(a), 5), round(log(b), 5), round(log(c), 5) from unsigned_integers; ---- -0.30103 NaN 2.08991 +0 2 2.75358 +0.30103 3 2.08991 +0.47712 4 2.99034 0.60206 NULL NULL -NaN 2 NaN -NaN 4 NaN ## log2 @@ -761,14 +776,18 @@ select log2(0); ---- -Infinity +# log2 scalar ops with negative edgecases +statement error cannot take logarithm of a negative number +select log2(-1.0); + # log2 with columns (round is needed to normalize the outputs of different operating systems) query RRR rowsort -select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integers; +select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from unsigned_integers; ---- -1 NaN 6.94251 +0 6.64386 9.1472 +1 9.96578 6.94251 +1.58496 13.28771 9.93369 2 NULL NULL -NaN 13.28771 NaN -NaN 6.64386 NaN ## nanvl