Skip to content

Commit 72bca8b

Browse files
committed
f
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 6fd7fc1 commit 72bca8b

3 files changed

Lines changed: 90 additions & 503 deletions

File tree

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 83 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use std::mem::align_of;
5-
use std::mem::size_of;
6-
74
use num_traits::AsPrimitive;
85
use num_traits::NumCast;
9-
use vortex_buffer::BitBuffer;
106
use vortex_buffer::Buffer;
117
use vortex_buffer::BufferMut;
128
use vortex_buffer::lane_ops_indexed::ReinterpretSink;
139
use vortex_buffer::lane_ops_indexed::map_no_validity;
10+
use vortex_buffer::lane_ops_indexed::map_no_validity_in_place;
1411
use vortex_buffer::lane_ops_indexed::try_map_no_validity;
12+
use vortex_buffer::lane_ops_indexed::try_map_no_validity_in_place;
1513
use vortex_buffer::lane_ops_indexed::try_map_with_mask;
1614
use vortex_buffer::lane_ops_indexed::try_map_with_mask_in_place;
1715
use vortex_error::VortexResult;
@@ -132,7 +130,6 @@ where
132130
F: NativePType + AsPrimitive<T>,
133131
T: NativePType,
134132
{
135-
let values = array.as_slice::<F>();
136133
let overflow = || {
137134
vortex_err!(
138135
Compute: "Cannot cast {} to {} — value exceeds target range",
@@ -156,138 +153,112 @@ where
156153
let infallible = casts_losslessly_to(F::PTYPE, T::PTYPE)
157154
|| cached_values_fit_in(array, &target_dtype) == Some(true);
158155

159-
// Same-bit-width in-place fast path: when F and T have the same byte width and the
160-
// buffer is uniquely owned, mutate in place and transmute the wrapper. Saves the
161-
// output allocation. Falls through to the out-of-place path when the buffer is shared
162-
// (the common case under the current borrow-based kernel API).
156+
let len = array.len();
157+
158+
// Same-bit-width in-place fast path: when F and T have the same byte width, try to take
159+
// unique ownership of the buffer. If successful, each kernel call site below mutates in
160+
// place via `ReinterpretSink` and transmutes the wrapper at the end, saving the output
161+
// allocation. Falls back to the out-of-place path (borrowed slice + fresh buffer) when
162+
// the buffer is shared — the common case under the current borrow-based kernel API.
163163
let same_bit_width = F::PTYPE.byte_width() == T::PTYPE.byte_width();
164-
if same_bit_width
165-
&& let Ok(buffer_mut) = array.into_owned().try_into_buffer_mut::<F>()
166-
{
167-
return cast_buffer_in_place::<F, T>(buffer_mut, array, new_validity, ctx, infallible);
168-
}
164+
let owned: Option<BufferMut<F>> = if same_bit_width {
165+
array.into_owned().try_into_buffer_mut::<F>().ok()
166+
} else {
167+
None
168+
};
169+
let values: &[F] = array.as_slice::<F>();
169170

170171
if infallible {
171-
let mut buffer = BufferMut::<T>::with_capacity(values.len());
172-
// Truncating `as`-cast — safe here because stats prove every valid value fits.
173-
// Null lanes' underlying garbage gets truncated/wrapped (harmless: the result
174-
// validity bitmap masks them downstream).
175-
map_no_validity(
176-
values,
177-
&mut buffer.spare_capacity_mut()[..values.len()],
178-
|v| v.as_(),
179-
);
180-
// SAFETY: map_no_validity initializes every lane.
181-
unsafe { buffer.set_len(values.len()) };
182-
return Ok(PrimitiveArray::new(buffer.freeze(), new_validity).into_array());
172+
// Truncating `as`-cast — safe here because static type analysis or cached stats prove
173+
// every valid value fits. Null lanes' underlying garbage gets truncated/wrapped
174+
// (harmless: the result validity bitmap masks them downstream).
175+
return match owned {
176+
Some(mut buf) => {
177+
map_no_validity_in_place(
178+
ReinterpretSink::<F, T>::new(buf.as_mut_slice()),
179+
|v: F| v.as_(),
180+
);
181+
// SAFETY: same size + alignment for NativePType same-byte-width pairs;
182+
// every F-slot was overwritten with a real `T` bit pattern.
183+
let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
184+
Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array())
185+
}
186+
None => {
187+
let mut buffer = BufferMut::<T>::with_capacity(len);
188+
map_no_validity(values, &mut buffer.spare_capacity_mut()[..len], |v| v.as_());
189+
// SAFETY: map_no_validity initializes every lane.
190+
unsafe { buffer.set_len(len) };
191+
Ok(PrimitiveArray::new(buffer.freeze(), new_validity).into_array())
192+
}
193+
};
183194
}
184195

185-
let mask = array.validity()?.execute_mask(array.len(), ctx)?;
196+
let mask = array.validity()?.execute_mask(len, ctx)?;
186197

187-
let buffer: Buffer<T> = match &mask {
188-
Mask::AllTrue(_) => {
189-
let mut buffer = BufferMut::<T>::with_capacity(values.len());
190-
try_map_no_validity(
191-
values,
192-
&mut buffer.spare_capacity_mut()[..values.len()],
193-
|v| <T as NumCast>::from(v),
198+
let buffer: Buffer<T> = match (&mask, owned) {
199+
(Mask::AllTrue(_), Some(mut buf)) => {
200+
try_map_no_validity_in_place(
201+
ReinterpretSink::<F, T>::new(buf.as_mut_slice()),
202+
|v: F| <T as NumCast>::from(v),
194203
)
195204
.map_err(|_| overflow())?;
205+
// SAFETY: same size + alignment for NativePType same-byte-width pairs;
206+
// every F-slot now holds a `T` bit pattern written by `ReinterpretSink`.
207+
let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
208+
result.freeze()
209+
}
210+
(Mask::AllTrue(_), None) => {
211+
let mut buffer = BufferMut::<T>::with_capacity(len);
212+
try_map_no_validity(values, &mut buffer.spare_capacity_mut()[..len], |v| {
213+
<T as NumCast>::from(v)
214+
})
215+
.map_err(|_| overflow())?;
196216
// SAFETY: try_map_no_validity returned Ok, so it initialized every lane.
197-
unsafe { buffer.set_len(values.len()) };
217+
unsafe { buffer.set_len(len) };
198218
buffer.freeze()
199219
}
200-
Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
201-
Mask::Values(m) => {
202-
let mut buffer = BufferMut::<T>::with_capacity(values.len());
220+
(Mask::AllFalse(_), Some(buf)) => {
221+
// SAFETY: same size + alignment by NativePType same-byte-width invariant.
222+
let mut t_buf: BufferMut<T> = unsafe { buf.transmute::<T>() };
223+
t_buf.as_mut_slice().fill(T::zero());
224+
t_buf.freeze()
225+
}
226+
(Mask::AllFalse(_), None) => BufferMut::<T>::zeroed(len).freeze(),
227+
(Mask::Values(m), Some(mut buf)) => {
228+
try_map_with_mask_in_place(
229+
ReinterpretSink::<F, T>::new(buf.as_mut_slice()),
230+
m.bit_buffer(),
231+
|v: F, valid| <T as NumCast>::from(v).or_else(|| (!valid).then(T::zero)),
232+
)
233+
.map_err(|_| overflow())?;
234+
// SAFETY: same size + alignment for NativePType same-byte-width pairs;
235+
// every F-slot now holds a `T` bit pattern written by `ReinterpretSink`.
236+
let result: BufferMut<T> = unsafe { buf.transmute::<T>() };
237+
result.freeze()
238+
}
239+
(Mask::Values(m), None) => {
240+
let mut buffer = BufferMut::<T>::with_capacity(len);
203241
try_map_with_mask(
204242
values,
205243
m.bit_buffer(),
206-
&mut buffer.spare_capacity_mut()[..values.len()],
207-
// Lazy validity: only consult `valid` on the failure branch. For
208-
// widening / statically-infallible casts, `NumCast::from` is always
209-
// `Some` so the `or_else` is provably dead — LLVM DCEs the validity
210-
// path entirely, giving the same codegen as the maskless kernel.
211-
// For narrowing, `valid` is only read at lanes that actually
212-
// overflowed (a cold check on top of the cast).
244+
&mut buffer.spare_capacity_mut()[..len],
245+
// Lazy validity: only consult `valid` on the failure branch. For widening /
246+
// statically-infallible casts, `NumCast::from` is always `Some` so the
247+
// `or_else` is provably dead — LLVM DCEs the validity path entirely, giving
248+
// the same codegen as the maskless kernel. For narrowing, `valid` is only
249+
// read at lanes that actually overflowed (a cold check on top of the cast).
213250
|v, valid| <T as NumCast>::from(v).or_else(|| (!valid).then(T::zero)),
214251
)
215252
.map_err(|_| overflow())?;
216253
// SAFETY: try_map_with_mask returned Ok, so it initialized every lane.
217-
unsafe { buffer.set_len(values.len()) };
254+
unsafe { buffer.set_len(len) };
218255
buffer.freeze()
219256
}
220257
};
221258

222259
Ok(PrimitiveArray::new(buffer, new_validity).into_array())
223260
}
224261

225-
/// In-place cast of an owned `BufferMut<F>` to `BufferMut<T>` when `F` and `T` have the
226-
/// same byte width. Each slot is read as `F`, converted, and written back as `T`-bits
227-
/// using `BufferMut`'s transmute family. Avoids allocating a second output buffer.
228-
///
229-
/// The caller has already verified `F::PTYPE.byte_width() == T::PTYPE.byte_width()`.
230-
fn cast_buffer_in_place<F, T>(
231-
buffer: BufferMut<F>,
232-
array: ArrayView<'_, Primitive>,
233-
new_validity: Validity,
234-
ctx: &mut ExecutionCtx,
235-
infallible: bool,
236-
) -> VortexResult<ArrayRef>
237-
where
238-
F: NativePType + AsPrimitive<T>,
239-
T: NativePType,
240-
{
241-
debug_assert_eq!(size_of::<F>(), size_of::<T>());
242-
debug_assert_eq!(align_of::<F>(), align_of::<T>());
243-
244-
if infallible {
245-
// `map_each_in_place` does the BufferMut<F> → BufferMut<T> transmute internally
246-
// (same size + alignment for primitives of equal byte width) and walks each slot
247-
// with the closure.
248-
let result: BufferMut<T> = buffer.map_each_in_place(|v: F| v.as_());
249-
return Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array());
250-
}
251-
252-
let mask = array.validity()?.execute_mask(array.len(), ctx)?;
253-
let overflow = || {
254-
vortex_err!(
255-
Compute: "Cannot cast {} to {} — value exceeds target range",
256-
F::PTYPE, T::PTYPE,
257-
)
258-
};
259-
260-
// All-null short-circuit: zero out the buffer and skip the conversion loop entirely.
261-
if matches!(mask, Mask::AllFalse(_)) {
262-
// SAFETY: same size + alignment by NativePType same-byte-width invariant.
263-
let mut t_buf: BufferMut<T> = unsafe { buffer.transmute::<T>() };
264-
t_buf.as_mut_slice().fill(T::zero());
265-
return Ok(PrimitiveArray::new(t_buf.freeze(), new_validity).into_array());
266-
}
267-
268-
let bit_buffer = match &mask {
269-
Mask::AllTrue(n) => BitBuffer::new_set(*n),
270-
Mask::AllFalse(_) => unreachable!("handled above"),
271-
Mask::Values(m) => m.bit_buffer().clone(),
272-
};
273-
274-
let mut buffer = buffer;
275-
try_map_with_mask_in_place(
276-
ReinterpretSink::<F, T>::new(buffer.as_mut_slice()),
277-
&bit_buffer,
278-
|f_val: F, valid| -> Option<T> {
279-
<T as NumCast>::from(f_val).or_else(|| (!valid).then(T::zero))
280-
},
281-
)
282-
.map_err(|_| overflow())?;
283-
284-
// SAFETY: same size + alignment for NativePType same-byte-width pairs. Every F-slot
285-
// now holds a valid T-bit pattern because `ReinterpretSink::set_unchecked` wrote a
286-
// real `T` at every visited lane.
287-
let result: BufferMut<T> = unsafe { buffer.transmute::<T>() };
288-
Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array())
289-
}
290-
291262
fn reinterpret(
292263
array: ArrayView<'_, Primitive>,
293264
new_ptype: PType,

vortex-buffer/benches/add_checked.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,13 @@ fn premask_then_simd(bencher: Bencher, n: usize) {
513513
.bench_refs(|(lhs, rhs, lm, rm)| {
514514
let combined = lm as &BitBuffer & rm as &BitBuffer;
515515
let mut out = alloc_out(n);
516-
handrolled_premask(lhs.as_slice(), rhs.as_slice(), &combined, out.as_mut_slice())
517-
.unwrap();
516+
handrolled_premask(
517+
lhs.as_slice(),
518+
rhs.as_slice(),
519+
&combined,
520+
out.as_mut_slice(),
521+
)
522+
.unwrap();
518523
(combined, out)
519524
});
520525
}

0 commit comments

Comments
 (0)