Skip to content

Commit 03e833b

Browse files
authored
fix(spark-expr): preserve scalar tag in WideDecimalBinaryExpr when both inputs are scalars (#4187)
1 parent 6f17d48 commit 03e833b

1 file changed

Lines changed: 83 additions & 1 deletion

File tree

native/spark-expr/src/math_funcs/wide_decimal_binary_expr.rs

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,17 @@ impl PhysicalExpr for WideDecimalBinaryExpr {
184184
let left_val = self.left.evaluate(batch)?;
185185
let right_val = self.right.evaluate(batch)?;
186186

187+
// Track scalar-ness so we can return a Scalar when both inputs are scalars.
188+
// Without this, a (Scalar op Scalar) result would be returned as a length-1
189+
// Array, and downstream comparisons against full batches would incorrectly
190+
// see two Array operands with mismatched lengths instead of (Array, Scalar),
191+
// crashing arrow-ord's compare_op with "Cannot compare arrays of different
192+
// lengths". This pattern appears, for example, in TPC-DS q23's BHJ filter
193+
// `0.95 * scalar_subquery > ssales`.
194+
let both_scalar = matches!(
195+
(&left_val, &right_val),
196+
(ColumnarValue::Scalar(_), ColumnarValue::Scalar(_))
197+
);
187198
let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (&left_val, &right_val) {
188199
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
189200
(ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
@@ -280,7 +291,16 @@ impl PhysicalExpr for WideDecimalBinaryExpr {
280291
result
281292
};
282293
let result = result.with_data_type(DataType::Decimal128(p_out, s_out));
283-
Ok(ColumnarValue::Array(Arc::new(result)))
294+
if both_scalar {
295+
// Convert the length-1 result back into a Scalar so downstream
296+
// expressions (binary ops, comparisons) can take the scalar fast-path
297+
// and propagate the scalar tag (Datum::is_scalar) through arrow-rs
298+
// kernels.
299+
let scalar = datafusion::common::ScalarValue::try_from_array(&result, 0)?;
300+
Ok(ColumnarValue::Scalar(scalar))
301+
} else {
302+
Ok(ColumnarValue::Array(Arc::new(result)))
303+
}
284304
}
285305

286306
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
@@ -557,4 +577,66 @@ mod tests {
557577
let arr = result.as_primitive::<Decimal128Type>();
558578
assert_eq!(arr.value(0), 20000); // 2.0000
559579
}
580+
581+
/// Regression test: evaluating a wide-decimal op over two scalar inputs.
///
/// With `ColumnarValue::Scalar` on both sides, `evaluate` has to hand back a
/// `ColumnarValue::Scalar` rather than a length-1 `ColumnarValue::Array`. If it
/// returned an array, a later comparison against a full batch would be given
/// two `Array` operands of different lengths, and arrow-ord's `compare_op`
/// fails with "Cannot compare arrays of different lengths, got N vs 1".
/// TPC-DS q23's BHJ filter `0.95 * scalar_subquery > ssales` is one place
/// where this shape shows up.
#[test]
fn test_scalar_scalar_returns_scalar() {
    use datafusion::common::ScalarValue;
    use datafusion::physical_expr::expressions::Literal;

    // Wraps an unscaled value as a Decimal128(38, 2) literal expression.
    let decimal_literal = |unscaled: i128| -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Decimal128(Some(unscaled), 38, 2)))
    };

    // 0.95 * 100.00 -- mirrors the Scalar x Scalar multiply from q23's filter.
    let expr = WideDecimalBinaryExpr::new(
        decimal_literal(95),
        decimal_literal(10000),
        WideDecimalOp::Multiply,
        38,
        2,
        EvalMode::Legacy,
    );

    // Literals read no columns, so a batch over the empty schema is enough.
    let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));

    let ColumnarValue::Scalar(scalar) = expr.evaluate(&batch).unwrap() else {
        panic!("Scalar x Scalar must return ColumnarValue::Scalar, not Array");
    };

    match scalar {
        // 0.95 * 100.00 = 95.00, i.e. unscaled 9500 at scale 2.
        ScalarValue::Decimal128(Some(v), 38, 2) => assert_eq!(v, 9500),
        other => panic!("expected Decimal128(Some(_), 38, 2), got {other:?}"),
    }
}
626+
627+
/// Companion test: with Array inputs the result must stay a per-row array.
///
/// Guards against over-eager scalar-unwrapping in the Scalar x Scalar fix by
/// checking both the row count and the individual element values, so a result
/// collapsed to (or broadcast from) a single value would be caught.
#[test]
fn test_array_input_returns_array() {
    let batch = make_batch(
        vec![Some(150), Some(250)],
        38,
        2,
        vec![Some(100), Some(200)],
        38,
        2,
    );
    let result = eval_expr(&batch, WideDecimalOp::Add, 38, 2, EvalMode::Legacy).unwrap();

    // One output row per input row.
    assert_eq!(result.len(), 2);

    // Length alone does not prove the rows were computed independently, so
    // pin each element as well.
    let arr = result.as_primitive::<Decimal128Type>();
    assert_eq!(arr.value(0), 250); // 1.50 + 1.00 = 2.50
    assert_eq!(arr.value(1), 450); // 2.50 + 2.00 = 4.50
}
560642
}

0 commit comments

Comments
 (0)