From 3de8efd6a0ea3796d329870f7357e2d83f07d5fe Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sat, 23 May 2026 14:30:52 -0400 Subject: [PATCH 1/2] fix: widen `power(decimal, float)` to Float64 and simplify --- .../core/tests/expr_api/simplification.rs | 10 +- datafusion/functions/Cargo.toml | 5 + datafusion/functions/benches/power.rs | 140 +++++++++++ datafusion/functions/src/math/power.rs | 217 ++++++------------ .../sqllogictest/test_files/decimal.slt | 33 ++- datafusion/sqllogictest/test_files/math.slt | 15 +- 6 files changed, 260 insertions(+), 160 deletions(-) create mode 100644 datafusion/functions/benches/power.rs diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 245aba66849ce..6e1271ef19aa9 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -648,13 +648,19 @@ fn test_simplify_power() { let expected = col("c3_non_null"); test_simplify(expr, expected) } - // Power(c3, Log(c3, c4)) ===> c4 + // Power(c3, Log(c3, c4)) ===> cast(c4 AS Int64) + // The simplifier rewrites `power(b, log(b, x))` to `x`, but the + // rewritten expression must keep the same type as the original + // `power` call. `power`'s declared return type follows its base + // argument (c3 = Int64), so the UInt32 c4 has to be cast to Int64 + // to preserve the output schema the optimizer already committed to. { let expr = power( col("c3_non_null"), log(col("c3_non_null"), col("c4_non_null")), ); - let expected = col("c4_non_null"); + let expected = + Expr::Cast(Cast::new(Box::new(col("c4_non_null")), DataType::Int64)); test_simplify(expr, expected) } // Power(c3, c4) ===> Power(c3, c4) diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index d6a6693d862cc..4eca16961fa8c 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -212,6 +212,11 @@ harness = false name = "atan2" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "power" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" diff --git a/datafusion/functions/benches/power.rs b/datafusion/functions/benches/power.rs new file mode 100644 index 0000000000000..5336e42ebe59b --- /dev/null +++ b/datafusion/functions/benches/power.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Microbenchmark for `power(decimal_array, int_*)`. +//! +//! Covers both array- and scalar-shaped integer exponents on a Decimal +//! base. Both shapes are dispatched to the native per-row decimal kernel; +//! the bench guards against any future change that routes either shape +//! through a Float64 round-trip, which is measurably slower than the +//! decimal kernel for the cases the kernel can handle. + +extern crate criterion; + +use arrow::array::{Decimal128Array, Int64Array}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; +use datafusion_functions::math::power; +use std::hint::black_box; +use std::sync::Arc; + +fn make_decimal_array(size: usize, precision: u8, scale: i8) -> Decimal128Array { + // Use a fixed unscaled value (250) so the bench is independent of `scale`. + // The four-arm dispatch in `power` only cares about the Decimal variant + // and the exponent's shape, not the numeric value. + let arr = Decimal128Array::from(vec![250i128; size]); + arr.with_precision_and_scale(precision, scale).unwrap() +} + +fn make_int_array(size: usize, value: i64) -> Int64Array { + Int64Array::from(vec![value; size]) +} + +fn run_power( + power_fn: &ScalarUDF, + args: &[ColumnarValue], + arg_fields: &[FieldRef], + return_field: &FieldRef, + config_options: &Arc, + num_rows: usize, +) { + black_box( + power_fn + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields: arg_fields.to_vec(), + number_rows: num_rows, + return_field: Arc::clone(return_field), + config_options: Arc::clone(config_options), + }) + .unwrap(), + ); +} + +fn criterion_benchmark(c: &mut Criterion) { + let power_fn = power(); + let config_options = Arc::new(ConfigOptions::default()); + let precision: u8 = 20; + let scale: i8 = 2; + let decimal_ty = DataType::Decimal128(precision, scale); + + // Exponents are bounded by what the native decimal kernel can handle + // without overflowing the i128 intermediate; see + // + let exponents = [2i64, 4, 8]; + + for size in [1024usize, 8192] { + let base_arr = Arc::new(make_decimal_array(size, precision, scale)); + let base_field: FieldRef = Field::new("base", decimal_ty.clone(), true).into(); + let exp_field: FieldRef = Field::new("exp", DataType::Int64, true).into(); + let return_field: FieldRef = Field::new("r", decimal_ty.clone(), true).into(); + let arg_fields = vec![base_field, exp_field]; + + for &exp in &exponents { + let exp_arr = Arc::new(make_int_array(size, exp)); + let array_args = vec![ + ColumnarValue::Array(base_arr.clone()), + ColumnarValue::Array(exp_arr), + ]; + c.bench_function( + &format!( + "power decimal({precision},{scale}) array x int array, exp={exp}, n={size}" + ), + |b| { + b.iter(|| { + run_power( + &power_fn, + &array_args, + &arg_fields, + &return_field, + &config_options, + size, + ) + }) + }, + ); + + let scalar_args = vec![ + ColumnarValue::Array(base_arr.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(exp))), + ]; + c.bench_function( + &format!( + "power decimal({precision},{scale}) array x int scalar, exp={exp}, n={size}" + ), + |b| { + b.iter(|| { + run_power( + &power_fn, + &scalar_args, + &arg_fields, + &return_field, + &config_options, + size, + ) + }) + }, + ); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 3fe30a1ffa86a..270bfd29c0524 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -32,7 +32,7 @@ use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + Cast, Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, lit, }; use datafusion_macros::user_doc; @@ -249,8 +249,7 @@ where }) } -/// Fallback implementation using f64 for negative or non-integer exponents. -/// This handles cases that cannot be computed using integer arithmetic. +/// Fallback for `pow_decimal_int` when the exponent is negative or non-integer. fn pow_decimal_float_fallback(base: T, scale: i8, exp: f64) -> Result where T: ToPrimitive + NumCast + Copy, @@ -271,7 +270,7 @@ where decimal_from_i128(result_i128) } -/// Decimal256 specialized float exponent version. +/// Like `pow_decimal_float`, but specialized for Decimal256. fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result { if exp.is_finite() && exp.trunc() == exp && exp >= 0f64 && exp < u32::MAX as f64 { return pow_decimal256_int(base, scale, exp as i64); @@ -286,7 +285,7 @@ fn pow_decimal256_float(base: i256, scale: i8, exp: f64) -> Result Result { if exp < 0 { return pow_decimal256_float(base, scale, exp as f64); @@ -346,7 +345,7 @@ fn pow_decimal256_int(base: i256, scale: i8, exp: i64) -> Result Result { - use arrow::compute::cast; - - let original_type = base.data_type().clone(); - let base_f64 = cast(base.as_ref(), &DataType::Float64)?; - - let exp_f64 = match exponent { - ColumnarValue::Array(arr) => cast(arr.as_ref(), &DataType::Float64)?, - ColumnarValue::Scalar(scalar) => { - let scalar_f64 = scalar.cast_to(&DataType::Float64)?; - scalar_f64.to_array_of_size(num_rows)? - } - }; - - let result_f64 = calculate_binary_math::( + let base_f64 = base + .cast_to(&DataType::Float64, None)? + .into_array(num_rows)?; + let result = calculate_binary_math::( &base_f64, - &ColumnarValue::Array(exp_f64), + exponent, float64_power_checked, )?; - - let result = cast(result_f64.as_ref(), &original_type)?; Ok(ColumnarValue::Array(result)) } @@ -410,11 +398,23 @@ impl ScalarUDFImpl for PowerFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0].is_null() { - Ok(DataType::Float64) - } else { - Ok(arg_types[0].clone()) + // Return type as a function of (base, exponent). After signature + // coercion the operands are one of three shapes, plus a NULL on + // either side when an operand is a literal NULL: + // + // - NULL on either side -> Float64 (matches the other math + // UDFs, which all return Float64 for NULL input) + // - (Decimal, Float64) -> Float64 + // - (Decimal, Int64) -> the base's Decimal type + // - (Float64, Float64) -> Float64 + let [base, exponent] = take_function_args(self.name(), arg_types)?; + if base.is_null() || exponent.is_null() { + return Ok(DataType::Float64); + } + if base.is_decimal() && exponent.is_floating() { + return Ok(DataType::Float64); } + Ok(base.clone()) } fn aliases(&self) -> &[String] { @@ -424,22 +424,24 @@ impl ScalarUDFImpl for PowerFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [base, exponent] = take_function_args(self.name(), &args.args)?; - // For decimal types, only use native decimal - // operations when we have a scalar exponent. When the exponent is an array, - // fall back to float computation for better performance. - let use_float_fallback = matches!( - base.data_type(), - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - ) && matches!(exponent, ColumnarValue::Array(_)); + // No native kernel exists for `(Decimal, Float64)`; bridge via + // Float64. The match below handles the remaining coerced shapes. + if base.data_type().is_decimal() && exponent.data_type().is_floating() { + return pow_decimal_via_float64(base, exponent, args.number_rows); + } let base = base.to_array(args.number_rows)?; - // If decimal with array exponent, cast to float and compute - if use_float_fallback { - return pow_decimal_with_float_fallback(&base, exponent, args.number_rows); + macro_rules! decimal_pow_arm { + ($decimal_ty:ident, $pow_fn:ident, $precision:expr, $scale:expr) => { + calculate_binary_decimal_math::<$decimal_ty, Int64Type, $decimal_ty, _>( + &base, + exponent, + |b, e| $pow_fn(b, *$scale, e), + *$precision, + *$scale, + )? + }; } let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { @@ -451,106 +453,16 @@ impl ScalarUDFImpl for PowerFunc { )? } (DataType::Decimal32(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )? - } - (DataType::Decimal32(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< - Decimal32Type, - Float64Type, - Decimal32Type, - _, - >( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, - )? + decimal_pow_arm!(Decimal32Type, pow_decimal_int, precision, scale) } (DataType::Decimal64(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )? - } - (DataType::Decimal64(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< - Decimal64Type, - Float64Type, - Decimal64Type, - _, - >( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, - )? + decimal_pow_arm!(Decimal64Type, pow_decimal_int, precision, scale) } (DataType::Decimal128(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::< - Decimal128Type, - Int64Type, - Decimal128Type, - _, - >( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )? - } - (DataType::Decimal128(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< - Decimal128Type, - Float64Type, - Decimal128Type, - _, - >( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, - )? + decimal_pow_arm!(Decimal128Type, pow_decimal_int, precision, scale) } (DataType::Decimal256(precision, scale), DataType::Int64) => { - calculate_binary_decimal_math::< - Decimal256Type, - Int64Type, - Decimal256Type, - _, - >( - &base, - exponent, - |b, e| pow_decimal256_int(b, *scale, e), - *precision, - *scale, - )? - } - (DataType::Decimal256(precision, scale), DataType::Float64) => { - calculate_binary_decimal_math::< - Decimal256Type, - Float64Type, - Decimal256Type, - _, - >( - &base, - exponent, - |b, e| pow_decimal256_float(b, *scale, e), - *precision, - *scale, - )? + decimal_pow_arm!(Decimal256Type, pow_decimal256_int, precision, scale) } (base_type, exp_type) => { return internal_err!( @@ -582,22 +494,43 @@ impl ScalarUDFImpl for PowerFunc { ))); } + // The simplified expression must keep `return_type`'s declared + // type. For example, `power(decimal, float)` returns `Float64`, + // so rewriting `power(decimal, 0)` to `1::decimal` or + // `power(decimal, 1)` to the decimal base would violate that + // contract and trip the optimizer's schema-compatibility check. + // Wrap simplified forms in a cast to `return_type` when needed. + let return_type = + self.return_type(&[base_type.clone(), exponent_type.clone()])?; + let needs_cast = |t: &DataType| *t != return_type; match exponent { Expr::Literal(value, _) if value == ScalarValue::new_zero(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( &base_type, - )?))) + )? + .cast_to(&return_type)?))) } Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(base)) + let result = if needs_cast(&base_type) { + Expr::Cast(Cast::new(Box::new(base), return_type)) + } else { + base + }; + Ok(ExprSimplifyResult::Simplified(result)) } Expr::ScalarFunction(ScalarFunction { func, mut args }) if is_log(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above - Ok(ExprSimplifyResult::Simplified(b)) + let b_type = info.get_data_type(&b)?; + let result = if needs_cast(&b_type) { + Expr::Cast(Cast::new(Box::new(b), return_type)) + } else { + b + }; + Ok(ExprSimplifyResult::Simplified(result)) } _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])), } diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 5faf801c84652..f650ce1a12627 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -1118,26 +1118,39 @@ SELECT power(2.5::decimal(38, 3), 4), arrow_typeof(power(2.5::decimal(38, 3), 4) ---- 39.062 Decimal128(38, 3) +# `power(decimal, float)` returns Float64: the result can grow beyond the +# base's Decimal precision (e.g. `2.5 ^ 4.0 = 39.0625` does not fit in +# `Decimal128(2, 1)`), so the function widens to Float64 instead of +# silently truncating. Inf / NaN are representable in Float64, so the +# edge cases below succeed rather than erroring on Decimal cast. query RT SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0)); ---- -39 Decimal128(2, 1) +39.0625 Float64 -# Non-integer exponent now works (fallback to f64) query RT SELECT power(2.5, 4.2), arrow_typeof(power(2.5, 4.2)); ---- -46.9 Decimal128(2, 1) +46.9189232024 Float64 -query error Compute error: Cannot use non-finite exp: NaN -SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64')) +query RT +SELECT power(2::decimal(38, 0), arrow_cast('NaN','Float64')), + arrow_typeof(power(2::decimal(38, 0), arrow_cast('NaN','Float64'))); +---- +NaN Float64 -query error Compute error: Cannot use non-finite exp: inf -SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64')) +query RT +SELECT power(2::decimal(38, 0), arrow_cast('INF','Float64')), + arrow_typeof(power(2::decimal(38, 0), arrow_cast('INF','Float64'))); +---- +Infinity Float64 -# Floating above u32::max now works (fallback to f64, returns infinity which is an error) -query error Arrow error: Arithmetic overflow: Result of 2\^5000000000.1 is not finite -SELECT power(2::decimal(38, 0), 5000000000.1) +# Result overflows finite Float64 range +query RT +SELECT power(2::decimal(38, 0), 5000000000.1), + arrow_typeof(power(2::decimal(38, 0), 5000000000.1)); +---- +Infinity Float64 # Integer Above u32::max - still goes through integer path which fails query error Arrow error: Arithmetic overflow: Unsupported exp value diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 475434883d315..a186699303108 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -818,6 +818,9 @@ from values 81 NULL +# `power` accepts (Decimal, Int), (Decimal, Float), or (Float, Float) -- there +# is no form that takes a Decimal exponent. So a Decimal exponent coerces to +# Float64, and `power(decimal, float)` returns Float64. query RT rowsort select power(base::decimal(38, 0), exponent::decimal(38, 0)), @@ -830,12 +833,12 @@ from values (2, 3), (3, 4) as t(base, exponent); ---- -0 Decimal128(38, 0) -1 Decimal128(38, 0) -4 Decimal128(38, 0) -625 Decimal128(38, 0) -8 Decimal128(38, 0) -81 Decimal128(38, 0) +0 Float64 +1 Float64 +4 Float64 +625 Float64 +8 Float64 +81 Float64 query RT select From 6b3e7cc77c7fa65a1e1e00184f060d8a36dc6f09 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Sun, 24 May 2026 10:29:50 -0400 Subject: [PATCH 2/2] Comment cleanup --- datafusion/functions/src/math/power.rs | 12 ++++-------- datafusion/sqllogictest/test_files/decimal.slt | 5 ----- datafusion/sqllogictest/test_files/math.slt | 6 +++--- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 270bfd29c0524..e9da7f09b0df1 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -370,8 +370,6 @@ fn pow_decimal256_float_fallback( /// Compute `power(decimal_base, float_exponent)` by casting the base to /// `Float64` and running `pow` in float space; returns `Float64`. -/// `calculate_binary_math` casts the exponent internally and preserves -/// scalar shape, so we only need to materialize the base here. fn pow_decimal_via_float64( base: &ColumnarValue, exponent: &ColumnarValue, @@ -399,11 +397,9 @@ impl ScalarUDFImpl for PowerFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { // Return type as a function of (base, exponent). After signature - // coercion the operands are one of three shapes, plus a NULL on - // either side when an operand is a literal NULL: + // coercion, we have to handle the following cases: // - // - NULL on either side -> Float64 (matches the other math - // UDFs, which all return Float64 for NULL input) + // - NULL on either side -> Float64 (typed NULL) // - (Decimal, Float64) -> Float64 // - (Decimal, Int64) -> the base's Decimal type // - (Float64, Float64) -> Float64 @@ -424,8 +420,8 @@ impl ScalarUDFImpl for PowerFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [base, exponent] = take_function_args(self.name(), &args.args)?; - // No native kernel exists for `(Decimal, Float64)`; bridge via - // Float64. The match below handles the remaining coerced shapes. + // No native kernel exists for `(Decimal, Float64)`; bridge by casting + // the base to Float64. if base.data_type().is_decimal() && exponent.data_type().is_floating() { return pow_decimal_via_float64(base, exponent, args.number_rows); } diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index f650ce1a12627..d9eac8492814c 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -1118,11 +1118,6 @@ SELECT power(2.5::decimal(38, 3), 4), arrow_typeof(power(2.5::decimal(38, 3), 4) ---- 39.062 Decimal128(38, 3) -# `power(decimal, float)` returns Float64: the result can grow beyond the -# base's Decimal precision (e.g. `2.5 ^ 4.0 = 39.0625` does not fit in -# `Decimal128(2, 1)`), so the function widens to Float64 instead of -# silently truncating. Inf / NaN are representable in Float64, so the -# edge cases below succeed rather than erroring on Decimal cast. query RT SELECT power(2.5, 4.0), arrow_typeof(power(2.5, 4.0)); ---- diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index a186699303108..51ceb4e716036 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -818,9 +818,9 @@ from values 81 NULL -# `power` accepts (Decimal, Int), (Decimal, Float), or (Float, Float) -- there -# is no form that takes a Decimal exponent. So a Decimal exponent coerces to -# Float64, and `power(decimal, float)` returns Float64. +# There is no variant of `power` that accepts (Decimal, Decimal); type coercion +# will cast the exponent to `Float64`, and `power(decimal, float)` returns +# `Float64`. query RT rowsort select power(base::decimal(38, 0), exponent::decimal(38, 0)),