Skip to content

Commit 359d0e7

Browse files
authored
fix(functions): respect tdigest weighted state weight (#19952)
* fix(functions): respect tdigest weighted state weight * fix(query): handle tdigest edge cases * test
1 parent 97ba2be commit 359d0e7

6 files changed

Lines changed: 848 additions & 119 deletions

File tree

src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs

Lines changed: 138 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ use databend_common_expression::types::number::*;
3737
use databend_common_expression::types::*;
3838
use databend_common_expression::with_decimal_mapped_type;
3939
use databend_common_expression::with_number_mapped_type;
40-
use itertools::Itertools;
4140

4241
use super::AggrState;
4342
use super::AggrStateLoc;
@@ -55,59 +54,63 @@ pub(crate) const MEDIAN: u8 = 0;
5554
pub(crate) const QUANTILE: u8 = 1;
5655

5756
#[derive(BorshSerialize, BorshDeserialize)]
58-
pub(crate) struct QuantileTDigestState {
59-
epsilon: u32,
60-
max_centroids: usize,
57+
pub(crate) enum QuantileTDigestState {
58+
Normal(TDigestData),
59+
Nan,
60+
}
6161

62+
#[derive(BorshSerialize, BorshDeserialize)]
63+
pub(crate) struct TDigestData {
6264
total_weight: f64,
63-
weights: Vec<f64>,
64-
means: Vec<f64>,
65+
centroids: Vec<Centroid>,
6566
unmerged_total_weight: f64,
66-
unmerged_weights: Vec<f64>,
67-
unmerged_means: Vec<f64>,
67+
unmerged: Vec<Centroid>,
6868

6969
min: f64,
7070
max: f64,
7171
}
7272

73+
#[derive(Clone, Copy, BorshSerialize, BorshDeserialize)]
74+
struct Centroid {
75+
mean: f64,
76+
weight: f64,
77+
}
78+
7379
impl QuantileTDigestState {
7480
pub(crate) fn new() -> Self {
75-
Self {
76-
epsilon: 100u32,
77-
max_centroids: 2048,
78-
total_weight: 0f64,
79-
weights: vec![],
80-
means: vec![],
81-
unmerged_total_weight: 0f64,
82-
unmerged_weights: vec![],
83-
unmerged_means: vec![],
84-
min: 0f64,
85-
max: 0f64,
86-
}
81+
Self::Normal(TDigestData {
82+
total_weight: 0.0,
83+
centroids: vec![],
84+
unmerged_total_weight: 0.0,
85+
unmerged: vec![],
86+
min: f64::INFINITY,
87+
max: f64::NEG_INFINITY,
88+
})
8789
}
8890

8991
pub(crate) fn add(&mut self, other: f64, weight: Option<u64>) {
90-
if self.unmerged_weights.len() + self.weights.len() >= self.max_centroids - 1 {
91-
self.compress();
92+
let weight = weight.unwrap_or(1) as f64;
93+
if weight == 0.0 {
94+
return;
95+
}
96+
if other.is_nan() {
97+
*self = Self::Nan;
98+
return;
9299
}
93100

94-
self.unmerged_weights.push(weight.unwrap_or(1) as f64);
95-
self.unmerged_means.push(other);
96-
self.unmerged_total_weight += 1f64;
101+
let Self::Normal(state) = self else {
102+
return;
103+
};
104+
state.add_finite(other, weight);
97105
}
98106

99107
pub(crate) fn merge(&mut self, rhs: &mut Self) -> Result<()> {
100-
if rhs.len() == 0 {
101-
return Ok(());
108+
match (&mut *self, rhs) {
109+
(Self::Nan, _) | (_, Self::Nan) => {
110+
*self = Self::Nan;
111+
}
112+
(Self::Normal(state), Self::Normal(rhs)) => state.merge(rhs)?,
102113
}
103-
104-
rhs.compress();
105-
106-
self.unmerged_weights.extend_from_slice(&rhs.weights);
107-
self.unmerged_means.extend_from_slice(&rhs.means);
108-
self.unmerged_total_weight = rhs.weights.iter().sum();
109-
self.compress();
110-
111114
Ok(())
112115
}
113116

@@ -135,74 +138,112 @@ impl QuantileTDigestState {
135138
}
136139

137140
pub(crate) fn quantile(&mut self, level: f64) -> f64 {
138-
self.compress();
139-
if self.weights.is_empty() {
140-
return 0f64;
141-
} else if self.weights.len() == 1 {
142-
return self.means[0];
141+
match self {
142+
Self::Normal(state) => state.quantile(level),
143+
Self::Nan => f64::NAN,
143144
}
145+
}
146+
}
144147

145-
let mean_last = self.means.len() - 1;
146-
let weight_last = self.weights.len() - 1;
148+
impl TDigestData {
149+
const EPSILON: f64 = 100.0;
150+
const MAX_CENTROIDS: usize = 2048;
151+
152+
fn add_finite(&mut self, other: f64, weight: f64) {
153+
if self.unmerged.len() + self.centroids.len() >= Self::MAX_CENTROIDS - 1 {
154+
self.compress();
155+
}
156+
157+
self.unmerged.push(Centroid {
158+
mean: other,
159+
weight,
160+
});
161+
self.unmerged_total_weight += weight;
162+
}
163+
164+
fn merge(&mut self, rhs: &mut Self) -> Result<()> {
165+
if rhs.len() == 0 {
166+
return Ok(());
167+
}
168+
169+
rhs.compress();
170+
171+
self.unmerged.extend_from_slice(&rhs.centroids);
172+
self.unmerged_total_weight += rhs
173+
.centroids
174+
.iter()
175+
.map(|centroid| centroid.weight)
176+
.sum::<f64>();
177+
self.compress();
178+
179+
Ok(())
180+
}
181+
182+
fn quantile(&mut self, level: f64) -> f64 {
183+
self.compress();
184+
let (first, last) = match self.centroids.as_slice() {
185+
[] => return 0.0,
186+
[Centroid { mean, .. }] => return *mean,
187+
[first, .., last] => (*first, *last),
188+
};
147189

148190
let index = level * self.total_weight;
149-
if index < 1f64 {
191+
if index < 1.0 {
150192
return self.min;
151193
}
152-
if self.weights[0] > 1f64 && index < self.weights[0] / 2f64 {
153-
return self.min
154-
+ (index - 1f64) / (self.weights[0] / 2f64 - 1f64) * (self.means[0] - self.min);
194+
if first.weight > 1.0 && index < first.weight / 2.0 {
195+
return self.min + (index - 1.0) / (first.weight / 2.0 - 1.0) * (first.mean - self.min);
155196
}
156-
if index > self.total_weight - 1f64 {
197+
if index > self.total_weight - 1.0 {
157198
return self.max;
158199
}
159-
if self.weights[weight_last] > 1f64
160-
&& self.total_weight - index <= self.weights[weight_last] / 2f64
161-
{
200+
if last.weight > 1.0 && self.total_weight - index <= last.weight / 2.0 {
201+
if last.weight / 2.0 <= 1.0 {
202+
return self.max;
203+
}
162204
return self.max
163-
- (self.total_weight - index - 1f64) / (self.weights[weight_last] / 2f64 - 1f64)
164-
* (self.max - self.means[mean_last]);
205+
- (self.total_weight - index - 1.0) / (last.weight / 2.0 - 1.0)
206+
* (self.max - last.mean);
165207
}
166208

167-
let mut weight_so_far = self.weights[0] / 2f64;
168-
for i in 0..(self.weights.len() - 1) {
169-
let dw = (self.weights[i] + self.weights[i + 1]) / 2f64;
209+
let mut weight_so_far = first.weight / 2.0;
210+
for (left, right) in self
211+
.centroids
212+
.windows(2)
213+
.map(|centroids| (centroids[0], centroids[1]))
214+
{
215+
let dw = (left.weight + right.weight) / 2.0;
170216
if weight_so_far + dw > index {
171-
let mut left_unit = 0f64;
172-
if self.weights[i] == 1f64 {
217+
let mut left_unit = 0.0;
218+
if left.weight == 1.0 {
173219
if index - weight_so_far < 0.5 {
174-
return self.means[i];
220+
return left.mean;
175221
}
176222
left_unit = 0.5;
177223
}
178224

179-
let mut right_unit = 0f64;
180-
if self.weights[i + 1] == 1f64 {
225+
let mut right_unit = 0.0;
226+
if right.weight == 1.0 {
181227
if weight_so_far + dw - index <= 0.5 {
182-
return self.means[i + 1];
228+
return right.mean;
183229
}
184230
right_unit = 0.5;
185231
}
186232

187233
let z1 = index - weight_so_far - left_unit;
188234
let z2 = weight_so_far + dw - index - right_unit;
189-
return QuantileTDigestState::weighted_average(
190-
self.means[i],
191-
z2,
192-
self.means[i + 1],
193-
z1,
194-
);
235+
return Self::weighted_average(left.mean, z2, right.mean, z1);
195236
}
196237
weight_so_far += dw;
197238
}
198239

199240
debug_assert!(index <= self.total_weight);
200-
debug_assert!(index >= self.total_weight - self.weights[weight_last] / 2f64);
241+
debug_assert!(index >= self.total_weight - last.weight / 2.0);
201242

202-
let z1 = index - self.total_weight - self.weights[weight_last] / 2f64;
203-
let z2 = self.weights[weight_last] / 2f64 - z1;
243+
let z1 = index - self.total_weight - last.weight / 2.0;
244+
let z2 = last.weight / 2.0 - z1;
204245

205-
QuantileTDigestState::weighted_average(self.means[mean_last], z1, self.max, z2)
246+
Self::weighted_average(last.mean, z1, self.max, z2)
206247
}
207248

208249
fn len(&self) -> usize {
@@ -218,65 +259,49 @@ impl QuantileTDigestState {
218259
}
219260

220261
fn compress(&mut self) {
221-
if self.unmerged_total_weight > 0f64 {
222-
self.merge_centroid(self.unmerged_weights.clone(), self.unmerged_means.clone());
223-
self.unmerged_weights.clear();
224-
self.unmerged_means.clear();
225-
self.unmerged_total_weight = 0f64;
262+
if self.unmerged_total_weight > 0.0 {
263+
self.merge_centroid(self.unmerged.clone());
264+
self.unmerged.clear();
265+
self.unmerged_total_weight = 0.0;
226266
}
227267
}
228268

229-
fn merge_centroid(&mut self, incoming_weights: Vec<f64>, incoming_means: Vec<f64>) {
230-
let mut incoming_weights = incoming_weights;
231-
incoming_weights.extend_from_slice(&self.weights);
232-
let mut incoming_means = incoming_means;
233-
incoming_means.extend_from_slice(&self.means);
234-
235-
// sort (0..incoming_means.len()) by values in incoming_means.
236-
// e.g. incoming_means[5.0, 2.0, 9.1, 1.3] => [3, 1, 0, 2]
237-
let incoming_order = (0..incoming_means.len())
238-
.sorted_by(|&i, &j| incoming_means[i].partial_cmp(&incoming_means[j]).unwrap())
239-
.collect::<Vec<_>>();
269+
fn merge_centroid(&mut self, mut incoming: Vec<Centroid>) {
270+
incoming.extend_from_slice(&self.centroids);
271+
incoming.sort_by(|a, b| a.mean.total_cmp(&b.mean));
240272

241273
self.total_weight += self.unmerged_total_weight;
242274

243-
let normalizer = self.epsilon as f64 / (PI * self.total_weight);
244-
245-
let mut weights = vec![];
246-
let mut means = vec![];
247-
248-
weights.push(incoming_weights[incoming_order[0]]);
249-
means.push(incoming_means[incoming_order[0]]);
275+
let normalizer = Self::EPSILON / (PI * self.total_weight);
250276

251-
let mut weight_so_far = 0f64;
277+
let mut incoming = incoming.into_iter();
278+
let mut current = incoming.next().unwrap();
279+
let first_mean = current.mean;
280+
let mut centroids = vec![];
281+
let mut weight_so_far = 0.0;
252282

253-
for idx in &incoming_order[1..] {
254-
let idx = *idx;
255-
let proposed_weight = weights[weights.len() - 1] + incoming_weights[idx];
283+
for centroid in incoming {
284+
let proposed_weight = current.weight + centroid.weight;
256285
let z = normalizer * proposed_weight;
257286
let q0 = weight_so_far / self.total_weight;
258287
let q2 = (weight_so_far + proposed_weight) / self.total_weight;
259-
let weight_last = weights.len() - 1;
260-
let mean_last = means.len() - 1;
261-
if z * z <= q0 * (1f64 - q0) && z * z <= q2 * (1f64 - q2) {
262-
weights[weight_last] += incoming_weights[idx];
263-
means[mean_last] = means[mean_last]
264-
+ (incoming_means[idx] - means[mean_last]) * incoming_weights[idx]
265-
/ weights[weight_last];
288+
if z * z <= q0 * (1.0 - q0) && z * z <= q2 * (1.0 - q2) {
289+
current.weight = proposed_weight;
290+
current.mean += (centroid.mean - current.mean) * centroid.weight / current.weight;
266291
} else {
267-
weight_so_far += weights[weight_last];
268-
weights.push(incoming_weights[idx]);
269-
means.push(incoming_means[idx]);
292+
weight_so_far += current.weight;
293+
centroids.push(current);
294+
current = centroid;
270295
}
271296
}
272297

273-
if self.total_weight > 0f64 {
274-
self.min = f64::min(self.min, means[0]);
275-
self.max = f64::max(self.max, means[means.len() - 1]);
298+
if self.total_weight > 0.0 {
299+
self.min = f64::min(self.min, first_mean);
300+
self.max = f64::max(self.max, current.mean);
276301
}
277302

278-
self.weights = weights;
279-
self.means = means;
303+
centroids.push(current);
304+
self.centroids = centroids;
280305
}
281306
}
282307

src/query/functions/tests/it/aggregates/agg.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,7 @@ fn test_aggr_functions() {
108108
test_agg_quantile_disc(file, eval_aggr);
109109
test_agg_quantile_cont(file, eval_aggr);
110110
test_agg_quantile_tdigest(file, eval_aggr);
111-
// FIXME
112-
test_agg_quantile_tdigest_weighted(file, |name, params, columns, rows, _sort_descs| {
113-
let block_entries = columns.to_vec();
114-
eval_aggr_for_test(name, params, &block_entries, rows, false, false, vec![])
115-
});
111+
test_agg_quantile_tdigest_weighted(file, eval_aggr);
116112
test_agg_median(file, eval_aggr);
117113
test_agg_median_tdigest(file, eval_aggr);
118114
test_agg_array_agg(file, eval_aggr);

src/query/functions/tests/it/aggregates/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
mod agg;
1818
mod agg_hashtable;
19+
mod quantile_tdigest;
1920

2021
use std::io::Write;
2122

0 commit comments

Comments
 (0)