Skip to content

Commit a0aa47a

Browse files
committed
feat(functions-aggregate): support sum(interval)
1 parent a00f749 commit a0aa47a

2 files changed

Lines changed: 121 additions & 1 deletion

File tree

datafusion/functions-aggregate/src/sum.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use arrow::datatypes::{
2424
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type,
2525
Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType,
2626
DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef,
27-
Float64Type, Int64Type, TimeUnit, UInt64Type,
27+
Float64Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
28+
IntervalYearMonthType, TimeUnit, UInt64Type,
2829
};
2930
use datafusion_common::hash_utils::RandomState;
3031
use datafusion_common::internal_err;
@@ -117,6 +118,21 @@ macro_rules! downcast_sum {
117118
$args.return_field.data_type().clone()
118119
)
119120
}
121+
DataType::Interval(IntervalUnit::YearMonth) => {
122+
$helper!(
123+
IntervalYearMonthType,
124+
$args.return_field.data_type().clone()
125+
)
126+
}
127+
DataType::Interval(IntervalUnit::DayTime) => {
128+
$helper!(IntervalDayTimeType, $args.return_field.data_type().clone())
129+
}
130+
DataType::Interval(IntervalUnit::MonthDayNano) => {
131+
$helper!(
132+
IntervalMonthDayNanoType,
133+
$args.return_field.data_type().clone()
134+
)
135+
}
120136
_ => {
121137
not_impl_err!(
122138
"Sum not supported for {}: {}",
@@ -186,6 +202,9 @@ impl Sum {
186202
TypeSignature::Coercible(vec![Coercion::new_exact(
187203
TypeSignatureClass::Duration,
188204
)]),
205+
TypeSignature::Coercible(vec![Coercion::new_exact(
206+
TypeSignatureClass::Interval,
207+
)]),
189208
],
190209
Volatility::Immutable,
191210
),
@@ -232,6 +251,7 @@ impl AggregateUDFImpl for Sum {
232251
Ok(DataType::Decimal256(new_precision, *scale))
233252
}
234253
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
254+
DataType::Interval(interval_unit) => Ok(DataType::Interval(*interval_unit)),
235255
other => {
236256
exec_err!("[return_type] SUM not supported for {}", other)
237257
}
@@ -378,6 +398,11 @@ impl AggregateUDFImpl for Sum {
378398
if lit_type == DataType::Null {
379399
return Ok(None);
380400
}
401+
// Skip the rewrite for interval: it requires `lit * COUNT(arg)`, but
402+
// there is no generic Interval×Int64 multiplication kernel.
403+
if matches!(lit_type, DataType::Interval(_)) {
404+
return Ok(None);
405+
}
381406

382407
// Build up SUM(arg)
383408
let mut sum_agg = agg_function.clone();

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6551,6 +6551,101 @@ c NULL 2
65516551
statement ok
65526552
drop table dn;
65536553

6554+
# sum_interval
6555+
# Component-wise sum across all three Interval variants (matches PostgreSQL).
6556+
6557+
# Basic Interval(MonthDayNano): the issue's repro.
6558+
# (0 mons, 0 days, 1s) + (12 mons, 0, 0) + (1 mon, 0, 0) = (13 mons, 0, 1s)
6559+
query T?
6560+
SELECT arrow_typeof(sum(v)), sum(v) FROM (VALUES
6561+
(interval '1 second'),
6562+
(interval '1 year'),
6563+
(interval '1 month')) t(v);
6564+
----
6565+
Interval(MonthDayNano) 13 mons 1.000000000 secs
6566+
6567+
# NULLs are skipped.
6568+
query ?
6569+
SELECT sum(v) FROM (VALUES
6570+
(interval '1 day'),
6571+
(NULL),
6572+
(interval '2 days')) t(v);
6573+
----
6574+
3 days
6575+
6576+
# Empty input → NULL.
6577+
query ?
6578+
SELECT sum(v) FROM (VALUES (interval '1 day')) t(v) WHERE 1 = 0;
6579+
----
6580+
NULL
6581+
6582+
# GROUP BY exercises the PrimitiveGroupsAccumulator path.
6583+
query I? rowsort
6584+
SELECT k, sum(v) FROM (VALUES
6585+
(1, interval '1 day'),
6586+
(1, interval '2 days'),
6587+
(2, interval '1 month')) t(k, v)
6588+
GROUP BY k;
6589+
----
6590+
1 3 days
6591+
2 1 mons
6592+
6593+
# Interval(YearMonth) via cast.
6594+
query T?
6595+
SELECT arrow_typeof(sum(v)), sum(v) FROM (VALUES
6596+
(arrow_cast('1 year', 'Interval(YearMonth)')),
6597+
(arrow_cast('6 months', 'Interval(YearMonth)'))) t(v);
6598+
----
6599+
Interval(YearMonth) 1 years 6 mons
6600+
6601+
# Interval(DayTime) via cast.
6602+
query T?
6603+
SELECT arrow_typeof(sum(v)), sum(v) FROM (VALUES
6604+
(arrow_cast('1 day', 'Interval(DayTime)')),
6605+
(arrow_cast('1 day', 'Interval(DayTime)'))) t(v);
6606+
----
6607+
Interval(DayTime) 2 days
6608+
6609+
# Sliding window sum on intervals.
6610+
query ??
6611+
SELECT v, sum(v) OVER (ORDER BY v ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
6612+
FROM (VALUES
6613+
(interval '1 day'),
6614+
(interval '2 days'),
6615+
(interval '3 days')) t(v);
6616+
----
6617+
1 days 1 days
6618+
2 days 3 days
6619+
3 days 5 days
6620+
6621+
# DISTINCT sum drops duplicates (DistinctSumAccumulator path).
6622+
query ?
6623+
SELECT sum(DISTINCT v) FROM (VALUES
6624+
(interval '1 day'),
6625+
(interval '1 day'),
6626+
(interval '2 days')) t(v);
6627+
----
6628+
3 days
6629+
6630+
# Regression: SUM(col + interval_lit) must NOT be rewritten to
6631+
# SUM(col) + lit * COUNT(col) by simplify_expr_op_literal — there is no
6632+
# Interval×Int64 multiplication kernel. Without the guard, this query fails.
6633+
query ?
6634+
SELECT sum(v + interval '1 day') FROM (VALUES
6635+
(interval '1 day'),
6636+
(interval '2 days'),
6637+
(interval '3 days')) t(v);
6638+
----
6639+
9 days
6640+
6641+
# Negative intervals: component-wise wrapping_add over signed i32/i64.
6642+
query ?
6643+
SELECT sum(v) FROM (VALUES
6644+
(interval '1 day'),
6645+
(interval '-3 days')) t(v);
6646+
----
6647+
-2 days
6648+
65546649
# Prepare the table with dictionary values for testing
65556650
statement ok
65566651
CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2);

0 commit comments

Comments
 (0)