Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

mod bytes;
mod dict;
mod groups;
mod native;

pub use bytes::BytesDistinctCountAccumulator;
pub use bytes::BytesViewDistinctCountAccumulator;
pub use dict::DictionaryCountAccumulator;
pub use groups::PrimitiveDistinctCountGroupsAccumulator;
pub use native::Bitmap65536DistinctCountAccumulator;
pub use native::Bitmap65536DistinctCountAccumulatorI16;
pub use native::BoolArray256DistinctCountAccumulator;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, PrimitiveArray,
};
use arrow::buffer::{OffsetBuffer, ScalarBuffer};
use arrow::datatypes::{ArrowPrimitiveType, Field};
use datafusion_common::HashSet;
use datafusion_common::hash_utils::RandomState;
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
use std::hash::Hash;
use std::mem::size_of;
use std::sync::Arc;

use crate::aggregate::groups_accumulator::accumulate::accumulate;

pub struct PrimitiveDistinctCountGroupsAccumulator<T: ArrowPrimitiveType>
where
T::Native: Eq + Hash,
{
seen: HashSet<(usize, T::Native), RandomState>,
counts: Vec<i64>,
}

impl<T: ArrowPrimitiveType> PrimitiveDistinctCountGroupsAccumulator<T>
where
T::Native: Eq + Hash,
{
pub fn new() -> Self {
Self {
seen: HashSet::default(),
counts: Vec::new(),
}
}
}

impl<T: ArrowPrimitiveType> Default for PrimitiveDistinctCountGroupsAccumulator<T>
where
T::Native: Eq + Hash,
{
fn default() -> Self {
Self::new()
}
}

impl<T: ArrowPrimitiveType + Send + std::fmt::Debug> GroupsAccumulator
for PrimitiveDistinctCountGroupsAccumulator<T>
where
T::Native: Eq + Hash,
{
fn update_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> datafusion_common::Result<()> {
debug_assert_eq!(values.len(), 1);
self.counts.resize(total_num_groups, 0);
let arr = values[0].as_primitive::<T>();
accumulate(group_indices, arr, opt_filter, |group_idx, value| {
if self.seen.insert((group_idx, value)) {
self.counts[group_idx] += 1;
}
});
Ok(())
}

fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
let counts = emit_to.take_needed(&mut self.counts);

match emit_to {
EmitTo::All => {
self.seen.clear();
}
EmitTo::First(n) => {
let mut remaining = HashSet::default();
for (group_idx, value) in self.seen.drain() {
if group_idx >= n {
remaining.insert((group_idx - n, value));
}
}
self.seen = remaining;
}
}

Ok(Arc::new(Int64Array::from(counts)))
}

fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
let num_emitted = match emit_to {
EmitTo::All => self.counts.len(),
EmitTo::First(n) => n,
};

// Prefix-sum counts[..num_emitted] into offsets
let mut offsets = Vec::with_capacity(num_emitted + 1);
offsets.push(0i32);
let mut total = 0i32;
for &c in &self.counts[..num_emitted] {
total += c as i32;
offsets.push(total);
}

let mut all_values = vec![T::Native::default(); total as usize];
let mut cursors: Vec<i32> = offsets[..num_emitted].to_vec();

if matches!(emit_to, EmitTo::All) {
for (group_idx, value) in self.seen.drain() {
let pos = cursors[group_idx] as usize;
all_values[pos] = value;
cursors[group_idx] += 1;
}
self.counts.clear();
} else {
let mut remaining = HashSet::default();
for (group_idx, value) in self.seen.drain() {
if group_idx < num_emitted {
let pos = cursors[group_idx] as usize;
all_values[pos] = value;
cursors[group_idx] += 1;
} else {
remaining.insert((group_idx - num_emitted, value));
}
}
self.seen = remaining;
let _ = emit_to.take_needed(&mut self.counts);
}

let values_array = Arc::new(PrimitiveArray::<T>::new(
ScalarBuffer::from(all_values),
None,
));
let list_array = ListArray::new(
Arc::new(Field::new_list_field(T::DATA_TYPE, true)),
OffsetBuffer::new(offsets.into()),
values_array,
None,
);

Ok(vec![Arc::new(list_array)])
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
group_indices: &[usize],
_opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> datafusion_common::Result<()> {
debug_assert_eq!(values.len(), 1);
self.counts.resize(total_num_groups, 0);
let list_array = values[0].as_list::<i32>();
let inner = list_array.values().as_primitive::<T>();
let inner_values = inner.values();
let offsets = list_array.offsets();

for (row_idx, &group_idx) in group_indices.iter().enumerate() {
let start = offsets[row_idx] as usize;
let end = offsets[row_idx + 1] as usize;
for &value in &inner_values[start..end] {
if self.seen.insert((group_idx, value)) {
self.counts[group_idx] += 1;
}
}
}

Ok(())
}

fn size(&self) -> usize {
size_of::<Self>()
+ self.seen.capacity() * (size_of::<(usize, T::Native)>() + size_of::<u64>())
+ self.counts.capacity() * size_of::<i64>()
}
}
71 changes: 61 additions & 10 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use arrow::{
compute,
datatypes::{
DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field,
FieldRef, Float16Type, Float32Type, Float64Type, Int32Type, Int64Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type,
Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt32Type, UInt64Type,
UInt8Type, UInt16Type, UInt32Type, UInt64Type,
},
};
use datafusion_common::hash_utils::RandomState;
Expand All @@ -41,6 +41,7 @@ use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
utils::format_state_name,
};
use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator;
use datafusion_functions_aggregate_common::aggregate::{
count_distinct::Bitmap65536DistinctCountAccumulator,
count_distinct::Bitmap65536DistinctCountAccumulatorI16,
Expand Down Expand Up @@ -344,20 +345,33 @@ impl AggregateUDFImpl for Count {
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
// groups accumulator only supports `COUNT(c1)`, not
// `COUNT(c1, c2)`, etc
if args.is_distinct {
if args.exprs.len() != 1 {
return false;
}
args.exprs.len() == 1
if !args.is_distinct {
return true;
}
matches!(
args.expr_fields[0].data_type(),
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
)
}

fn create_groups_accumulator(
&self,
_args: AccumulatorArgs,
args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
// instantiate specialized accumulator
Ok(Box::new(CountGroupsAccumulator::new()))
if !args.is_distinct {
return Ok(Box::new(CountGroupsAccumulator::new()));
}
create_distinct_count_groups_accumulator(&args)
}

fn reverse_expr(&self) -> ReversedUDAF {
Expand Down Expand Up @@ -430,6 +444,43 @@ impl AggregateUDFImpl for Count {
}
}

#[cold]
fn create_distinct_count_groups_accumulator(
args: &AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
let data_type = args.expr_fields[0].data_type();
match data_type {
DataType::Int8 => Ok(Box::new(
PrimitiveDistinctCountGroupsAccumulator::<Int8Type>::new(),
)),
DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
Int16Type,
>::new())),
DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
Int32Type,
>::new())),
DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
Int64Type,
>::new())),
DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
UInt8Type,
>::new())),
DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
UInt16Type,
>::new())),
DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
UInt32Type,
>::new())),
DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::<
UInt64Type,
>::new())),
_ => not_impl_err!(
"GroupsAccumulator not supported for COUNT(DISTINCT) with {}",
data_type
),
}
}

// DistinctCountAccumulator does not support retract_batch and sliding window
// this is a specialized accumulator for distinct count that supports retract_batch
// and sliding window.
Expand Down
Loading