Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 66 additions & 36 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
log_checked(decimal_to_f64(value, scale)?, base)
}

/// Calculate logarithm for Decimal64 values.
Expand All @@ -139,7 +139,7 @@ fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
log_checked(decimal_to_f64(value, scale)?, base)
}

/// Calculate logarithm for Decimal128 values.
Expand All @@ -157,7 +157,7 @@ fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError>
return Ok(int_log as f64);
}
}
decimal_to_f64(value, scale).map(|v| v.log(base))
log_checked(decimal_to_f64(value, scale)?, base)
}

/// Convert a scaled decimal value to f64.
Expand All @@ -180,11 +180,22 @@ fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError>
ArrowError::ComputeError(format!("Cannot convert {value} to f64"))
})?;
let scale_factor = 10f64.powi(scale as i32);
Ok((value_f64 / scale_factor).log(base))
log_checked(value_f64 / scale_factor, base)
}
}
}

#[inline]
fn log_checked<T: Float>(value: T, base: T) -> Result<T, ArrowError> {
if value < T::zero() {
Err(ArrowError::ComputeError(
"cannot take logarithm of a negative number".to_string(),
))
} else {
Ok(value.log(base))
}
}

impl ScalarUDFImpl for LogFunc {
fn name(&self) -> &str {
"log"
Expand Down Expand Up @@ -247,27 +258,24 @@ impl ScalarUDFImpl for LogFunc {
let value = value.to_array(args.number_rows)?;

let output: ArrayRef = match value.data_type() {
DataType::Float16 => {
calculate_binary_math::<Float16Type, Float16Type, Float16Type, _>(
&value,
&base,
|value, base| Ok(value.log(base)),
)?
}
DataType::Float32 => {
calculate_binary_math::<Float32Type, Float32Type, Float32Type, _>(
&value,
&base,
|value, base| Ok(value.log(base)),
)?
}
DataType::Float64 => {
calculate_binary_math::<Float64Type, Float64Type, Float64Type, _>(
&value,
&base,
|value, base| Ok(value.log(base)),
)?
}
DataType::Float16 => calculate_binary_math::<
Float16Type,
Float16Type,
Float16Type,
_,
>(&value, &base, log_checked)?,
DataType::Float32 => calculate_binary_math::<
Float32Type,
Float32Type,
Float32Type,
_,
>(&value, &base, log_checked)?,
DataType::Float64 => calculate_binary_math::<
Float64Type,
Float64Type,
Float64Type,
_,
>(&value, &base, log_checked)?,
DataType::Decimal32(_, scale) => {
calculate_binary_math::<Decimal32Type, Float64Type, Float64Type, _>(
&value,
Expand Down Expand Up @@ -465,6 +473,32 @@ mod tests {
result.expect_err("expected error");
}

#[test]
fn test_log_negative_value() {
let arg_fields = vec![
Field::new("b", DataType::Float64, false).into(),
Field::new("n", DataType::Float64, false).into(),
];
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(Float64Array::from(vec![10.0]))),
ColumnarValue::Array(Arc::new(Float64Array::from(vec![-1.0]))),
],
arg_fields,
number_rows: 1,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};

let error = LogFunc::new().invoke_with_args(args).unwrap_err();
assert!(
error
.to_string()
.contains("cannot take logarithm of a negative number"),
"{error}"
);
}

#[test]
fn test_log_scalar_f32_unary() {
let arg_field = Field::new("a", DataType::Float32, false).into();
Expand Down Expand Up @@ -675,7 +709,7 @@ mod tests {
2.0, 2.0, 3.0, 5.0, 5.0,
]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
8.0, 4.0, 81.0, 625.0, -123.0,
8.0, 4.0, 81.0, 625.0, 125.0,
]))), // num
],
arg_fields,
Expand All @@ -697,7 +731,7 @@ mod tests {
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 4.0).abs() < 1e-10);
assert!((floats.value(3) - 4.0).abs() < 1e-10);
assert!(floats.value(4).is_nan());
assert!((floats.value(4) - 3.0).abs() < 1e-10);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
Expand Down Expand Up @@ -957,13 +991,13 @@ mod tests {
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(
Decimal128Array::from(vec![10, 100, 1000, 10000, 12600, -123])
Decimal128Array::from(vec![10, 100, 1000, 10000, 12600])
.with_precision_and_scale(38, 0)
.unwrap(),
)), // num
],
arg_fields: vec![arg_field],
number_rows: 6,
number_rows: 5,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
Expand All @@ -976,14 +1010,13 @@ mod tests {
let floats = as_float64_array(&arr)
.expect("failed to convert result to a Float64Array");

assert_eq!(floats.len(), 6);
assert_eq!(floats.len(), 5);
assert!((floats.value(0) - 1.0).abs() < 1e-10);
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 3.0).abs() < 1e-10);
assert!((floats.value(3) - 4.0).abs() < 1e-10);
let expected = 12600_f64.log(10.0);
assert!((floats.value(4) - expected).abs() < 1e-10);
assert!(floats.value(5).is_nan());
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
Expand Down Expand Up @@ -1090,15 +1123,13 @@ mod tests {
Some(i256::from(12600)),
// Slightly lower than i128 max - can calculate
Some(i256::from_i128(i128::MAX) - i256::from(1000)),
// Give NaN for incorrect inputs, as in f64::log
Some(i256::from(-123)),
])
.with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0)
.unwrap(),
)), // num
],
arg_fields: vec![arg_field],
number_rows: 7,
number_rows: 6,
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
Expand All @@ -1111,7 +1142,7 @@ mod tests {
let floats = as_float64_array(&arr)
.expect("failed to convert result to a Float64Array");

assert_eq!(floats.len(), 7);
assert_eq!(floats.len(), 6);
assert!((floats.value(0) - 1.0).abs() < 1e-10);
assert!((floats.value(1) - 2.0).abs() < 1e-10);
assert!((floats.value(2) - 3.0).abs() < 1e-10);
Expand All @@ -1120,7 +1151,6 @@ mod tests {
assert!((floats.value(4) - expected).abs() < 1e-10);
let expected = ((i128::MAX - 1000) as f64).log(10.0);
assert!((floats.value(5) - expected).abs() < 1e-10);
assert!(floats.value(6).is_nan());
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
Expand Down
17 changes: 14 additions & 3 deletions datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ fn validate_sqrt_input(value: f64) -> Result<()> {
}
}

fn validate_log_input(value: f64) -> Result<()> {
if value < 0.0 {
exec_err!("cannot take logarithm of a negative number")
} else {
Ok(())
}
}

// Create UDFs
make_udf_function!(abs::AbsFunc, abs);
make_math_unary_udf!(
Expand Down Expand Up @@ -163,23 +171,26 @@ make_math_unary_udf!(
ln,
super::ln_order,
super::bounds::unbounded_bounds,
super::get_ln_doc
super::get_ln_doc,
Some(super::validate_log_input)
);
make_math_unary_udf!(
Log2Func,
log2,
log2,
super::log2_order,
super::bounds::unbounded_bounds,
super::get_log2_doc
super::get_log2_doc,
Some(super::validate_log_input)
);
make_math_unary_udf!(
Log10Func,
log10,
log10,
super::log10_order,
super::bounds::unbounded_bounds,
super::get_log10_doc
super::get_log10_doc,
Some(super::validate_log_input)
);
make_udf_function!(nanvl::NanvlFunc, nanvl);
make_udf_function!(pi::PiFunc, pi);
Expand Down
69 changes: 44 additions & 25 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -600,14 +600,18 @@ select ln(0);
----
-Infinity

# ln scalar ops with negative edgecases
statement error cannot take logarithm of a negative number
select ln(-1.0);

# ln with columns (round is needed to normalize the outputs of different operating systems)
query RRR rowsort
select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from signed_integers;
select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from unsigned_integers;
----
0.69315 NaN 4.81218
0 4.60517 6.34036
0.69315 6.90776 4.81218
1.09861 9.21034 6.88551
1.38629 NULL NULL
NaN 4.60517 NaN
NaN 9.21034 NaN

## log

Expand Down Expand Up @@ -649,6 +653,13 @@ select log(0) a, log(1, 64) b;
----
-Infinity Infinity

# log scalar ops with negative edgecases
statement error cannot take logarithm of a negative number
select log(-1.0);

statement error cannot take logarithm of a negative number
select log(2, -1.0);

# log with columns #1
query RRR rowsort
select log(a, 64) a, log(b), log(10, b) from unsigned_integers;
Expand All @@ -660,12 +671,12 @@ Infinity 2 2

# log with columns #2
query RRR rowsort
select log(a, 64) a, log(b), log(10, b) from signed_integers;
select log(2, abs(b)) a, log(abs(b)), log(10, abs(b)) from signed_integers;
----
3 NULL NULL
6 NaN NaN
NaN 2 2
NaN 4 4
13.287712379549 4 4
6.643856189775 2 2
9.965784284662 3 3
NULL NULL NULL

# log overloaded base 10 float64 and float32 casting scalar
query RR rowsort
Expand All @@ -675,12 +686,12 @@ select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b;

# log overloaded base 10 float64 and float32 casting with columns
query RR rowsort
select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers;
select log(arrow_cast(abs(a), 'Float64')), log(arrow_cast(abs(b), 'Float32')) from signed_integers;
----
0.301029995664 NaN
0 2
0.301029995664 3
0.47712125472 4
0.602059991328 NULL
NaN 2
NaN 4

# log float64 and float32 casting scalar
query RR rowsort
Expand All @@ -690,12 +701,12 @@ select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b;

# log float64 and float32 casting with columns
query RR rowsort
select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers;
select log(2,arrow_cast(abs(a), 'Float64')), log(4,arrow_cast(abs(b), 'Float32')) from signed_integers;
----
1 NaN
0 3.321928
1 4.982892
1.584962500721 6.643856
2 NULL
NaN 3.321928
NaN 6.643856


## log10
Expand Down Expand Up @@ -725,14 +736,18 @@ select log10(0);
----
-Infinity

# log10 scalar ops with negative edgecases
statement error cannot take logarithm of a negative number
select log10(-1.0);

# log10 with columns (round is needed to normalize the outputs of different operating systems)
query RRR rowsort
select round(log(a), 5), round(log(b), 5), round(log(c), 5) from signed_integers;
select round(log(a), 5), round(log(b), 5), round(log(c), 5) from unsigned_integers;
----
0.30103 NaN 2.08991
0 2 2.75358
0.30103 3 2.08991
0.47712 4 2.99034
0.60206 NULL NULL
NaN 2 NaN
NaN 4 NaN

## log2

Expand Down Expand Up @@ -761,14 +776,18 @@ select log2(0);
----
-Infinity

# log2 scalar ops with negative edgecases
statement error cannot take logarithm of a negative number
select log2(-1.0);

# log2 with columns (round is needed to normalize the outputs of different operating systems)
query RRR rowsort
select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integers;
select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from unsigned_integers;
----
1 NaN 6.94251
0 6.64386 9.1472
1 9.96578 6.94251
1.58496 13.28771 9.93369
2 NULL NULL
NaN 13.28771 NaN
NaN 6.64386 NaN

## nanvl

Expand Down
Loading