Skip to content

Commit 4b3fd48

Browse files
committed
feat: fused WideDecimalBinaryExpr for Decimal128 add/sub/mul
Replace the 4-node expression tree used for Decimal128 arithmetic that may overflow (Cast(left) and Cast(right) to Decimal256 → BinaryExpr → Cast back to Decimal128) with a single fused expression that performs i256 register arithmetic directly. This reduces per-batch allocation from 4 intermediate arrays (112 bytes/elem) to 1 output array (16 bytes/elem). The new WideDecimalBinaryExpr evaluates its children, performs add/sub/mul using i256 intermediates via try_binary, applies scale adjustment with HALF_UP rounding, checks precision bounds, and outputs a single Decimal128 array. It follows the same pattern as decimal_div.
1 parent 6cff599 commit 4b3fd48

4 files changed

Lines changed: 528 additions & 22 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ use datafusion_comet_spark_expr::{
128128
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
129129
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr,
130130
RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
131+
WideDecimalBinaryExpr, WideDecimalOp,
131132
};
132133
use itertools::Itertools;
133134
use jni::objects::GlobalRef;
@@ -674,31 +675,31 @@ impl PhysicalPlanner {
674675
) {
675676
(
676677
DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply,
677-
Ok(DataType::Decimal128(p1, s1)),
678-
Ok(DataType::Decimal128(p2, s2)),
678+
Ok(DataType::Decimal128(_p1, _s1)),
679+
Ok(DataType::Decimal128(_p2, _s2)),
679680
) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
680-
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
681+
&& max(_s1, _s2) as u8 + max(_p1 - _s1 as u8, _p2 - _s2 as u8)
681682
>= DECIMAL128_MAX_PRECISION)
682-
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
683+
|| (op == DataFusionOperator::Multiply
684+
&& _p1 + _p2 >= DECIMAL128_MAX_PRECISION) =>
683685
{
684686
let data_type = return_type.map(to_arrow_datatype).unwrap();
685-
// For some Decimal128 operations, we need wider internal digits.
686-
// Cast left and right to Decimal256 and cast the result back to Decimal128
687-
let left = Arc::new(Cast::new(
688-
left,
689-
DataType::Decimal256(p1, s1),
690-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
691-
));
692-
let right = Arc::new(Cast::new(
693-
right,
694-
DataType::Decimal256(p2, s2),
695-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
696-
));
697-
let child = Arc::new(BinaryExpr::new(left, op, right));
698-
Ok(Arc::new(Cast::new(
699-
child,
700-
data_type,
701-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
687+
let (p_out, s_out) = match &data_type {
688+
DataType::Decimal128(p, s) => (*p, *s),
689+
dt => {
690+
return Err(ExecutionError::GeneralError(format!(
691+
"Expected Decimal128 return type, got {dt:?}"
692+
)))
693+
}
694+
};
695+
let wide_op = match op {
696+
DataFusionOperator::Plus => WideDecimalOp::Add,
697+
DataFusionOperator::Minus => WideDecimalOp::Subtract,
698+
DataFusionOperator::Multiply => WideDecimalOp::Multiply,
699+
_ => unreachable!(),
700+
};
701+
Ok(Arc::new(WideDecimalBinaryExpr::new(
702+
left, right, wide_op, p_out, s_out, eval_mode,
702703
)))
703704
}
704705
(

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ pub use json_funcs::{FromJson, ToJson};
8181
pub use math_funcs::{
8282
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
8383
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
84-
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
84+
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero, WideDecimalBinaryExpr,
85+
WideDecimalOp,
8586
};
8687
pub use string_funcs::*;
8788

native/spark-expr/src/math_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod negative;
2626
mod round;
2727
pub(crate) mod unhex;
2828
mod utils;
29+
mod wide_decimal_binary_expr;
2930

3031
pub use ceil::spark_ceil;
3132
pub use div::spark_decimal_div;
@@ -36,3 +37,4 @@ pub use modulo_expr::create_modulo_expr;
3637
pub use negative::{create_negate_expr, NegativeExpr};
3738
pub use round::spark_round;
3839
pub use unhex::spark_unhex;
40+
pub use wide_decimal_binary_expr::{WideDecimalBinaryExpr, WideDecimalOp};

0 commit comments

Comments
 (0)