Skip to content

Commit c46cc26

Browse files
b41sh authored and zhang2014 committed
fix(query): Fix window aggregate function result materialization (#19823)
* fix(query): Fix window aggregate function result materialization
* fix
1 parent b9f7afd commit c46cc26

6 files changed

Lines changed: 94 additions & 34 deletions

File tree

src/query/functions/src/aggregates/aggregate_array_agg.rs

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ use super::assert_unary_arguments;
6969
use super::batch_merge1;
7070
use super::batch_serialize1;
7171

72-
#[derive(Debug)]
72+
#[derive(Clone, Debug)]
7373
struct ArrayAggStateAny<T>
7474
where T: ValueType
7575
{
@@ -170,7 +170,7 @@ where
170170
}
171171
}
172172

173-
#[derive(Debug)]
173+
#[derive(Clone, Debug)]
174174
struct ArrayAggStateSimple<T>
175175
where T: SimpleType
176176
{
@@ -288,7 +288,7 @@ where T: SimpleType
288288
}
289289
}
290290

291-
#[derive(Debug)]
291+
#[derive(Clone, Debug)]
292292
struct ArrayAggStateZST<const IS_NULL: bool> {
293293
validity: MutableBitmap,
294294
}
@@ -387,7 +387,7 @@ impl<const IS_NULL: bool> StateSerde for ArrayAggStateZST<IS_NULL> {
387387
}
388388
}
389389

390-
#[derive(Debug)]
390+
#[derive(Clone, Debug)]
391391
struct ArrayAggStateBinary<T>
392392
where T: ArgType
393393
{
@@ -565,7 +565,7 @@ struct AggregateArrayAggFunction<T, State> {
565565
impl<T, State> AggregateFunction for AggregateArrayAggFunction<T, State>
566566
where
567567
T: AccessType,
568-
State: ScalarStateFunc<T>,
568+
State: Clone + ScalarStateFunc<T>,
569569
{
570570
fn name(&self) -> &str {
571571
"AggregateArrayAggFunction"
@@ -686,11 +686,16 @@ where
686686
fn merge_result(
687687
&self,
688688
place: AggrState,
689-
_read_only: bool,
689+
read_only: bool,
690690
builder: &mut ColumnBuilder,
691691
) -> Result<()> {
692692
let state = place.get::<State>();
693-
state.merge_result(builder)
693+
if read_only {
694+
let mut state = state.clone();
695+
state.merge_result(builder)
696+
} else {
697+
state.merge_result(builder)
698+
}
694699
}
695700

696701
fn need_manual_drop_state(&self) -> bool {
@@ -712,7 +717,7 @@ impl<T, State> fmt::Display for AggregateArrayAggFunction<T, State> {
712717
impl<T, State> AggregateArrayAggFunction<T, State>
713718
where
714719
T: ValueType,
715-
State: ScalarStateFunc<T>,
720+
State: Clone + ScalarStateFunc<T>,
716721
{
717722
fn create(display_name: &str, return_type: DataType) -> Result<Arc<dyn AggregateFunction>> {
718723
let func = AggregateArrayAggFunction::<T, State> {

src/query/functions/src/aggregates/aggregate_json_array_agg.rs

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ use super::assert_params;
5555
use super::assert_unary_arguments;
5656
use super::batch_merge1;
5757

58-
#[derive(BorshSerialize, BorshDeserialize, Debug)]
58+
#[derive(Clone, BorshSerialize, BorshDeserialize, Debug)]
5959
pub struct JsonArrayAggState<T>
6060
where
6161
T: ValueType,
@@ -184,7 +184,7 @@ struct AggregateJsonArrayAggFunction<T> {
184184
impl<T> AggregateFunction for AggregateJsonArrayAggFunction<T>
185185
where
186186
T: ValueType,
187-
T::Scalar: borsh::BorshSerialize + borsh::BorshDeserialize,
187+
T::Scalar: borsh::BorshSerialize + borsh::BorshDeserialize + Clone,
188188
{
189189
fn name(&self) -> &str {
190190
"AggregateJsonArrayAggFunction"
@@ -307,11 +307,16 @@ where
307307
fn merge_result(
308308
&self,
309309
place: AggrState,
310-
_read_only: bool,
310+
read_only: bool,
311311
builder: &mut ColumnBuilder,
312312
) -> Result<()> {
313313
let state = place.get::<JsonArrayAggState<T>>();
314-
state.merge_result(builder)
314+
if read_only {
315+
let mut state = state.clone();
316+
state.merge_result(builder)
317+
} else {
318+
state.merge_result(builder)
319+
}
315320
}
316321

317322
fn need_manual_drop_state(&self) -> bool {

src/query/functions/src/aggregates/aggregate_json_object_agg.rs

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ use super::borsh_partial_deserialize;
5454
use crate::aggregates::AggregateFunctionFeatures;
5555

5656
pub(super) trait BinaryScalarStateFunc<V: ValueType>:
57-
BorshSerialize + BorshDeserialize + Send + 'static
57+
Clone + BorshSerialize + BorshDeserialize + Send + 'static
5858
{
5959
fn new() -> Self;
6060
fn add(&mut self, other: Option<(&str, V::ScalarRef<'_>)>) -> Result<()>;
@@ -68,7 +68,7 @@ pub(super) trait BinaryScalarStateFunc<V: ValueType>:
6868
fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()>;
6969
}
7070

71-
#[derive(BorshSerialize, BorshDeserialize, Debug)]
71+
#[derive(Clone, BorshSerialize, BorshDeserialize, Debug)]
7272
pub struct JsonObjectAggState<V>
7373
where
7474
V: ValueType,
@@ -347,11 +347,16 @@ where
347347
fn merge_result(
348348
&self,
349349
place: AggrState,
350-
_read_only: bool,
350+
read_only: bool,
351351
builder: &mut ColumnBuilder,
352352
) -> Result<()> {
353353
let state = place.get::<State>();
354-
state.merge_result(builder)
354+
if read_only {
355+
let mut state = state.clone();
356+
state.merge_result(builder)
357+
} else {
358+
state.merge_result(builder)
359+
}
355360
}
356361

357362
fn need_manual_drop_state(&self) -> bool {

src/query/functions/src/aggregates/aggregate_markov_tarin.rs

Lines changed: 24 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -174,12 +174,32 @@ impl AggregateFunction for MarkovTarin {
174174
fn merge_result(
175175
&self,
176176
place: AggrState,
177-
_read_only: bool,
177+
read_only: bool,
178178
builder: &mut ColumnBuilder,
179179
) -> Result<()> {
180-
let model = place.get::<MarkovModel>();
181-
model.finalize(&self.params);
180+
if read_only {
181+
let mut model = place.get::<MarkovModel>().clone();
182+
model.finalize(&self.params);
183+
self.append_model_result(&model, builder)
184+
} else {
185+
let model = place.get::<MarkovModel>();
186+
model.finalize(&self.params);
187+
self.append_model_result(model, builder)
188+
}
189+
}
190+
191+
fn need_manual_drop_state(&self) -> bool {
192+
true
193+
}
182194

195+
unsafe fn drop_state(&self, place: AggrState) {
196+
let state = place.get::<MarkovModel>();
197+
unsafe { std::ptr::drop_in_place(state) };
198+
}
199+
}
200+
201+
impl MarkovTarin {
202+
fn append_model_result(&self, model: &MarkovModel, builder: &mut ColumnBuilder) -> Result<()> {
183203
let ColumnBuilder::Array(box array_builder) = builder else {
184204
unreachable!()
185205
};
@@ -207,15 +227,6 @@ impl AggregateFunction for MarkovTarin {
207227
array_builder.commit_row();
208228
Ok(())
209229
}
210-
211-
fn need_manual_drop_state(&self) -> bool {
212-
true
213-
}
214-
215-
unsafe fn drop_state(&self, place: AggrState) {
216-
let state = place.get::<MarkovModel>();
217-
unsafe { std::ptr::drop_in_place(state) };
218-
}
219230
}
220231

221232
impl fmt::Display for MarkovTarin {
@@ -303,7 +314,7 @@ impl Histogram {
303314
}
304315
}
305316

306-
#[derive(Default, BorshSerialize, BorshDeserialize)]
317+
#[derive(Clone, Default, BorshSerialize, BorshDeserialize)]
307318
struct MarkovModel {
308319
table: BTreeMap<NGramHash, Histogram>,
309320
}

src/query/functions/src/aggregates/aggregate_range_bound.rs

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -170,13 +170,12 @@ where
170170
) -> Result<()> {
171171
let step = self.total_rows as f64 / range_bound_data.partitions as f64;
172172

173-
let values = std::mem::take(&mut self.values);
174173
let mut data = Vec::with_capacity(self.total_samples);
175174
let mut weights = Vec::with_capacity(self.total_samples);
176-
for (num, values) in values.into_iter() {
177-
let weight = num as f64 / values.len() as f64;
178-
values.into_iter().for_each(|v| {
179-
data.push(v);
175+
for (num, values) in self.values.iter() {
176+
let weight = *num as f64 / values.len() as f64;
177+
values.iter().for_each(|v| {
178+
data.push(v.clone());
180179
weights.push(weight);
181180
});
182181
}

tests/sqllogictests/suites/query/window_function/window_basic.test

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,41 @@ SELECT sum(salary) OVER (PARTITION BY depname ORDER BY salary) ss FROM empsalary
6060
9600
6161
14600
6262

63+
query T
64+
SELECT array_agg(v) OVER (PARTITION BY k) FROM (SELECT 1 AS k, ['a'] AS v UNION ALL SELECT 1, ['b'] UNION ALL SELECT 1, ['c'])
65+
----
66+
[["a"],["b"],["c"]]
67+
[["a"],["b"],["c"]]
68+
[["a"],["b"],["c"]]
69+
70+
query T rowsort
71+
SELECT json_array_agg(v) OVER (PARTITION BY k) FROM (SELECT 1 AS k, 'a' AS v UNION ALL SELECT 1, 'b' UNION ALL SELECT 1, 'c')
72+
----
73+
["a","b","c"]
74+
["a","b","c"]
75+
["a","b","c"]
76+
77+
query T
78+
SELECT json_object_agg(v, n) OVER (PARTITION BY k) FROM (SELECT 1 AS k, 'a' AS v, 10 AS n UNION ALL SELECT 1, 'b', 20 UNION ALL SELECT 1, 'c', 30)
79+
----
80+
{"a":10,"b":20,"c":30}
81+
{"a":10,"b":20,"c":30}
82+
{"a":10,"b":20,"c":30}
83+
84+
query T
85+
SELECT range_bound(3)(v) OVER (PARTITION BY k) FROM (SELECT 1 AS k, 10 AS v UNION ALL SELECT 1, 20 UNION ALL SELECT 1, 30)
86+
----
87+
[10,20]
88+
[10,20]
89+
[10,20]
90+
91+
query T
92+
SELECT m FROM (SELECT markov_train(1, 0, 0, 1, 0)(v) OVER (PARTITION BY k) AS m FROM (SELECT 1 AS k, 'ab' AS v UNION ALL SELECT 1, 'ac' UNION ALL SELECT 1, 'ad'));
93+
----
94+
[(0,6,4,{97:4,98:2,99:2,100:2})]
95+
[(0,6,4,{97:4,98:2,99:2,100:2})]
96+
[(0,6,4,{97:4,98:2,99:2,100:2})]
97+
6398
# row_number
6499
query I
65100
SELECT row_number() OVER (PARTITION BY depname ORDER BY salary) rn FROM empsalary ORDER BY depname, rn

0 commit comments

Comments (0)