|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | | -use arrayref::array_mut_ref; |
5 | 4 | use fastlanes::RLE as FastLanesRLE; |
6 | 5 | use vortex_array::IntoArray; |
7 | 6 | use vortex_array::ToCanonical; |
@@ -51,46 +50,58 @@ where |
51 | 50 | let mut values_idx_offsets = BufferMut::<u64>::with_capacity(len.div_ceil(FL_CHUNK_SIZE)); |
52 | 51 |
|
53 | 52 | let values_uninit = values_buf.spare_capacity_mut(); |
54 | | - let indices_uninit = indices_buf.spare_capacity_mut(); |
| 53 | + let (indices_uninit, _) = indices_buf |
| 54 | + .spare_capacity_mut() |
| 55 | + .as_chunks_mut::<FL_CHUNK_SIZE>(); |
55 | 56 | let mut value_count_acc = 0; // Chunk value count prefix sum. |
56 | 57 |
|
57 | 58 | let (chunks, remainder) = values.as_chunks::<FL_CHUNK_SIZE>(); |
58 | 59 |
|
59 | | - let mut process_chunk = |chunk_start_idx: usize, input: &[T; FL_CHUNK_SIZE]| { |
| 60 | + let mut process_chunk = |chunk_start_idx: usize, |
| 61 | + input: &[T; FL_CHUNK_SIZE], |
| 62 | + rle_idxs: &mut [u16; FL_CHUNK_SIZE]| { |
60 | 63 | // SAFETY: NativeValue is repr(transparent) |
61 | 64 | let input: &[NativeValue<T>; FL_CHUNK_SIZE] = unsafe { std::mem::transmute(input) }; |
62 | 65 |
|
63 | 66 | // SAFETY: `MaybeUninit<NativeValue<T>>` and `NativeValue<T>` have the same layout. |
64 | 67 | let rle_vals: &mut [NativeValue<T>] = |
65 | 68 | unsafe { std::mem::transmute(&mut values_uninit[value_count_acc..][..FL_CHUNK_SIZE]) }; |
66 | 69 |
|
67 | | - // SAFETY: `MaybeUninit<u16>` and `u16` have the same layout. |
68 | | - let rle_idxs: &mut [u16] = |
69 | | - unsafe { std::mem::transmute(&mut indices_uninit[chunk_start_idx..][..FL_CHUNK_SIZE]) }; |
70 | | - |
71 | 70 | // Capture chunk start indices. This is necessary as indices |
72 | 71 | // returned from `T::encode` are relative to the chunk. |
73 | 72 | values_idx_offsets.push(value_count_acc as u64); |
74 | 73 |
|
75 | 74 | let value_count = NativeValue::<T>::encode( |
76 | 75 | input, |
77 | | - array_mut_ref![rle_vals, 0, FL_CHUNK_SIZE], |
78 | | - array_mut_ref![rle_idxs, 0, FL_CHUNK_SIZE], |
| 76 | + unsafe { &mut *(rle_vals.as_mut_ptr() as *mut [_; FL_CHUNK_SIZE]) }, |
| 77 | + rle_idxs, |
79 | 78 | ); |
80 | 79 |
|
81 | 80 | value_count_acc += value_count; |
82 | 81 | }; |
83 | 82 |
|
84 | | - for (chunk_idx, chunk_slice) in chunks.iter().enumerate() { |
85 | | - process_chunk(chunk_idx * FL_CHUNK_SIZE, chunk_slice); |
| 83 | + for (chunk_idx, (chunk_slice, rle_idxs)) in |
| 84 | + chunks.iter().zip(indices_uninit.iter_mut()).enumerate() |
| 85 | + { |
| 86 | + // SAFETY: `MaybeUninit<u16>` and `u16` have the same layout. |
| 87 | + process_chunk(chunk_idx * FL_CHUNK_SIZE, chunk_slice, unsafe { |
| 88 | + std::mem::transmute(rle_idxs) |
| 89 | + }); |
86 | 90 | } |
87 | 91 |
|
88 | 92 | if !remainder.is_empty() { |
89 | 93 | // Repeat the last value for padding to prevent |
90 | 94 | // accounting for an additional value change. |
91 | 95 | let mut padded_chunk = [values[len - 1]; FL_CHUNK_SIZE]; |
92 | 96 | padded_chunk[..remainder.len()].copy_from_slice(remainder); |
93 | | - process_chunk((len / FL_CHUNK_SIZE) * FL_CHUNK_SIZE, &padded_chunk); |
| 97 | + let last_idx_chunk = indices_uninit |
| 98 | + .last_mut() |
| 99 | + .vortex_expect("Must have the trailing chunk"); |
| 100 | + process_chunk( |
| 101 | + (len / FL_CHUNK_SIZE) * FL_CHUNK_SIZE, |
| 102 | + &padded_chunk, |
| 103 | + unsafe { std::mem::transmute(last_idx_chunk) }, |
| 104 | + ); |
94 | 105 | } |
95 | 106 |
|
96 | 107 | unsafe { |
@@ -143,11 +154,14 @@ mod tests { |
143 | 154 | use rstest::rstest; |
144 | 155 | use vortex_array::IntoArray; |
145 | 156 | use vortex_array::ToCanonical; |
| 157 | + use vortex_array::arrays::ConstantArray; |
| 158 | + use vortex_array::arrays::MaskedArray; |
| 159 | + use vortex_array::arrays::PrimitiveArray; |
146 | 160 | use vortex_array::assert_arrays_eq; |
147 | 161 | use vortex_array::dtype::half::f16; |
148 | 162 | use vortex_buffer::Buffer; |
149 | 163 | use vortex_buffer::buffer; |
150 | | - use vortex_error::VortexExpect; |
| 164 | + use vortex_error::VortexResult; |
151 | 165 |
|
152 | 166 | use super::*; |
153 | 167 | use crate::rle::array::RLEArrayExt; |
@@ -271,6 +285,84 @@ mod tests { |
271 | 285 | assert_arrays_eq!(decoded, expected); |
272 | 286 | } |
273 | 287 |
|
| 288 | + /// Replaces the indices of an RLE array with MaskedArray(ConstantArray(1u16), validity). |
| 289 | + /// |
| 290 | + /// Simulates a compressor that represents indices as a masked constant. |
| 291 | + /// Valid when every chunk has at least two RLE dictionary entries (the |
| 292 | + /// fill-forward default at index 0 and the actual value at index 1), which |
| 293 | + /// holds whenever the first position of each chunk is null. |
| 294 | + fn with_masked_constant_indices(rle: &RLEArray) -> VortexResult<RLEArray> { |
| 295 | + let indices_prim = rle.indices().to_primitive(); |
| 296 | + let masked_indices = MaskedArray::try_new( |
| 297 | + ConstantArray::new(1u16, indices_prim.len()).into_array(), |
| 298 | + indices_prim.validity()?, |
| 299 | + )? |
| 300 | + .into_array(); |
| 301 | + RLE::try_new( |
| 302 | + rle.values().clone(), |
| 303 | + masked_indices, |
| 304 | + rle.values_idx_offsets().clone(), |
| 305 | + rle.offset(), |
| 306 | + rle.len(), |
| 307 | + ) |
| 308 | + } |
| 309 | + |
#[test]
fn test_encode_all_null_chunk() -> VortexResult<()> {
    // Exactly one chunk's worth of elements, every one of them null.
    let original = PrimitiveArray::from_option_iter(vec![None::<u32>; FL_CHUNK_SIZE]);

    let encoded = RLEData::encode(&original)?;
    let masked = with_masked_constant_indices(&encoded)?;

    assert_arrays_eq!(masked, original);
    Ok(())
}
| 319 | + |
#[test]
fn test_encode_all_null_chunk_then_value_chunk() -> VortexResult<()> {
    // Two chunks' worth of elements: the first chunk is entirely null, the
    // second holds a single value at offset 100 with nulls everywhere else.
    let mut nullable = vec![None::<u32>; 2 * FL_CHUNK_SIZE];
    nullable[FL_CHUNK_SIZE + 100] = Some(42);

    let original = PrimitiveArray::from_option_iter(nullable);
    let encoded = RLEData::encode(&original)?;
    let masked = with_masked_constant_indices(&encoded)?;

    assert_arrays_eq!(masked, original);
    Ok(())
}
| 331 | + |
#[test]
fn test_encode_one_value_near_end() -> VortexResult<()> {
    // One chunk's worth of elements, null everywhere except a single value
    // at position 1000.
    let mut nullable = vec![None::<u32>; FL_CHUNK_SIZE];
    nullable[1000] = Some(42);

    let original = PrimitiveArray::from_option_iter(nullable);
    let encoded = RLEData::encode(&original)?;
    let masked = with_masked_constant_indices(&encoded)?;

    assert_arrays_eq!(masked, original);
    Ok(())
}
| 343 | + |
#[test]
fn test_encode_value_chunk_then_all_null_remainder() -> VortexResult<()> {
    // 1085 elements in total (2 chunks: 1024 + 61 padded to 1024).
    // Chunk 0 carries -1i16 at the scattered positions listed below
    // (all within 273..=366) and is null elsewhere.
    // Chunk 1 (the remainder) contains nothing but nulls.
    const TOTAL_LEN: usize = 1085;
    const NEG1_POSITIONS: &[usize] = &[
        273, 276, 277, 278, 279, 281, 282, 284, 285, 286, 287, 288, 289, 291, 292, 293, 296,
        298, 299, 302, 304, 308, 310, 311, 313, 314, 315, 317, 318, 322, 324, 325, 334, 335,
        336, 337, 338, 339, 340, 341, 342, 343, 344, 346, 347, 348, 350, 352, 353, 355, 358,
        359, 362, 363, 364, 366,
    ];

    let mut nullable = vec![None::<i16>; TOTAL_LEN];
    NEG1_POSITIONS
        .iter()
        .for_each(|&pos| nullable[pos] = Some(-1));

    let original = PrimitiveArray::from_option_iter(nullable);
    let encoded = RLEData::encode(&original)?;
    let masked = with_masked_constant_indices(&encoded)?;

    assert_arrays_eq!(masked, original);
    Ok(())
}
| 365 | + |
274 | 366 | // Regression test: RLE compression properly supports decoding pos/neg zeros |
275 | 367 | // See <https://github.com/vortex-data/vortex/issues/6491> |
276 | 368 | #[rstest] |
|
0 commit comments