Skip to content

Commit 2d0b1ca

Browse files
committed
Test fix
1 parent 025ddde commit 2d0b1ca

4 files changed

Lines changed: 153 additions & 6 deletions

File tree

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20+
use std::any::{type_name, Any};
2021
use std::cmp::Ordering;
2122
use std::collections::{HashSet, VecDeque};
22-
use std::mem::{size_of, size_of_val};
23+
use std::hash::{DefaultHasher, Hash, Hasher};
24+
use std::mem::{size_of, size_of_val, take};
2325
use std::sync::Arc;
2426

2527
use arrow::array::{
@@ -31,7 +33,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields};
3133

3234
use datafusion_common::cast::as_list_array;
3335
use datafusion_common::scalar::copy_array_data;
34-
use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
36+
use datafusion_common::utils::{compare_rows, get_row_at_idx, SingleRowListArrayBuilder};
3537
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
3638
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3739
use datafusion_expr::utils::format_state_name;
@@ -74,22 +76,24 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp
7476
"#,
7577
standard_argument(name = "expression",)
7678
)]
77-
#[derive(Debug)]
79+
#[derive(Debug, PartialEq, Eq, Hash)]
7880
/// ARRAY_AGG aggregate expression
7981
pub struct ArrayAgg {
8082
signature: Signature,
83+
is_input_pre_ordered: bool,
8184
}
8285

8386
impl Default for ArrayAgg {
8487
fn default() -> Self {
8588
Self {
8689
signature: Signature::any(1, Volatility::Immutable),
90+
is_input_pre_ordered: false,
8791
}
8892
}
8993
}
9094

9195
impl AggregateUDFImpl for ArrayAgg {
92-
fn as_any(&self) -> &dyn std::any::Any {
96+
fn as_any(&self) -> &dyn Any {
9397
self
9498
}
9599

@@ -144,6 +148,16 @@ impl AggregateUDFImpl for ArrayAgg {
144148
Ok(fields)
145149
}
146150

151+
fn with_beneficial_ordering(
152+
self: Arc<Self>,
153+
beneficial_ordering: bool,
154+
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
155+
Ok(Some(Arc::new(Self {
156+
signature: self.signature.clone(),
157+
is_input_pre_ordered: beneficial_ordering,
158+
})))
159+
}
160+
147161
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
148162
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
149163
let ignore_nulls =
@@ -196,6 +210,7 @@ impl AggregateUDFImpl for ArrayAgg {
196210
&data_type,
197211
&ordering_dtypes,
198212
ordering,
213+
self.is_input_pre_ordered,
199214
acc_args.is_reversed,
200215
ignore_nulls,
201216
)
@@ -209,6 +224,23 @@ impl AggregateUDFImpl for ArrayAgg {
209224
fn documentation(&self) -> Option<&Documentation> {
210225
self.doc()
211226
}
227+
228+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
229+
let Some(other) = <dyn Any + 'static>::downcast_ref::<Self>(other.as_any())
230+
else {
231+
return false;
232+
};
233+
fn assert_self_impls_eq<T: Eq>() {}
234+
assert_self_impls_eq::<Self>();
235+
PartialEq::eq(self, other)
236+
}
237+
238+
fn hash_value(&self) -> u64 {
239+
let hasher = &mut DefaultHasher::new();
240+
type_name::<Self>().hash(hasher);
241+
Hash::hash(self, hasher);
242+
Hasher::finish(hasher)
243+
}
212244
}
213245

214246
#[derive(Debug)]
@@ -518,6 +550,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
518550
datatypes: Vec<DataType>,
519551
/// Stores the ordering requirement of the `Accumulator`.
520552
ordering_req: LexOrdering,
553+
/// Whether the input is known to be pre-ordered
554+
is_input_pre_ordered: bool,
521555
/// Whether the aggregation is running in reverse.
522556
reverse: bool,
523557
/// Whether the aggregation should ignore null values.
@@ -531,6 +565,7 @@ impl OrderSensitiveArrayAggAccumulator {
531565
datatype: &DataType,
532566
ordering_dtypes: &[DataType],
533567
ordering_req: LexOrdering,
568+
is_input_pre_ordered: bool,
534569
reverse: bool,
535570
ignore_nulls: bool,
536571
) -> Result<Self> {
@@ -541,11 +576,34 @@ impl OrderSensitiveArrayAggAccumulator {
541576
ordering_values: vec![],
542577
datatypes,
543578
ordering_req,
579+
is_input_pre_ordered,
544580
reverse,
545581
ignore_nulls,
546582
})
547583
}
548584

585+
fn sort(&mut self) {
586+
let sort_options = self
587+
.ordering_req
588+
.iter()
589+
.map(|sort_expr| sort_expr.options)
590+
.collect::<Vec<_>>();
591+
let mut values = take(&mut self.values)
592+
.into_iter()
593+
.zip(take(&mut self.ordering_values))
594+
.collect::<Vec<_>>();
595+
let mut delayed_cmp_err = Ok(());
596+
values.sort_by(|(_, left_ordering), (_, right_ordering)| {
597+
compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
598+
|err| {
599+
delayed_cmp_err = Err(err);
600+
Ordering::Equal
601+
},
602+
)
603+
});
604+
(self.values, self.ordering_values) = values.into_iter().unzip();
605+
}
606+
549607
fn evaluate_orderings(&self) -> Result<ScalarValue> {
550608
let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
551609

@@ -629,6 +687,9 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
629687
let mut partition_ordering_values = vec![];
630688

631689
// Existing values should be merged also.
690+
if !self.is_input_pre_ordered {
691+
self.sort();
692+
}
632693
partition_values.push(self.values.clone().into());
633694
partition_ordering_values.push(self.ordering_values.clone().into());
634695

@@ -679,13 +740,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
679740
}
680741

681742
fn state(&mut self) -> Result<Vec<ScalarValue>> {
743+
if !self.is_input_pre_ordered {
744+
self.sort();
745+
}
746+
682747
let mut result = vec![self.evaluate()?];
683748
result.push(self.evaluate_orderings()?);
684749

685750
Ok(result)
686751
}
687752

688753
fn evaluate(&mut self) -> Result<ScalarValue> {
754+
if !self.is_input_pre_ordered {
755+
self.sort();
756+
}
757+
689758
if self.values.is_empty() {
690759
return Ok(ScalarValue::new_null_list(
691760
self.datatypes[0].clone(),

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6028,6 +6028,84 @@ GROUP BY dummy
60286028
----
60296029
text1
60306030

6031+
6032+
# Test string_agg with ORDER BY clasuses (issue #17011)
6033+
statement ok
6034+
create table t (k varchar, v int);
6035+
6036+
statement ok
6037+
insert into t values ('a', 2), ('b', 3), ('c', 1), ('d', null);
6038+
6039+
query T
6040+
select string_agg(k, ',' order by k) from t;
6041+
----
6042+
a,b,c,d
6043+
6044+
query T
6045+
select string_agg(k, ',' order by k desc) from t;
6046+
----
6047+
d,c,b,a
6048+
6049+
query T
6050+
select string_agg(k, ',' order by v) from t;
6051+
----
6052+
c,a,b,d
6053+
6054+
query T
6055+
select string_agg(k, ',' order by v nulls first) from t;
6056+
----
6057+
d,c,a,b
6058+
6059+
query T
6060+
select string_agg(k, ',' order by v desc) from t;
6061+
----
6062+
d,b,a,c
6063+
6064+
query T
6065+
select string_agg(k, ',' order by v desc nulls last) from t;
6066+
----
6067+
b,a,c,d
6068+
6069+
query T
6070+
-- odd indexes should appear first, ties solved by v
6071+
select string_agg(k, ',' order by v % 2 == 0, v) from t;
6072+
----
6073+
c,b,a,d
6074+
6075+
query T
6076+
-- odd indexes should appear first, ties solved by v desc
6077+
select string_agg(k, ',' order by v % 2 == 0, v desc) from t;
6078+
----
6079+
b,c,a,d
6080+
6081+
query T
6082+
select string_agg(k, ',' order by
6083+
case
6084+
when k = 'a' then 3
6085+
when k = 'b' then 0
6086+
when k = 'c' then 2
6087+
when k = 'd' then 1
6088+
end)
6089+
from t;
6090+
----
6091+
b,d,c,a
6092+
6093+
query T
6094+
select string_agg(k, ',' order by
6095+
case
6096+
when k = 'a' then 3
6097+
when k = 'b' then 0
6098+
when k = 'c' then 2
6099+
when k = 'd' then 1
6100+
end desc)
6101+
from t;
6102+
----
6103+
a,c,d,b
6104+
6105+
statement ok
6106+
drop table t;
6107+
6108+
60316109
# Tests for aggregating with NaN values
60326110
statement ok
60336111
CREATE TABLE float_table (

0 commit comments

Comments
 (0)