Skip to content

Commit 854387e

Browse files
authored
Fix: simplify dict array compute functions and report the correct DictArray dtype now that codes and values may differ in nullability (#4579)
Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent 7310b2d commit 854387e

4 files changed

Lines changed: 45 additions & 34 deletions

File tree

encodings/dict/src/array.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pub struct DictArray {
4444
codes: ArrayRef,
4545
values: ArrayRef,
4646
stats_set: ArrayStats,
47+
dtype: DType,
4748
}
4849

4950
#[derive(Clone, Debug)]
@@ -57,10 +58,14 @@ impl DictArray {
5758
/// by the safe [`DictArray::try_new`] constructor are valid, for example when
5859
/// you are filtering or slicing an existing valid `DictArray`.
5960
pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
61+
let dtype = values
62+
.dtype()
63+
.union_nullability(codes.dtype().nullability());
6064
Self {
6165
codes,
6266
values,
6367
stats_set: Default::default(),
68+
dtype,
6469
}
6570
}
6671

@@ -88,11 +93,7 @@ impl DictArray {
8893
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
8994
}
9095

91-
Ok(Self {
92-
codes,
93-
values,
94-
stats_set: Default::default(),
95-
})
96+
Ok(unsafe { Self::new_unchecked(codes, values) })
9697
}
9798

9899
#[inline]
@@ -112,7 +113,7 @@ impl ArrayVTable<DictVTable> for DictVTable {
112113
}
113114

114115
fn dtype(array: &DictArray) -> &DType {
115-
array.values.dtype()
116+
&array.dtype
116117
}
117118

118119
fn stats(array: &DictArray) -> StatsSetRef<'_> {

encodings/dict/src/compute/cast.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5-
use vortex_array::{ArrayRef, IntoArray, register_kernel};
5+
use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
66
use vortex_dtype::DType;
77
use vortex_error::VortexResult;
88

@@ -13,7 +13,8 @@ impl CastKernel for DictVTable {
1313
// Cast the dictionary values to the target type
1414
let casted_values = cast(array.values(), dtype)?;
1515

16-
let casted_codes = if dtype.nullability() != array.codes().dtype().nullability() {
16+
// If the codes are nullable but we are casting to a non-nullable dtype, we have to remove nullability from the codes as well.
17+
let casted_codes = if array.codes().dtype().is_nullable() && !dtype.is_nullable() {
1718
cast(
1819
array.codes(),
1920
&array.codes().dtype().with_nullability(dtype.nullability()),
@@ -23,11 +24,9 @@ impl CastKernel for DictVTable {
2324
};
2425

2526
// SAFETY: casting does not alter invariants of the codes
26-
unsafe {
27-
Ok(Some(
28-
DictArray::new_unchecked(casted_codes, casted_values).into_array(),
29-
))
30-
}
27+
Ok(Some(
28+
unsafe { DictArray::new_unchecked(casted_codes, casted_values) }.into_array(),
29+
))
3130
}
3231
}
3332

@@ -132,7 +131,7 @@ mod tests {
132131
let nullable_dict = nullable.as_::<DictVTable>();
133132
assert_eq!(
134133
nullable_dict.codes().dtype().nullability(),
135-
Nullability::Nullable
134+
Nullability::NonNullable
136135
);
137136
assert_eq!(
138137
nullable_dict.values().dtype().nullability(),

encodings/dict/src/compute/mod.rs

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ mod like;
1010
mod min_max;
1111

1212
use vortex_array::compute::{
13-
FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
13+
FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, filter, take,
1414
};
1515
use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
1616
use vortex_error::VortexResult;
@@ -20,16 +20,9 @@ use crate::{DictArray, DictVTable};
2020

2121
impl TakeKernel for DictVTable {
2222
fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23-
// TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
2423
let codes = take(array.codes(), indices)?;
25-
let values_dtype = array
26-
.values()
27-
.dtype()
28-
.union_nullability(codes.dtype().nullability());
2924
// SAFETY: selecting codes doesn't change the invariants of DictArray
30-
unsafe {
31-
Ok(DictArray::new_unchecked(codes, cast(array.values(), &values_dtype)?).into_array())
32-
}
25+
Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()) }.into_array())
3326
}
3427
}
3528

encodings/dict/src/ops.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,33 @@ impl OperationsVTable<DictVTable> for DictVTable {
1919
return if let Some(code) = code {
2020
ConstantArray::new(array.values().scalar_at(*code), sliced_code.len()).into_array()
2121
} else {
22-
let dtype = array.values().dtype().with_nullability(
23-
array.values().dtype().nullability() | array.codes().dtype().nullability(),
24-
);
25-
ConstantArray::new(Scalar::null(dtype), sliced_code.len()).to_array()
22+
ConstantArray::new(Scalar::null(array.dtype().clone()), sliced_code.len())
23+
.to_array()
2624
};
2725
}
2826
// SAFETY: slicing the codes preserves invariants
2927
unsafe { DictArray::new_unchecked(sliced_code, array.values().clone()).into_array() }
3028
}
3129

3230
fn scalar_at(array: &DictArray, index: usize) -> Scalar {
33-
let dict_index: usize = array
34-
.codes()
35-
.scalar_at(index)
36-
.as_ref()
37-
.try_into()
38-
.vortex_expect("code overflowed usize");
39-
array.values().scalar_at(dict_index)
31+
let Some(dict_index) = array.codes().scalar_at(index).as_primitive().as_::<usize>() else {
32+
return Scalar::null(array.dtype().clone());
33+
};
34+
35+
array
36+
.values()
37+
.scalar_at(dict_index)
38+
.cast(array.dtype())
39+
.vortex_expect("Array dtype will only differ by nullability")
4040
}
4141
}
4242

4343
#[cfg(test)]
4444
mod tests {
45+
use vortex_array::IntoArray;
4546
use vortex_array::arrays::PrimitiveArray;
47+
use vortex_buffer::buffer;
48+
use vortex_dtype::Nullability;
4649
use vortex_scalar::Scalar;
4750

4851
use crate::DictArray;
@@ -65,4 +68,19 @@ mod tests {
6568
dict.slice(1..2).as_constant()
6669
);
6770
}
71+
72+
#[test]
73+
fn test_scalar_at_null_code() {
74+
let dict = DictArray::try_new(
75+
PrimitiveArray::from_option_iter(vec![None, Some(0u32), None]).to_array(),
76+
buffer![1i32].into_array(),
77+
)
78+
.unwrap();
79+
80+
assert_eq!(dict.scalar_at(0), Scalar::null(dict.dtype().clone()));
81+
assert_eq!(
82+
dict.scalar_at(1),
83+
Scalar::primitive(1, Nullability::Nullable)
84+
);
85+
}
6886
}

0 commit comments

Comments
 (0)