@@ -244,6 +244,7 @@ impl Validity {
244244 }
245245 }
246246
247+ #[ inline]
247248 pub fn execute_mask ( & self , length : usize , ctx : & mut ExecutionCtx ) -> VortexResult < Mask > {
248249 match self {
249250 Self :: NonNullable | Self :: AllValid => Ok ( Mask :: AllTrue ( length) ) ,
@@ -263,18 +264,22 @@ impl Validity {
263264 }
264265 }
265266
266- /// Compare two Validity values of the same length by executing them into masks if necessary.
267- pub fn mask_eq ( & self , other : & Validity , ctx : & mut ExecutionCtx ) -> VortexResult < bool > {
267+ /// Compare the logical masks of two Validity values of the given length, executing them
268+ /// into [`Mask`]s if necessary.
269+ pub fn mask_eq (
270+ & self ,
271+ other : & Validity ,
272+ length : usize ,
273+ ctx : & mut ExecutionCtx ,
274+ ) -> VortexResult < bool > {
268275 match ( self , other) {
269- ( Validity :: NonNullable , Validity :: NonNullable ) => Ok ( true ) ,
270- ( Validity :: AllValid , Validity :: AllValid ) => Ok ( true ) ,
271- ( Validity :: AllInvalid , Validity :: AllInvalid ) => Ok ( true ) ,
272- ( Validity :: Array ( a) , Validity :: Array ( b) ) => {
273- let a = a. clone ( ) . execute :: < Mask > ( ctx) ?;
274- let b = b. clone ( ) . execute :: < Mask > ( ctx) ?;
275- Ok ( a == b)
276- }
277- _ => Ok ( false ) ,
276+ // Fast paths that avoid executing: constant variants with known-equal masks.
277+ (
278+ Validity :: NonNullable | Validity :: AllValid ,
279+ Validity :: NonNullable | Validity :: AllValid ,
280+ )
281+ | ( Validity :: AllInvalid , Validity :: AllInvalid ) => Ok ( true ) ,
282+ _ => Ok ( self . execute_mask ( length, ctx) ? == other. execute_mask ( length, ctx) ?) ,
278283 }
279284 }
280285
@@ -703,7 +708,7 @@ mod tests {
703708 validity
704709 . patch( len, 0 , & indices, & patches, & mut ctx, )
705710 . unwrap( )
706- . mask_eq( & expected, & mut ctx)
711+ . mask_eq( & expected, len , & mut ctx)
707712 . unwrap( )
708713 ) ;
709714 }
@@ -768,8 +773,50 @@ mod tests {
768773 validity
769774 . take( & indices)
770775 . unwrap( )
771- . mask_eq( & expected, & mut ctx)
776+ . mask_eq( & expected, indices . len ( ) , & mut ctx)
772777 . unwrap( )
773778 ) ;
774779 }
780+
781+ #[ rstest]
782+ // Mixed constant variants with equal masks.
783+ #[ case( Validity :: NonNullable , Validity :: AllValid , true ) ]
784+ #[ case( Validity :: AllValid , Validity :: NonNullable , true ) ]
785+ #[ case( Validity :: AllValid , Validity :: AllInvalid , false ) ]
786+ #[ case( Validity :: NonNullable , Validity :: AllInvalid , false ) ]
787+ // An array that resolves to a constant mask must equal the constant variant.
788+ #[ case(
789+ Validity :: Array ( BoolArray :: from_iter( [ true , true , true ] ) . into_array( ) ) ,
790+ Validity :: AllValid ,
791+ true
792+ ) ]
793+ #[ case(
794+ Validity :: NonNullable ,
795+ Validity :: Array ( BoolArray :: from_iter( [ true , true , true ] ) . into_array( ) ) ,
796+ true
797+ ) ]
798+ #[ case(
799+ Validity :: Array ( BoolArray :: from_iter( [ false , false , false ] ) . into_array( ) ) ,
800+ Validity :: AllInvalid ,
801+ true
802+ ) ]
803+ #[ case(
804+ Validity :: Array ( BoolArray :: from_iter( [ true , false , true ] ) . into_array( ) ) ,
805+ Validity :: AllValid ,
806+ false
807+ ) ]
808+ #[ case(
809+ Validity :: Array ( BoolArray :: from_iter( [ true , false , true ] ) . into_array( ) ) ,
810+ Validity :: AllInvalid ,
811+ false
812+ ) ]
813+ fn mask_eq_mixed_variants (
814+ #[ case] lhs : Validity ,
815+ #[ case] rhs : Validity ,
816+ #[ case] expected : bool ,
817+ ) -> vortex_error:: VortexResult < ( ) > {
818+ let mut ctx = LEGACY_SESSION . create_execution_ctx ( ) ;
819+ assert_eq ! ( lhs. mask_eq( & rhs, 3 , & mut ctx) ?, expected) ;
820+ Ok ( ( ) )
821+ }
775822}
0 commit comments