Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ impl AggregateHashTable {
Self {
direct_append: false,
current_radix_bits: config.initial_radix_bits,
payload: PartitionedPayload::new(
payload: PartitionedPayload::new_with_start_bit(
group_types,
aggrs,
1 << config.initial_radix_bits,
config.partition_start_bit,
vec![arena],
),
hash_index: HashIndex::new(&config, capacity),
Expand Down Expand Up @@ -105,10 +106,11 @@ impl AggregateHashTable {
Self {
direct_append: !need_init_entry,
current_radix_bits: config.initial_radix_bits,
payload: PartitionedPayload::new(
payload: PartitionedPayload::new_with_start_bit(
group_types,
aggrs,
1 << config.initial_radix_bits,
config.partition_start_bit,
vec![arena],
),
hash_index,
Expand Down
7 changes: 7 additions & 0 deletions src/query/expression/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ pub struct HashTableConfig {
// Max radix bits across all threads, this is a hint to repartition
pub current_max_radix_bits: Arc<AtomicU64>,
pub initial_radix_bits: u64,
pub partition_start_bit: u64,
pub max_radix_bits: u64,
pub repartition_radix_bits_incr: u64,
pub block_fill_factor: f64,
Expand All @@ -167,6 +168,7 @@ impl Default for HashTableConfig {
Self {
current_max_radix_bits: Arc::new(AtomicU64::new(3)),
initial_radix_bits: 3,
partition_start_bit: 0,
max_radix_bits: MAX_RADIX_BITS,
repartition_radix_bits_incr: 2,
block_fill_factor: 1.8,
Expand Down Expand Up @@ -211,6 +213,11 @@ impl HashTableConfig {
self
}

pub fn with_partition_start_bit(mut self, partition_start_bit: u64) -> Self {
self.partition_start_bit = partition_start_bit;
self
}

pub fn with_experiment_hash_index(mut self, enable: bool) -> Self {
self.enable_experiment_hash_index = enable;
self
Expand Down
51 changes: 46 additions & 5 deletions src/query/expression/src/aggregate/partitioned_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,15 @@ struct PartitionMask {

impl PartitionMask {
/// Mask selecting the top `log2(partition_count)` bits of the hash prefix;
/// shorthand for `with_start_bit(partition_count, 0)`.
fn new(partition_count: u64) -> Self {
Self::with_start_bit(partition_count, 0)
}

fn with_start_bit(partition_count: u64, start_bit: u64) -> Self {
let radix_bits = partition_count.trailing_zeros() as u64;
debug_assert_eq!(1 << radix_bits, partition_count);
debug_assert!(start_bit + radix_bits <= 48);

let shift = 48 - radix_bits;
let shift = 48 - start_bit - radix_bits;
let mask = ((1 << radix_bits) - 1) << shift;

Self { mask, shift }
Expand All @@ -59,6 +64,7 @@ pub struct PartitionedPayload {

pub arenas: Vec<Arc<Bump>>,

partition_start_bit: u64,
partition_mask: PartitionMask,
}

Expand All @@ -71,6 +77,16 @@ impl PartitionedPayload {
aggrs: Vec<AggregateFunctionRef>,
partition_count: u64,
arenas: Vec<Arc<Bump>>,
) -> Self {
Self::new_with_start_bit(group_types, aggrs, partition_count, 0, arenas)
}

pub fn new_with_start_bit(
group_types: Vec<DataType>,
aggrs: Vec<AggregateFunctionRef>,
partition_count: u64,
partition_start_bit: u64,
arenas: Vec<Arc<Bump>>,
) -> Self {
let states_layout = if !aggrs.is_empty() {
Some(get_states_layout(&aggrs).unwrap())
Expand Down Expand Up @@ -101,7 +117,8 @@ impl PartitionedPayload {
row_layout,

arenas,
partition_mask: PartitionMask::new(partition_count),
partition_start_bit,
partition_mask: PartitionMask::with_start_bit(partition_count, partition_start_bit),
}
}

Expand Down Expand Up @@ -169,11 +186,17 @@ impl PartitionedPayload {
group_types,
aggrs,
arenas,
partition_start_bit,
..
} = self;

let mut new_partition_payload =
PartitionedPayload::new(group_types, aggrs, new_partition_count as u64, arenas);
let mut new_partition_payload = PartitionedPayload::new_with_start_bit(
group_types,
aggrs,
new_partition_count as u64,
partition_start_bit,
arenas,
);

state.clear();
for payload in payloads.into_iter() {
Expand All @@ -184,7 +207,9 @@ impl PartitionedPayload {
}

pub fn combine(&mut self, other: PartitionedPayload, state: &mut PayloadFlushState) {
if other.partition_count() == self.partition_count() {
if other.partition_count() == self.partition_count()
&& other.partition_start_bit == self.partition_start_bit
{
for (l, r) in self.payloads.iter_mut().zip(other.payloads.into_iter()) {
l.combine(r);
}
Expand Down Expand Up @@ -293,3 +318,19 @@ impl PartitionedPayload {
self.payloads.iter().map(|x| x.memory_size()).sum()
}
}

#[cfg(test)]
mod tests {
    use super::PartitionMask;

    #[test]
    fn test_partition_mask_with_start_bit() {
        // start_bit = 0: the single partition bit is the highest hash
        // prefix bit (bit 47), so only that bit selects partition 1.
        let mask_at_top = PartitionMask::new(2);
        assert_eq!(mask_at_top.index(1_u64 << 47), 1);
        assert_eq!(mask_at_top.index(1_u64 << 44), 0);

        // start_bit = 3: the selected bit slides down three positions to
        // bit 44; bit 47 no longer influences the partition index.
        let mask_shifted = PartitionMask::with_start_bit(2, 3);
        assert_eq!(mask_shifted.index(1_u64 << 47), 0);
        assert_eq!(mask_shifted.index(1_u64 << 44), 1);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::sync::Arc;
use std::sync::atomic::AtomicU64;

use databend_common_catalog::table_context::TableContext;
use databend_common_exception::Result;
Expand Down Expand Up @@ -49,6 +50,7 @@ fn build_partition_bucket_experimental(
shuffle_mode: AggregateShuffleMode,
) -> Result<()> {
let mut final_parallelism = ctx.get_settings().get_max_threads()? as usize;
let base_consumed_bits = shuffle_mode.determine_radix_bits();
match shuffle_mode {
AggregateShuffleMode::Row => {
let schema = params.spill_schema();
Expand Down Expand Up @@ -107,6 +109,7 @@ fn build_partition_bucket_experimental(

let mut builder = TransformPipeBuilder::create();
let (tx, rx) = async_channel::unbounded();
let next_task_id = Arc::new(AtomicU64::new(1));
for id in 0..final_parallelism {
let input_port = InputPort::create();
let output_port = OutputPort::create();
Expand All @@ -115,9 +118,11 @@ fn build_partition_bucket_experimental(
output_port.clone(),
params.clone(),
id,
base_consumed_bits,
ctx.clone(),
tx.clone(),
rx.clone(),
next_task_id.clone(),
)?;
builder.add_transform(input_port, output_port, ProcessorPtr::create(processor));
}
Expand Down
Loading
Loading