diff --git a/parquet/src/arrow/array_reader/byte_view_array.rs b/parquet/src/arrow/array_reader/byte_view_array.rs index c134261609be..801d62edd06b 100644 --- a/parquet/src/arrow/array_reader/byte_view_array.rs +++ b/parquet/src/arrow/array_reader/byte_view_array.rs @@ -30,7 +30,6 @@ use crate::schema::types::ColumnDescPtr; use crate::util::utf8::check_valid_utf8; use arrow_array::{ArrayRef, builder::make_view}; use arrow_buffer::Buffer; -use arrow_data::ByteView; use arrow_schema::DataType as ArrowType; use bytes::Bytes; use std::any::Any; @@ -451,6 +450,26 @@ impl ByteViewArrayDecoderPlain { } } +/// Branchlessly add `base` to the buffer-index field of a long view +/// (inline views with length ≤ 12 carry data in the high bits and are +/// left untouched). +#[inline(always)] +fn adjust_buffer_index(view: u128, base: u32) -> u128 { + let is_long = ((view as u32) > 12) as u128; + view.wrapping_add((is_long * base as u128) << 64) +} + +#[cold] +#[inline(never)] +fn invalid_dict_key(chunk: &[i32], dict_len: usize) -> ParquetError { + let bad = chunk + .iter() + .copied() + .find(|&k| (k as usize) >= dict_len) + .unwrap_or(0); + general_err!("invalid key={} for dictionary of length {}", bad, dict_len) +} + pub struct ByteViewArrayDecoderDictionary { decoder: DictIndexDecoder, } @@ -504,49 +523,70 @@ impl ByteViewArrayDecoderDictionary { // Pre-reserve output capacity to avoid per-chunk reallocation in extend output.views.reserve(len); - let mut error = None; + let dict_views: &[u128] = dict.views.as_slice(); + let dict_len = dict_views.len(); + + if base_buffer_idx == 0 { + // Fused path: RLE decode + view gather in one pass via + // `RleDecoder::get_batch_with_dict`, writing directly into spare + // capacity (no zero-init) and skipping the intermediate index + // buffer for RLE runs. + let base = output.views.len(); + // SAFETY: `reserve(len)` above ensures the spare slice is at + // least `len` long. 
+ let spare = unsafe { output.views.spare_capacity_mut().get_unchecked_mut(..len) }; + let read = self.decoder.read_with_dict(len, dict_views, spare)?; + // SAFETY: `read_with_dict` wrote exactly `read` views. + unsafe { output.views.set_len(base + read) }; + return Ok(read); + } + + let mut out_offset = 0usize; + let read = self.decoder.read(len, |keys| { - if base_buffer_idx == 0 { - // the dictionary buffers are the last buffers in output, we can directly use the views + // SAFETY: `reserve(len)` above + callbacks summing to `len` means + // spare capacity is always at least `keys.len()` from `out_offset`. + let out: &mut [std::mem::MaybeUninit<u128>] = unsafe { output .views - .extend(keys.iter().map(|k| match dict.views.get(*k as usize) { - Some(&view) => view, - None => { - if error.is_none() { - error = Some(general_err!("invalid key={} for dictionary", *k)); - } - 0 - } - })); - Ok(()) - } else { - output - .views - .extend(keys.iter().map(|k| match dict.views.get(*k as usize) { - Some(&view) => { - let len = view as u32; - if len <= 12 { - view - } else { - let mut view = ByteView::from(view); - view.buffer_index += base_buffer_idx; - view.into() - } - } - None => { - if error.is_none() { - error = Some(general_err!("invalid key={} for dictionary", *k)); - } - 0 - } - })); - Ok(()) + .spare_capacity_mut() + .get_unchecked_mut(out_offset..out_offset + keys.len()) + }; + out_offset += keys.len(); + + const CHUNK: usize = 16; + let mut out_chunks = out.chunks_exact_mut(CHUNK); + let mut key_chunks = keys.chunks_exact(CHUNK); + // Cast to u32 so negative i32 (corrupt data) compares as a large value. 
+ let dict_len_u32 = dict_len as u32; + + for (out_chunk, key_chunk) in out_chunks.by_ref().zip(key_chunks.by_ref()) { + let max_key = key_chunk.iter().fold(0u32, |acc, &k| acc.max(k as u32)); + if max_key >= dict_len_u32 { + return Err(invalid_dict_key(key_chunk, dict_len)); + } + for (dst, &k) in out_chunk.iter_mut().zip(key_chunk.iter()) { + // SAFETY: bounds checked above. + let view = unsafe { *dict_views.get_unchecked(k as usize) }; + dst.write(adjust_buffer_index(view, base_buffer_idx)); + } + } + for (dst, &k) in out_chunks + .into_remainder() + .iter_mut() + .zip(key_chunks.remainder().iter()) + { + let view = *dict_views + .get(k as usize) + .ok_or_else(|| general_err!("invalid key={k} for dictionary"))?; + dst.write(adjust_buffer_index(view, base_buffer_idx)); } + + Ok(()) })?; - if let Some(e) = error { - return Err(e); - } + // SAFETY: decoder.read wrote exactly `read` views via dst.write. + debug_assert_eq!(out_offset, read); + unsafe { output.views.set_len(output.views.len() + read) }; Ok(read) } diff --git a/parquet/src/arrow/decoder/dictionary_index.rs b/parquet/src/arrow/decoder/dictionary_index.rs index 7a4b77f89d59..987793ae1ace 100644 --- a/parquet/src/arrow/decoder/dictionary_index.rs +++ b/parquet/src/arrow/decoder/dictionary_index.rs @@ -91,6 +91,78 @@ impl DictIndexDecoder { Ok(values_read) } + /// Read up to `len` values directly into `output`, gathering through `dict` + /// in a single pass. Avoids expanding RLE runs into the index buffer and + /// writes through a `MaybeUninit<T>` slice so the caller does not need to + /// zero-initialise output capacity. + pub fn read_with_dict<T: Clone>( + &mut self, + len: usize, + dict: &[T], + output: &mut [std::mem::MaybeUninit<T>], + ) -> Result<usize> { + use crate::errors::ParquetError; + let total_to_read = len.min(self.max_remaining_values); + let mut values_read = 0; + + // Drain any leftover indices buffered from a prior `read` call before + // switching to the direct-gather path. 
Uses the same CHUNK=16 + + // max-reduction pattern as `RleDecoder::get_batch_with_dict`. + let leftover = self.index_buf_len - self.index_offset; + if leftover > 0 { + let n = leftover.min(total_to_read); + let keys = &self.index_buf[self.index_offset..self.index_offset + n]; + let out = &mut output[..n]; + let dict_len = dict.len(); + let dict_len_u32 = dict_len as u32; + + const CHUNK: usize = 16; + let mut out_chunks = out.chunks_exact_mut(CHUNK); + let mut key_chunks = keys.chunks_exact(CHUNK); + for (out_chunk, key_chunk) in out_chunks.by_ref().zip(key_chunks.by_ref()) { + let max_key = key_chunk.iter().fold(0u32, |acc, &k| acc.max(k as u32)); + if max_key >= dict_len_u32 { + return Err(ParquetError::General(format!( + "dictionary index out of bounds: the len is {dict_len} but the index is {max_key}" + ))); + } + for (dst, &k) in out_chunk.iter_mut().zip(key_chunk.iter()) { + // SAFETY: bounds checked above. + dst.write(unsafe { dict.get_unchecked(k as usize) }.clone()); + } + } + for (dst, &k) in out_chunks + .into_remainder() + .iter_mut() + .zip(key_chunks.remainder().iter()) + { + let idx = k as usize; + if idx >= dict_len { + return Err(ParquetError::General(format!( + "dictionary index out of bounds: the len is {dict_len} but the index is {idx}" + ))); + } + // SAFETY: bounds checked above. 
+ dst.write(unsafe { dict.get_unchecked(idx) }.clone()); + } + + self.index_offset += n; + values_read += n; + } + + if values_read < total_to_read { + let got = self.decoder.get_batch_with_dict( + dict, + &mut output[values_read..total_to_read], + total_to_read - values_read, + )?; + values_read += got; + } + + self.max_remaining_values -= values_read; + Ok(values_read) + } + /// Skip up to `to_skip` values, returning the number of values skipped pub fn skip(&mut self, to_skip: usize) -> Result<usize> { let to_skip = to_skip.min(self.max_remaining_values); diff --git a/parquet/src/encodings/decoding.rs b/parquet/src/encodings/decoding.rs index f7f4d9be4726..0f7e8101aa85 100644 --- a/parquet/src/encodings/decoding.rs +++ b/parquet/src/encodings/decoding.rs @@ -405,7 +405,17 @@ impl<T: DataType> Decoder<T> for DictDecoder<T> { let rle = self.rle_decoder.as_mut().unwrap(); let num_values = cmp::min(buffer.len(), self.num_values); - rle.get_batch_with_dict(&self.dictionary[..], buffer, num_values) + // SAFETY: reinterpreting `&mut [T]` as `&mut [MaybeUninit<T>]` is sound + // because every initialised `T` is a valid `MaybeUninit<T>`; `get_batch_with_dict` + // only writes through the slice, and we do not read through the original + // reference after this call. 
+ let uninit: &mut [std::mem::MaybeUninit<T::T>] = unsafe { + std::slice::from_raw_parts_mut( + buffer.as_mut_ptr().cast::<std::mem::MaybeUninit<T::T>>(), + buffer.len(), + ) + }; + rle.get_batch_with_dict(&self.dictionary[..], uninit, num_values) } /// Number of values left in this decoder stream diff --git a/parquet/src/encodings/rle.rs b/parquet/src/encodings/rle.rs index 806b41a353b4..42c53b644719 100644 --- a/parquet/src/encodings/rle.rs +++ b/parquet/src/encodings/rle.rs @@ -477,33 +477,30 @@ impl RleDecoder { pub fn get_batch_with_dict<T>( &mut self, dict: &[T], - buffer: &mut [T], + buffer: &mut [std::mem::MaybeUninit<T>], max_values: usize, ) -> Result<usize> where - T: Default + Clone, + T: Clone, { debug_assert!(buffer.len() >= max_values); let mut values_read = 0; while values_read < max_values { - let index_buf = self.index_buf.get_or_insert_with(|| Box::new([0; 1024])); - if self.rle_left > 0 { let num_values = cmp::min(max_values - values_read, self.rle_left as usize); let dict_idx = self.current_value.unwrap() as usize; - let dict_value = dict - .get(dict_idx) - .ok_or_else(|| { - general_err!( - "dictionary index out of bounds: the len is {} but the index is {}", - dict.len(), - dict_idx - ) - })? 
- .clone(); - - buffer[values_read..values_read + num_values].fill(dict_value); + let dict_value = dict.get(dict_idx).ok_or_else(|| { + general_err!( + "dictionary index out of bounds: the len is {} but the index is {}", + dict.len(), + dict_idx + ) + })?; + + for slot in &mut buffer[values_read..values_read + num_values] { + slot.write(dict_value.clone()); + } self.rle_left -= num_values as u32; values_read += num_values; @@ -512,6 +509,7 @@ impl RleDecoder { .bit_reader .as_mut() .ok_or_else(|| general_err!("bit_reader should be set"))?; + let index_buf = self.index_buf.get_or_insert_with(|| Box::new([0; 1024])); loop { let to_read = index_buf @@ -557,7 +555,7 @@ impl RleDecoder { } for (b, i) in out_chunk.iter_mut().zip(idx_chunk.iter()) { // SAFETY: all indices checked above to be in bounds - b.clone_from(unsafe { dict.get_unchecked(*i as usize) }); + b.write(unsafe { dict.get_unchecked(*i as usize) }.clone()); } } for (b, i) in out_chunks @@ -570,7 +568,7 @@ impl RleDecoder { return Err(oob(*i as u32, dict_len)); } // SAFETY: bounds checked above - b.clone_from(unsafe { dict.get_unchecked(dict_idx) }); + b.write(unsafe { dict.get_unchecked(dict_idx) }.clone()); } } self.bit_packed_left -= num_values as u32; @@ -624,6 +622,14 @@ mod tests { use crate::util::bit_util::ceil; use rand::{self, Rng, SeedableRng, distr::StandardUniform, rng}; + use std::mem::MaybeUninit; + + /// Reinterpret an initialised slice as a `MaybeUninit` slice for calls to + /// `get_batch_with_dict`. Sound because every `T` is a valid `MaybeUninit` + /// and the callee only writes. 
+ fn as_uninit<T>(s: &mut [T]) -> &mut [MaybeUninit<T>] { + unsafe { std::slice::from_raw_parts_mut(s.as_mut_ptr().cast::<MaybeUninit<T>>(), s.len()) } + } const MAX_WIDTH: usize = 32; @@ -772,7 +778,7 @@ decoder.set_data(data.into()).unwrap(); let mut buffer = vec![0; 12]; let expected = vec![10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 30]; - let result = decoder.get_batch_with_dict::<i32>(&dict, &mut buffer, 12); + let result = decoder.get_batch_with_dict::<i32>(&dict, as_uninit(&mut buffer), 12); assert!(result.is_ok()); assert_eq!(buffer, expected); @@ -788,7 +794,7 @@ "ddd", "eee", "fff", "ddd", "eee", "fff", "ddd", "eee", "fff", "eee", "fff", "fff", ]; let result = - decoder.get_batch_with_dict::<&str>(dict.as_slice(), buffer.as_mut_slice(), 12); + decoder.get_batch_with_dict::<&str>(dict.as_slice(), as_uninit(&mut buffer), 12); assert!(result.is_ok()); assert_eq!(buffer, expected); } @@ -806,7 +812,7 @@ let skipped = decoder.skip(2).expect("skipping two values"); assert_eq!(skipped, 2); let remainder = decoder - .get_batch_with_dict::<i32>(&dict, &mut buffer, 10) + .get_batch_with_dict::<i32>(&dict, as_uninit(&mut buffer), 10) .expect("getting remainder"); assert_eq!(remainder, 10); assert_eq!(buffer, expected); @@ -825,7 +831,7 @@ let remainder = decoder .get_batch_with_dict::<&str>( dict.as_slice(), - buffer.as_mut_slice(), + as_uninit(&mut buffer), BIT_PACK_GROUP_SIZE, ) .expect("getting remainder"); @@ -986,7 +992,7 @@ let dict: Vec<u16> = (0..256).collect(); let mut output = vec![0_u16; 100]; let read = decoder - .get_batch_with_dict(&dict, &mut output, 100) + .get_batch_with_dict(&dict, as_uninit(&mut output), 100) .unwrap(); assert_eq!(read, 20); @@ -1056,7 +1062,7 @@ decoder.set_data(buffer).unwrap(); let r = decoder - .get_batch_with_dict(&[0, 23], &mut decoded, num_values) + .get_batch_with_dict(&[0, 23], as_uninit(&mut decoded), num_values) .unwrap(); assert_eq!(r, num_values); assert_eq!(vec![23; num_values], decoded);