Skip to content

Commit c3f3b7a

Browse files
authored
fix: Avoid precision loss for atan2 with integer args (#22516)
## Which issue does this PR close? - Closes #22514. ## Rationale for this change `atan2` defined two input signatures: `(Float32, Float32)` and `(Float64, Float64)` (in that order). That meant that integer inputs were coerced into `Float32` values, which lead to incorrect results: `atan2(1, 1000000)` resulted in less precision than `atan2(1.0, 1000000.0)`; the results for the former were also inconsistent with the behavior of `atan2` in Postgres and DuckDB. Fix this by only using the `Float32` path when given two` Float32` inputs; for other inputs, we should use `Float64`. This avoids rounding for large integer inputs (`Float32` has only 24 mantissa bits, so larger integers would get rounded). ## What changes are included in this PR? * Fix `atan2` signature to only take the `Float32` code path for two `Float32` inputs * Update SLT, add new SLT test ## Are these changes tested? Yes, new test added. ## Are there any user-facing changes? Yes: the return type and semantics of `atan2` in some circumstances has changed. `atan2` will now only be computed in `Float32` when passed two `Float32` values. In all other cases, the computation will be done in `Float64` and a `Float64` value will be returned.
1 parent e292f33 commit c3f3b7a

4 files changed

Lines changed: 50 additions & 28 deletions

File tree

datafusion/functions/src/macros.rs

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,10 @@ macro_rules! make_math_unary_udf {
245245

246246
impl $UDF {
247247
pub fn new() -> Self {
248-
use DataType::*;
249248
Self {
250249
signature: Signature::uniform(
251250
1,
252-
vec![Float64, Float32],
251+
vec![DataType::Float64, DataType::Float32],
253252
Volatility::Immutable,
254253
),
255254
}
@@ -270,7 +269,6 @@ macro_rules! make_math_unary_udf {
270269

271270
match arg_type {
272271
DataType::Float32 => Ok(DataType::Float32),
273-
// For other types (possible values float64/null/int), use Float64
274272
_ => Ok(DataType::Float64),
275273
}
276274
}
@@ -345,8 +343,12 @@ macro_rules! make_math_unary_udf {
345343

346344
/// Macro to create a binary math UDF.
347345
///
348-
/// A binary math function takes two arguments of types Float32 or Float64,
349-
/// applies a binary floating function to the argument, and returns a value of the same type.
346+
/// A binary math function takes two numeric arguments. When both arguments are
347+
/// Float32 the function is evaluated in single precision and returns Float32.
348+
/// Any other combination of numeric (or null) argument types is coerced to
349+
/// Float64 and returns Float64; in particular integers are widened to Float64
350+
/// rather than Float32 so that values needing more than 24 bits of mantissa are
351+
/// not silently rounded.
350352
///
351353
/// $UDF: the name of the UDF struct that implements `ScalarUDFImpl`
352354
/// $NAME: the name of the function
@@ -365,7 +367,6 @@ macro_rules! make_math_binary_udf {
365367
use arrow::datatypes::{DataType, Float32Type, Float64Type};
366368
use datafusion_common::utils::take_function_args;
367369
use datafusion_common::{Result, ScalarValue, internal_err};
368-
use datafusion_expr::TypeSignature;
369370
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
370371
use datafusion_expr::{
371372
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
@@ -379,13 +380,18 @@ macro_rules! make_math_binary_udf {
379380

380381
impl $UDF {
381382
pub fn new() -> Self {
382-
use DataType::*;
383383
Self {
384-
signature: Signature::one_of(
385-
vec![
386-
TypeSignature::Exact(vec![Float32, Float32]),
387-
TypeSignature::Exact(vec![Float64, Float64]),
388-
],
384+
// Float64 is listed first so that integer (and other
385+
// non-float) arguments coerce to Float64 rather than
386+
// Float32; genuine Float32 arguments still match
387+
// exactly and stay in single precision. Coercing
388+
// integers to Float64 matters for correctness: Float32
389+
// has only a 24-bit mantissa, so widening a large
390+
// integer to Float32 would round it before the function
391+
// is ever applied.
392+
signature: Signature::uniform(
393+
2,
394+
vec![DataType::Float64, DataType::Float32],
389395
Volatility::Immutable,
390396
),
391397
}
@@ -402,11 +408,8 @@ macro_rules! make_math_binary_udf {
402408
}
403409

404410
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
405-
let arg_type = &arg_types[0];
406-
407-
match arg_type {
408-
DataType::Float32 => Ok(DataType::Float32),
409-
// For other types (possible values float64/null/int), use Float64
411+
match (&arg_types[0], &arg_types[1]) {
412+
(DataType::Float32, DataType::Float32) => Ok(DataType::Float32),
410413
_ => Ok(DataType::Float64),
411414
}
412415
}

datafusion/functions/src/math/monotonicity.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,11 @@ Can be a constant, column, or function, and any combination of arithmetic operat
262262
)
263263
.with_sql_example(r#"```sql
264264
> SELECT atan2(1, 1);
265-
+------------+
266-
| atan2(1,1) |
267-
+------------+
268-
| 0.7853982 |
269-
+------------+
265+
+--------------------+
266+
| atan2(1,1) |
267+
+--------------------+
268+
| 0.7853981633974483 |
269+
+--------------------+
270270
```"#)
271271
.build()
272272
});

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,26 @@ select round(atanh(a), 5), round(atanh(b), 5), round(atanh(c), 5) from small_flo
234234
query RRR rowsort
235235
select atan2(0, 1), atan2(1, 2), atan2(2, 2);
236236
----
237-
0 0.4636476 0.7853982
237+
0 0.463647609001 0.785398163397
238+
239+
# atan2 returns Float32 only when both arguments are Float32; every other
240+
# numeric combination (integers, Float64, mixed, NULL) is computed in Float64
241+
query TTTTTT
242+
select
243+
arrow_typeof(atan2(arrow_cast(1.0, 'Float32'), arrow_cast(1.0, 'Float32'))),
244+
arrow_typeof(atan2(1, 1)),
245+
arrow_typeof(atan2(arrow_cast(1.0, 'Float32'), arrow_cast(1.0, 'Float64'))),
246+
arrow_typeof(atan2(arrow_cast(1.0, 'Float64'), arrow_cast(1.0, 'Float32'))),
247+
arrow_typeof(atan2(null, null)),
248+
arrow_typeof(atan2(null, 64));
249+
----
250+
Float32 Float64 Float64 Float64 Float64 Float64
251+
252+
# atan2 with integer inputs is computed in double precision
253+
query B
254+
select atan2(1, 1000000) = atan2(1.0, 1000000.0);
255+
----
256+
true
238257

239258
# atan2 scalar nulls
240259
query R rowsort

docs/source/user-guide/sql/scalar_functions.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,11 @@ atan2(expression_y, expression_x)
227227

228228
```sql
229229
> SELECT atan2(1, 1);
230-
+------------+
231-
| atan2(1,1) |
232-
+------------+
233-
| 0.7853982 |
234-
+------------+
230+
+--------------------+
231+
| atan2(1,1) |
232+
+--------------------+
233+
| 0.7853981633974483 |
234+
+--------------------+
235235
```
236236

237237
### `atanh`

0 commit comments

Comments
 (0)