Skip to content

Commit 6586aed

Browse files
committed
fix: reject negative ln inputs
1 parent 66f82af commit 6586aed

3 files changed

Lines changed: 147 additions & 15 deletions

File tree

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
21+
use arrow::datatypes::{DataType, Float32Type, Float64Type};
22+
use arrow::error::ArrowError;
23+
use datafusion_common::{Result, exec_err, utils::take_function_args};
24+
use datafusion_expr::interval_arithmetic::Interval;
25+
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
26+
use datafusion_expr::{
27+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28+
Volatility,
29+
};
30+
31+
use super::{bounds::unbounded_bounds, get_ln_doc, ln_order};
32+
33+
#[derive(Debug, PartialEq, Eq, Hash)]
34+
pub struct LnFunc {
35+
signature: Signature,
36+
}
37+
38+
impl Default for LnFunc {
39+
fn default() -> Self {
40+
Self::new()
41+
}
42+
}
43+
44+
impl LnFunc {
45+
pub fn new() -> Self {
46+
use DataType::*;
47+
Self {
48+
signature: Signature::uniform(
49+
1,
50+
vec![Float64, Float32],
51+
Volatility::Immutable,
52+
),
53+
}
54+
}
55+
}
56+
57+
impl ScalarUDFImpl for LnFunc {
58+
fn name(&self) -> &str {
59+
"ln"
60+
}
61+
62+
fn signature(&self) -> &Signature {
63+
&self.signature
64+
}
65+
66+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
67+
match &arg_types[0] {
68+
DataType::Float32 => Ok(DataType::Float32),
69+
_ => Ok(DataType::Float64),
70+
}
71+
}
72+
73+
fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
74+
ln_order(input)
75+
}
76+
77+
fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
78+
unbounded_bounds(inputs)
79+
}
80+
81+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
82+
let args = ColumnarValue::values_to_arrays(&args.args)?;
83+
let [arg] = take_function_args(self.name(), args)?;
84+
85+
let arr: ArrayRef = match arg.data_type() {
86+
DataType::Float64 => {
87+
let result: Float64Array = arg
88+
.as_primitive::<Float64Type>()
89+
.try_unary(checked_ln_f64)?;
90+
Arc::new(result) as ArrayRef
91+
}
92+
DataType::Float32 => {
93+
let result: Float32Array = arg
94+
.as_primitive::<Float32Type>()
95+
.try_unary(checked_ln_f32)?;
96+
Arc::new(result) as ArrayRef
97+
}
98+
other => {
99+
return exec_err!(
100+
"Unsupported data type {other:?} for function {}",
101+
self.name()
102+
);
103+
}
104+
};
105+
106+
Ok(ColumnarValue::Array(arr))
107+
}
108+
109+
fn documentation(&self) -> Option<&Documentation> {
110+
Some(get_ln_doc())
111+
}
112+
}
113+
114+
fn checked_ln_f64(value: f64) -> Result<f64, ArrowError> {
115+
if value < 0.0 {
116+
Err(ArrowError::ComputeError(
117+
"Cannot take logarithm of a negative number".to_string(),
118+
))
119+
} else {
120+
Ok(value.ln())
121+
}
122+
}
123+
124+
fn checked_ln_f32(value: f32) -> Result<f32, ArrowError> {
125+
if value < 0.0 {
126+
Err(ArrowError::ComputeError(
127+
"Cannot take logarithm of a negative number".to_string(),
128+
))
129+
} else {
130+
Ok(value.ln())
131+
}
132+
}

datafusion/functions/src/math/mod.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub mod floor;
3131
pub mod gcd;
3232
pub mod iszero;
3333
pub mod lcm;
34+
pub mod ln;
3435
pub mod log;
3536
pub mod monotonicity;
3637
pub mod nans;
@@ -148,14 +149,7 @@ make_udf_function!(gcd::GcdFunc, gcd);
148149
make_udf_function!(nans::IsNanFunc, isnan);
149150
make_udf_function!(iszero::IsZeroFunc, iszero);
150151
make_udf_function!(lcm::LcmFunc, lcm);
151-
make_math_unary_udf!(
152-
LnFunc,
153-
ln,
154-
ln,
155-
super::ln_order,
156-
super::bounds::unbounded_bounds,
157-
super::get_ln_doc
158-
);
152+
make_udf_function!(ln::LnFunc, ln);
159153
make_math_unary_udf!(
160154
Log2Func,
161155
log2,

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -581,14 +581,20 @@ select ln(0);
581581
----
582582
-Infinity
583583

584-
# ln with columns (round is needed to normalize the outputs of different operating systems)
585-
query RRR rowsort
586-
select round(ln(a), 5), round(ln(b), 5), round(ln(c), 5) from signed_integers;
584+
# ln with positive column values (round is needed to normalize the outputs of different operating systems)
585+
query R rowsort
586+
select round(ln(a), 5) from signed_integers where a > 0;
587587
----
588-
0.69315 NaN 4.81218
589-
1.38629 NULL NULL
590-
NaN 4.60517 NaN
591-
NaN 9.21034 NaN
588+
0.69315
589+
1.38629
590+
591+
# ln rejects negative scalar inputs
592+
query error Arrow error: Compute error: Cannot take logarithm of a negative number
593+
select ln((-1.0)::float8);
594+
595+
# ln rejects negative column inputs
596+
query error Arrow error: Compute error: Cannot take logarithm of a negative number
597+
select round(ln(a), 5) from signed_integers;
592598

593599
## log
594600

0 commit comments

Comments
 (0)