Skip to content

Commit a08b70e

Browse files
authored
feat(query): add decimal sum widening setting (#19836)
1 parent 3afb936 commit a08b70e

4 files changed

Lines changed: 77 additions & 0 deletions

File tree

src/query/settings/src/settings_default.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,13 @@ impl DefaultSettings {
12551255
scope: SettingScope::Both,
12561256
range: Some(SettingRange::Numeric(0..=1)),
12571257
}),
1258+
("enable_decimal_sum_widening", DefaultSettingValue {
1259+
value: UserSettingValue::UInt64(0),
1260+
desc: "Automatically widen SUM arguments from Decimal(19..38, scale) to Decimal(76, scale).",
1261+
mode: SettingMode::Both,
1262+
scope: SettingScope::Both,
1263+
range: Some(SettingRange::Numeric(0..=1)),
1264+
}),
12581265
("statement_queued_timeout_in_seconds", DefaultSettingValue {
12591266
value: UserSettingValue::UInt64(0),
12601267
desc: "The maximum waiting seconds in the queue. The default value is 0(no limit).",

src/query/settings/src/settings_getter_setter.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,10 @@ impl Settings {
907907
Ok(self.try_get_u64("use_legacy_query_executor")? == 1)
908908
}
909909

910+
pub fn get_enable_decimal_sum_widening(&self) -> Result<bool> {
911+
Ok(self.try_get_u64("enable_decimal_sum_widening")? != 0)
912+
}
913+
910914
pub fn get_statement_queued_timeout(&self) -> Result<u64> {
911915
self.try_get_u64("statement_queued_timeout_in_seconds")
912916
}

src/query/sql/src/planner/semantic/type_check/aggregate.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ use databend_common_expression::FunctionContext;
2121
use databend_common_expression::Scalar;
2222
use databend_common_expression::type_check::check_number;
2323
use databend_common_expression::types::DataType;
24+
use databend_common_expression::types::Decimal;
2425
use databend_common_expression::types::NumberScalar;
26+
use databend_common_expression::types::decimal::DecimalSize;
27+
use databend_common_expression::types::i256;
2528
use databend_common_functions::BUILTIN_FUNCTIONS;
2629
use databend_common_functions::aggregates::AggregateFunctionFactory;
2730

@@ -30,7 +33,9 @@ use crate::binder::ExprContext;
3033
use crate::planner::metadata::optimize_remove_count_args;
3134
use crate::plans::AggregateFunction;
3235
use crate::plans::AggregateFunctionScalarSortDesc;
36+
use crate::plans::CastExpr;
3337
use crate::plans::ConstantExpr;
38+
use crate::plans::ScalarExpr;
3439

3540
impl<'a> TypeChecker<'a> {
3641
/// Resolve aggregation function call.
@@ -98,6 +103,8 @@ impl<'a> TypeChecker<'a> {
98103
self.in_aggregate_function = false;
99104
let (mut arguments, mut arg_types) = arguments_result?;
100105

106+
self.try_widen_sum_decimal_argument(func_name, &mut arguments, &mut arg_types)?;
107+
101108
let sort_descs = order_by
102109
.iter()
103110
.map(
@@ -202,4 +209,45 @@ impl<'a> TypeChecker<'a> {
202209

203210
Ok((new_agg_func, data_type))
204211
}
212+
213+
fn try_widen_sum_decimal_argument(
214+
&self,
215+
func_name: &str,
216+
arguments: &mut [ScalarExpr],
217+
arg_types: &mut [DataType],
218+
) -> Result<()> {
219+
if !func_name.eq_ignore_ascii_case("sum")
220+
|| arguments.len() != 1
221+
|| !self.ctx.get_settings().get_enable_decimal_sum_widening()?
222+
{
223+
return Ok(());
224+
}
225+
226+
let input_is_nullable = arg_types[0].is_nullable();
227+
let DataType::Decimal(size) = arg_types[0].remove_nullable() else {
228+
return Ok(());
229+
};
230+
231+
if !size.can_carried_by_128() || size.precision() <= i64::MAX_PRECISION {
232+
return Ok(());
233+
}
234+
235+
let mut target_type = DataType::Decimal(DecimalSize::new_unchecked(
236+
i256::MAX_PRECISION,
237+
size.scale(),
238+
));
239+
if input_is_nullable {
240+
target_type = target_type.wrap_nullable();
241+
}
242+
243+
arguments[0] = ScalarExpr::CastExpr(CastExpr {
244+
span: arguments[0].span(),
245+
is_try: false,
246+
argument: Box::new(arguments[0].clone()),
247+
target_type: Box::new(target_type.clone()),
248+
});
249+
arg_types[0] = target_type;
250+
251+
Ok(())
252+
}
205253
}

tests/sqllogictests/suites/query/aggregate.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,24 @@ SELECT SUM(agg0) FROM (
6969
----
7070
4.0
7171

72+
query TT
73+
select typeof(sum(number::Decimal(18, 1))), typeof(sum(number::Decimal(19, 1))) from numbers(1000);
74+
----
75+
DECIMAL(18, 1) NULL DECIMAL(38, 1) NULL
76+
77+
query TT
78+
settings(enable_decimal_sum_widening=1) select typeof(sum(number::Decimal(18, 1))), typeof(sum(number::Decimal(19, 1))) from numbers(1000);
79+
----
80+
DECIMAL(18, 1) NULL DECIMAL(76, 1) NULL
81+
82+
statement error Decimal overflow
83+
select sum(a) from (select '99999999999999999999999999999999999999'::Decimal(38, 0) as a union all select 1::Decimal(38, 0) as a)
84+
85+
query RT
86+
settings(enable_decimal_sum_widening=1) select sum(a), typeof(sum(a)) from (select '99999999999999999999999999999999999999'::Decimal(38, 0) as a union all select 1::Decimal(38, 0) as a);
87+
----
88+
100000000000000000000000000000000000000 DECIMAL(76, 0) NULL
89+
7290
query RRT
7391
select avg(number * number), avg( (number * number)::Decimal(39, 7) ), typeof(avg( (number * number)::Decimal(39, 7) )) from numbers(100);
7492
----

0 commit comments

Comments
 (0)