@@ -19,30 +19,33 @@ impl OperationsVTable<DictVTable> for DictVTable {
1919 return if let Some ( code) = code {
2020 ConstantArray :: new ( array. values ( ) . scalar_at ( * code) , sliced_code. len ( ) ) . into_array ( )
2121 } else {
22- let dtype = array. values ( ) . dtype ( ) . with_nullability (
23- array. values ( ) . dtype ( ) . nullability ( ) | array. codes ( ) . dtype ( ) . nullability ( ) ,
24- ) ;
25- ConstantArray :: new ( Scalar :: null ( dtype) , sliced_code. len ( ) ) . to_array ( )
22+ ConstantArray :: new ( Scalar :: null ( array. dtype ( ) . clone ( ) ) , sliced_code. len ( ) )
23+ . to_array ( )
2624 } ;
2725 }
2826 // SAFETY: slicing the codes preserves invariants
2927 unsafe { DictArray :: new_unchecked ( sliced_code, array. values ( ) . clone ( ) ) . into_array ( ) }
3028 }
3129
3230 fn scalar_at ( array : & DictArray , index : usize ) -> Scalar {
33- let dict_index: usize = array
34- . codes ( )
35- . scalar_at ( index)
36- . as_ref ( )
37- . try_into ( )
38- . vortex_expect ( "code overflowed usize" ) ;
39- array. values ( ) . scalar_at ( dict_index)
31+ let Some ( dict_index) = array. codes ( ) . scalar_at ( index) . as_primitive ( ) . as_ :: < usize > ( ) else {
32+ return Scalar :: null ( array. dtype ( ) . clone ( ) ) ;
33+ } ;
34+
35+ array
36+ . values ( )
37+ . scalar_at ( dict_index)
38+ . cast ( array. dtype ( ) )
39+ . vortex_expect ( "Array dtype will only differ by nullability" )
4040 }
4141}
4242
4343#[ cfg( test) ]
4444mod tests {
45+ use vortex_array:: IntoArray ;
4546 use vortex_array:: arrays:: PrimitiveArray ;
47+ use vortex_buffer:: buffer;
48+ use vortex_dtype:: Nullability ;
4649 use vortex_scalar:: Scalar ;
4750
4851 use crate :: DictArray ;
@@ -65,4 +68,19 @@ mod tests {
6568 dict. slice( 1 ..2 ) . as_constant( )
6669 ) ;
6770 }
71+
72+ #[ test]
73+ fn test_scalar_at_null_code ( ) {
74+ let dict = DictArray :: try_new (
75+ PrimitiveArray :: from_option_iter ( vec ! [ None , Some ( 0u32 ) , None ] ) . to_array ( ) ,
76+ buffer ! [ 1i32 ] . into_array ( ) ,
77+ )
78+ . unwrap ( ) ;
79+
80+ assert_eq ! ( dict. scalar_at( 0 ) , Scalar :: null( dict. dtype( ) . clone( ) ) ) ;
81+ assert_eq ! (
82+ dict. scalar_at( 1 ) ,
83+ Scalar :: primitive( 1 , Nullability :: Nullable )
84+ ) ;
85+ }
6886}
0 commit comments