44use std:: fmt:: Display ;
55use std:: fmt:: Formatter ;
66use std:: num:: NonZeroUsize ;
7+ use std:: sync:: LazyLock ;
78
89use vortex_buffer:: BufferString ;
910use vortex_buffer:: ByteBuffer ;
1011use vortex_error:: VortexExpect ;
1112use vortex_error:: VortexResult ;
13+ use vortex_error:: vortex_bail;
1214use vortex_error:: vortex_ensure;
1315use vortex_session:: VortexSession ;
1416
@@ -24,12 +26,24 @@ use crate::aggregate_fn::EmptyOptions;
2426use crate :: aggregate_fn:: fns:: max:: Max ;
2527use crate :: aggregate_fn:: fns:: min_max:: MinMax ;
2628use crate :: aggregate_fn:: fns:: min_max:: min_max;
29+ use crate :: builtins:: ArrayBuiltins ;
2730use crate :: dtype:: DType ;
31+ use crate :: dtype:: FieldNames ;
32+ use crate :: dtype:: Nullability ;
33+ use crate :: dtype:: StructFields ;
2834use crate :: partial_ord:: partial_max;
2935use crate :: scalar:: Scalar ;
3036use crate :: scalar:: ScalarTruncation ;
3137use crate :: scalar:: upper_bound;
3238
39+ /// Field name for the bounded maximum upper-bound value in the partial state.
40+ pub const BOUNDED_MAX_BOUND : & str = "bound" ;
41+ /// Field name for whether the partial state represents an unknown upper bound.
42+ pub const BOUNDED_MAX_UNKNOWN : & str = "unknown" ;
43+
44+ static NAMES : LazyLock < FieldNames > =
45+ LazyLock :: new ( || FieldNames :: from ( [ BOUNDED_MAX_BOUND , BOUNDED_MAX_UNKNOWN ] ) ) ;
46+
3347/// Options for [`BoundedMax`].
3448#[ derive( Clone , Debug , PartialEq , Eq , Hash ) ]
3549pub struct BoundedMaxOptions {
@@ -61,12 +75,8 @@ pub struct BoundedMaxPartial {
6175}
6276
6377impl BoundedMaxPartial {
64- fn merge ( & mut self , max : Scalar ) {
78+ fn merge_bound ( & mut self , max : Scalar ) {
6579 if max. is_null ( ) {
66- // Serialized partials encode both empty input and unknown upper bounds as null.
67- // Treat null as unknown when merging; this may lose a bound from an empty shard, but
68- // it preserves pruning soundness.
69- self . state = BoundedMaxState :: Unknown ;
7080 return ;
7181 }
7282
@@ -82,6 +92,32 @@ impl BoundedMaxPartial {
8292 fn unknown ( & mut self ) {
8393 self . state = BoundedMaxState :: Unknown ;
8494 }
95+
96+ fn final_scalar ( & self ) -> VortexResult < Scalar > {
97+ let dtype = self . element_dtype . as_nullable ( ) ;
98+ match & self . state {
99+ BoundedMaxState :: Value ( max) => max. cast ( & dtype) ,
100+ BoundedMaxState :: Empty | BoundedMaxState :: Unknown => Ok ( Scalar :: null ( dtype) ) ,
101+ }
102+ }
103+ }
104+
105+ /// Return the serialized partial-state dtype for [`BoundedMax`].
106+ ///
107+ /// A null struct means the partial is empty. A non-null struct with a null `bound` and
108+ /// `unknown = true` means the input has a non-null maximum but no finite upper bound could be
109+ /// represented within the configured byte limit.
110+ pub fn make_bounded_max_partial_dtype ( element_dtype : & DType ) -> DType {
111+ DType :: Struct (
112+ StructFields :: new (
113+ NAMES . clone ( ) ,
114+ vec ! [
115+ element_dtype. as_nullable( ) ,
116+ DType :: Bool ( Nullability :: NonNullable ) ,
117+ ] ,
118+ ) ,
119+ Nullability :: Nullable ,
120+ )
85121}
86122
87123impl AggregateFnVTable for BoundedMax {
@@ -144,7 +180,7 @@ impl AggregateFnVTable for BoundedMax {
144180 }
145181
146182 fn partial_dtype ( & self , options : & Self :: Options , input_dtype : & DType ) -> Option < DType > {
147- self . return_dtype ( options, input_dtype)
183+ supported_dtype ( options, input_dtype) . map ( make_bounded_max_partial_dtype )
148184 }
149185
150186 fn empty_partial (
@@ -160,15 +196,50 @@ impl AggregateFnVTable for BoundedMax {
160196 }
161197
162198 fn combine_partials ( & self , partial : & mut Self :: Partial , other : Scalar ) -> VortexResult < ( ) > {
163- partial. merge ( other) ;
199+ if other. is_null ( ) {
200+ return Ok ( ( ) ) ;
201+ }
202+
203+ let Some ( other) = other. as_struct_opt ( ) else {
204+ vortex_bail ! ( "BoundedMax partial must be a struct, got {}" , other. dtype( ) ) ;
205+ } ;
206+ let Some ( bound) = other. field_by_idx ( 0 ) else {
207+ vortex_bail ! ( "BoundedMax partial is missing its bound field" ) ;
208+ } ;
209+ let Some ( unknown) = other
210+ . field_by_idx ( 1 )
211+ . and_then ( |unknown| unknown. as_bool ( ) . value ( ) )
212+ else {
213+ vortex_bail ! ( "BoundedMax partial is missing its non-null unknown field" ) ;
214+ } ;
215+
216+ if unknown {
217+ partial. unknown ( ) ;
218+ } else {
219+ partial. merge_bound ( bound) ;
220+ }
164221 Ok ( ( ) )
165222 }
166223
167224 fn to_scalar ( & self , partial : & Self :: Partial ) -> VortexResult < Scalar > {
168- let dtype = partial. element_dtype . as_nullable ( ) ;
225+ let dtype = make_bounded_max_partial_dtype ( & partial. element_dtype ) ;
226+ let bound_dtype = partial. element_dtype . as_nullable ( ) ;
169227 match & partial. state {
170- BoundedMaxState :: Value ( max) => max. cast ( & dtype) ,
171- BoundedMaxState :: Empty | BoundedMaxState :: Unknown => Ok ( Scalar :: null ( dtype) ) ,
228+ BoundedMaxState :: Empty => Ok ( Scalar :: null ( dtype) ) ,
229+ BoundedMaxState :: Value ( max) => Ok ( Scalar :: struct_ (
230+ dtype,
231+ vec ! [
232+ max. cast( & bound_dtype) ?,
233+ Scalar :: bool ( false , Nullability :: NonNullable ) ,
234+ ] ,
235+ ) ) ,
236+ BoundedMaxState :: Unknown => Ok ( Scalar :: struct_ (
237+ dtype,
238+ vec ! [
239+ Scalar :: null( bound_dtype) ,
240+ Scalar :: bool ( true , Nullability :: NonNullable ) ,
241+ ] ,
242+ ) ) ,
172243 }
173244 }
174245
@@ -196,18 +267,18 @@ impl AggregateFnVTable for BoundedMax {
196267 return Ok ( ( ) ) ;
197268 } ;
198269 match truncate_max ( result. max , partial. max_bytes . get ( ) ) ? {
199- Some ( bound) => partial. merge ( bound) ,
270+ Some ( bound) => partial. merge_bound ( bound) ,
200271 None => partial. unknown ( ) ,
201272 }
202273 Ok ( ( ) )
203274 }
204275
205276 fn finalize ( & self , partials : ArrayRef ) -> VortexResult < ArrayRef > {
206- Ok ( partials)
277+ partials. get_item ( BOUNDED_MAX_BOUND )
207278 }
208279
209280 fn finalize_scalar ( & self , partial : & Self :: Partial ) -> VortexResult < Scalar > {
210- self . to_scalar ( partial )
281+ partial . final_scalar ( )
211282 }
212283}
213284
@@ -256,6 +327,7 @@ mod tests {
256327 use crate :: aggregate_fn:: EmptyOptions ;
257328 use crate :: aggregate_fn:: fns:: bounded_max:: BoundedMax ;
258329 use crate :: aggregate_fn:: fns:: bounded_max:: BoundedMaxOptions ;
330+ use crate :: aggregate_fn:: fns:: bounded_max:: make_bounded_max_partial_dtype;
259331 use crate :: aggregate_fn:: fns:: max:: Max ;
260332 use crate :: aggregate_fn:: fns:: min:: Min ;
261333 use crate :: arrays:: PrimitiveArray ;
@@ -352,7 +424,29 @@ mod tests {
352424 }
353425
354426 #[ test]
355- fn bounded_max_null_partial_poisons_existing_bound ( ) -> VortexResult < ( ) > {
427+ fn bounded_max_empty_partial_does_not_poison_existing_bound ( ) -> VortexResult < ( ) > {
428+ let mut ctx = fresh_session ( ) . create_execution_ctx ( ) ;
429+ let values = VarBinViewArray :: from_iter_bin ( [ & [ 1u8 ] [ ..] ] ) . into_array ( ) ;
430+ let mut acc = Accumulator :: try_new (
431+ BoundedMax ,
432+ BoundedMaxOptions {
433+ max_bytes : max_bytes ( 2 ) ,
434+ } ,
435+ values. dtype ( ) . clone ( ) ,
436+ ) ?;
437+
438+ acc. accumulate ( & values, & mut ctx) ?;
439+ acc. combine_partials ( Scalar :: null ( make_bounded_max_partial_dtype ( values. dtype ( ) ) ) ) ?;
440+
441+ assert_eq ! (
442+ acc. finish( ) ?,
443+ Scalar :: binary( buffer![ 1u8 ] , Nullability :: Nullable )
444+ ) ;
445+ Ok ( ( ) )
446+ }
447+
448+ #[ test]
449+ fn bounded_max_unknown_partial_poisons_existing_bound ( ) -> VortexResult < ( ) > {
356450 let mut ctx = fresh_session ( ) . create_execution_ctx ( ) ;
357451 let values = VarBinViewArray :: from_iter_bin ( [ & [ 1u8 ] [ ..] ] ) . into_array ( ) ;
358452 let mut acc = Accumulator :: try_new (
@@ -363,8 +457,17 @@ mod tests {
363457 values. dtype ( ) . clone ( ) ,
364458 ) ?;
365459
460+ let partial_dtype = make_bounded_max_partial_dtype ( values. dtype ( ) ) ;
461+ let unknown = Scalar :: struct_ (
462+ partial_dtype,
463+ vec ! [
464+ Scalar :: null( values. dtype( ) . as_nullable( ) ) ,
465+ Scalar :: bool ( true , Nullability :: NonNullable ) ,
466+ ] ,
467+ ) ;
468+
366469 acc. accumulate ( & values, & mut ctx) ?;
367- acc. combine_partials ( Scalar :: null ( values . dtype ( ) . as_nullable ( ) ) ) ?;
470+ acc. combine_partials ( unknown ) ?;
368471
369472 assert_eq ! ( acc. finish( ) ?, Scalar :: null( values. dtype( ) . as_nullable( ) ) ) ;
370473 Ok ( ( ) )
0 commit comments