Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 138 additions & 113 deletions src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ use databend_common_expression::types::number::*;
use databend_common_expression::types::*;
use databend_common_expression::with_decimal_mapped_type;
use databend_common_expression::with_number_mapped_type;
use itertools::Itertools;

use super::AggrState;
use super::AggrStateLoc;
Expand All @@ -55,59 +54,63 @@ pub(crate) const MEDIAN: u8 = 0;
pub(crate) const QUANTILE: u8 = 1;

#[derive(BorshSerialize, BorshDeserialize)]
pub(crate) struct QuantileTDigestState {
epsilon: u32,
max_centroids: usize,
pub(crate) enum QuantileTDigestState {
Normal(TDigestData),
Nan,
}
Comment thread
forsaken628 marked this conversation as resolved.

#[derive(BorshSerialize, BorshDeserialize)]
pub(crate) struct TDigestData {
total_weight: f64,
weights: Vec<f64>,
means: Vec<f64>,
centroids: Vec<Centroid>,
unmerged_total_weight: f64,
unmerged_weights: Vec<f64>,
unmerged_means: Vec<f64>,
unmerged: Vec<Centroid>,

min: f64,
max: f64,
}

#[derive(Clone, Copy, BorshSerialize, BorshDeserialize)]
struct Centroid {
mean: f64,
weight: f64,
}

impl QuantileTDigestState {
pub(crate) fn new() -> Self {
Self {
epsilon: 100u32,
max_centroids: 2048,
total_weight: 0f64,
weights: vec![],
means: vec![],
unmerged_total_weight: 0f64,
unmerged_weights: vec![],
unmerged_means: vec![],
min: 0f64,
max: 0f64,
}
Self::Normal(TDigestData {
total_weight: 0.0,
centroids: vec![],
unmerged_total_weight: 0.0,
unmerged: vec![],
min: f64::INFINITY,
max: f64::NEG_INFINITY,
})
}

pub(crate) fn add(&mut self, other: f64, weight: Option<u64>) {
if self.unmerged_weights.len() + self.weights.len() >= self.max_centroids - 1 {
self.compress();
let weight = weight.unwrap_or(1) as f64;
if weight == 0.0 {
return;
}
if other.is_nan() {
*self = Self::Nan;
return;
}

self.unmerged_weights.push(weight.unwrap_or(1) as f64);
self.unmerged_means.push(other);
self.unmerged_total_weight += 1f64;
let Self::Normal(state) = self else {
return;
};
state.add_finite(other, weight);
}

pub(crate) fn merge(&mut self, rhs: &mut Self) -> Result<()> {
if rhs.len() == 0 {
return Ok(());
match (&mut *self, rhs) {
(Self::Nan, _) | (_, Self::Nan) => {
*self = Self::Nan;
}
(Self::Normal(state), Self::Normal(rhs)) => state.merge(rhs)?,
}

rhs.compress();

self.unmerged_weights.extend_from_slice(&rhs.weights);
self.unmerged_means.extend_from_slice(&rhs.means);
self.unmerged_total_weight = rhs.weights.iter().sum();
self.compress();

Ok(())
}

Expand Down Expand Up @@ -135,74 +138,112 @@ impl QuantileTDigestState {
}

pub(crate) fn quantile(&mut self, level: f64) -> f64 {
self.compress();
if self.weights.is_empty() {
return 0f64;
} else if self.weights.len() == 1 {
return self.means[0];
match self {
Self::Normal(state) => state.quantile(level),
Self::Nan => f64::NAN,
}
}
}

let mean_last = self.means.len() - 1;
let weight_last = self.weights.len() - 1;
impl TDigestData {
const EPSILON: f64 = 100.0;
const MAX_CENTROIDS: usize = 2048;

fn add_finite(&mut self, other: f64, weight: f64) {
if self.unmerged.len() + self.centroids.len() >= Self::MAX_CENTROIDS - 1 {
self.compress();
}

self.unmerged.push(Centroid {
mean: other,
weight,
});
self.unmerged_total_weight += weight;
}

fn merge(&mut self, rhs: &mut Self) -> Result<()> {
if rhs.len() == 0 {
return Ok(());
}

rhs.compress();

self.unmerged.extend_from_slice(&rhs.centroids);
self.unmerged_total_weight += rhs
.centroids
.iter()
.map(|centroid| centroid.weight)
.sum::<f64>();
self.compress();

Ok(())
}

fn quantile(&mut self, level: f64) -> f64 {
self.compress();
let (first, last) = match self.centroids.as_slice() {
[] => return 0.0,
[Centroid { mean, .. }] => return *mean,
[first, .., last] => (*first, *last),
};

let index = level * self.total_weight;
if index < 1f64 {
if index < 1.0 {
return self.min;
}
if self.weights[0] > 1f64 && index < self.weights[0] / 2f64 {
return self.min
+ (index - 1f64) / (self.weights[0] / 2f64 - 1f64) * (self.means[0] - self.min);
if first.weight > 1.0 && index < first.weight / 2.0 {
return self.min + (index - 1.0) / (first.weight / 2.0 - 1.0) * (first.mean - self.min);
}
if index > self.total_weight - 1f64 {
if index > self.total_weight - 1.0 {
return self.max;
}
if self.weights[weight_last] > 1f64
&& self.total_weight - index <= self.weights[weight_last] / 2f64
{
if last.weight > 1.0 && self.total_weight - index <= last.weight / 2.0 {
if last.weight / 2.0 <= 1.0 {
return self.max;
}
return self.max
- (self.total_weight - index - 1f64) / (self.weights[weight_last] / 2f64 - 1f64)
* (self.max - self.means[mean_last]);
- (self.total_weight - index - 1.0) / (last.weight / 2.0 - 1.0)
* (self.max - last.mean);
}

let mut weight_so_far = self.weights[0] / 2f64;
for i in 0..(self.weights.len() - 1) {
let dw = (self.weights[i] + self.weights[i + 1]) / 2f64;
let mut weight_so_far = first.weight / 2.0;
for (left, right) in self
.centroids
.windows(2)
.map(|centroids| (centroids[0], centroids[1]))
{
let dw = (left.weight + right.weight) / 2.0;
if weight_so_far + dw > index {
let mut left_unit = 0f64;
if self.weights[i] == 1f64 {
let mut left_unit = 0.0;
if left.weight == 1.0 {
if index - weight_so_far < 0.5 {
return self.means[i];
return left.mean;
}
left_unit = 0.5;
}

let mut right_unit = 0f64;
if self.weights[i + 1] == 1f64 {
let mut right_unit = 0.0;
if right.weight == 1.0 {
if weight_so_far + dw - index <= 0.5 {
return self.means[i + 1];
return right.mean;
}
right_unit = 0.5;
}

let z1 = index - weight_so_far - left_unit;
let z2 = weight_so_far + dw - index - right_unit;
return QuantileTDigestState::weighted_average(
self.means[i],
z2,
self.means[i + 1],
z1,
);
return Self::weighted_average(left.mean, z2, right.mean, z1);
}
weight_so_far += dw;
}

debug_assert!(index <= self.total_weight);
debug_assert!(index >= self.total_weight - self.weights[weight_last] / 2f64);
debug_assert!(index >= self.total_weight - last.weight / 2.0);

let z1 = index - self.total_weight - self.weights[weight_last] / 2f64;
let z2 = self.weights[weight_last] / 2f64 - z1;
let z1 = index - self.total_weight - last.weight / 2.0;
let z2 = last.weight / 2.0 - z1;

QuantileTDigestState::weighted_average(self.means[mean_last], z1, self.max, z2)
Self::weighted_average(last.mean, z1, self.max, z2)
}

fn len(&self) -> usize {
Expand All @@ -218,65 +259,49 @@ impl QuantileTDigestState {
}

fn compress(&mut self) {
if self.unmerged_total_weight > 0f64 {
self.merge_centroid(self.unmerged_weights.clone(), self.unmerged_means.clone());
self.unmerged_weights.clear();
self.unmerged_means.clear();
self.unmerged_total_weight = 0f64;
if self.unmerged_total_weight > 0.0 {
self.merge_centroid(self.unmerged.clone());
self.unmerged.clear();
self.unmerged_total_weight = 0.0;
}
}

fn merge_centroid(&mut self, incoming_weights: Vec<f64>, incoming_means: Vec<f64>) {
let mut incoming_weights = incoming_weights;
incoming_weights.extend_from_slice(&self.weights);
let mut incoming_means = incoming_means;
incoming_means.extend_from_slice(&self.means);

// sort (0..incoming_means.len()) by values in incoming_means.
// e.g. incoming_means[5.0, 2.0, 9.1, 1.3] => [3, 1, 0, 2]
let incoming_order = (0..incoming_means.len())
.sorted_by(|&i, &j| incoming_means[i].partial_cmp(&incoming_means[j]).unwrap())
.collect::<Vec<_>>();
fn merge_centroid(&mut self, mut incoming: Vec<Centroid>) {
incoming.extend_from_slice(&self.centroids);
incoming.sort_by(|a, b| a.mean.total_cmp(&b.mean));

self.total_weight += self.unmerged_total_weight;

let normalizer = self.epsilon as f64 / (PI * self.total_weight);

let mut weights = vec![];
let mut means = vec![];

weights.push(incoming_weights[incoming_order[0]]);
means.push(incoming_means[incoming_order[0]]);
let normalizer = Self::EPSILON / (PI * self.total_weight);

let mut weight_so_far = 0f64;
let mut incoming = incoming.into_iter();
let mut current = incoming.next().unwrap();
let first_mean = current.mean;
let mut centroids = vec![];
let mut weight_so_far = 0.0;

for idx in &incoming_order[1..] {
let idx = *idx;
let proposed_weight = weights[weights.len() - 1] + incoming_weights[idx];
for centroid in incoming {
let proposed_weight = current.weight + centroid.weight;
let z = normalizer * proposed_weight;
let q0 = weight_so_far / self.total_weight;
let q2 = (weight_so_far + proposed_weight) / self.total_weight;
let weight_last = weights.len() - 1;
let mean_last = means.len() - 1;
if z * z <= q0 * (1f64 - q0) && z * z <= q2 * (1f64 - q2) {
weights[weight_last] += incoming_weights[idx];
means[mean_last] = means[mean_last]
+ (incoming_means[idx] - means[mean_last]) * incoming_weights[idx]
/ weights[weight_last];
if z * z <= q0 * (1.0 - q0) && z * z <= q2 * (1.0 - q2) {
current.weight = proposed_weight;
current.mean += (centroid.mean - current.mean) * centroid.weight / current.weight;
} else {
weight_so_far += weights[weight_last];
weights.push(incoming_weights[idx]);
means.push(incoming_means[idx]);
weight_so_far += current.weight;
centroids.push(current);
current = centroid;
}
}

if self.total_weight > 0f64 {
self.min = f64::min(self.min, means[0]);
self.max = f64::max(self.max, means[means.len() - 1]);
if self.total_weight > 0.0 {
self.min = f64::min(self.min, first_mean);
self.max = f64::max(self.max, current.mean);
}

self.weights = weights;
self.means = means;
centroids.push(current);
self.centroids = centroids;
}
}

Expand Down
6 changes: 1 addition & 5 deletions src/query/functions/tests/it/aggregates/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,7 @@ fn test_aggr_functions() {
test_agg_quantile_disc(file, eval_aggr);
test_agg_quantile_cont(file, eval_aggr);
test_agg_quantile_tdigest(file, eval_aggr);
// FIXME
test_agg_quantile_tdigest_weighted(file, |name, params, columns, rows, _sort_descs| {
let block_entries = columns.to_vec();
eval_aggr_for_test(name, params, &block_entries, rows, false, false, vec![])
});
test_agg_quantile_tdigest_weighted(file, eval_aggr);
test_agg_median(file, eval_aggr);
test_agg_median_tdigest(file, eval_aggr);
test_agg_array_agg(file, eval_aggr);
Expand Down
1 change: 1 addition & 0 deletions src/query/functions/tests/it/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

mod agg;
mod agg_hashtable;
mod quantile_tdigest;

use std::io::Write;

Expand Down
Loading
Loading