Skip to content

Commit f308dfe

Browse files
add unchecked array slot take and put (#7514)
Adds crate-private unchecked slot take/put helpers on `ArrayRef`. This allows for in-place swapping of array children is the array is exclusively owned. --------- Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 9b11e57 commit f308dfe

File tree

3 files changed

+118
-12
lines changed

3 files changed

+118
-12
lines changed

vortex-array/src/array/erased.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,56 @@ impl ArrayRef {
431431
self.with_slots(slots)
432432
}
433433

434+
/// Take a slot for executor-owned physical rewrites. This has the result that the array may
435+
/// either be taken or cloned from the parent.
436+
///
437+
/// The array can be put back with [`put_slot_unchecked`].
438+
///
439+
/// # Safety
440+
/// The caller must put back a slot with the same logical dtype and length before exposing the
441+
/// parent array, and must only use this for physical rewrites.
442+
pub(crate) unsafe fn take_slot_unchecked(
443+
mut self,
444+
slot_idx: usize,
445+
) -> VortexResult<(ArrayRef, ArrayRef)> {
446+
let child = if let Some(inner) = Arc::get_mut(&mut self.0) {
447+
// # Safety: ensured by the caller.
448+
unsafe { inner.slots_mut()[slot_idx].take() }
449+
.vortex_expect("take_slot_unchecked cannot take an absent slot")
450+
} else {
451+
self.slots()[slot_idx]
452+
.as_ref()
453+
.vortex_expect("take_slot_unchecked cannot take an absent slot")
454+
.clone()
455+
};
456+
457+
Ok((self, child))
458+
}
459+
460+
/// Puts an array into `slot_idx` by either, cloning the inner array if the Arc is not exclusive
461+
/// or replacing the slot in this `ArrayRef`.
462+
/// This is the mirror of [`take_slot_unchecked`].
463+
///
464+
/// # Safety
465+
/// The replacement must have the same logical dtype and length as the taken slot, and this
466+
/// must only be used for physical rewrites.
467+
pub(crate) unsafe fn put_slot_unchecked(
468+
mut self,
469+
slot_idx: usize,
470+
replacement: ArrayRef,
471+
) -> VortexResult<ArrayRef> {
472+
if let Some(inner) = Arc::get_mut(&mut self.0) {
473+
// # Safety: ensured by the caller.
474+
unsafe { inner.slots_mut()[slot_idx] = Some(replacement) };
475+
return Ok(self);
476+
}
477+
478+
let mut slots = self.slots().to_vec();
479+
slots[slot_idx] = Some(replacement);
480+
let inner = Arc::clone(&self.0);
481+
inner.with_slots(self, slots)
482+
}
483+
434484
/// Returns a new array with the provided slots.
435485
///
436486
/// This is only valid for physical rewrites: slot count, presence, logical `DType`, and
@@ -611,6 +661,7 @@ impl<V: VTable> Matcher for V {
611661

612662
fn try_match<'a>(array: &'a ArrayRef) -> Option<ArrayView<'a, V>> {
613663
let inner = array.0.as_any().downcast_ref::<ArrayInner<V>>()?;
664+
// # Safety checked by `downcast_ref`.
614665
Some(unsafe { ArrayView::new_unchecked(array, &inner.data) })
615666
}
616667
}

vortex-array/src/array/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,13 @@ pub(crate) trait DynArray: 'static + private::Sealed + Send + Sync + Debug {
6868
/// Returns the slots of the array.
6969
fn slots(&self) -> &[Option<ArrayRef>];
7070

71+
/// Returns mutable slots of the array.
72+
///
73+
/// # Safety: any slot (Some(child)) that replaces an existing slot must have a compatible
74+
/// DType and length. Currently compatible means equal, but there is no reason why that must
75+
/// be the case.
76+
unsafe fn slots_mut(&mut self) -> &mut [Option<ArrayRef>];
77+
7178
/// Returns the encoding ID of the array.
7279
fn encoding_id(&self) -> ArrayId;
7380

@@ -212,6 +219,10 @@ impl<V: VTable> DynArray for ArrayInner<V> {
212219
&self.slots
213220
}
214221

222+
unsafe fn slots_mut(&mut self) -> &mut [Option<ArrayRef>] {
223+
&mut self.slots
224+
}
225+
215226
fn encoding_id(&self) -> ArrayId {
216227
self.vtable.id()
217228
}

vortex-array/src/executor.rs

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::AnyCanonical;
3333
use crate::ArrayRef;
3434
use crate::Canonical;
3535
use crate::IntoArray;
36+
use crate::dtype::DType;
3637
use crate::matcher::Matcher;
3738
use crate::memory::HostAllocatorRef;
3839
use crate::memory::MemorySessionExt;
@@ -107,22 +108,21 @@ impl ArrayRef {
107108
/// maximum (default 128, override with `VORTEX_MAX_ITERATIONS`).
108109
pub fn execute_until<M: Matcher>(self, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
109110
let mut current = self.optimize()?;
110-
// Stack frames: (parent, slot_idx, done_predicate_for_slot)
111-
let mut stack: Vec<(ArrayRef, usize, DonePredicate)> = Vec::new();
111+
let mut stack: Vec<StackFrame> = Vec::new();
112112

113113
for _ in 0..max_iterations() {
114114
// Step 1: done / canonical — splice back into stacked parent or return.
115115
let is_done = stack
116116
.last()
117-
.map_or(M::matches as DonePredicate, |frame| frame.2);
117+
.map_or(M::matches as DonePredicate, |frame| frame.done);
118118
if is_done(&current) || AnyCanonical::matches(&current) {
119119
match stack.pop() {
120120
None => {
121121
ctx.log(format_args!("-> {}", current));
122122
return Ok(current);
123123
}
124-
Some((parent, slot_idx, _)) => {
125-
current = parent.with_slot(slot_idx, current)?.optimize()?;
124+
Some(frame) => {
125+
current = frame.put_back(current)?.optimize()?;
126126
continue;
127127
}
128128
}
@@ -139,8 +139,8 @@ impl ArrayRef {
139139
current, rewritten
140140
));
141141
current = rewritten.optimize()?;
142-
if let Some((parent, slot_idx, _)) = stack.pop() {
143-
current = parent.with_slot(slot_idx, current)?.optimize()?;
142+
if let Some(frame) = stack.pop() {
143+
current = frame.put_back(current)?.optimize()?;
144144
}
145145
continue;
146146
}
@@ -150,14 +150,15 @@ impl ArrayRef {
150150
let (array, step) = result.into_parts();
151151
match step {
152152
ExecutionStep::ExecuteSlot(i, done) => {
153-
let child = array.slots()[i]
154-
.clone()
155-
.vortex_expect("ExecuteSlot index in bounds");
153+
// SAFETY: we record the child's dtype and len, and assert they are preserved
154+
// when the slot is put back via `put_slot_unchecked`.
155+
let (parent, child) = unsafe { array.take_slot_unchecked(i) }?;
156156
ctx.log(format_args!(
157157
"ExecuteSlot({i}): pushing {}, focusing on {}",
158-
array, child
158+
parent, child
159159
));
160-
stack.push((array, i, done));
160+
let frame = StackFrame::new(parent, i, done, &child);
161+
stack.push(frame);
161162
current = child.optimize()?;
162163
}
163164
ExecutionStep::Done => {
@@ -174,6 +175,49 @@ impl ArrayRef {
174175
}
175176
}
176177

178+
/// A stack frame for the iterative executor, tracking the parent array whose slot is being
179+
/// executed and the original child's dtype/len for validation on put-back.
180+
struct StackFrame {
181+
parent: ArrayRef,
182+
slot_idx: usize,
183+
done: DonePredicate,
184+
original_dtype: DType,
185+
original_len: usize,
186+
}
187+
188+
impl StackFrame {
189+
fn new(parent: ArrayRef, slot_idx: usize, done: DonePredicate, child: &ArrayRef) -> Self {
190+
Self {
191+
parent,
192+
slot_idx,
193+
done,
194+
original_dtype: child.dtype().clone(),
195+
original_len: child.len(),
196+
}
197+
}
198+
199+
fn put_back(self, replacement: ArrayRef) -> VortexResult<ArrayRef> {
200+
debug_assert_eq!(
201+
replacement.dtype(),
202+
&self.original_dtype,
203+
"slot {} dtype changed from {} to {} during execution",
204+
self.slot_idx,
205+
self.original_dtype,
206+
replacement.dtype()
207+
);
208+
debug_assert_eq!(
209+
replacement.len(),
210+
self.original_len,
211+
"slot {} len changed from {} to {} during execution",
212+
self.slot_idx,
213+
self.original_len,
214+
replacement.len()
215+
);
216+
// SAFETY: we assert above that dtype and len are preserved.
217+
unsafe { self.parent.put_slot_unchecked(self.slot_idx, replacement) }
218+
}
219+
}
220+
177221
/// Execution context for batch CPU compute.
178222
#[derive(Debug, Clone)]
179223
pub struct ExecutionCtx {

0 commit comments

Comments
 (0)