Skip to content

Commit 31227eb

Browse files
committed
feat: add custom_metadata support to RecordBatch with IPC read/write
1 parent 9d0e8be commit 31227eb

9 files changed

Lines changed: 418 additions & 9 deletions

File tree

arrow-array/src/record_batch.rs

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
use crate::cast::AsArray;
2222
use crate::{Array, ArrayRef, StructArray, new_empty_array};
2323
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24+
use std::collections::HashMap;
2425
use std::ops::Index;
2526
use std::sync::Arc;
2627

@@ -207,6 +208,12 @@ pub struct RecordBatch {
207208
///
208209
/// This is stored separately from the columns to handle the case of no columns
209210
row_count: usize,
211+
212+
/// Per-batch custom metadata
213+
///
214+
/// This corresponds to the `custom_metadata` field on the IPC `Message`
215+
/// flatbuffer, allowing per-batch metadata separate from schema-level metadata.
216+
custom_metadata: HashMap<String, String>,
210217
}
211218

212219
impl RecordBatch {
@@ -267,6 +274,7 @@ impl RecordBatch {
267274
schema,
268275
columns,
269276
row_count,
277+
custom_metadata: HashMap::new(),
270278
}
271279
}
272280

@@ -294,6 +302,7 @@ impl RecordBatch {
294302
schema,
295303
columns,
296304
row_count: 0,
305+
custom_metadata: HashMap::new(),
297306
}
298307
}
299308

@@ -368,14 +377,30 @@ impl RecordBatch {
368377
schema,
369378
columns,
370379
row_count,
380+
custom_metadata: HashMap::new(),
371381
})
372382
}
373383

374384
/// Return the schema, columns and row count of this [`RecordBatch`]
385+
///
386+
/// Note: this discards any [`Self::custom_metadata`]. Use
387+
/// [`Self::into_parts_with_custom_metadata`] to also retrieve it.
375388
pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
376389
(self.schema, self.columns, self.row_count)
377390
}
378391

392+
/// Return the schema, columns, row count and custom metadata of this [`RecordBatch`]
393+
pub fn into_parts_with_custom_metadata(
394+
self,
395+
) -> (SchemaRef, Vec<ArrayRef>, usize, HashMap<String, String>) {
396+
(
397+
self.schema,
398+
self.columns,
399+
self.row_count,
400+
self.custom_metadata,
401+
)
402+
}
403+
379404
/// Override the schema of this [`RecordBatch`]
380405
///
381406
/// Returns an error if `schema` is not a superset of the current schema
@@ -394,6 +419,7 @@ impl RecordBatch {
394419
schema,
395420
columns: self.columns,
396421
row_count: self.row_count,
422+
custom_metadata: self.custom_metadata,
397423
})
398424
}
399425

@@ -429,6 +455,25 @@ impl RecordBatch {
429455
&mut schema.metadata
430456
}
431457

458+
/// Returns a reference to the per-batch custom metadata.
459+
///
460+
/// This metadata corresponds to the `custom_metadata` field on the IPC
461+
/// `Message` flatbuffer, separate from schema-level metadata.
462+
pub fn custom_metadata(&self) -> &HashMap<String, String> {
463+
&self.custom_metadata
464+
}
465+
466+
/// Returns a mutable reference to the per-batch custom metadata.
467+
pub fn custom_metadata_mut(&mut self) -> &mut HashMap<String, String> {
468+
&mut self.custom_metadata
469+
}
470+
471+
/// Sets the per-batch custom metadata, returning `self`.
472+
pub fn with_custom_metadata(mut self, metadata: HashMap<String, String>) -> Self {
473+
self.custom_metadata = metadata;
474+
self
475+
}
476+
432477
/// Projects the schema onto the specified columns
433478
pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
434479
let projected_schema = self.schema.project(indices)?;
@@ -453,7 +498,8 @@ impl RecordBatch {
453498
SchemaRef::new(projected_schema),
454499
batch_fields,
455500
self.row_count,
456-
))
501+
)
502+
.with_custom_metadata(self.custom_metadata.clone()))
457503
}
458504
}
459505

@@ -556,7 +602,9 @@ impl RecordBatch {
556602
}
557603
}
558604
}
605+
let custom_metadata = self.custom_metadata.clone();
559606
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
607+
.map(|b| b.with_custom_metadata(custom_metadata))
560608
}
561609

562610
/// Returns the number of columns in the record batch.
@@ -677,6 +725,7 @@ impl RecordBatch {
677725
schema: self.schema.clone(),
678726
columns,
679727
row_count: length,
728+
custom_metadata: self.custom_metadata.clone(),
680729
}
681730
}
682731

@@ -836,6 +885,7 @@ impl From<StructArray> for RecordBatch {
836885
schema: Arc::new(Schema::new(fields)),
837886
row_count,
838887
columns,
888+
custom_metadata: HashMap::new(),
839889
}
840890
}
841891
}
@@ -1706,4 +1756,81 @@ mod tests {
17061756
"bar"
17071757
);
17081758
}
1759+
1760+
#[test]
1761+
fn test_with_custom_metadata() {
1762+
let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1763+
assert!(batch.custom_metadata().is_empty());
1764+
1765+
let mut metadata = HashMap::new();
1766+
metadata.insert("key".to_string(), "value".to_string());
1767+
let batch = batch.with_custom_metadata(metadata.clone());
1768+
assert_eq!(batch.custom_metadata(), &metadata);
1769+
}
1770+
1771+
#[test]
1772+
fn test_custom_metadata_mut() {
1773+
let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1774+
batch
1775+
.custom_metadata_mut()
1776+
.insert("key".to_string(), "value".to_string());
1777+
assert_eq!(
1778+
batch.custom_metadata().get("key"),
1779+
Some(&"value".to_string())
1780+
);
1781+
}
1782+
1783+
#[test]
1784+
fn test_slice_preserves_custom_metadata() {
1785+
let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1786+
let mut metadata = HashMap::new();
1787+
metadata.insert("k".to_string(), "v".to_string());
1788+
let batch = batch.with_custom_metadata(metadata.clone());
1789+
1790+
let sliced = batch.slice(0, 2);
1791+
assert_eq!(sliced.custom_metadata(), &metadata);
1792+
}
1793+
1794+
#[test]
1795+
fn test_project_preserves_custom_metadata() {
1796+
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
1797+
let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
1798+
let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1799+
1800+
let mut metadata = HashMap::new();
1801+
metadata.insert("k".to_string(), "v".to_string());
1802+
let batch = batch.with_custom_metadata(metadata.clone());
1803+
1804+
let projected = batch.project(&[0]).unwrap();
1805+
assert_eq!(projected.custom_metadata(), &metadata);
1806+
}
1807+
1808+
#[test]
1809+
fn test_into_parts_with_custom_metadata() {
1810+
let batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1811+
let mut metadata = HashMap::new();
1812+
metadata.insert("k".to_string(), "v".to_string());
1813+
let batch = batch.with_custom_metadata(metadata.clone());
1814+
1815+
let (schema, columns, row_count, custom_metadata) = batch.into_parts_with_custom_metadata();
1816+
assert_eq!(schema.fields().len(), 1);
1817+
assert_eq!(columns.len(), 1);
1818+
assert_eq!(row_count, 3);
1819+
assert_eq!(custom_metadata, metadata);
1820+
}
1821+
1822+
#[test]
1823+
fn test_custom_metadata_equality() {
1824+
let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1825+
let batch2 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
1826+
1827+
// Both empty metadata -> equal
1828+
assert_eq!(batch1, batch2);
1829+
1830+
// Different metadata -> not equal
1831+
let mut metadata = HashMap::new();
1832+
metadata.insert("k".to_string(), "v".to_string());
1833+
let batch1 = batch1.with_custom_metadata(metadata);
1834+
assert_ne!(batch1, batch2);
1835+
}
17091836
}

arrow-flight/src/utils.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ pub fn flight_data_to_arrow_batch(
6161
let message = arrow_ipc::root_as_message(&data.data_header[..])
6262
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
6363

64+
let custom_metadata = arrow_ipc::reader::message_custom_metadata(&message);
65+
6466
message
6567
.header_as_record_batch()
6668
.ok_or_else(|| {
@@ -77,6 +79,13 @@ pub fn flight_data_to_arrow_batch(
7779
None,
7880
&message.version(),
7981
)
82+
.map(|rb| {
83+
if custom_metadata.is_empty() {
84+
rb
85+
} else {
86+
rb.with_custom_metadata(custom_metadata)
87+
}
88+
})
8089
})?
8190
}
8291

arrow-ipc/src/reader.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ use crate::r#gen::Message::{self};
4747
use crate::{Block, CONTINUATION_MARKER, FieldNode, MetadataVersion};
4848
use DataType::*;
4949

50+
/// Extract `custom_metadata` key-value pairs from an IPC [`Message`].
51+
///
52+
/// Returns an empty [`HashMap`] if the message has no custom metadata.
53+
pub fn message_custom_metadata(message: &crate::Message) -> HashMap<String, String> {
54+
let mut metadata = HashMap::new();
55+
if let Some(list) = message.custom_metadata() {
56+
for kv in list {
57+
if let (Some(k), Some(v)) = (kv.key(), kv.value()) {
58+
metadata.insert(k.to_string(), v.to_string());
59+
}
60+
}
61+
}
62+
metadata
63+
}
64+
5065
/// Read a buffer based on offset and length
5166
/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
5267
/// Each constituent buffer is first compressed with the indicated
@@ -467,6 +482,8 @@ pub struct RecordBatchDecoder<'a> {
467482
///
468483
/// See [`FileDecoder::with_skip_validation`] for details.
469484
skip_validation: UnsafeFlag,
485+
/// Per-batch custom metadata to attach to the decoded RecordBatch
486+
custom_metadata: HashMap<String, String>,
470487
}
471488

472489
impl<'a> RecordBatchDecoder<'a> {
@@ -503,6 +520,7 @@ impl<'a> RecordBatchDecoder<'a> {
503520
projection: None,
504521
require_alignment: false,
505522
skip_validation: UnsafeFlag::new(),
523+
custom_metadata: HashMap::new(),
506524
})
507525
}
508526

@@ -541,6 +559,12 @@ impl<'a> RecordBatchDecoder<'a> {
541559
self
542560
}
543561

562+
/// Set per-batch custom metadata to attach to the decoded [`RecordBatch`]
563+
pub(crate) fn with_custom_metadata(mut self, custom_metadata: HashMap<String, String>) -> Self {
564+
self.custom_metadata = custom_metadata;
565+
self
566+
}
567+
544568
/// Read the record batch, consuming the reader
545569
fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
546570
let mut variadic_counts: VecDeque<i64> = self
@@ -551,9 +575,10 @@ impl<'a> RecordBatchDecoder<'a> {
551575
.collect();
552576

553577
let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));
578+
let custom_metadata = std::mem::take(&mut self.custom_metadata);
554579

555580
let schema = Arc::clone(&self.schema);
556-
if let Some(projection) = self.projection {
581+
let batch = if let Some(projection) = self.projection {
557582
let mut arrays = vec![];
558583
// project fields
559584
for (idx, field) in schema.fields().iter().enumerate() {
@@ -605,7 +630,15 @@ impl<'a> RecordBatchDecoder<'a> {
605630
assert!(variadic_counts.is_empty());
606631
RecordBatch::try_new_with_options(schema, children, &options)
607632
}
608-
}
633+
};
634+
635+
batch.map(|b| {
636+
if custom_metadata.is_empty() {
637+
b
638+
} else {
639+
b.with_custom_metadata(custom_metadata)
640+
}
641+
})
609642
}
610643

611644
fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
@@ -718,6 +751,11 @@ impl<'a> RecordBatchDecoder<'a> {
718751
/// and copy over the data if any array data in the input `buf` is not properly aligned.
719752
/// (Properly aligned array data will remain zero-copy.)
720753
/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
754+
///
755+
/// Note: this function operates on the inner `RecordBatch` flatbuffer, not the
756+
/// outer `Message` envelope. Message-level `custom_metadata` is not extracted.
757+
/// Callers who need it should use [`message_custom_metadata`] on the `Message`
758+
/// and apply it via [`RecordBatch::with_custom_metadata`].
721759
pub fn read_record_batch(
722760
buf: &Buffer,
723761
batch: crate::RecordBatch,
@@ -1080,6 +1118,7 @@ impl FileDecoder {
10801118
let batch = message.header_as_record_batch().ok_or_else(|| {
10811119
ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
10821120
})?;
1121+
let custom_metadata = message_custom_metadata(&message);
10831122
// read the block that makes up the record batch into a buffer
10841123
RecordBatchDecoder::try_new(
10851124
&buf.slice(block.metaDataLength() as _),
@@ -1091,6 +1130,7 @@ impl FileDecoder {
10911130
.with_projection(self.projection.as_deref())
10921131
.with_require_alignment(self.require_alignment)
10931132
.with_skip_validation(self.skip_validation.clone())
1133+
.with_custom_metadata(custom_metadata)
10941134
.read_record_batch()
10951135
.map(Some)
10961136
}
@@ -1645,6 +1685,7 @@ impl<R: Read> StreamReader<R> {
16451685
ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
16461686
})?;
16471687

1688+
let custom_metadata = message_custom_metadata(&message);
16481689
let version = message.version();
16491690
let schema = self.schema.clone();
16501691
let record_batch = RecordBatchDecoder::try_new(
@@ -1657,6 +1698,7 @@ impl<R: Read> StreamReader<R> {
16571698
.with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
16581699
.with_require_alignment(false)
16591700
.with_skip_validation(self.skip_validation.clone())
1701+
.with_custom_metadata(custom_metadata)
16601702
.read_record_batch()?;
16611703
IpcMessage::RecordBatch(record_batch)
16621704
}

arrow-ipc/src/reader/stream.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use arrow_data::UnsafeFlag;
2525
use arrow_schema::{ArrowError, SchemaRef};
2626

2727
use crate::convert::MessageBuffer;
28-
use crate::reader::{RecordBatchDecoder, read_dictionary_impl};
28+
use crate::reader::{RecordBatchDecoder, message_custom_metadata, read_dictionary_impl};
2929
use crate::{CONTINUATION_MARKER, MessageHeader};
3030

3131
/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
@@ -223,6 +223,7 @@ impl StreamDecoder {
223223
}
224224
MessageHeader::RecordBatch => {
225225
let batch = message.header_as_record_batch().unwrap();
226+
let custom_metadata = message_custom_metadata(&message);
226227
let schema = self.schema.clone().ok_or_else(|| {
227228
ArrowError::IpcError("Missing schema".to_string())
228229
})?;
@@ -234,6 +235,7 @@ impl StreamDecoder {
234235
&version,
235236
)?
236237
.with_require_alignment(self.require_alignment)
238+
.with_custom_metadata(custom_metadata)
237239
.read_record_batch()?;
238240
self.state = DecoderState::default();
239241
return Ok(Some(batch));

0 commit comments

Comments
 (0)