1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ use std:: cell:: Cell ;
1516use std:: sync:: Arc ;
1617
1718use databend_common_base:: runtime:: drop_guard;
@@ -55,14 +56,21 @@ pub struct WindowFuncAggImpl {
5556 addr : StateAddr ,
5657 loc : Box < [ AggrStateLoc ] > ,
5758 args : Vec < usize > ,
59+ initialized : Cell < bool > ,
5860}
5961
6062unsafe impl Send for WindowFuncAggImpl { }
6163
6264impl WindowFuncAggImpl {
6365 #[ inline]
6466 pub fn reset ( & self ) {
67+ if self . initialized . replace ( false ) && self . agg . need_manual_drop_state ( ) {
68+ unsafe {
69+ self . agg . drop_state ( AggrState :: new ( self . addr , & self . loc ) ) ;
70+ }
71+ }
6572 self . agg . init_state ( AggrState :: new ( self . addr , & self . loc ) ) ;
73+ self . initialized . set ( true ) ;
6674 }
6775
6876 #[ inline]
@@ -86,7 +94,7 @@ impl WindowFuncAggImpl {
8694impl Drop for WindowFuncAggImpl {
8795 fn drop ( & mut self ) {
8896 drop_guard ( move || {
89- if self . agg . need_manual_drop_state ( ) {
97+ if self . initialized . get ( ) && self . agg . need_manual_drop_state ( ) {
9098 unsafe {
9199 self . agg . drop_state ( AggrState :: new ( self . addr , & self . loc ) ) ;
92100 }
@@ -271,6 +279,7 @@ impl WindowFunctionImpl {
271279 loc,
272280 args,
273281 _arena : arena,
282+ initialized : Cell :: new ( false ) ,
274283 } ;
275284 agg. reset ( ) ;
276285 Self :: Aggregate ( agg)
@@ -306,3 +315,157 @@ impl WindowFunctionImpl {
306315 }
307316 }
308317}
318+
319+ #[ cfg( test) ]
320+ mod tests {
321+ use std:: alloc:: Layout ;
322+ use std:: fmt;
323+ use std:: sync:: Arc ;
324+ use std:: sync:: atomic:: AtomicUsize ;
325+ use std:: sync:: atomic:: Ordering ;
326+
327+ use databend_common_exception:: Result ;
328+ use databend_common_expression:: AggrStateRegistry ;
329+ use databend_common_expression:: AggrStateType ;
330+ use databend_common_expression:: BlockEntry ;
331+ use databend_common_expression:: StateSerdeItem ;
332+ use databend_common_expression:: types:: Bitmap ;
333+ use databend_common_functions:: aggregates:: AggregateFunctionRef ;
334+
335+ use super :: * ;
336+
337+ struct DropCountingState {
338+ drops : Arc < AtomicUsize > ,
339+ }
340+
341+ impl Drop for DropCountingState {
342+ fn drop ( & mut self ) {
343+ self . drops . fetch_add ( 1 , Ordering :: SeqCst ) ;
344+ }
345+ }
346+
347+ struct DropCountingAggregate {
348+ drops : Arc < AtomicUsize > ,
349+ }
350+
351+ impl fmt:: Display for DropCountingAggregate {
352+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
353+ write ! ( f, "drop_counting_aggregate" )
354+ }
355+ }
356+
357+ impl AggregateFunction for DropCountingAggregate {
358+ fn name ( & self ) -> & str {
359+ "DropCountingAggregate"
360+ }
361+
362+ fn return_type ( & self ) -> Result < DataType > {
363+ Ok ( DataType :: Null )
364+ }
365+
366+ fn init_state ( & self , place : AggrState ) {
367+ let drops = self . drops . clone ( ) ;
368+ place. write ( || DropCountingState { drops } ) ;
369+ }
370+
371+ fn register_state ( & self , registry : & mut AggrStateRegistry ) {
372+ registry. register ( AggrStateType :: Custom ( Layout :: new :: < DropCountingState > ( ) ) ) ;
373+ }
374+
375+ fn accumulate (
376+ & self ,
377+ _place : AggrState ,
378+ _columns : ProjectedBlock ,
379+ _validity : Option < & Bitmap > ,
380+ _input_rows : usize ,
381+ ) -> Result < ( ) > {
382+ Ok ( ( ) )
383+ }
384+
385+ fn accumulate_row (
386+ & self ,
387+ _place : AggrState ,
388+ _columns : ProjectedBlock ,
389+ _row : usize ,
390+ ) -> Result < ( ) > {
391+ Ok ( ( ) )
392+ }
393+
394+ fn serialize_type ( & self ) -> Vec < StateSerdeItem > {
395+ vec ! [ ]
396+ }
397+
398+ fn batch_serialize (
399+ & self ,
400+ _places : & [ StateAddr ] ,
401+ _loc : & [ AggrStateLoc ] ,
402+ _builders : & mut [ ColumnBuilder ] ,
403+ ) -> Result < ( ) > {
404+ Ok ( ( ) )
405+ }
406+
407+ fn batch_merge (
408+ & self ,
409+ _places : & [ StateAddr ] ,
410+ _loc : & [ AggrStateLoc ] ,
411+ _state : & BlockEntry ,
412+ _filter : Option < & Bitmap > ,
413+ ) -> Result < ( ) > {
414+ Ok ( ( ) )
415+ }
416+
417+ fn merge_states ( & self , _place : AggrState , _rhs : AggrState ) -> Result < ( ) > {
418+ Ok ( ( ) )
419+ }
420+
421+ fn merge_result (
422+ & self ,
423+ _place : AggrState ,
424+ _read_only : bool ,
425+ _builder : & mut ColumnBuilder ,
426+ ) -> Result < ( ) > {
427+ Ok ( ( ) )
428+ }
429+
430+ fn need_manual_drop_state ( & self ) -> bool {
431+ true
432+ }
433+
434+ unsafe fn drop_state ( & self , place : AggrState ) {
435+ let state = place. get :: < DropCountingState > ( ) ;
436+ unsafe { std:: ptr:: drop_in_place ( state) } ;
437+ }
438+ }
439+
440+ #[ test]
441+ fn reset_drops_existing_manual_state_before_reinitializing ( ) -> Result < ( ) > {
442+ let drops = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
443+ let agg: AggregateFunctionRef = Arc :: new ( DropCountingAggregate {
444+ drops : drops. clone ( ) ,
445+ } ) ;
446+ let arena = Arena :: new ( ) ;
447+ let mut states_layout = get_states_layout ( std:: slice:: from_ref ( & agg) ) ?;
448+ let addr = arena. alloc_layout ( states_layout. layout ) . into ( ) ;
449+ let loc = states_layout. states_loc . pop ( ) . unwrap ( ) ;
450+
451+ let window_agg = WindowFuncAggImpl {
452+ _arena : arena,
453+ agg,
454+ addr,
455+ loc,
456+ args : vec ! [ ] ,
457+ initialized : Cell :: new ( false ) ,
458+ } ;
459+
460+ window_agg. reset ( ) ;
461+ assert_eq ! ( drops. load( Ordering :: SeqCst ) , 0 ) ;
462+
463+ window_agg. reset ( ) ;
464+ assert_eq ! ( drops. load( Ordering :: SeqCst ) , 1 ) ;
465+
466+ drop ( window_agg) ;
467+ assert_eq ! ( drops. load( Ordering :: SeqCst ) , 2 ) ;
468+
469+ Ok ( ( ) )
470+ }
471+ }
0 commit comments