Skip to content

Commit 06065ff

Browse files
authored
Normalize Execution (#7278)
Allow execution during normalization.
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 9dcb75e commit 06065ff

File tree

8 files changed

+265
-81
lines changed

8 files changed

+265
-81
lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-array/public-api.lock

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13264,15 +13264,17 @@ pub fn V::try_match<'a>(array: &'a vortex_array::ArrayRef) -> core::option::Opti
1326413264

1326513265
pub mod vortex_array::normalize
1326613266

13267-
pub enum vortex_array::normalize::Operation
13267+
pub enum vortex_array::normalize::Operation<'a>
1326813268

1326913269
pub vortex_array::normalize::Operation::Error
1327013270

13271+
pub vortex_array::normalize::Operation::Execute(&'a mut vortex_array::ExecutionCtx)
13272+
1327113273
pub struct vortex_array::normalize::NormalizeOptions<'a>
1327213274

13273-
pub vortex_array::normalize::NormalizeOptions::allowed: &'a vortex_array::session::ArrayRegistry
13275+
pub vortex_array::normalize::NormalizeOptions::allowed: &'a vortex_utils::aliases::hash_set::HashSet<vortex_session::registry::Id>
1327413276

13275-
pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normalize::Operation
13277+
pub vortex_array::normalize::NormalizeOptions::operation: vortex_array::normalize::Operation<'a>
1327613278

1327713279
pub mod vortex_array::optimizer
1327813280

vortex-array/src/normalize.rs

Lines changed: 195 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use itertools::Itertools;
54
use vortex_error::VortexResult;
65
use vortex_error::vortex_bail;
76
use vortex_session::registry::Id;
7+
use vortex_utils::aliases::hash_set::HashSet;
88

99
use crate::ArrayRef;
10-
use crate::session::ArrayRegistry;
10+
use crate::ExecutionCtx;
1111

1212
/// Options for normalizing an array.
1313
pub struct NormalizeOptions<'a> {
1414
/// The set of array encoding IDs (in addition to the canonical ones) that are permitted
1515
/// in the normalized array.
16-
pub allowed: &'a ArrayRegistry,
16+
pub allowed: &'a HashSet<Id>,
1717
/// The operation to perform when a non-allowed encoding is encountered.
18-
pub operation: Operation,
18+
pub operation: Operation<'a>,
1919
}
2020

2121
/// The operation to perform when a non-allowed encoding is encountered.
22-
pub enum Operation {
22+
pub enum Operation<'a> {
2323
Error,
24-
// TODO(joe): add into canonical variant
24+
Execute(&'a mut ExecutionCtx),
2525
}
2626

2727
impl ArrayRef {
@@ -30,14 +30,18 @@ impl ArrayRef {
3030
/// This operation performs a recursive traversal of the array. Any non-allowed encoding is
3131
/// normalized per the configured operation.
3232
pub fn normalize(self, options: &mut NormalizeOptions) -> VortexResult<ArrayRef> {
33-
let array_ids = options.allowed.ids().collect_vec();
34-
self.normalize_with_error(&array_ids)?;
35-
// Note this takes ownership so we can at a later date remove non-allowed encodings.
36-
Ok(self)
33+
match &mut options.operation {
34+
Operation::Error => {
35+
self.normalize_with_error(options.allowed)?;
36+
// Note: this takes ownership so that non-allowed encodings can be removed at a later date.
37+
Ok(self)
38+
}
39+
Operation::Execute(ctx) => self.normalize_with_execution(options.allowed, ctx),
40+
}
3741
}
3842

39-
fn normalize_with_error(&self, allowed: &[Id]) -> VortexResult<()> {
40-
if !allowed.contains(&self.encoding_id()) {
43+
fn normalize_with_error(&self, allowed: &HashSet<Id>) -> VortexResult<()> {
44+
if !self.is_allowed_encoding(allowed) {
4145
vortex_bail!(AssertionFailed: "normalize forbids encoding ({})", self.encoding_id())
4246
}
4347

@@ -46,4 +50,183 @@ impl ArrayRef {
4650
}
4751
Ok(())
4852
}
53+
54+
/// Normalizes this array by executing any encoding that is not allowed.
///
/// The root is executed with `ctx` until `is_allowed_encoding` accepts it, then every
/// populated slot is normalized recursively with `Operation::Execute`. The parent is only
/// rebuilt (via `with_slots`) when at least one child came back as a different allocation,
/// so an already-normalized tree is returned as-is.
///
/// # Errors
///
/// Propagates any error returned by `execute`, by the recursive `normalize` call, or by
/// `with_slots`.
fn normalize_with_execution(
55+
self,
56+
allowed: &HashSet<Id>,
57+
ctx: &mut ExecutionCtx,
58+
) -> VortexResult<ArrayRef> {
59+
let mut normalized = self;
60+
61+
// Execute the root, top-down, for as long as its encoding is not allowed.
// NOTE(review): termination relies on repeated `execute` calls eventually producing an
// allowed or canonical encoding -- confirm against `execute`'s contract.
62+
while !normalized.is_allowed_encoding(allowed) {
63+
normalized = normalized.execute(ctx)?;
64+
}
65+
66+
// Now that the root is normalized, make sure each of its children is normalized as well.
67+
let slots = normalized.slots();
68+
let mut normalized_slots = Vec::with_capacity(slots.len());
69+
let mut any_slot_changed = false;
70+
71+
for slot in slots {
72+
match slot {
73+
Some(child) => {
74+
let normalized_child = child.clone().normalize(&mut NormalizeOptions {
75+
allowed,
76+
operation: Operation::Execute(ctx),
77+
})?;
78+
// Pointer identity (not value equality) decides whether the child was replaced.
any_slot_changed |= !ArrayRef::ptr_eq(child, &normalized_child);
79+
normalized_slots.push(Some(normalized_child));
80+
}
81+
// Empty slots are carried over unchanged.
None => normalized_slots.push(None),
82+
}
83+
}
84+
85+
// Only rebuild the parent when a child actually changed.
if any_slot_changed {
86+
normalized = normalized.with_slots(normalized_slots)?;
87+
}
88+
89+
Ok(normalized)
90+
}
91+
92+
/// Returns `true` when this array's encoding ID is contained in `allowed`, or when
/// `is_canonical()` reports `true` for it.
///
/// Only this array is inspected; its children are not visited here.
fn is_allowed_encoding(&self, allowed: &HashSet<Id>) -> bool {
93+
allowed.contains(&self.encoding_id()) || self.is_canonical()
94+
}
95+
}
96+
97+
#[cfg(test)]
98+
mod tests {
99+
use vortex_error::VortexResult;
100+
use vortex_session::VortexSession;
101+
use vortex_utils::aliases::hash_set::HashSet;
102+
103+
use super::NormalizeOptions;
104+
use super::Operation;
105+
use crate::ArrayRef;
106+
use crate::ExecutionCtx;
107+
use crate::IntoArray;
108+
use crate::arrays::Dict;
109+
use crate::arrays::DictArray;
110+
use crate::arrays::Primitive;
111+
use crate::arrays::PrimitiveArray;
112+
use crate::arrays::Slice;
113+
use crate::arrays::SliceArray;
114+
use crate::arrays::StructArray;
115+
use crate::assert_arrays_eq;
116+
use crate::validity::Validity;
117+
118+
#[test]
119+
/// A struct array whose own encoding and whose single field's encoding are both in the
/// allow-set must come back from `Operation::Execute` normalization as the very same
/// array (pointer-equal), i.e. the parent is not rebuilt.
fn normalize_with_execution_keeps_parent_when_children_are_unchanged() -> VortexResult<()> {
120+
let field = PrimitiveArray::from_iter(0i32..4).into_array();
121+
let array = StructArray::try_new(
122+
["field"].into(),
123+
vec![field.clone()],
124+
field.len(),
125+
Validity::NonNullable,
126+
)?
127+
.into_array();
128+
// Allow exactly the two encodings that appear in the tree.
let allowed = HashSet::from_iter([array.encoding_id(), field.encoding_id()]);
129+
let mut ctx = ExecutionCtx::new(VortexSession::empty());
130+
131+
let normalized = array.clone().normalize(&mut NormalizeOptions {
132+
allowed: &allowed,
133+
operation: Operation::Execute(&mut ctx),
134+
})?;
135+
136+
assert!(ArrayRef::ptr_eq(&array, &normalized));
137+
Ok(())
138+
}
139+
140+
#[test]
141+
/// With an empty allow-set, `Operation::Error` normalization of a struct array holding one
/// primitive field must still succeed and return the same array (pointer-equal).
///
/// Per the test name, this is expected because the encodings involved are canonical.
fn normalize_with_error_allows_canonical_arrays() -> VortexResult<()> {
142+
let field = PrimitiveArray::from_iter(0i32..4).into_array();
143+
let array = StructArray::try_new(
144+
["field"].into(),
145+
vec![field.clone()],
146+
field.len(),
147+
Validity::NonNullable,
148+
)?
149+
.into_array();
150+
// No encodings are explicitly allowed.
let allowed = HashSet::default();
151+
152+
let normalized = array.clone().normalize(&mut NormalizeOptions {
153+
allowed: &allowed,
154+
operation: Operation::Error,
155+
})?;
156+
157+
assert!(ArrayRef::ptr_eq(&array, &normalized));
158+
Ok(())
159+
}
160+
161+
#[test]
162+
/// A struct array with one already-allowed child and one `SliceArray` child (whose encoding
/// is not in the allow-set) must be rebuilt by `Operation::Execute` normalization:
/// the root and the sliced child are replaced, the untouched child stays pointer-equal,
/// and the replaced child holds the sliced values `12..16`.
fn normalize_with_execution_rebuilds_parent_when_a_child_changes() -> VortexResult<()> {
163+
let unchanged = PrimitiveArray::from_iter(0i32..4).into_array();
164+
let sliced =
165+
SliceArray::new(PrimitiveArray::from_iter(10i32..20).into_array(), 2..6).into_array();
166+
let array = StructArray::try_new(
167+
["lhs", "rhs"].into(),
168+
vec![unchanged.clone(), sliced],
169+
unchanged.len(),
170+
Validity::NonNullable,
171+
)?
172+
.into_array();
173+
// The slice encoding is deliberately left out of the allow-set.
let allowed = HashSet::from_iter([array.encoding_id(), unchanged.encoding_id()]);
174+
let mut ctx = ExecutionCtx::new(VortexSession::empty());
175+
176+
let normalized = array.clone().normalize(&mut NormalizeOptions {
177+
allowed: &allowed,
178+
operation: Operation::Execute(&mut ctx),
179+
})?;
180+
181+
assert!(!ArrayRef::ptr_eq(&array, &normalized));
182+
183+
let original_children = array.children();
184+
let normalized_children = normalized.children();
185+
assert!(ArrayRef::ptr_eq(
186+
&original_children[0],
187+
&normalized_children[0]
188+
));
189+
assert!(!ArrayRef::ptr_eq(
190+
&original_children[1],
191+
&normalized_children[1]
192+
));
193+
// Elements 2..6 of 10..20 are 12, 13, 14, 15.
assert_arrays_eq!(normalized_children[1], PrimitiveArray::from_iter(12i32..16));
194+
195+
Ok(())
196+
}
197+
198+
#[test]
199+
/// Normalizing a `SliceArray` wrapping a `DictArray`, with only the dict and primitive
/// encodings allowed, must yield an array whose root encoding is `Dict::ID`, with length 3
/// and canonical values `[20, 10, 20]`.
fn normalize_slice_of_dict_returns_dict() -> VortexResult<()> {
200+
let codes = PrimitiveArray::from_iter(vec![0u32, 1, 0, 1, 2]).into_array();
201+
let values = PrimitiveArray::from_iter(vec![10i32, 20, 30]).into_array();
202+
let dict = DictArray::try_new(codes, values)?.into_array();
203+
204+
// Slice the dict array to get a SliceArray wrapping a DictArray.
205+
let sliced = SliceArray::new(dict, 1..4).into_array();
206+
assert_eq!(sliced.encoding_id(), Slice::ID);
207+
208+
let allowed = HashSet::from_iter([Dict::ID, Primitive::ID]);
209+
let mut ctx = ExecutionCtx::new(VortexSession::empty());
210+
211+
// NOTE(review): this `println!` (and the one after normalization) looks like leftover
// debugging output; nothing asserts on it -- consider removing.
println!("sliced {}", sliced.display_tree());
212+
213+
let normalized = sliced.normalize(&mut NormalizeOptions {
214+
allowed: &allowed,
215+
operation: Operation::Execute(&mut ctx),
216+
})?;
217+
218+
println!("after {}", normalized.display_tree());
219+
220+
// The normalized result should be a DictArray, not a SliceArray.
221+
assert_eq!(normalized.encoding_id(), Dict::ID);
222+
assert_eq!(normalized.len(), 3);
223+
224+
// Verify the data: codes [1,0,1] -> values [20, 10, 20]
225+
assert_arrays_eq!(
226+
normalized.to_canonical()?,
227+
PrimitiveArray::from_iter(vec![20i32, 10, 20])
228+
);
229+
230+
Ok(())
231+
}
49232
}

vortex-cuda/src/layout.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use futures::FutureExt;
1414
use futures::StreamExt;
1515
use futures::future::BoxFuture;
1616
use vortex::array::ArrayContext;
17+
use vortex::array::ArrayId;
1718
use vortex::array::ArrayRef;
1819
use vortex::array::DeserializeMetadata;
1920
use vortex::array::MaskFuture;
@@ -28,7 +29,6 @@ use vortex::array::normalize::NormalizeOptions;
2829
use vortex::array::normalize::Operation;
2930
use vortex::array::serde::SerializeOptions;
3031
use vortex::array::serde::SerializedArray;
31-
use vortex::array::session::ArrayRegistry;
3232
use vortex::array::stats::StatsSetRef;
3333
use vortex::buffer::BufferString;
3434
use vortex::buffer::ByteBuffer;
@@ -63,6 +63,7 @@ use vortex::scalar::upper_bound;
6363
use vortex::session::VortexSession;
6464
use vortex::session::registry::ReadContext;
6565
use vortex::utils::aliases::hash_map::HashMap;
66+
use vortex::utils::aliases::hash_set::HashSet;
6667

6768
/// A buffer inlined into layout metadata for host-side access.
6869
#[derive(Clone, prost::Message)]
@@ -390,7 +391,7 @@ pub struct CudaFlatLayoutStrategy {
390391
/// Maximum length of variable length statistics.
391392
pub max_variable_length_statistics_size: usize,
392393
/// Optional set of allowed array encodings for normalization.
393-
pub allowed_encodings: Option<ArrayRegistry>,
394+
pub allowed_encodings: Option<HashSet<ArrayId>>,
394395
}
395396

396397
impl Default for CudaFlatLayoutStrategy {
@@ -414,7 +415,7 @@ impl CudaFlatLayoutStrategy {
414415
self
415416
}
416417

417-
pub fn with_allow_encodings(mut self, allow_encodings: ArrayRegistry) -> Self {
418+
pub fn with_allow_encodings(mut self, allow_encodings: HashSet<ArrayId>) -> Self {
418419
self.allowed_encodings = Some(allow_encodings);
419420
self
420421
}

vortex-file/public-api.lock

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ impl vortex_file::WriteStrategyBuilder
344344

345345
pub fn vortex_file::WriteStrategyBuilder::build(self) -> alloc::sync::Arc<dyn vortex_layout::strategy::LayoutStrategy>
346346

347-
pub fn vortex_file::WriteStrategyBuilder::with_allow_encodings(self, allow_encodings: vortex_array::session::ArrayRegistry) -> Self
347+
pub fn vortex_file::WriteStrategyBuilder::with_allow_encodings(self, allow_encodings: vortex_utils::aliases::hash_set::HashSet<vortex_array::array::ArrayId>) -> Self
348348

349349
pub fn vortex_file::WriteStrategyBuilder::with_btrblocks_builder(self, builder: vortex_btrblocks::builder::BtrBlocksCompressorBuilder) -> Self
350350

@@ -396,7 +396,7 @@ pub const vortex_file::VERSION: u16
396396

397397
pub const vortex_file::VORTEX_FILE_EXTENSION: &str
398398

399-
pub static vortex_file::ALLOWED_ENCODINGS: std::sync::lazy_lock::LazyLock<vortex_array::session::ArrayRegistry>
399+
pub static vortex_file::ALLOWED_ENCODINGS: std::sync::lazy_lock::LazyLock<vortex_utils::aliases::hash_set::HashSet<vortex_array::array::ArrayId>>
400400

401401
pub trait vortex_file::OpenOptionsSessionExt: vortex_array::session::ArraySessionExt + vortex_layout::session::LayoutSessionExt + vortex_io::session::RuntimeSessionExt
402402

0 commit comments

Comments
 (0)