Skip to content

Commit 2de6e4a

Browse files
nuno-fariaLiaCastaneda
authored andcommitted
Implement alternative fix
1 parent d4ebaa2 commit 2de6e4a

2 files changed

Lines changed: 8 additions & 72 deletions

File tree

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 4 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717

1818
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20-
use std::any::{type_name, Any};
20+
use std::any::Any;
2121
use std::cmp::Ordering;
2222
use std::collections::{HashSet, VecDeque};
23-
use std::hash::{DefaultHasher, Hash, Hasher};
24-
use std::mem::{size_of, size_of_val, take};
23+
use std::mem::{size_of, size_of_val};
2524
use std::sync::Arc;
2625

2726
use arrow::array::{
@@ -33,7 +32,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields};
3332

3433
use datafusion_common::cast::as_list_array;
3534
use datafusion_common::scalar::copy_array_data;
36-
use datafusion_common::utils::{compare_rows, get_row_at_idx, SingleRowListArrayBuilder};
35+
use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
3736
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
3837
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3938
use datafusion_expr::utils::format_state_name;
@@ -76,18 +75,16 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp
7675
"#,
7776
standard_argument(name = "expression",)
7877
)]
79-
#[derive(Debug, PartialEq, Eq, Hash)]
78+
#[derive(Debug)]
8079
/// ARRAY_AGG aggregate expression
8180
pub struct ArrayAgg {
8281
signature: Signature,
83-
is_input_pre_ordered: bool,
8482
}
8583

8684
impl Default for ArrayAgg {
8785
fn default() -> Self {
8886
Self {
8987
signature: Signature::any(1, Volatility::Immutable),
90-
is_input_pre_ordered: false,
9188
}
9289
}
9390
}
@@ -148,16 +145,6 @@ impl AggregateUDFImpl for ArrayAgg {
148145
Ok(fields)
149146
}
150147

151-
fn with_beneficial_ordering(
152-
self: Arc<Self>,
153-
beneficial_ordering: bool,
154-
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
155-
Ok(Some(Arc::new(Self {
156-
signature: self.signature.clone(),
157-
is_input_pre_ordered: beneficial_ordering,
158-
})))
159-
}
160-
161148
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
162149
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
163150
let ignore_nulls =
@@ -210,7 +197,6 @@ impl AggregateUDFImpl for ArrayAgg {
210197
&data_type,
211198
&ordering_dtypes,
212199
ordering,
213-
self.is_input_pre_ordered,
214200
acc_args.is_reversed,
215201
ignore_nulls,
216202
)
@@ -224,23 +210,6 @@ impl AggregateUDFImpl for ArrayAgg {
224210
fn documentation(&self) -> Option<&Documentation> {
225211
self.doc()
226212
}
227-
228-
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
229-
let Some(other) = <dyn Any + 'static>::downcast_ref::<Self>(other.as_any())
230-
else {
231-
return false;
232-
};
233-
fn assert_self_impls_eq<T: Eq>() {}
234-
assert_self_impls_eq::<Self>();
235-
PartialEq::eq(self, other)
236-
}
237-
238-
fn hash_value(&self) -> u64 {
239-
let hasher = &mut DefaultHasher::new();
240-
type_name::<Self>().hash(hasher);
241-
Hash::hash(self, hasher);
242-
Hasher::finish(hasher)
243-
}
244213
}
245214

246215
#[derive(Debug)]
@@ -550,8 +519,6 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
550519
datatypes: Vec<DataType>,
551520
/// Stores the ordering requirement of the `Accumulator`.
552521
ordering_req: LexOrdering,
553-
/// Whether the input is known to be pre-ordered
554-
is_input_pre_ordered: bool,
555522
/// Whether the aggregation is running in reverse.
556523
reverse: bool,
557524
/// Whether the aggregation should ignore null values.
@@ -565,7 +532,6 @@ impl OrderSensitiveArrayAggAccumulator {
565532
datatype: &DataType,
566533
ordering_dtypes: &[DataType],
567534
ordering_req: LexOrdering,
568-
is_input_pre_ordered: bool,
569535
reverse: bool,
570536
ignore_nulls: bool,
571537
) -> Result<Self> {
@@ -576,34 +542,11 @@ impl OrderSensitiveArrayAggAccumulator {
576542
ordering_values: vec![],
577543
datatypes,
578544
ordering_req,
579-
is_input_pre_ordered,
580545
reverse,
581546
ignore_nulls,
582547
})
583548
}
584549

585-
fn sort(&mut self) {
586-
let sort_options = self
587-
.ordering_req
588-
.iter()
589-
.map(|sort_expr| sort_expr.options)
590-
.collect::<Vec<_>>();
591-
let mut values = take(&mut self.values)
592-
.into_iter()
593-
.zip(take(&mut self.ordering_values))
594-
.collect::<Vec<_>>();
595-
let mut delayed_cmp_err = Ok(());
596-
values.sort_by(|(_, left_ordering), (_, right_ordering)| {
597-
compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
598-
|err| {
599-
delayed_cmp_err = Err(err);
600-
Ordering::Equal
601-
},
602-
)
603-
});
604-
(self.values, self.ordering_values) = values.into_iter().unzip();
605-
}
606-
607550
fn evaluate_orderings(&self) -> Result<ScalarValue> {
608551
let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
609552

@@ -687,9 +630,6 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
687630
let mut partition_ordering_values = vec![];
688631

689632
// Existing values should be merged also.
690-
if !self.is_input_pre_ordered {
691-
self.sort();
692-
}
693633
partition_values.push(self.values.clone().into());
694634
partition_ordering_values.push(self.ordering_values.clone().into());
695635

@@ -740,21 +680,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
740680
}
741681

742682
fn state(&mut self) -> Result<Vec<ScalarValue>> {
743-
if !self.is_input_pre_ordered {
744-
self.sort();
745-
}
746-
747683
let mut result = vec![self.evaluate()?];
748684
result.push(self.evaluate_orderings()?);
749685

750686
Ok(result)
751687
}
752688

753689
fn evaluate(&mut self) -> Result<ScalarValue> {
754-
if !self.is_input_pre_ordered {
755-
self.sort();
756-
}
757-
758690
if self.values.is_empty() {
759691
return Ok(ScalarValue::new_null_list(
760692
self.datatypes[0].clone(),

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ impl AggregateUDFImpl for StringAgg {
178178
)))
179179
}
180180

181+
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
182+
datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
183+
}
184+
181185
fn documentation(&self) -> Option<&Documentation> {
182186
self.doc()
183187
}

0 commit comments

Comments
 (0)