Skip to content

Commit e8845d0

Browse files
committed
Avoid zero-filling IPC reads with typed buffer handling
1 parent 1ffd202 commit e8845d0

1 file changed

Lines changed: 224 additions & 42 deletions

File tree

arrow-ipc/src/reader.rs

Lines changed: 224 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,57 @@ fn read_buffer(
7272
}
7373
}
7474
}
75+
76+
/// Source for IPC body buffers.
77+
///
78+
/// Most decode paths use an already materialized [`Buffer`] and slice into it
79+
/// using the offsets from the IPC metadata. Keeping this behind a small helper
80+
/// lets typed buffer reads share the same bounds handling as regular buffer
81+
/// reads.
82+
enum IpcBufferSource<'a> {
83+
Buffer(&'a Buffer),
84+
}
85+
86+
impl<'a> IpcBufferSource<'a> {
87+
fn read_buffer(
88+
&self,
89+
buf: &crate::Buffer,
90+
compression: Option<CompressionCodec>,
91+
decompression_context: &mut DecompressionContext,
92+
) -> Result<Buffer, ArrowError> {
93+
match self {
94+
Self::Buffer(data) => read_buffer(buf, data, compression, decompression_context),
95+
}
96+
}
97+
/// Reads a physical IPC buffer that is expected to contain `len` values of
98+
/// type `T`.
99+
///
100+
/// The returned value is still a [`Buffer`], not a [`ScalarBuffer<T>`]. This
101+
/// preserves the existing alignment behavior: properly aligned buffers remain
102+
/// zero-copy, while unaligned buffers are still handled later by
103+
/// `ArrayDataBuilder::align_buffers` according to `require_alignment`.
104+
fn read_typed_buffer<T: ArrowNativeType>(
105+
&self,
106+
buf: &crate::Buffer,
107+
len: usize,
108+
compression: Option<CompressionCodec>,
109+
decompression_context: &mut DecompressionContext,
110+
) -> Result<Buffer, ArrowError> {
111+
let byte_len = len
112+
.checked_mul(std::mem::size_of::<T>())
113+
.ok_or_else(|| ArrowError::IpcError("Buffer length overflow".to_string()))?;
114+
115+
let buffer = self.read_buffer(buf, compression, decompression_context)?;
116+
// Some invalid or legacy IPC inputs may contain shorter buffers than
117+
// implied by the schema. Preserve the existing behavior and let array
118+
// construction/validation report the error.
119+
if buffer.len() < byte_len {
120+
return Ok(buffer);
121+
}
122+
Ok(buffer.slice_with_length(0, byte_len))
123+
}
124+
}
125+
75126
impl RecordBatchDecoder<'_> {
76127
/// Coordinates reading arrays based on data types.
77128
///
@@ -93,25 +144,46 @@ impl RecordBatchDecoder<'_> {
93144
let data_type = field.data_type();
94145
match data_type {
95146
Utf8 | Binary | LargeBinary | LargeUtf8 => {
147+
// Binary and string arrays use fixed-width offset buffers
148+
// followed by the variable-width value bytes buffer.
149+
// Read the offsets through the next_typed_buffer::<T>() helper.
96150
let field_node = self.next_node(field)?;
151+
let null_buffer = self.next_buffer()?;
152+
let len = field_node.length() as usize + 1;
153+
let offsets = match data_type {
154+
Utf8 | Binary => self.next_typed_buffer::<i32>(len)?,
155+
LargeBinary | LargeUtf8 => self.next_typed_buffer::<i64>(len)?,
156+
_ => unreachable!(),
157+
};
158+
97159
let buffers = [
98-
self.next_buffer()?,
99-
self.next_buffer()?,
100-
self.next_buffer()?,
160+
null_buffer,
161+
offsets,
162+
self.next_buffer()?, // value bytes
101163
];
164+
102165
self.create_primitive_array(field_node, data_type, &buffers)
103166
}
167+
// The first buffer after the null bitmap is the fixed-width view
168+
// buffer. Any remaining buffers are variadic data buffers.
104169
BinaryView | Utf8View => {
105170
let count = variadic_counts
106171
.pop_front()
107172
.ok_or(ArrowError::IpcError(format!(
108173
"Missing variadic count for {data_type} column"
109174
)))?;
110-
let count = count + 2; // view and null buffer.
111-
let buffers = (0..count)
112-
.map(|_| self.next_buffer())
113-
.collect::<Result<Vec<_>, _>>()?;
175+
114176
let field_node = self.next_node(field)?;
177+
let len = field_node.length() as usize;
178+
179+
let mut buffers = Vec::with_capacity(count as usize + 2);
180+
buffers.push(self.next_buffer()?); // null buffer
181+
buffers.push(self.next_typed_buffer::<u128>(len)?); // views
182+
183+
for _ in 0..count {
184+
buffers.push(self.next_buffer()?); // variadic data buffers
185+
}
186+
115187
self.create_primitive_array(field_node, data_type, &buffers)
116188
}
117189
FixedSizeBinary(_) => {
@@ -121,17 +193,37 @@ impl RecordBatchDecoder<'_> {
121193
}
122194
List(list_field) | LargeList(list_field) | Map(list_field, _) => {
123195
let list_node = self.next_node(field)?;
124-
let list_buffers = [self.next_buffer()?, self.next_buffer()?];
196+
let null_buffer = self.next_buffer()?;
197+
198+
let offset_len = list_node.length() as usize + 1;
199+
let offsets = match data_type {
200+
List(_) | Map(_, _) => self.next_typed_buffer::<i32>(offset_len)?,
201+
LargeList(_) => self.next_typed_buffer::<i64>(offset_len)?,
202+
_ => unreachable!(),
203+
};
204+
205+
let list_buffers = [null_buffer, offsets];
125206
let values = self.create_array(list_field, variadic_counts)?;
126207
self.create_list_array(list_node, data_type, &list_buffers, values)
127208
}
128209
ListView(list_field) | LargeListView(list_field) => {
129210
let list_node = self.next_node(field)?;
130-
let list_buffers = [
131-
self.next_buffer()?, // null buffer
132-
self.next_buffer()?, // offsets
133-
self.next_buffer()?, // sizes
134-
];
211+
let null_buffer = self.next_buffer()?;
212+
213+
let len = list_node.length() as usize;
214+
let (offsets, sizes) = match data_type {
215+
ListView(_) => (
216+
self.next_typed_buffer::<i32>(len)?,
217+
self.next_typed_buffer::<i32>(len)?,
218+
),
219+
LargeListView(_) => (
220+
self.next_typed_buffer::<i64>(len)?,
221+
self.next_typed_buffer::<i64>(len)?,
222+
),
223+
_ => unreachable!(),
224+
};
225+
226+
let list_buffers = [null_buffer, offsets, sizes];
135227
let values = self.create_array(list_field, variadic_counts)?;
136228
self.create_list_view_array(list_node, data_type, &list_buffers, values)
137229
}
@@ -170,10 +262,29 @@ impl RecordBatchDecoder<'_> {
170262

171263
self.create_array_from_builder(builder)
172264
}
173-
// Create dictionary array from RecordBatch
174265
Dictionary(_, _) => {
175266
let index_node = self.next_node(field)?;
176-
let index_buffers = [self.next_buffer()?, self.next_buffer()?];
267+
let null_buffer = self.next_buffer()?;
268+
269+
// Dictionary indices are fixed-width values. Read the index
270+
// buffer through the next_typed_buffer::<T>() helper so length handling is
271+
// based on the physical key type.
272+
let len = index_node.length() as usize;
273+
let indices = match data_type {
274+
Dictionary(key_type, _) => match key_type.as_ref() {
275+
Int8 => self.next_typed_buffer::<i8>(len)?,
276+
Int16 => self.next_typed_buffer::<i16>(len)?,
277+
Int32 => self.next_typed_buffer::<i32>(len)?,
278+
Int64 => self.next_typed_buffer::<i64>(len)?,
279+
UInt8 => self.next_typed_buffer::<u8>(len)?,
280+
UInt16 => self.next_typed_buffer::<u16>(len)?,
281+
UInt32 => self.next_typed_buffer::<u32>(len)?,
282+
UInt64 => self.next_typed_buffer::<u64>(len)?,
283+
t => unreachable!("Unsupported dictionary key type {t:?}"),
284+
},
285+
_ => unreachable!(),
286+
};
287+
let index_buffers = [null_buffer, indices];
177288

178289
#[allow(deprecated)]
179290
let dict_id = field.dict_id().ok_or_else(|| {
@@ -206,13 +317,13 @@ impl RecordBatchDecoder<'_> {
206317
self.next_buffer()?;
207318
}
208319

209-
let type_ids: ScalarBuffer<i8> =
210-
self.next_buffer()?.slice_with_length(0, len).into();
211-
320+
// Union type ids and dense union offsets are fixed-width
321+
// buffers. Read them through next_typed_buffer::<T>() before
322+
// constructing the ScalarBuffers used by UnionArray.
323+
let type_ids: ScalarBuffer<i8> = self.next_typed_buffer::<i8>(len)?.into();
212324
let value_offsets = match mode {
213325
UnionMode::Dense => {
214-
let offsets: ScalarBuffer<i32> =
215-
self.next_buffer()?.slice_with_length(0, len * 4).into();
326+
let offsets: ScalarBuffer<i32> = self.next_typed_buffer::<i32>(len)?.into();
216327
Some(offsets)
217328
}
218329
UnionMode::Sparse => None,
@@ -253,7 +364,41 @@ impl RecordBatchDecoder<'_> {
253364
}
254365
_ => {
255366
let field_node = self.next_node(field)?;
256-
let buffers = [self.next_buffer()?, self.next_buffer()?];
367+
let null_buffer = self.next_buffer()?;
368+
369+
// Primitive and primitive-like arrays use fixed-width physical
370+
// buffers with widths that depend on the logical data type.
371+
let len = field_node.length() as usize;
372+
let values = match data_type {
373+
Int8 => self.next_typed_buffer::<i8>(len)?,
374+
Int16 => self.next_typed_buffer::<i16>(len)?,
375+
Int32 => self.next_typed_buffer::<i32>(len)?,
376+
Int64 => self.next_typed_buffer::<i64>(len)?,
377+
UInt8 => self.next_typed_buffer::<u8>(len)?,
378+
UInt16 => self.next_typed_buffer::<u16>(len)?,
379+
UInt32 => self.next_typed_buffer::<u32>(len)?,
380+
UInt64 => self.next_typed_buffer::<u64>(len)?,
381+
Float16 => self.next_typed_buffer::<u16>(len)?,
382+
Float32 => self.next_typed_buffer::<f32>(len)?,
383+
Float64 => self.next_typed_buffer::<f64>(len)?,
384+
Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) | Decimal32(_, _) => {
385+
self.next_typed_buffer::<i32>(len)?
386+
}
387+
Date64 | Time64(_) | Timestamp(_, _) | Duration(_) | Decimal64(_, _) => {
388+
self.next_typed_buffer::<i64>(len)?
389+
}
390+
Decimal128(_, _) => self.next_typed_buffer::<i128>(len)?,
391+
Decimal256(_, _) => self.next_typed_buffer::<arrow_buffer::i256>(len)?,
392+
Interval(IntervalUnit::DayTime) => {
393+
self.next_typed_buffer::<arrow_buffer::IntervalDayTime>(len)?
394+
}
395+
Interval(IntervalUnit::MonthDayNano) => {
396+
self.next_typed_buffer::<arrow_buffer::IntervalMonthDayNano>(len)?
397+
}
398+
Boolean | FixedSizeBinary(_) => self.next_buffer()?,
399+
t => unreachable!("Unsupported primitive data type {t:?}"),
400+
};
401+
let buffers = [null_buffer, values];
257402
self.create_primitive_array(field_node, data_type, &buffers)
258403
}
259404
}
@@ -454,8 +599,8 @@ pub struct RecordBatchDecoder<'a> {
454599
decompression_context: DecompressionContext,
455600
/// The format version
456601
version: MetadataVersion,
457-
/// The raw data buffer
458-
data: &'a Buffer,
602+
/// Source of IPC body buffers
603+
data: IpcBufferSource<'a>,
459604
/// The fields comprising this array
460605
nodes: VectorIter<'a, FieldNode>,
461606
/// The buffers comprising this array
@@ -500,7 +645,7 @@ impl<'a> RecordBatchDecoder<'a> {
500645
compression,
501646
decompression_context: DecompressionContext::new(),
502647
version: *metadata,
503-
data: buf,
648+
data: IpcBufferSource::Buffer(buf),
504649
nodes: field_nodes.iter(),
505650
buffers: buffers.iter(),
506651
projection: None,
@@ -615,9 +760,22 @@ impl<'a> RecordBatchDecoder<'a> {
615760
let buffer = self.buffers.next().ok_or_else(|| {
616761
ArrowError::IpcError("Buffer count mismatched with metadata".to_string())
617762
})?;
618-
read_buffer(
763+
self.data
764+
.read_buffer(buffer, self.compression, &mut self.decompression_context)
765+
}
766+
/// Advances to the next IPC buffer and trims it to the expected physical
767+
/// length for `len` values of `T`.
768+
///
769+
/// This keeps typed buffer length handling in one place while leaving final
770+
/// array construction on the existing `ArrayDataBuilder` path.
771+
fn next_typed_buffer<T: ArrowNativeType>(&mut self, len: usize) -> Result<Buffer, ArrowError> {
772+
let buffer = self.buffers.next().ok_or_else(|| {
773+
ArrowError::IpcError("Buffer count mismatched with metadata".to_string())
774+
})?;
775+
776+
self.data.read_typed_buffer::<T>(
619777
buffer,
620-
self.data,
778+
len,
621779
self.compression,
622780
&mut self.decompression_context,
623781
)
@@ -902,16 +1060,29 @@ fn get_dictionary_values(
9021060
Ok(dictionary_values)
9031061
}
9041062

905-
/// Read the data for a given block
1063+
/// Read the data for a given IPC file block.
1064+
///
1065+
/// The returned buffer is fully initialized by the reader. This avoids first
1066+
/// zero-filling the allocation and then immediately overwriting it with block
1067+
/// data.
9061068
fn read_block<R: Read + Seek>(mut reader: R, block: &Block) -> Result<Buffer, ArrowError> {
9071069
reader.seek(SeekFrom::Start(block.offset() as u64))?;
9081070
let body_len = block.bodyLength().to_usize().unwrap();
9091071
let metadata_len = block.metaDataLength().to_usize().unwrap();
9101072
let total_len = body_len.checked_add(metadata_len).unwrap();
9111073

912-
let mut buf = MutableBuffer::from_len_zeroed(total_len);
913-
reader.read_exact(&mut buf)?;
914-
Ok(buf.into())
1074+
let mut buf = Vec::with_capacity(total_len);
1075+
reader
1076+
.by_ref()
1077+
.take(total_len as u64)
1078+
.read_to_end(&mut buf)?;
1079+
if buf.len() != total_len {
1080+
return Err(ArrowError::IpcError(format!(
1081+
"Expected IPC block of length {total_len}, got {}",
1082+
buf.len()
1083+
)));
1084+
}
1085+
Ok(Buffer::from_vec(buf))
9151086
}
9161087

9171088
/// Parse an encapsulated message
@@ -1809,16 +1980,14 @@ impl<R: Read> MessageReader<R> {
18091980
}
18101981
}
18111982

1812-
/// Reads the entire next message from the underlying reader which includes
1813-
/// the metadata length, the metadata, and the body.
1983+
/// Reads the next IPC message, including the metadata and body.
18141984
///
18151985
/// # Returns
1816-
/// - `Ok(None)` if the the reader signals the end of stream with EOF on
1817-
/// the first read
1818-
/// - `Err(_)` if the reader returns an error other than on the first
1819-
/// read, or if the metadata length is invalid
1820-
/// - `Ok(Some(_))` with the Message and buffer containiner the
1821-
/// body bytes otherwise.
1986+
/// - `Ok(None)` if the reader signals end-of-stream before reading a
1987+
/// metadata length
1988+
/// - `Err(_)` if the reader returns an error or the IPC message is invalid
1989+
/// - `Ok(Some(_))` with the decoded message metadata and body bytes
1990+
/// otherwise
18221991
fn maybe_next(&mut self) -> Result<Option<(Message::Message<'_>, MutableBuffer)>, ArrowError> {
18231992
let meta_len = self.read_meta_len()?;
18241993
let Some(meta_len) = meta_len else {
@@ -1832,10 +2001,23 @@ impl<R: Read> MessageReader<R> {
18322001
ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
18332002
})?;
18342003

1835-
let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
1836-
self.reader.read_exact(&mut buf)?;
1837-
1838-
Ok(Some((message, buf)))
2004+
// The message body is fully provided by the reader. Read it into a
2005+
// Vec to avoid zero-filling a MutableBuffer before immediately
2006+
// overwriting it with the body bytes.
2007+
let body_len = message.bodyLength() as usize;
2008+
let mut buf = Vec::with_capacity(body_len);
2009+
self.reader
2010+
.by_ref()
2011+
.take(body_len as u64)
2012+
.read_to_end(&mut buf)?;
2013+
2014+
if buf.len() != body_len {
2015+
return Err(ArrowError::IpcError(format!(
2016+
"Expected IPC message body of length {body_len}, got {}",
2017+
buf.len()
2018+
)));
2019+
}
2020+
Ok(Some((message, MutableBuffer::from(buf))))
18392021
}
18402022

18412023
/// Get a mutable reference to the underlying reader.

0 commit comments

Comments
 (0)