Skip to content

Commit fdf19ab

Browse files
committed
fix(query): handle tdigest edge cases
1 parent 39b244c commit fdf19ab

4 files changed

Lines changed: 688 additions & 114 deletions

File tree

src/query/functions/src/aggregates/aggregate_quantile_tdigest.rs

Lines changed: 138 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@ use databend_common_expression::types::number::*;
3737
use databend_common_expression::types::*;
3838
use databend_common_expression::with_decimal_mapped_type;
3939
use databend_common_expression::with_number_mapped_type;
40-
use itertools::Itertools;
4140

4241
use super::AggrState;
4342
use super::AggrStateLoc;
@@ -55,60 +54,63 @@ pub(crate) const MEDIAN: u8 = 0;
5554
pub(crate) const QUANTILE: u8 = 1;
5655

5756
#[derive(BorshSerialize, BorshDeserialize)]
58-
pub(crate) struct QuantileTDigestState {
59-
epsilon: u32,
60-
max_centroids: usize,
57+
pub(crate) enum QuantileTDigestState {
58+
Normal(TDigestData),
59+
Nan,
60+
}
6161

62+
#[derive(BorshSerialize, BorshDeserialize)]
63+
pub(crate) struct TDigestData {
6264
total_weight: f64,
63-
weights: Vec<f64>,
64-
means: Vec<f64>,
65+
centroids: Vec<Centroid>,
6566
unmerged_total_weight: f64,
66-
unmerged_weights: Vec<f64>,
67-
unmerged_means: Vec<f64>,
67+
unmerged: Vec<Centroid>,
6868

6969
min: f64,
7070
max: f64,
7171
}
7272

73+
#[derive(Clone, Copy, BorshSerialize, BorshDeserialize)]
74+
struct Centroid {
75+
mean: f64,
76+
weight: f64,
77+
}
78+
7379
impl QuantileTDigestState {
7480
pub(crate) fn new() -> Self {
75-
Self {
76-
epsilon: 100u32,
77-
max_centroids: 2048,
78-
total_weight: 0f64,
79-
weights: vec![],
80-
means: vec![],
81-
unmerged_total_weight: 0f64,
82-
unmerged_weights: vec![],
83-
unmerged_means: vec![],
84-
min: 0f64,
85-
max: 0f64,
86-
}
81+
Self::Normal(TDigestData {
82+
total_weight: 0.0,
83+
centroids: vec![],
84+
unmerged_total_weight: 0.0,
85+
unmerged: vec![],
86+
min: f64::INFINITY,
87+
max: f64::NEG_INFINITY,
88+
})
8789
}
8890

8991
pub(crate) fn add(&mut self, other: f64, weight: Option<u64>) {
90-
if self.unmerged_weights.len() + self.weights.len() >= self.max_centroids - 1 {
91-
self.compress();
92+
let weight = weight.unwrap_or(1) as f64;
93+
if weight == 0.0 {
94+
return;
95+
}
96+
if other.is_nan() {
97+
*self = Self::Nan;
98+
return;
9299
}
93100

94-
let weight = weight.unwrap_or(1) as f64;
95-
self.unmerged_weights.push(weight);
96-
self.unmerged_means.push(other);
97-
self.unmerged_total_weight += weight;
101+
let Self::Normal(state) = self else {
102+
return;
103+
};
104+
state.add_finite(other, weight);
98105
}
99106

100107
pub(crate) fn merge(&mut self, rhs: &mut Self) -> Result<()> {
101-
if rhs.len() == 0 {
102-
return Ok(());
108+
match (&mut *self, rhs) {
109+
(Self::Nan, _) | (_, Self::Nan) => {
110+
*self = Self::Nan;
111+
}
112+
(Self::Normal(state), Self::Normal(rhs)) => state.merge(rhs)?,
103113
}
104-
105-
rhs.compress();
106-
107-
self.unmerged_weights.extend_from_slice(&rhs.weights);
108-
self.unmerged_means.extend_from_slice(&rhs.means);
109-
self.unmerged_total_weight = rhs.weights.iter().sum();
110-
self.compress();
111-
112114
Ok(())
113115
}
114116

@@ -136,74 +138,112 @@ impl QuantileTDigestState {
136138
}
137139

138140
pub(crate) fn quantile(&mut self, level: f64) -> f64 {
139-
self.compress();
140-
if self.weights.is_empty() {
141-
return 0f64;
142-
} else if self.weights.len() == 1 {
143-
return self.means[0];
141+
match self {
142+
Self::Normal(state) => state.quantile(level),
143+
Self::Nan => f64::NAN,
144144
}
145+
}
146+
}
145147

146-
let mean_last = self.means.len() - 1;
147-
let weight_last = self.weights.len() - 1;
148+
impl TDigestData {
149+
const EPSILON: f64 = 100.0;
150+
const MAX_CENTROIDS: usize = 2048;
151+
152+
fn add_finite(&mut self, other: f64, weight: f64) {
153+
if self.unmerged.len() + self.centroids.len() >= Self::MAX_CENTROIDS - 1 {
154+
self.compress();
155+
}
156+
157+
self.unmerged.push(Centroid {
158+
mean: other,
159+
weight,
160+
});
161+
self.unmerged_total_weight += weight;
162+
}
163+
164+
fn merge(&mut self, rhs: &mut Self) -> Result<()> {
165+
if rhs.len() == 0 {
166+
return Ok(());
167+
}
168+
169+
rhs.compress();
170+
171+
self.unmerged.extend_from_slice(&rhs.centroids);
172+
self.unmerged_total_weight += rhs
173+
.centroids
174+
.iter()
175+
.map(|centroid| centroid.weight)
176+
.sum::<f64>();
177+
self.compress();
178+
179+
Ok(())
180+
}
181+
182+
fn quantile(&mut self, level: f64) -> f64 {
183+
self.compress();
184+
let (first, last) = match self.centroids.as_slice() {
185+
[] => return 0.0,
186+
[Centroid { mean, .. }] => return *mean,
187+
[first, .., last] => (*first, *last),
188+
};
148189

149190
let index = level * self.total_weight;
150-
if index < 1f64 {
191+
if index < 1.0 {
151192
return self.min;
152193
}
153-
if self.weights[0] > 1f64 && index < self.weights[0] / 2f64 {
154-
return self.min
155-
+ (index - 1f64) / (self.weights[0] / 2f64 - 1f64) * (self.means[0] - self.min);
194+
if first.weight > 1.0 && index < first.weight / 2.0 {
195+
return self.min + (index - 1.0) / (first.weight / 2.0 - 1.0) * (first.mean - self.min);
156196
}
157-
if index > self.total_weight - 1f64 {
197+
if index > self.total_weight - 1.0 {
158198
return self.max;
159199
}
160-
if self.weights[weight_last] > 1f64
161-
&& self.total_weight - index <= self.weights[weight_last] / 2f64
162-
{
200+
if last.weight > 1.0 && self.total_weight - index <= last.weight / 2.0 {
201+
if last.weight / 2.0 <= 1.0 {
202+
return self.max;
203+
}
163204
return self.max
164-
- (self.total_weight - index - 1f64) / (self.weights[weight_last] / 2f64 - 1f64)
165-
* (self.max - self.means[mean_last]);
205+
- (self.total_weight - index - 1.0) / (last.weight / 2.0 - 1.0)
206+
* (self.max - last.mean);
166207
}
167208

168-
let mut weight_so_far = self.weights[0] / 2f64;
169-
for i in 0..(self.weights.len() - 1) {
170-
let dw = (self.weights[i] + self.weights[i + 1]) / 2f64;
209+
let mut weight_so_far = first.weight / 2.0;
210+
for (left, right) in self
211+
.centroids
212+
.windows(2)
213+
.map(|centroids| (centroids[0], centroids[1]))
214+
{
215+
let dw = (left.weight + right.weight) / 2.0;
171216
if weight_so_far + dw > index {
172-
let mut left_unit = 0f64;
173-
if self.weights[i] == 1f64 {
217+
let mut left_unit = 0.0;
218+
if left.weight == 1.0 {
174219
if index - weight_so_far < 0.5 {
175-
return self.means[i];
220+
return left.mean;
176221
}
177222
left_unit = 0.5;
178223
}
179224

180-
let mut right_unit = 0f64;
181-
if self.weights[i + 1] == 1f64 {
225+
let mut right_unit = 0.0;
226+
if right.weight == 1.0 {
182227
if weight_so_far + dw - index <= 0.5 {
183-
return self.means[i + 1];
228+
return right.mean;
184229
}
185230
right_unit = 0.5;
186231
}
187232

188233
let z1 = index - weight_so_far - left_unit;
189234
let z2 = weight_so_far + dw - index - right_unit;
190-
return QuantileTDigestState::weighted_average(
191-
self.means[i],
192-
z2,
193-
self.means[i + 1],
194-
z1,
195-
);
235+
return Self::weighted_average(left.mean, z2, right.mean, z1);
196236
}
197237
weight_so_far += dw;
198238
}
199239

200240
debug_assert!(index <= self.total_weight);
201-
debug_assert!(index >= self.total_weight - self.weights[weight_last] / 2f64);
241+
debug_assert!(index >= self.total_weight - last.weight / 2.0);
202242

203-
let z1 = index - self.total_weight - self.weights[weight_last] / 2f64;
204-
let z2 = self.weights[weight_last] / 2f64 - z1;
243+
let z1 = index - self.total_weight - last.weight / 2.0;
244+
let z2 = last.weight / 2.0 - z1;
205245

206-
QuantileTDigestState::weighted_average(self.means[mean_last], z1, self.max, z2)
246+
Self::weighted_average(last.mean, z1, self.max, z2)
207247
}
208248

209249
fn len(&self) -> usize {
@@ -219,65 +259,49 @@ impl QuantileTDigestState {
219259
}
220260

221261
fn compress(&mut self) {
222-
if self.unmerged_total_weight > 0f64 {
223-
self.merge_centroid(self.unmerged_weights.clone(), self.unmerged_means.clone());
224-
self.unmerged_weights.clear();
225-
self.unmerged_means.clear();
226-
self.unmerged_total_weight = 0f64;
262+
if self.unmerged_total_weight > 0.0 {
263+
self.merge_centroid(self.unmerged.clone());
264+
self.unmerged.clear();
265+
self.unmerged_total_weight = 0.0;
227266
}
228267
}
229268

230-
fn merge_centroid(&mut self, incoming_weights: Vec<f64>, incoming_means: Vec<f64>) {
231-
let mut incoming_weights = incoming_weights;
232-
incoming_weights.extend_from_slice(&self.weights);
233-
let mut incoming_means = incoming_means;
234-
incoming_means.extend_from_slice(&self.means);
235-
236-
// sort (0..incoming_means.len()) by values in incoming_means.
237-
// e.g. incoming_means[5.0, 2.0, 9.1, 1.3] => [3, 1, 0, 2]
238-
let incoming_order = (0..incoming_means.len())
239-
.sorted_by(|&i, &j| incoming_means[i].partial_cmp(&incoming_means[j]).unwrap())
240-
.collect::<Vec<_>>();
269+
fn merge_centroid(&mut self, mut incoming: Vec<Centroid>) {
270+
incoming.extend_from_slice(&self.centroids);
271+
incoming.sort_by(|a, b| a.mean.total_cmp(&b.mean));
241272

242273
self.total_weight += self.unmerged_total_weight;
243274

244-
let normalizer = self.epsilon as f64 / (PI * self.total_weight);
245-
246-
let mut weights = vec![];
247-
let mut means = vec![];
248-
249-
weights.push(incoming_weights[incoming_order[0]]);
250-
means.push(incoming_means[incoming_order[0]]);
275+
let normalizer = Self::EPSILON / (PI * self.total_weight);
251276

252-
let mut weight_so_far = 0f64;
277+
let mut incoming = incoming.into_iter();
278+
let mut current = incoming.next().unwrap();
279+
let first_mean = current.mean;
280+
let mut centroids = vec![];
281+
let mut weight_so_far = 0.0;
253282

254-
for idx in &incoming_order[1..] {
255-
let idx = *idx;
256-
let proposed_weight = weights[weights.len() - 1] + incoming_weights[idx];
283+
for centroid in incoming {
284+
let proposed_weight = current.weight + centroid.weight;
257285
let z = normalizer * proposed_weight;
258286
let q0 = weight_so_far / self.total_weight;
259287
let q2 = (weight_so_far + proposed_weight) / self.total_weight;
260-
let weight_last = weights.len() - 1;
261-
let mean_last = means.len() - 1;
262-
if z * z <= q0 * (1f64 - q0) && z * z <= q2 * (1f64 - q2) {
263-
weights[weight_last] += incoming_weights[idx];
264-
means[mean_last] = means[mean_last]
265-
+ (incoming_means[idx] - means[mean_last]) * incoming_weights[idx]
266-
/ weights[weight_last];
288+
if z * z <= q0 * (1.0 - q0) && z * z <= q2 * (1.0 - q2) {
289+
current.weight = proposed_weight;
290+
current.mean += (centroid.mean - current.mean) * centroid.weight / current.weight;
267291
} else {
268-
weight_so_far += weights[weight_last];
269-
weights.push(incoming_weights[idx]);
270-
means.push(incoming_means[idx]);
292+
weight_so_far += current.weight;
293+
centroids.push(current);
294+
current = centroid;
271295
}
272296
}
273297

274-
if self.total_weight > 0f64 {
275-
self.min = f64::min(self.min, means[0]);
276-
self.max = f64::max(self.max, means[means.len() - 1]);
298+
if self.total_weight > 0.0 {
299+
self.min = f64::min(self.min, first_mean);
300+
self.max = f64::max(self.max, current.mean);
277301
}
278302

279-
self.weights = weights;
280-
self.means = means;
303+
centroids.push(current);
304+
self.centroids = centroids;
281305
}
282306
}
283307

src/query/functions/tests/it/aggregates/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
mod agg;
1818
mod agg_hashtable;
19+
mod quantile_tdigest;
1920

2021
use std::io::Write;
2122

0 commit comments

Comments
 (0)