Skip to content

Commit 921b04b

Browse files
committed
u
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent b0d1f54 commit 921b04b

3 files changed

Lines changed: 137 additions & 99 deletions

File tree

encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,16 @@ where
106106
let len = array.len();
107107
let mut uninit_range = builder.uninit_range(len);
108108

109-
// SAFETY: We initialize all `len` values below via `decode_cast_into` and the patch loop.
109+
// SAFETY: We initialize all `len` values below via `decode_map_into` and the patch loop.
110110
unsafe {
111111
uninit_range.append_mask(array.validity()?.execute_mask(len, ctx)?);
112112
}
113113

114-
// SAFETY: `decode_cast_into` writes a value to every slot in this range.
114+
// SAFETY: `decode_map_into` writes a value to every slot in this range.
115115
let uninit_slice = unsafe { uninit_range.slice_uninit_mut(0, len) };
116116

117117
let mut chunks = array.unpacked_chunks::<F>()?;
118-
chunks.decode_cast_into(uninit_slice, |v: F| v.as_());
118+
chunks.decode_map_into(uninit_slice, |v: F| v.as_());
119119

120120
if let Some(patches) = array.patches() {
121121
apply_cast_patches_to_uninit_range::<F, T>(&mut uninit_range, &patches, ctx)?;

encodings/fastlanes/src/bitpacking/array/unpack_iter.rs

Lines changed: 70 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -104,50 +104,6 @@ impl<T: BitPacked> BitUnpackedChunks<T> {
104104
)
105105
}
106106

107-
/// Decode all chunks (initial, full, and trailer), mapping each value through `f` and writing
108-
/// the result into a differently-typed `output`.
109-
///
110-
/// Unlike [`decode_into`](Self::decode_into), full chunks cannot be unpacked directly into the
111-
/// output because the output element type `U` differs from the packed type `T`. Each chunk is
112-
/// unpacked into the small internal scratch buffer (which stays resident in cache) and then
113-
/// mapped value-by-value into `output`. This avoids materializing a full-length `T`-typed
114-
/// intermediate buffer, which is the win for cast-on-decompression (e.g. bit-packed `u16` cast
115-
/// to `u32`).
116-
pub fn decode_cast_into<U: Copy>(&mut self, output: &mut [MaybeUninit<U>], f: impl Fn(T) -> U) {
117-
let mut local_idx = 0;
118-
119-
if let Some(initial) = self.initial() {
120-
for (dst, &src) in output[..initial.len()].iter_mut().zip(initial.iter()) {
121-
dst.write(f(src));
122-
}
123-
local_idx += initial.len();
124-
}
125-
126-
// `initial` already handled the only chunk when `num_chunks == 1`; mirror the guard in
127-
// `decode_full_chunks_into_at` so we don't decode chunk 0 twice.
128-
if self.num_chunks != 1 {
129-
let mut chunks = self.full_chunks();
130-
while let Some(chunk) = chunks.next() {
131-
for (dst, &src) in output[local_idx..][..CHUNK_SIZE]
132-
.iter_mut()
133-
.zip(chunk.iter())
134-
{
135-
dst.write(f(src));
136-
}
137-
local_idx += CHUNK_SIZE;
138-
}
139-
}
140-
141-
if let Some(trailer) = self.trailer() {
142-
for (dst, &src) in output[local_idx..][..trailer.len()]
143-
.iter_mut()
144-
.zip(trailer.iter())
145-
{
146-
dst.write(f(src));
147-
}
148-
}
149-
}
150-
151107
pub fn full_chunks(&mut self) -> BitUnpackIterator<'_, T> {
152108
let elems_per_chunk = self.elems_per_chunk();
153109
let last_chunk_is_sliced = self.last_chunk_is_sliced() as usize;
@@ -161,6 +117,15 @@ impl<T: BitPacked> BitUnpackedChunks<T> {
161117
first_chunk_is_sliced,
162118
)
163119
}
120+
121+
/// Decode all chunks (initial, full, and trailer), mapping each value through `f` and writing
122+
/// the result into a differently-typed `output`.
123+
///
124+
/// Kept as a cast-oriented alias for callers that want the old name. Internal code can call
125+
/// `decode_map_into` directly.
126+
pub fn decode_cast_into<U: Copy>(&mut self, output: &mut [MaybeUninit<U>], f: impl Fn(T) -> U) {
127+
self.decode_map_into(output, f);
128+
}
164129
}
165130

166131
impl<T: PhysicalPType, S: UnpackStrategy<T>> UnpackedChunks<T, S> {
@@ -225,70 +190,79 @@ impl<T: PhysicalPType, S: UnpackStrategy<T>> UnpackedChunks<T, S> {
225190
})
226191
}
227192

228-
/// Decode all chunks (initial, full, and trailer) into the output range.
229-
/// This consolidates the logic for handling all three chunk types in one place.
230-
pub fn decode_into(&mut self, output: &mut [MaybeUninit<T>]) {
193+
/// Decode all chunks (initial, full, and trailer), calling `write_chunk` for each concrete
194+
/// unpacked chunk and its corresponding output range.
195+
pub(crate) fn decode_chunks_into<U>(
196+
&mut self,
197+
output: &mut [MaybeUninit<U>],
198+
mut write_chunk: impl FnMut(&[T], &mut [MaybeUninit<U>]),
199+
) {
200+
debug_assert_eq!(output.len(), self.len);
231201
let mut local_idx = 0;
232202

233-
// Handle initial partial chunk if present
234203
if let Some(initial) = self.initial() {
235-
local_idx = initial.len();
236-
237-
// TODO(connor): use `maybe_uninit_write_slice` feature when it gets stabilized.
238-
// https://github.com/rust-lang/rust/issues/79995
239-
// SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
240-
let init_initial: &[MaybeUninit<T>] = unsafe { mem::transmute(initial) };
241-
output[..local_idx].copy_from_slice(init_initial);
204+
let chunk_len = initial.len();
205+
write_chunk(initial, &mut output[..chunk_len]);
206+
local_idx += chunk_len;
242207
}
243208

244-
// Handle full chunks
245-
local_idx = self.decode_full_chunks_into_at(output, local_idx);
209+
if self.num_chunks != 1 {
210+
let first_chunk_is_sliced = self.first_chunk_is_sliced();
211+
let last_chunk_is_sliced = self.last_chunk_is_sliced();
212+
let full_chunks_range =
213+
(first_chunk_is_sliced as usize)..(self.num_chunks - last_chunk_is_sliced as usize);
214+
215+
let packed_slice: &[T::Physical] = buffer_as_slice(&self.packed);
216+
let elems_per_chunk = self.elems_per_chunk();
217+
for i in full_chunks_range {
218+
let chunk = &packed_slice[i * elems_per_chunk..][..elems_per_chunk];
219+
220+
unsafe {
221+
let dst: &mut [MaybeUninit<T>] = &mut self.buffer;
222+
// SAFETY: `MaybeUninit<T>` has the same layout as `T`; this also assumes `T::Physical` has the same layout as `T` (TODO confirm).
223+
let dst: &mut [T::Physical] = mem::transmute(dst);
224+
self.strategy.unpack_chunk(self.bit_width, chunk, dst);
225+
}
246226

247-
// Handle trailing partial chunk if present
248-
if let Some(trailer) = self.trailer() {
249-
// TODO(connor): use `maybe_uninit_write_slice` feature when it gets stabilized.
250-
// https://github.com/rust-lang/rust/issues/79995
251-
// SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
252-
let init_trailer: &[MaybeUninit<T>] = unsafe { mem::transmute(trailer) };
253-
output[local_idx..][..init_trailer.len()].copy_from_slice(init_trailer);
227+
// SAFETY: `unpack_chunk` initialized the whole scratch chunk above.
228+
let unpacked: &[T] = unsafe { mem::transmute(&self.buffer[..]) };
229+
write_chunk(unpacked, &mut output[local_idx..local_idx + CHUNK_SIZE]);
230+
local_idx += CHUNK_SIZE;
231+
}
254232
}
255-
}
256233

257-
/// Unpack full chunks into output range starting at the given index.
258-
/// Returns the next local index to write to.
259-
fn decode_full_chunks_into_at(
260-
&mut self,
261-
output: &mut [MaybeUninit<T>],
262-
start_idx: usize,
263-
) -> usize {
264-
// If there's only one chunk it has been handled already by `initial` method
265-
if self.num_chunks == 1 {
266-
// Return the start_idx since initial already wrote everything.
267-
return start_idx;
234+
if let Some(trailer) = self.trailer() {
235+
let chunk_len = trailer.len();
236+
write_chunk(trailer, &mut output[local_idx..local_idx + chunk_len]);
237+
local_idx += chunk_len;
268238
}
269239

270-
let first_chunk_is_sliced = self.first_chunk_is_sliced();
271-
272-
let last_chunk_is_sliced = self.last_chunk_is_sliced();
273-
let full_chunks_range =
274-
(first_chunk_is_sliced as usize)..(self.num_chunks - last_chunk_is_sliced as usize);
275-
276-
let mut local_idx = start_idx;
277-
278-
let packed_slice: &[T::Physical] = buffer_as_slice(&self.packed);
279-
let elems_per_chunk = self.elems_per_chunk();
280-
for i in full_chunks_range {
281-
let chunk = &packed_slice[i * elems_per_chunk..][..elems_per_chunk];
240+
debug_assert_eq!(local_idx, self.len);
241+
}
282242

283-
unsafe {
284-
let uninit_dst = &mut output[local_idx..local_idx + CHUNK_SIZE];
285-
// SAFETY: &[T] and &[MaybeUninit<T>] have the same layout
286-
let dst: &mut [T::Physical] = mem::transmute(uninit_dst);
287-
self.strategy.unpack_chunk(self.bit_width, chunk, dst);
243+
/// Decode all chunks (initial, full, and trailer), mapping each unpacked value through `f` and writing the result into `output`.
244+
pub(crate) fn decode_map_into<U>(
245+
&mut self,
246+
output: &mut [MaybeUninit<U>],
247+
mut f: impl FnMut(T) -> U,
248+
) {
249+
self.decode_chunks_into(output, |chunk, dst| {
250+
for (dst, &src) in dst.iter_mut().zip(chunk.iter()) {
251+
dst.write(f(src));
288252
}
289-
local_idx += CHUNK_SIZE;
290-
}
291-
local_idx
253+
});
254+
}
255+
256+
/// Decode all chunks (initial, full, and trailer) into the output range.
257+
/// Equivalent to [`decode_map_into`](Self::decode_map_into) with the identity function, but copies each chunk as a slice.
258+
pub fn decode_into(&mut self, output: &mut [MaybeUninit<T>]) {
259+
self.decode_chunks_into(output, |chunk, dst| {
260+
// TODO(connor): use `maybe_uninit_write_slice` feature when it gets stabilized.
261+
// https://github.com/rust-lang/rust/issues/79995
262+
// SAFETY: &[T] and &[MaybeUninit<T>] have the same layout.
263+
let initialized: &[MaybeUninit<T>] = unsafe { mem::transmute(chunk) };
264+
dst.copy_from_slice(initialized);
265+
});
292266
}
293267

294268
/// Access the last chunk of the array if it has fewer than 1024 elements due to slicing

encodings/fastlanes/src/bitpacking/compute/cast.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,15 @@ mod tests {
121121
use vortex_array::builtins::ArrayBuiltins;
122122
use vortex_array::compute::conformance::cast::test_cast_conformance;
123123
use vortex_array::dtype::DType;
124+
use vortex_array::dtype::NativePType;
124125
use vortex_array::dtype::Nullability;
125126
use vortex_array::dtype::PType;
127+
use vortex_array::match_each_integer_ptype;
128+
use vortex_array::scalar_fn::fns::cast::CastKernel;
126129
use vortex_buffer::buffer;
130+
use vortex_error::VortexResult;
127131

132+
use crate::BitPacked;
128133
use crate::BitPackedArray;
129134
use crate::BitPackedData;
130135

@@ -166,6 +171,65 @@ mod tests {
166171
);
167172
}
168173

174+
#[test]
175+
fn test_cast_bitpacked_widening_integer_matrix() -> VortexResult<()> {
176+
fn values<T: NativePType>(len: usize) -> PrimitiveArray {
177+
PrimitiveArray::from_iter((0..len).map(|i| {
178+
let value = if i % 17 == 0 { 31 } else { i % 8 };
179+
<T as num_traits::FromPrimitive>::from_usize(value)
180+
.expect("test values fit every integer ptype")
181+
}))
182+
}
183+
184+
fn supported(src: PType, tgt: PType) -> bool {
185+
src.is_int()
186+
&& tgt.is_int()
187+
&& tgt.byte_width() > src.byte_width()
188+
&& (src.is_unsigned_int() || tgt.is_signed_int())
189+
}
190+
191+
let ptypes = [
192+
PType::I8,
193+
PType::I16,
194+
PType::I32,
195+
PType::I64,
196+
PType::U8,
197+
PType::U16,
198+
PType::U32,
199+
PType::U64,
200+
];
201+
let lengths = [0, 1, 7, 1023, 1024, 1025, 2051];
202+
203+
for src in ptypes {
204+
for tgt in ptypes {
205+
if !supported(src, tgt) {
206+
continue;
207+
}
208+
209+
for len in lengths {
210+
let source = match_each_integer_ptype!(src, |S| { values::<S>(len) });
211+
let source_ref = source.clone().into_array();
212+
let packed = bp(&source_ref, 3);
213+
let target = DType::Primitive(tgt, Nullability::NonNullable);
214+
215+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
216+
let casted =
217+
<BitPacked as CastKernel>::cast(packed.as_view(), &target, &mut ctx)?
218+
.expect(
219+
"supported widening integer cast should hit BitPacked CastKernel",
220+
);
221+
let reference = source_ref
222+
.cast(target)?
223+
.execute::<PrimitiveArray>(&mut ctx)?;
224+
225+
assert_arrays_eq!(casted, reference);
226+
}
227+
}
228+
}
229+
230+
Ok(())
231+
}
232+
169233
#[rstest]
170234
#[case(bp(&buffer![0u8, 10, 20, 30, 40, 50, 60, 63].into_array(), 6))]
171235
#[case(bp(&buffer![0u16, 100, 200, 300, 400, 500].into_array(), 9))]

0 commit comments

Comments
 (0)