Skip to content

Commit ad61b12

Browse files
committed
add union type
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent e160125 commit ad61b12

46 files changed

Lines changed: 1784 additions & 109 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

encodings/sparse/src/canonical.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,6 @@ pub(super) fn execute_sparse(parts: SparseParts, ctx: &mut ExecutionCtx) -> Vort
114114
execute_sparse_primitives::<P>(&patches, &fill_value, ctx)?
115115
})
116116
}
117-
DType::Struct(struct_fields, ..) => execute_sparse_struct(
118-
struct_fields,
119-
fill_value.as_struct(),
120-
dtype.nullability(),
121-
&patches,
122-
len,
123-
ctx,
124-
)?,
125117
DType::Decimal(decimal_dtype, nullability) => {
126118
let canonical_decimal_value_type =
127119
DecimalType::smallest_decimal_value_type(decimal_dtype);
@@ -157,6 +149,15 @@ pub(super) fn execute_sparse(parts: SparseParts, ctx: &mut ExecutionCtx) -> Vort
157149
DType::FixedSizeList(.., nullability) => {
158150
execute_sparse_fixed_size_list(&patches, &fill_value, len, *nullability, ctx)?
159151
}
152+
DType::Struct(struct_fields, ..) => execute_sparse_struct(
153+
struct_fields,
154+
fill_value.as_struct(),
155+
dtype.nullability(),
156+
&patches,
157+
len,
158+
ctx,
159+
)?,
160+
DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
160161
DType::Extension(_ext_dtype) => todo!(),
161162
DType::Variant(_) => vortex_bail!("Sparse canonicalization does not support Variant"),
162163
})

fuzz/src/array/compare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ pub fn compare_canonical_array(
173173
}))
174174
.into_array()
175175
}
176-
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
176+
d @ (DType::Null | DType::Union(..) | DType::Extension(_) | DType::Variant(_)) => {
177177
unreachable!("DType {d} not supported for fuzzing")
178178
}
179179
}

fuzz/src/array/filter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ pub fn filter_canonical_array(
122122
}
123123
take_canonical_array_non_nullable_indices(array, indices.as_slice(), ctx)
124124
}
125-
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
125+
d @ (DType::Null | DType::Union(..) | DType::Extension(_) | DType::Variant(_)) => {
126126
unreachable!("DType {d} not supported for fuzzing")
127127
}
128128
}

fuzz/src/array/mod.rs

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -464,24 +464,8 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
464464
use ActionType::*;
465465

466466
match dtype {
467-
DType::Struct(sdt, _) => {
468-
// Struct supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
469-
// Does NOT support: SearchSorted (requires scalar comparison), Compare, Cast, Sum, FillNull
470-
let struct_actions = [Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt];
471-
sdt.fields()
472-
.map(|child| actions_for_dtype(&child))
473-
.fold(struct_actions.into(), |acc, actions| {
474-
acc.intersection(&actions).copied().collect()
475-
})
476-
}
477-
DType::List(..) | DType::FixedSizeList(..) => {
478-
// List supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
479-
// Does NOT support: SearchSorted, Compare, Cast, Sum, FillNull
480-
[Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt].into()
481-
}
482-
DType::Utf8(_) | DType::Binary(_) => {
483-
// Utf8/Binary supports everything except Sum and FillNull
484-
// Actions: Compress, Slice, Take, SearchSorted, Filter, Compare, Cast, MinMax, Mask, ScalarAt
467+
DType::Null => {
468+
// Null arrays support most operations but not Sum or MinMax (return None for dtype)
485469
[
486470
Compress,
487471
Slice,
@@ -490,7 +474,7 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
490474
Filter,
491475
Compare,
492476
Cast,
493-
MinMax,
477+
FillNull,
494478
Mask,
495479
ScalarAt,
496480
]
@@ -500,8 +484,9 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
500484
// These support all actions
501485
ActionType::iter().collect()
502486
}
503-
DType::Null => {
504-
// Null arrays support most operations but not Sum or MinMax (return None for dtype)
487+
DType::Utf8(_) | DType::Binary(_) => {
488+
// Utf8/Binary supports everything except Sum and FillNull
489+
// Actions: Compress, Slice, Take, SearchSorted, Filter, Compare, Cast, MinMax, Mask, ScalarAt
505490
[
506491
Compress,
507492
Slice,
@@ -510,12 +495,28 @@ fn actions_for_dtype(dtype: &DType) -> HashSet<ActionType> {
510495
Filter,
511496
Compare,
512497
Cast,
513-
FillNull,
498+
MinMax,
514499
Mask,
515500
ScalarAt,
516501
]
517502
.into()
518503
}
504+
DType::List(..) | DType::FixedSizeList(..) => {
505+
// List supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
506+
// Does NOT support: SearchSorted, Compare, Cast, Sum, FillNull
507+
[Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt].into()
508+
}
509+
DType::Struct(sdt, _) => {
510+
// Struct supports: Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt
511+
// Does NOT support: SearchSorted (requires scalar comparison), Compare, Cast, Sum, FillNull
512+
let struct_actions = [Compress, Slice, Take, Filter, MinMax, Mask, ScalarAt];
513+
sdt.fields()
514+
.map(|child| actions_for_dtype(&child))
515+
.fold(struct_actions.into(), |acc, actions| {
516+
acc.intersection(&actions).copied().collect()
517+
})
518+
}
519+
DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
519520
DType::Extension(_) => {
520521
// Extension types delegate to storage dtype, support most operations
521522
ActionType::iter().collect()

fuzz/src/array/search_sorted.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ pub fn search_sorted_canonical_array(
148148
.collect::<VortexResult<Vec<_>>>()?;
149149
scalar_vals.search_sorted(&scalar.cast(array.dtype())?, side)
150150
}
151-
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
151+
d @ (DType::Null | DType::Union(..) | DType::Extension(_) | DType::Variant(_)) => {
152152
unreachable!("DType {d} not supported for fuzzing")
153153
}
154154
}

fuzz/src/array/slice.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ pub fn slice_canonical_array(
123123
.into_array(),
124124
)
125125
}
126-
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
126+
d @ (DType::Null | DType::Union(..) | DType::Extension(_) | DType::Variant(_)) => {
127127
unreachable!("DType {d} not supported for fuzzing")
128128
}
129129
}

fuzz/src/array/sort.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ pub fn sort_canonical_array(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexR
101101
});
102102
take_canonical_array_non_nullable_indices(array, &sort_indices, ctx)
103103
}
104-
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
104+
d @ (DType::Null | DType::Union(..) | DType::Extension(_) | DType::Variant(_)) => {
105105
unreachable!("DType {d} not supported for fuzzing")
106106
}
107107
}

fuzz/src/array/take.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ pub fn take_canonical_array(
147147
}
148148
Ok(builder.finish())
149149
}
150-
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
150+
d @ (DType::Null | DType::Union(..) | DType::Extension(_) | DType::Variant(_)) => {
151151
unreachable!("DType {d} not supported for fuzzing")
152152
}
153153
}

vortex-array/public-api.lock

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9346,6 +9346,8 @@ pub vortex_array::dtype::DType::Primitive(vortex_array::dtype::PType, vortex_arr
93469346

93479347
pub vortex_array::dtype::DType::Struct(vortex_array::dtype::StructFields, vortex_array::dtype::Nullability)
93489348

9349+
pub vortex_array::dtype::DType::Union(vortex_array::dtype::UnionVariants, vortex_array::dtype::Nullability)
9350+
93499351
pub vortex_array::dtype::DType::Utf8(vortex_array::dtype::Nullability)
93509352

93519353
pub vortex_array::dtype::DType::Variant(vortex_array::dtype::Nullability)
@@ -9376,6 +9378,8 @@ pub fn vortex_array::dtype::DType::as_struct_fields(&self) -> &vortex_array::dty
93769378

93779379
pub fn vortex_array::dtype::DType::as_struct_fields_opt(&self) -> core::option::Option<&vortex_array::dtype::StructFields>
93789380

9381+
pub fn vortex_array::dtype::DType::as_union_variants_opt(&self) -> core::option::Option<&vortex_array::dtype::UnionVariants>
9382+
93799383
pub fn vortex_array::dtype::DType::element_size(&self) -> core::option::Option<usize>
93809384

93819385
pub fn vortex_array::dtype::DType::eq_ignore_nullability(&self, &Self) -> bool
@@ -9396,6 +9400,8 @@ pub fn vortex_array::dtype::DType::into_struct_fields(self) -> vortex_array::dty
93969400

93979401
pub fn vortex_array::dtype::DType::into_struct_fields_opt(self) -> core::option::Option<vortex_array::dtype::StructFields>
93989402

9403+
pub fn vortex_array::dtype::DType::into_union_variants_opt(self) -> core::option::Option<vortex_array::dtype::UnionVariants>
9404+
93999405
pub fn vortex_array::dtype::DType::is_binary(&self) -> bool
94009406

94019407
pub fn vortex_array::dtype::DType::is_boolean(&self) -> bool
@@ -9422,6 +9428,8 @@ pub fn vortex_array::dtype::DType::is_signed_int(&self) -> bool
94229428

94239429
pub fn vortex_array::dtype::DType::is_struct(&self) -> bool
94249430

9431+
pub fn vortex_array::dtype::DType::is_union(&self) -> bool
9432+
94259433
pub fn vortex_array::dtype::DType::is_unsigned_int(&self) -> bool
94269434

94279435
pub fn vortex_array::dtype::DType::is_utf8(&self) -> bool
@@ -10506,6 +10514,66 @@ impl<T, V> core::iter::traits::collect::FromIterator<(T, V)> for vortex_array::d
1050610514

1050710515
pub fn vortex_array::dtype::StructFields::from_iter<I: core::iter::traits::collect::IntoIterator<Item = (T, V)>>(I) -> Self
1050810516

10517+
pub struct vortex_array::dtype::UnionVariants(_)
10518+
10519+
impl vortex_array::dtype::UnionVariants
10520+
10521+
pub fn vortex_array::dtype::UnionVariants::child_index_to_tag(&self, usize) -> i8
10522+
10523+
pub fn vortex_array::dtype::UnionVariants::empty() -> Self
10524+
10525+
pub fn vortex_array::dtype::UnionVariants::find(&self, impl core::convert::AsRef<str>) -> core::option::Option<usize>
10526+
10527+
pub fn vortex_array::dtype::UnionVariants::is_consecutive(&self) -> bool
10528+
10529+
pub fn vortex_array::dtype::UnionVariants::is_empty(&self) -> bool
10530+
10531+
pub fn vortex_array::dtype::UnionVariants::len(&self) -> usize
10532+
10533+
pub fn vortex_array::dtype::UnionVariants::names(&self) -> &vortex_array::dtype::FieldNames
10534+
10535+
pub fn vortex_array::dtype::UnionVariants::new_consecutive(vortex_array::dtype::FieldNames, alloc::vec::Vec<vortex_array::dtype::DType>) -> vortex_error::VortexResult<Self>
10536+
10537+
pub fn vortex_array::dtype::UnionVariants::nullability_constraints_satisfied(&self, vortex_array::dtype::Nullability) -> bool
10538+
10539+
pub fn vortex_array::dtype::UnionVariants::tag_to_child_index(&self, i8) -> core::option::Option<usize>
10540+
10541+
pub fn vortex_array::dtype::UnionVariants::try_new(vortex_array::dtype::FieldNames, alloc::vec::Vec<vortex_array::dtype::DType>, alloc::vec::Vec<i8>) -> vortex_error::VortexResult<Self>
10542+
10543+
pub fn vortex_array::dtype::UnionVariants::type_ids(&self) -> &[i8]
10544+
10545+
pub fn vortex_array::dtype::UnionVariants::variant(&self, impl core::convert::AsRef<str>) -> core::option::Option<vortex_array::dtype::DType>
10546+
10547+
pub fn vortex_array::dtype::UnionVariants::variant_by_index(&self, usize) -> core::option::Option<vortex_array::dtype::DType>
10548+
10549+
pub fn vortex_array::dtype::UnionVariants::variants(&self) -> impl core::iter::traits::exact_size::ExactSizeIterator<Item = vortex_array::dtype::DType> + '_
10550+
10551+
impl core::clone::Clone for vortex_array::dtype::UnionVariants
10552+
10553+
pub fn vortex_array::dtype::UnionVariants::clone(&self) -> vortex_array::dtype::UnionVariants
10554+
10555+
impl core::cmp::Eq for vortex_array::dtype::UnionVariants
10556+
10557+
impl core::cmp::PartialEq for vortex_array::dtype::UnionVariants
10558+
10559+
pub fn vortex_array::dtype::UnionVariants::eq(&self, &Self) -> bool
10560+
10561+
impl core::default::Default for vortex_array::dtype::UnionVariants
10562+
10563+
pub fn vortex_array::dtype::UnionVariants::default() -> Self
10564+
10565+
impl core::fmt::Debug for vortex_array::dtype::UnionVariants
10566+
10567+
pub fn vortex_array::dtype::UnionVariants::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result
10568+
10569+
impl core::fmt::Display for vortex_array::dtype::UnionVariants
10570+
10571+
pub fn vortex_array::dtype::UnionVariants::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result
10572+
10573+
impl core::hash::Hash for vortex_array::dtype::UnionVariants
10574+
10575+
pub fn vortex_array::dtype::UnionVariants::hash<__H: core::hash::Hasher>(&self, &mut __H)
10576+
1050910577
#[repr(transparent)] pub struct vortex_array::dtype::i256(_)
1051010578

1051110579
impl vortex_array::dtype::i256

vortex-array/src/aggregate_fn/fns/uncompressed_size_in_bytes/mod.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,14 @@ pub(crate) fn constant_uncompressed_size_in_bytes(
240240
array.len(),
241241
array.scalar().as_binary().value().map(|value| value.len()),
242242
)?,
243-
DType::Variant(_) => {
244-
vortex_bail!("UncompressedSizeInBytes is not supported for Variant arrays")
245-
}
246243
DType::Struct(..) | DType::List(..) | DType::FixedSizeList(..) | DType::Extension(_) => {
247244
let canonical = array.array().clone().execute::<Canonical>(ctx)?;
248245
return canonical_uncompressed_size_in_bytes(&canonical, ctx);
249246
}
247+
DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
248+
DType::Variant(_) => {
249+
vortex_bail!("UncompressedSizeInBytes is not supported for Variant arrays")
250+
}
250251
};
251252

252253
value_size
@@ -287,22 +288,25 @@ fn checked_len_mul(len: usize, width: usize, name: &str) -> VortexResult<u64> {
287288

288289
fn supports_uncompressed_size_in_bytes(dtype: &DType) -> bool {
289290
match dtype {
291+
DType::Null
292+
| DType::Bool(_)
293+
| DType::Primitive(..)
294+
| DType::Decimal(..)
295+
| DType::Utf8(_)
296+
| DType::Binary(_) => true,
290297
DType::List(element_dtype, _) | DType::FixedSizeList(element_dtype, ..) => {
291298
supports_uncompressed_size_in_bytes(element_dtype)
292299
}
293300
DType::Struct(fields, _) => fields
294301
.fields()
295302
.all(|field| supports_uncompressed_size_in_bytes(&field)),
303+
DType::Union(variants, _) => variants
304+
.variants()
305+
.all(|variant| supports_uncompressed_size_in_bytes(&variant)),
296306
DType::Extension(ext_dtype) => {
297307
supports_uncompressed_size_in_bytes(ext_dtype.storage_dtype())
298308
}
299309
DType::Variant(_) => false,
300-
DType::Null
301-
| DType::Bool(_)
302-
| DType::Primitive(..)
303-
| DType::Decimal(..)
304-
| DType::Utf8(_)
305-
| DType::Binary(_) => true,
306310
}
307311
}
308312

0 commit comments

Comments
 (0)