Skip to content

Commit 83fc799

Browse files
committed
feat: implement retract_batch for array_agg(DISTINCT) sliding window
1 parent 7843ab3 commit 83fc799

3 files changed

Lines changed: 447 additions & 12 deletions

File tree

datafusion/common/src/scalar/mod.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ mod struct_builder;
2323

2424
use std::borrow::Borrow;
2525
use std::cmp::Ordering;
26-
use std::collections::{HashSet, VecDeque};
26+
use std::collections::{HashMap, HashSet, VecDeque};
2727
use std::convert::Infallible;
2828
use std::fmt;
2929
use std::fmt::Write;
@@ -4753,6 +4753,18 @@ impl ScalarValue {
47534753
.sum::<usize>()
47544754
}
47554755

4756+
/// Estimates [size](Self::size) of [`HashMap`] keyed by [`ScalarValue`] in bytes.
4757+
///
4758+
/// Includes the size of the [`HashMap`] container itself. Heap payload of
4759+
/// `V` is not accounted for; callers storing heap-backed values should
4760+
/// supplement this estimate.
4761+
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key
4762+
pub fn size_of_hashmap<V, S>(map: &HashMap<Self, V, S>) -> usize {
4763+
size_of_val(map)
4764+
+ ((size_of::<ScalarValue>() + size_of::<V>()) * map.capacity())
4765+
+ map.keys().map(|k| k.size() - size_of_val(k)).sum::<usize>()
4766+
}
4767+
47564768
/// Compacts the allocation referenced by `self` to the minimum, copying the data if
47574769
/// necessary.
47584770
///

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 185 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
2020
use std::cmp::Ordering;
21-
use std::collections::{HashSet, VecDeque};
21+
use std::collections::{HashMap, VecDeque};
2222
use std::mem::{size_of, size_of_val, take};
2323
use std::sync::Arc;
2424

@@ -34,7 +34,9 @@ use datafusion_common::cast::as_list_array;
3434
use datafusion_common::utils::{
3535
SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args,
3636
};
37-
use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err, exec_err};
37+
use datafusion_common::{
38+
Result, ScalarValue, assert_eq_or_internal_err, exec_err, internal_err,
39+
};
3840
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3941
use datafusion_expr::utils::format_state_name;
4042
use datafusion_expr::{
@@ -814,7 +816,10 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator {
814816

815817
#[derive(Debug)]
816818
pub struct DistinctArrayAggAccumulator {
817-
values: HashSet<ScalarValue>,
819+
// Value → live refcount. Multiset state lets `retract_batch` correctly
820+
// drop a duplicate occurrence while keeping the key alive if other
821+
// copies remain in the current window frame.
822+
values: HashMap<ScalarValue, u64>,
818823
datatype: DataType,
819824
sort_options: Option<SortOptions>,
820825
ignore_nulls: bool,
@@ -827,7 +832,7 @@ impl DistinctArrayAggAccumulator {
827832
ignore_nulls: bool,
828833
) -> Result<Self> {
829834
Ok(Self {
830-
values: HashSet::new(),
835+
values: HashMap::new(),
831836
datatype: datatype.clone(),
832837
sort_options,
833838
ignore_nulls,
@@ -856,8 +861,8 @@ impl Accumulator for DistinctArrayAggAccumulator {
856861
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
857862
for i in 0..val.len() {
858863
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
859-
self.values
860-
.insert(ScalarValue::try_from_array(val, i)?.compacted());
864+
let key = ScalarValue::try_from_array(val, i)?.compacted();
865+
*self.values.entry(key).or_insert(0) += 1;
861866
}
862867
}
863868
}
@@ -872,6 +877,12 @@ impl Accumulator for DistinctArrayAggAccumulator {
872877

873878
assert_eq_or_internal_err!(states.len(), 1, "expects single state");
874879

880+
// The DISTINCT state schema is `List<value>` — partial accumulators
881+
// ship the set of values they saw, not multiplicities. Re-ingesting
882+
// each element here makes the merged counts represent "partitions
883+
// that emitted this value," which is fine because `evaluate` only
884+
// reads keys. Refcount semantics for retract are only valid within
885+
// a single accumulator instance (window execution).
875886
states[0]
876887
.as_list::<i32>()
877888
.iter()
@@ -880,7 +891,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
880891
}
881892

882893
fn evaluate(&mut self) -> Result<ScalarValue> {
883-
let mut values: Vec<ScalarValue> = self.values.iter().cloned().collect();
894+
let mut values: Vec<ScalarValue> = self.values.keys().cloned().collect();
884895
if values.is_empty() {
885896
return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1));
886897
}
@@ -916,8 +927,50 @@ impl Accumulator for DistinctArrayAggAccumulator {
916927
Ok(ScalarValue::List(arr))
917928
}
918929

930+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
931+
if values.is_empty() {
932+
return Ok(());
933+
}
934+
935+
assert_eq_or_internal_err!(values.len(), 1, "expects single batch");
936+
937+
let val = &values[0];
938+
let nulls = if self.ignore_nulls {
939+
val.logical_nulls()
940+
} else {
941+
None
942+
};
943+
let nulls = nulls.as_ref();
944+
945+
for i in 0..val.len() {
946+
if nulls.is_some_and(|nulls| !nulls.is_valid(i)) {
947+
continue;
948+
}
949+
let key = ScalarValue::try_from_array(val, i)?.compacted();
950+
match self.values.get_mut(&key) {
951+
Some(count) => {
952+
*count -= 1;
953+
if *count == 0 {
954+
self.values.remove(&key);
955+
}
956+
}
957+
None => {
958+
return internal_err!(
959+
"DistinctArrayAggAccumulator::retract_batch: value not present in state"
960+
);
961+
}
962+
}
963+
}
964+
965+
Ok(())
966+
}
967+
968+
/// Always returns `true`: this accumulator provides a `retract_batch`
/// implementation backed by its per-value refcounts.
fn supports_retract_batch(&self) -> bool {
968+
true
969+
}
971+
919972
fn size(&self) -> usize {
920-
size_of_val(self) + ScalarValue::size_of_hashset(&self.values)
973+
size_of_val(self) + ScalarValue::size_of_hashmap(&self.values)
921974
- size_of_val(&self.values)
922975
+ self.datatype.size()
923976
- size_of_val(&self.datatype)
@@ -1494,8 +1547,8 @@ mod tests {
14941547
acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?;
14951548
acc1 = merge(acc1, acc2)?;
14961549

1497-
// without compaction, the size is 16660
1498-
assert_eq!(acc1.size(), 1660);
1550+
// without compaction, the size is 16684
1551+
assert_eq!(acc1.size(), 1684);
14991552

15001553
Ok(())
15011554
}
@@ -2415,4 +2468,126 @@ mod tests {
24152468

24162469
Ok(())
24172470
}
2471+
2472+
// ---- DistinctArrayAggAccumulator retract_batch tests ----
2473+
2474+
// Build a DISTINCT accumulator with ascending sort so evaluate output is
2475+
// deterministic regardless of HashMap iteration order.
2476+
fn distinct_acc(ignore_nulls: bool) -> Result<DistinctArrayAggAccumulator> {
2477+
DistinctArrayAggAccumulator::try_new(
2478+
&DataType::Utf8,
2479+
Some(SortOptions::default()),
2480+
ignore_nulls,
2481+
)
2482+
}
2483+
2484+
#[test]
fn distinct_retract_duplicate_remains() -> Result<()> {
    // Core regression for set-based state being unable to retract: when a
    // value occurs more than once inside the frame, retracting a single
    // occurrence must leave that value visible.
    let mut acc = distinct_acc(false)?;

    // Load [A, A, B] as two separate batches so the state spans batches.
    let first_batch = data(["A", "A"]);
    acc.update_batch(&[first_batch])?;
    let second_batch = data(["B"]);
    acc.update_batch(&[second_batch])?;
    let after_update = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_update, vec!["A", "B"]);

    // First retraction of A: one copy is still inside the frame.
    acc.retract_batch(&[data(["A"])])?;
    let after_first_retract = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_first_retract, vec!["A", "B"]);

    // Second retraction of A: no copies remain, so only B is reported.
    acc.retract_batch(&[data(["A"])])?;
    let after_second_retract = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_second_retract, vec!["B"]);

    Ok(())
}
2506+
2507+
#[test]
fn distinct_retract_full_removal() -> Result<()> {
    let mut acc = distinct_acc(false)?;

    // Retract exactly what was added, leaving no values in the state.
    acc.update_batch(&[data(["A", "B"])])?;
    acc.retract_batch(&[data(["A", "B"])])?;

    // An emptied accumulator must evaluate to a NULL list.
    let result = acc.evaluate()?;
    let is_null_list = match &result {
        ScalarValue::List(arr) => arr.is_null(0),
        _ => false,
    };
    assert!(
        is_null_list,
        "expected null list after full retract, got {result:?}"
    );

    Ok(())
}
2522+
2523+
#[test]
fn distinct_retract_ignore_nulls_skips() -> Result<()> {
    // With ignore_nulls=true a NULL is never stored by update, so retract
    // has to skip NULL as well; looking it up would hit the missing-key
    // error path.
    let mut acc = distinct_acc(true)?;

    let added = data([Some("A"), None, Some("B")]);
    acc.update_batch(&[added])?;
    let after_update = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_update, vec!["A", "B"]);

    // Retracting [A, NULL] removes A only; the NULL row is passed over.
    let removed = data([Some("A"), None]);
    acc.retract_batch(&[removed])?;
    let after_retract = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_retract, vec!["B"]);

    Ok(())
}
2538+
2539+
#[test]
fn distinct_retract_null_tracked() -> Result<()> {
    // With ignore_nulls=false a NULL is refcounted like any other value and
    // has to be retracted symmetrically. Once its count reaches zero the
    // NULL key must disappear, otherwise evaluate would keep emitting a
    // NULL element.
    let mut acc = distinct_acc(false)?;

    let added = data([Some("A"), None, None]);
    acc.update_batch(&[added])?;
    // SortOptions::default() has nulls_first=true, so NULL precedes A.
    let after_update = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_update, vec!["NULL", "A"]);

    // First NULL retraction: one NULL is left, so the key stays.
    let one_null = data::<Option<&str>, 1>([None]);
    acc.retract_batch(&[one_null])?;
    let after_first_retract = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_first_retract, vec!["NULL", "A"]);

    // Second NULL retraction: the count hits zero and the key is dropped.
    let one_null = data::<Option<&str>, 1>([None]);
    acc.retract_batch(&[one_null])?;
    let after_second_retract = print_nulls(str_arr(acc.evaluate()?)?);
    assert_eq!(after_second_retract, vec!["A"]);

    Ok(())
}
2560+
2561+
#[test]
fn distinct_supports_retract_batch() -> Result<()> {
    // Retraction support must be reported for both NULL-handling modes,
    // checked in the order: NULLs tracked, then NULLs ignored.
    for ignore_nulls in [false, true] {
        let acc = distinct_acc(ignore_nulls)?;
        assert!(acc.supports_retract_batch());
    }

    Ok(())
}
2571+
2572+
#[test]
fn distinct_merge_then_evaluate_regression() -> Result<()> {
    // Non-window path: state -> merge_batch -> evaluate has to keep
    // yielding the union of the distinct values seen by each partition.
    let mut acc1 = distinct_acc(false)?;
    let mut acc2 = distinct_acc(false)?;

    acc1.update_batch(&[data(["A", "A", "B"])])?;
    acc2.update_batch(&[data(["A", "C"])])?;

    // Turn each state scalar of acc2 into a one-row array for merging.
    let mut state_arrs: Vec<ArrayRef> = Vec::new();
    for scalar in acc2.state()? {
        state_arrs.push(scalar.to_array_of_size(1)?);
    }
    acc1.merge_batch(&state_arrs)?;

    let merged = print_nulls(str_arr(acc1.evaluate()?)?);
    assert_eq!(merged, vec!["A", "B", "C"]);

    Ok(())
}
24182593
}

0 commit comments

Comments
 (0)