@@ -14,7 +14,12 @@ use vortex_session::VortexSession;
1414
1515use crate :: ArrayRef ;
1616use crate :: ExecutionCtx ;
17+ use crate :: arrays:: Filter ;
18+ use crate :: arrays:: Slice ;
1719use crate :: arrays:: Variant ;
20+ use crate :: arrays:: filter:: FilterArrayExt ;
21+ use crate :: arrays:: scalar_fn:: ScalarFnFactoryExt ;
22+ use crate :: arrays:: slice:: SliceArrayExt ;
1823use crate :: arrays:: variant:: VariantArrayExt ;
1924use crate :: dtype:: DType ;
2025use crate :: dtype:: Nullability ;
@@ -213,14 +218,17 @@ impl ScalarFnVTable for VariantGet {
213218
214219 fn execute (
215220 & self ,
216- _options : & VariantGetOptions ,
217- _args : & dyn ExecutionArgs ,
218- _ctx : & mut ExecutionCtx ,
221+ options : & VariantGetOptions ,
222+ args : & dyn ExecutionArgs ,
223+ ctx : & mut ExecutionCtx ,
219224 ) -> VortexResult < ArrayRef > {
220- vortex_bail ! (
221- "variant_get cannot be executed directly; \
222- it must be pushed down to a variant encoding via execute_parent"
223- )
225+ let input = args. get ( 0 ) ?;
226+
227+ if let Some ( rewritten) = rewrite_wrapped_variant_input ( input, options, ctx) ? {
228+ return Ok ( rewritten) ;
229+ }
230+
231+ vortex_bail ! ( "variant_get cannot be executed directly" )
224232 }
225233
226234 fn reduce (
@@ -254,6 +262,48 @@ impl ScalarFnVTable for VariantGet {
254262 }
255263}
256264
265+ fn rewrite_wrapped_variant_input (
266+ input : ArrayRef ,
267+ options : & VariantGetOptions ,
268+ ctx : & mut ExecutionCtx ,
269+ ) -> VortexResult < Option < ArrayRef > > {
270+ if let Some ( slice) = input. as_opt :: < Slice > ( ) {
271+ let Some ( inner) = unwrap_variant_input ( slice. child ( ) ) else {
272+ return Ok ( None ) ;
273+ } ;
274+ let sliced = inner
275+ . slice ( slice. slice_range ( ) . clone ( ) ) ?
276+ . execute :: < ArrayRef > ( ctx) ?;
277+ return VariantGet
278+ . try_new_array ( input. len ( ) , options. clone ( ) , [ sliced] ) ?
279+ . execute :: < ArrayRef > ( ctx)
280+ . map ( Some ) ;
281+ }
282+
283+ if let Some ( filter) = input. as_opt :: < Filter > ( ) {
284+ let Some ( inner) = unwrap_variant_input ( filter. child ( ) ) else {
285+ return Ok ( None ) ;
286+ } ;
287+ let filtered = inner
288+ . filter ( filter. filter_mask ( ) . clone ( ) ) ?
289+ . execute :: < ArrayRef > ( ctx) ?;
290+ return VariantGet
291+ . try_new_array ( input. len ( ) , options. clone ( ) , [ filtered] ) ?
292+ . execute :: < ArrayRef > ( ctx)
293+ . map ( Some ) ;
294+ }
295+
296+ Ok ( None )
297+ }
298+
299+ fn unwrap_variant_input ( input : & ArrayRef ) -> Option < ArrayRef > {
300+ if let Some ( variant) = input. as_opt :: < Variant > ( ) {
301+ return Some ( variant. child ( ) . clone ( ) ) ;
302+ }
303+
304+ matches ! ( input. dtype( ) , DType :: Variant ( _) ) . then ( || input. clone ( ) )
305+ }
306+
257307#[ cfg( test) ]
258308mod tests {
259309 use vortex_session:: VortexSession ;
0 commit comments