Skip to content
118 changes: 79 additions & 39 deletions parquet/src/arrow/array_reader/byte_view_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use crate::schema::types::ColumnDescPtr;
use crate::util::utf8::check_valid_utf8;
use arrow_array::{ArrayRef, builder::make_view};
use arrow_buffer::Buffer;
use arrow_data::ByteView;
use arrow_schema::DataType as ArrowType;
use bytes::Bytes;
use std::any::Any;
Expand Down Expand Up @@ -451,6 +450,26 @@ impl ByteViewArrayDecoderPlain {
}
}

/// Branchlessly add `base` to the buffer-index field of a long view.
///
/// A byte-view `u128` stores the value length in its low 32 bits. Inline
/// views (length <= 12) carry their data in the high bits and are returned
/// unchanged; long views store `prefix` in bits 32..64, `buffer_index` in
/// bits 64..96 and `offset` in bits 96..128.
///
/// The add is confined to the 32-bit buffer-index field: a previous version
/// added `(is_long * base) << 64` to the whole `u128`, which on overflow of
/// the index (only possible with corrupt data) would carry into — and
/// silently corrupt — the offset field. Wrapping within the field matches
/// the behaviour of `ByteView::buffer_index += base` in release builds.
#[inline(always)]
fn adjust_buffer_index(view: u128, base: u32) -> u128 {
    const BUFFER_INDEX_MASK: u128 = 0xFFFF_FFFF << 64;
    // 1 for long (length > 12) views, 0 for inline views — no branch.
    let is_long = ((view as u32) > 12) as u32;
    // Multiply-by-0/1 zeroes the bump for inline views without branching.
    let idx = ((view >> 64) as u32).wrapping_add(is_long * base);
    (view & !BUFFER_INDEX_MASK) | ((idx as u128) << 64)
}

/// Build the error for an out-of-bounds dictionary key found in `chunk`.
///
/// Kept out of line and marked cold: callers only reach this after a batched
/// max-key check fails, so the scan here re-locates the first offending key
/// without burdening the hot gather loop. Negative `i32` keys sign-extend to
/// huge `usize` values and are therefore also reported.
#[cold]
#[inline(never)]
fn invalid_dict_key(chunk: &[i32], dict_len: usize) -> ParquetError {
    let mut bad = 0;
    for &key in chunk {
        if key as usize >= dict_len {
            bad = key;
            break;
        }
    }
    general_err!("invalid key={} for dictionary of length {}", bad, dict_len)
}

/// Decoder for dictionary-encoded byte-view array data.
///
/// Holds the index decoder that yields dictionary keys; the associated
/// `read` method (see the `impl` below) gathers those keys through the
/// dictionary's pre-built views to produce output views directly.
pub struct ByteViewArrayDecoderDictionary {
    /// Produces the dictionary key indices for each encoded value.
    decoder: DictIndexDecoder,
}
Expand Down Expand Up @@ -504,49 +523,70 @@ impl ByteViewArrayDecoderDictionary {
// Pre-reserve output capacity to avoid per-chunk reallocation in extend
output.views.reserve(len);

let mut error = None;
let dict_views: &[u128] = dict.views.as_slice();
let dict_len = dict_views.len();

if base_buffer_idx == 0 {
// Fused path: RLE decode + view gather in one pass via
// `RleDecoder::get_batch_with_dict`, writing directly into spare
// capacity (no zero-init) and skipping the intermediate index
// buffer for RLE runs.
let base = output.views.len();
// SAFETY: `reserve(len)` above ensures the spare slice is at
// least `len` long.
let spare = unsafe { output.views.spare_capacity_mut().get_unchecked_mut(..len) };
let read = self.decoder.read_with_dict(len, dict_views, spare)?;
// SAFETY: `read_with_dict` wrote exactly `read` views.
unsafe { output.views.set_len(base + read) };
return Ok(read);
}

let mut out_offset = 0usize;

let read = self.decoder.read(len, |keys| {
if base_buffer_idx == 0 {
// the dictionary buffers are the last buffers in output, we can directly use the views
// SAFETY: `reserve(len)` above + callbacks summing to `len` means
// spare capacity is always at least `keys.len()` from `out_offset`.
let out: &mut [std::mem::MaybeUninit<u128>] = unsafe {
output
.views
.extend(keys.iter().map(|k| match dict.views.get(*k as usize) {
Some(&view) => view,
None => {
if error.is_none() {
error = Some(general_err!("invalid key={} for dictionary", *k));
}
0
}
}));
Ok(())
} else {
output
.views
.extend(keys.iter().map(|k| match dict.views.get(*k as usize) {
Some(&view) => {
let len = view as u32;
if len <= 12 {
view
} else {
let mut view = ByteView::from(view);
view.buffer_index += base_buffer_idx;
view.into()
}
}
None => {
if error.is_none() {
error = Some(general_err!("invalid key={} for dictionary", *k));
}
0
}
}));
Ok(())
.spare_capacity_mut()
.get_unchecked_mut(out_offset..out_offset + keys.len())
};
out_offset += keys.len();

const CHUNK: usize = 16;
let mut out_chunks = out.chunks_exact_mut(CHUNK);
let mut key_chunks = keys.chunks_exact(CHUNK);
// Cast to u32 so negative i32 (corrupt data) compares as a large value.
let dict_len_u32 = dict_len as u32;

for (out_chunk, key_chunk) in out_chunks.by_ref().zip(key_chunks.by_ref()) {
let max_key = key_chunk.iter().fold(0u32, |acc, &k| acc.max(k as u32));
if max_key >= dict_len_u32 {
return Err(invalid_dict_key(key_chunk, dict_len));
}
for (dst, &k) in out_chunk.iter_mut().zip(key_chunk.iter()) {
// SAFETY: bounds checked above.
let view = unsafe { *dict_views.get_unchecked(k as usize) };
dst.write(adjust_buffer_index(view, base_buffer_idx));
}
}
for (dst, &k) in out_chunks
.into_remainder()
.iter_mut()
.zip(key_chunks.remainder().iter())
{
let view = *dict_views
.get(k as usize)
.ok_or_else(|| general_err!("invalid key={k} for dictionary"))?;
dst.write(adjust_buffer_index(view, base_buffer_idx));
}

Ok(())
})?;
if let Some(e) = error {
return Err(e);
}
// SAFETY: decoder.read wrote exactly `read` views via dst.write.
debug_assert_eq!(out_offset, read);
unsafe { output.views.set_len(output.views.len() + read) };
Ok(read)
}

Expand Down
72 changes: 72 additions & 0 deletions parquet/src/arrow/decoder/dictionary_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,78 @@ impl DictIndexDecoder {
Ok(values_read)
}

/// Read up to `len` values directly into `output`, gathering through `dict`
/// in a single pass. Avoids expanding RLE runs into the index buffer and
/// writes through a `MaybeUninit` slice so the caller does not need to
/// zero-initialise output capacity.
///
/// Returns the number of values written, which may be less than `len` when
/// fewer than `len` values remain in the column chunk.
///
/// # Errors
///
/// Returns `ParquetError::General` when a key is out of bounds for `dict`;
/// negative `i32` keys (corrupt data) compare as large unsigned values and
/// are rejected by the same check.
pub fn read_with_dict<T: Clone>(
    &mut self,
    len: usize,
    dict: &[T],
    output: &mut [std::mem::MaybeUninit<T>],
) -> Result<usize> {
    use crate::errors::ParquetError;
    // Never read past the end of the column chunk.
    let total_to_read = len.min(self.max_remaining_values);
    let mut values_read = 0;

    // Drain any leftover indices buffered from a prior `read` call before
    // switching to the direct-gather path. Uses the same CHUNK=16 +
    // max-reduction pattern as `RleDecoder::get_batch_with_dict`.
    let leftover = self.index_buf_len - self.index_offset;
    if leftover > 0 {
        let n = leftover.min(total_to_read);
        let keys = &self.index_buf[self.index_offset..self.index_offset + n];
        let out = &mut output[..n];
        let dict_len = dict.len();
        // Compare in u32 space so negative keys become huge and fail the
        // check. NOTE(review): assumes dict_len fits in u32 (no truncation)
        // — parquet dictionary pages cannot realistically exceed that.
        let dict_len_u32 = dict_len as u32;

        const CHUNK: usize = 16;
        let mut out_chunks = out.chunks_exact_mut(CHUNK);
        let mut key_chunks = keys.chunks_exact(CHUNK);
        for (out_chunk, key_chunk) in out_chunks.by_ref().zip(key_chunks.by_ref()) {
            // One range check per 16 keys: reduce to the max, compare once,
            // then gather unchecked.
            let max_key = key_chunk.iter().fold(0u32, |acc, &k| acc.max(k as u32));
            if max_key >= dict_len_u32 {
                return Err(ParquetError::General(format!(
                    "dictionary index out of bounds: the len is {dict_len} but the index is {max_key}"
                )));
            }
            for (dst, &k) in out_chunk.iter_mut().zip(key_chunk.iter()) {
                // SAFETY: bounds checked above.
                dst.write(unsafe { dict.get_unchecked(k as usize) }.clone());
            }
        }
        // Tail shorter than CHUNK: fall back to per-element checks.
        for (dst, &k) in out_chunks
            .into_remainder()
            .iter_mut()
            .zip(key_chunks.remainder().iter())
        {
            let idx = k as usize;
            if idx >= dict_len {
                return Err(ParquetError::General(format!(
                    "dictionary index out of bounds: the len is {dict_len} but the index is {idx}"
                )));
            }
            // SAFETY: bounds checked above.
            dst.write(unsafe { dict.get_unchecked(idx) }.clone());
        }

        self.index_offset += n;
        values_read += n;
    }

    // Remaining values come straight from the RLE decoder, which gathers
    // through `dict` itself without materialising an index buffer.
    if values_read < total_to_read {
        let got = self.decoder.get_batch_with_dict(
            dict,
            &mut output[values_read..total_to_read],
            total_to_read - values_read,
        )?;
        values_read += got;
    }

    self.max_remaining_values -= values_read;
    Ok(values_read)
}

/// Skip up to `to_skip` values, returning the number of values skipped
pub fn skip(&mut self, to_skip: usize) -> Result<usize> {
let to_skip = to_skip.min(self.max_remaining_values);
Expand Down
12 changes: 11 additions & 1 deletion parquet/src/encodings/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,17 @@ impl<T: DataType> Decoder<T> for DictDecoder<T> {

let rle = self.rle_decoder.as_mut().unwrap();
let num_values = cmp::min(buffer.len(), self.num_values);
rle.get_batch_with_dict(&self.dictionary[..], buffer, num_values)
// SAFETY: reinterpreting `&mut [T]` as `&mut [MaybeUninit<T>]` is sound
// because every initialised `T` is a valid `MaybeUninit<T>`; `get_batch_with_dict`
// only writes through the slice, and we do not read through the original
// reference after this call.
let uninit: &mut [std::mem::MaybeUninit<T::T>] = unsafe {
std::slice::from_raw_parts_mut(
buffer.as_mut_ptr().cast::<std::mem::MaybeUninit<T::T>>(),
buffer.len(),
)
};
rle.get_batch_with_dict(&self.dictionary[..], uninit, num_values)
}

/// Number of values left in this decoder stream
Expand Down
54 changes: 30 additions & 24 deletions parquet/src/encodings/rle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,33 +477,30 @@ impl RleDecoder {
pub fn get_batch_with_dict<T>(
&mut self,
dict: &[T],
buffer: &mut [T],
buffer: &mut [std::mem::MaybeUninit<T>],
max_values: usize,
) -> Result<usize>
where
T: Default + Clone,
T: Clone,
{
debug_assert!(buffer.len() >= max_values);

let mut values_read = 0;
while values_read < max_values {
let index_buf = self.index_buf.get_or_insert_with(|| Box::new([0; 1024]));

if self.rle_left > 0 {
let num_values = cmp::min(max_values - values_read, self.rle_left as usize);
let dict_idx = self.current_value.unwrap() as usize;
let dict_value = dict
.get(dict_idx)
.ok_or_else(|| {
general_err!(
"dictionary index out of bounds: the len is {} but the index is {}",
dict.len(),
dict_idx
)
})?
.clone();

buffer[values_read..values_read + num_values].fill(dict_value);
let dict_value = dict.get(dict_idx).ok_or_else(|| {
general_err!(
"dictionary index out of bounds: the len is {} but the index is {}",
dict.len(),
dict_idx
)
})?;

for slot in &mut buffer[values_read..values_read + num_values] {
slot.write(dict_value.clone());
}

self.rle_left -= num_values as u32;
values_read += num_values;
Expand All @@ -512,6 +509,7 @@ impl RleDecoder {
.bit_reader
.as_mut()
.ok_or_else(|| general_err!("bit_reader should be set"))?;
let index_buf = self.index_buf.get_or_insert_with(|| Box::new([0; 1024]));

loop {
let to_read = index_buf
Expand Down Expand Up @@ -557,7 +555,7 @@ impl RleDecoder {
}
for (b, i) in out_chunk.iter_mut().zip(idx_chunk.iter()) {
// SAFETY: all indices checked above to be in bounds
b.clone_from(unsafe { dict.get_unchecked(*i as usize) });
b.write(unsafe { dict.get_unchecked(*i as usize) }.clone());
}
}
for (b, i) in out_chunks
Expand All @@ -570,7 +568,7 @@ impl RleDecoder {
return Err(oob(*i as u32, dict_len));
}
// SAFETY: bounds checked above
b.clone_from(unsafe { dict.get_unchecked(dict_idx) });
b.write(unsafe { dict.get_unchecked(dict_idx) }.clone());
}
}
self.bit_packed_left -= num_values as u32;
Expand Down Expand Up @@ -624,6 +622,14 @@ mod tests {

use crate::util::bit_util::ceil;
use rand::{self, Rng, SeedableRng, distr::StandardUniform, rng};
use std::mem::MaybeUninit;

/// View an initialised `&mut [T]` as `&mut [MaybeUninit<T>]` so it can be
/// handed to `get_batch_with_dict`.
///
/// Sound because `MaybeUninit<T>` has the same size and alignment as `T`,
/// every initialised `T` is a valid `MaybeUninit<T>`, and the callee only
/// writes initialised values through the slice.
fn as_uninit<T>(s: &mut [T]) -> &mut [MaybeUninit<T>] {
    let len = s.len();
    let ptr = s.as_mut_ptr() as *mut MaybeUninit<T>;
    // SAFETY: same pointer and length as the source slice; layout of
    // `MaybeUninit<T>` matches `T` exactly.
    unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}

const MAX_WIDTH: usize = 32;

Expand Down Expand Up @@ -772,7 +778,7 @@ mod tests {
decoder.set_data(data.into()).unwrap();
let mut buffer = vec![0; 12];
let expected = vec![10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 30];
let result = decoder.get_batch_with_dict::<i32>(&dict, &mut buffer, 12);
let result = decoder.get_batch_with_dict::<i32>(&dict, as_uninit(&mut buffer), 12);
assert!(result.is_ok());
assert_eq!(buffer, expected);

Expand All @@ -788,7 +794,7 @@ mod tests {
"ddd", "eee", "fff", "ddd", "eee", "fff", "ddd", "eee", "fff", "eee", "fff", "fff",
];
let result =
decoder.get_batch_with_dict::<&str>(dict.as_slice(), buffer.as_mut_slice(), 12);
decoder.get_batch_with_dict::<&str>(dict.as_slice(), as_uninit(&mut buffer), 12);
assert!(result.is_ok());
assert_eq!(buffer, expected);
}
Expand All @@ -806,7 +812,7 @@ mod tests {
let skipped = decoder.skip(2).expect("skipping two values");
assert_eq!(skipped, 2);
let remainder = decoder
.get_batch_with_dict::<i32>(&dict, &mut buffer, 10)
.get_batch_with_dict::<i32>(&dict, as_uninit(&mut buffer), 10)
.expect("getting remainder");
assert_eq!(remainder, 10);
assert_eq!(buffer, expected);
Expand All @@ -825,7 +831,7 @@ mod tests {
let remainder = decoder
.get_batch_with_dict::<&str>(
dict.as_slice(),
buffer.as_mut_slice(),
as_uninit(&mut buffer),
BIT_PACK_GROUP_SIZE,
)
.expect("getting remainder");
Expand Down Expand Up @@ -986,7 +992,7 @@ mod tests {
let dict: Vec<u16> = (0..256).collect();
let mut output = vec![0_u16; 100];
let read = decoder
.get_batch_with_dict(&dict, &mut output, 100)
.get_batch_with_dict(&dict, as_uninit(&mut output), 100)
.unwrap();

assert_eq!(read, 20);
Expand Down Expand Up @@ -1056,7 +1062,7 @@ mod tests {

decoder.set_data(buffer).unwrap();
let r = decoder
.get_batch_with_dict(&[0, 23], &mut decoded, num_values)
.get_batch_with_dict(&[0, 23], as_uninit(&mut decoded), num_values)
.unwrap();
assert_eq!(r, num_values);
assert_eq!(vec![23; num_values], decoded);
Expand Down
Loading