Skip to content

Commit 9390ec4

Browse files
committed
feat: add custom_metadata support to RecordBatch with IPC read/write
1 parent 4fa8d2f commit 9390ec4

9 files changed

Lines changed: 418 additions & 9 deletions

File tree

arrow-array/src/record_batch.rs

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
use crate::cast::AsArray;
2222
use crate::{Array, ArrayRef, StructArray, new_empty_array};
2323
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
24+
use std::collections::HashMap;
2425
use std::ops::Index;
2526
use std::sync::Arc;
2627

@@ -229,6 +230,12 @@ pub struct RecordBatch {
229230
///
230231
/// This is stored separately from the columns to handle the case of no columns
231232
row_count: usize,
233+
234+
/// Per-batch custom metadata
235+
///
236+
/// This corresponds to the `custom_metadata` field on the IPC `Message`
237+
/// flatbuffer, allowing per-batch metadata separate from schema-level metadata.
238+
custom_metadata: HashMap<String, String>,
232239
}
233240

234241
impl RecordBatch {
@@ -289,6 +296,7 @@ impl RecordBatch {
289296
schema,
290297
columns,
291298
row_count,
299+
custom_metadata: HashMap::new(),
292300
}
293301
}
294302

@@ -316,6 +324,7 @@ impl RecordBatch {
316324
schema,
317325
columns,
318326
row_count: 0,
327+
custom_metadata: HashMap::new(),
319328
}
320329
}
321330

@@ -390,14 +399,30 @@ impl RecordBatch {
390399
schema,
391400
columns,
392401
row_count,
402+
custom_metadata: HashMap::new(),
393403
})
394404
}
395405

396406
/// Return the schema, columns and row count of this [`RecordBatch`]
407+
///
408+
/// Note: this discards any [`Self::custom_metadata`]. Use
409+
/// [`Self::into_parts_with_custom_metadata`] to also retrieve it.
397410
pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
398411
(self.schema, self.columns, self.row_count)
399412
}
400413

414+
/// Return the schema, columns, row count and custom metadata of this [`RecordBatch`]
415+
pub fn into_parts_with_custom_metadata(
416+
self,
417+
) -> (SchemaRef, Vec<ArrayRef>, usize, HashMap<String, String>) {
418+
(
419+
self.schema,
420+
self.columns,
421+
self.row_count,
422+
self.custom_metadata,
423+
)
424+
}
425+
401426
/// Override the schema of this [`RecordBatch`]
402427
///
403428
/// Returns an error if `schema` is not a superset of the current schema
@@ -416,6 +441,7 @@ impl RecordBatch {
416441
schema,
417442
columns: self.columns,
418443
row_count: self.row_count,
444+
custom_metadata: self.custom_metadata,
419445
})
420446
}
421447

@@ -451,6 +477,25 @@ impl RecordBatch {
451477
&mut schema.metadata
452478
}
453479

480+
/// Returns a reference to the per-batch custom metadata.
481+
///
482+
/// This metadata corresponds to the `custom_metadata` field on the IPC
483+
/// `Message` flatbuffer, separate from schema-level metadata.
484+
pub fn custom_metadata(&self) -> &HashMap<String, String> {
485+
&self.custom_metadata
486+
}
487+
488+
/// Returns a mutable reference to the per-batch custom metadata.
489+
pub fn custom_metadata_mut(&mut self) -> &mut HashMap<String, String> {
490+
&mut self.custom_metadata
491+
}
492+
493+
/// Sets the per-batch custom metadata, returning `self`.
494+
pub fn with_custom_metadata(mut self, metadata: HashMap<String, String>) -> Self {
495+
self.custom_metadata = metadata;
496+
self
497+
}
498+
454499
/// Projects the schema onto the specified columns
455500
pub fn project(&self, indices: &[usize]) -> Result<RecordBatch, ArrowError> {
456501
let projected_schema = self.schema.project(indices)?;
@@ -475,7 +520,8 @@ impl RecordBatch {
475520
SchemaRef::new(projected_schema),
476521
batch_fields,
477522
self.row_count,
478-
))
523+
)
524+
.with_custom_metadata(self.custom_metadata.clone()))
479525
}
480526
}
481527

@@ -570,7 +616,9 @@ impl RecordBatch {
570616
}
571617
}
572618
}
619+
let custom_metadata = self.custom_metadata.clone();
573620
RecordBatch::try_new(Arc::new(Schema::new(fields)), columns)
621+
.map(|b| b.with_custom_metadata(custom_metadata))
574622
}
575623

576624
/// Returns the number of columns in the record batch.
@@ -691,6 +739,7 @@ impl RecordBatch {
691739
schema: self.schema.clone(),
692740
columns,
693741
row_count: length,
742+
custom_metadata: self.custom_metadata.clone(),
694743
}
695744
}
696745

@@ -864,6 +913,7 @@ impl From<StructArray> for RecordBatch {
864913
schema: Arc::new(Schema::new(fields)),
865914
row_count,
866915
columns,
916+
custom_metadata: HashMap::new(),
867917
}
868918
}
869919
}
@@ -1792,4 +1842,81 @@ mod tests {
17921842
assert!(col.is_null(1));
17931843
assert!(col.is_valid(2));
17941844
}
1845+
1846+
#[test]
fn test_with_custom_metadata() {
    // A freshly built batch starts out with no per-batch metadata.
    let plain = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
    assert!(plain.custom_metadata().is_empty());

    // The builder-style setter stores the supplied map as-is.
    let expected = HashMap::from([("key".to_string(), "value".to_string())]);
    let tagged = plain.with_custom_metadata(expected.clone());
    assert_eq!(tagged.custom_metadata(), &expected);
}
1856+
1857+
#[test]
fn test_custom_metadata_mut() {
    let mut batch = record_batch!(("a", Int32, [1, 2, 3])).unwrap();

    // Insert through the mutable accessor ...
    let entries = batch.custom_metadata_mut();
    entries.insert("key".to_string(), "value".to_string());

    // ... and observe the entry through the shared accessor.
    let stored = batch.custom_metadata().get("key").map(String::as_str);
    assert_eq!(stored, Some("value"));
}
1868+
1869+
#[test]
fn test_slice_preserves_custom_metadata() {
    let expected = HashMap::from([("k".to_string(), "v".to_string())]);
    let batch = record_batch!(("a", Int32, [1, 2, 3]))
        .unwrap()
        .with_custom_metadata(expected.clone());

    // Slicing rows must carry the per-batch metadata along unchanged.
    let head = batch.slice(0, 2);
    assert_eq!(head.custom_metadata(), &expected);
}
1879+
1880+
#[test]
fn test_project_preserves_custom_metadata() {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    let expected = HashMap::from([("k".to_string(), "v".to_string())]);

    let batch = RecordBatch::try_from_iter(vec![("a", ints), ("b", strings)])
        .unwrap()
        .with_custom_metadata(expected.clone());

    // Dropping a column must not drop the per-batch metadata.
    let first_only = batch.project(&[0]).unwrap();
    assert_eq!(first_only.custom_metadata(), &expected);
}
1893+
1894+
#[test]
fn test_into_parts_with_custom_metadata() {
    let expected = HashMap::from([("k".to_string(), "v".to_string())]);
    let batch = record_batch!(("a", Int32, [1, 2, 3]))
        .unwrap()
        .with_custom_metadata(expected.clone());

    // Tuple order: schema, columns, row count, custom metadata.
    let parts = batch.into_parts_with_custom_metadata();
    assert_eq!(parts.0.fields().len(), 1);
    assert_eq!(parts.1.len(), 1);
    assert_eq!(parts.2, 3);
    assert_eq!(parts.3, expected);
}
1907+
1908+
#[test]
fn test_custom_metadata_equality() {
    let make_batch = || record_batch!(("a", Int32, [1, 2, 3])).unwrap();
    let left = make_batch();
    let right = make_batch();

    // Identical data and empty metadata on both sides compare equal.
    assert_eq!(left, right);

    // Attaching metadata to only one side makes the batches differ.
    let left = left.with_custom_metadata(HashMap::from([("k".to_string(), "v".to_string())]));
    assert_ne!(left, right);
}
17951922
}

arrow-flight/src/utils.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ pub fn flight_data_to_arrow_batch(
6161
let message = arrow_ipc::root_as_message(&data.data_header[..])
6262
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
6363

64+
let custom_metadata = arrow_ipc::reader::message_custom_metadata(&message);
65+
6466
message
6567
.header_as_record_batch()
6668
.ok_or_else(|| {
@@ -77,6 +79,13 @@ pub fn flight_data_to_arrow_batch(
7779
None,
7880
&message.version(),
7981
)
82+
.map(|rb| {
83+
if custom_metadata.is_empty() {
84+
rb
85+
} else {
86+
rb.with_custom_metadata(custom_metadata)
87+
}
88+
})
8089
})?
8190
}
8291

arrow-ipc/src/reader.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@ use crate::r#gen::Message::{self};
4747
use crate::{Block, CONTINUATION_MARKER, FieldNode, MetadataVersion};
4848
use DataType::*;
4949

50+
/// Extract `custom_metadata` key-value pairs from an IPC [`Message`].
51+
///
52+
/// Returns an empty [`HashMap`] if the message has no custom metadata.
53+
pub fn message_custom_metadata(message: &crate::Message) -> HashMap<String, String> {
54+
let mut metadata = HashMap::new();
55+
if let Some(list) = message.custom_metadata() {
56+
for kv in list {
57+
if let (Some(k), Some(v)) = (kv.key(), kv.value()) {
58+
metadata.insert(k.to_string(), v.to_string());
59+
}
60+
}
61+
}
62+
metadata
63+
}
64+
5065
/// Read a buffer based on offset and length
5166
/// From <https://github.com/apache/arrow/blob/6a936c4ff5007045e86f65f1a6b6c3c955ad5103/format/Message.fbs#L58>
5267
/// Each constituent buffer is first compressed with the indicated
@@ -470,6 +485,8 @@ pub struct RecordBatchDecoder<'a> {
470485
///
471486
/// See [`FileDecoder::with_skip_validation`] for details.
472487
skip_validation: UnsafeFlag,
488+
/// Per-batch custom metadata to attach to the decoded RecordBatch
489+
custom_metadata: HashMap<String, String>,
473490
}
474491

475492
impl<'a> RecordBatchDecoder<'a> {
@@ -506,6 +523,7 @@ impl<'a> RecordBatchDecoder<'a> {
506523
projection: None,
507524
require_alignment: false,
508525
skip_validation: UnsafeFlag::new(),
526+
custom_metadata: HashMap::new(),
509527
})
510528
}
511529

@@ -544,6 +562,12 @@ impl<'a> RecordBatchDecoder<'a> {
544562
self
545563
}
546564

565+
/// Set per-batch custom metadata to attach to the decoded [`RecordBatch`]
566+
pub(crate) fn with_custom_metadata(mut self, custom_metadata: HashMap<String, String>) -> Self {
567+
self.custom_metadata = custom_metadata;
568+
self
569+
}
570+
547571
/// Read the record batch, consuming the reader
548572
fn read_record_batch(mut self) -> Result<RecordBatch, ArrowError> {
549573
let mut variadic_counts: VecDeque<i64> = self
@@ -554,9 +578,10 @@ impl<'a> RecordBatchDecoder<'a> {
554578
.collect();
555579

556580
let options = RecordBatchOptions::new().with_row_count(Some(self.batch.length() as usize));
581+
let custom_metadata = std::mem::take(&mut self.custom_metadata);
557582

558583
let schema = Arc::clone(&self.schema);
559-
if let Some(projection) = self.projection {
584+
let batch = if let Some(projection) = self.projection {
560585
let mut arrays = vec![];
561586
// project fields
562587
for (idx, field) in schema.fields().iter().enumerate() {
@@ -608,7 +633,15 @@ impl<'a> RecordBatchDecoder<'a> {
608633
assert!(variadic_counts.is_empty());
609634
RecordBatch::try_new_with_options(schema, children, &options)
610635
}
611-
}
636+
};
637+
638+
batch.map(|b| {
639+
if custom_metadata.is_empty() {
640+
b
641+
} else {
642+
b.with_custom_metadata(custom_metadata)
643+
}
644+
})
612645
}
613646

614647
fn next_buffer(&mut self) -> Result<Buffer, ArrowError> {
@@ -752,6 +785,11 @@ impl<'a> RecordBatchDecoder<'a> {
752785
/// and copy over the data if any array data in the input `buf` is not properly aligned.
753786
/// (Properly aligned array data will remain zero-copy.)
754787
/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
788+
///
789+
/// Note: this function operates on the inner `RecordBatch` flatbuffer, not the
790+
/// outer `Message` envelope. Message-level `custom_metadata` is not extracted.
791+
/// Callers who need it should use [`message_custom_metadata`] on the `Message`
792+
/// and apply it via [`RecordBatch::with_custom_metadata`].
755793
pub fn read_record_batch(
756794
buf: &Buffer,
757795
batch: crate::RecordBatch,
@@ -1114,6 +1152,7 @@ impl FileDecoder {
11141152
let batch = message.header_as_record_batch().ok_or_else(|| {
11151153
ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
11161154
})?;
1155+
let custom_metadata = message_custom_metadata(&message);
11171156
// read the block that makes up the record batch into a buffer
11181157
RecordBatchDecoder::try_new(
11191158
&buf.slice(block.metaDataLength() as _),
@@ -1125,6 +1164,7 @@ impl FileDecoder {
11251164
.with_projection(self.projection.as_deref())
11261165
.with_require_alignment(self.require_alignment)
11271166
.with_skip_validation(self.skip_validation.clone())
1167+
.with_custom_metadata(custom_metadata)
11281168
.read_record_batch()
11291169
.map(Some)
11301170
}
@@ -1679,6 +1719,7 @@ impl<R: Read> StreamReader<R> {
16791719
ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
16801720
})?;
16811721

1722+
let custom_metadata = message_custom_metadata(&message);
16821723
let version = message.version();
16831724
let schema = self.schema.clone();
16841725
let record_batch = RecordBatchDecoder::try_new(
@@ -1691,6 +1732,7 @@ impl<R: Read> StreamReader<R> {
16911732
.with_projection(self.projection.as_ref().map(|x| x.0.as_ref()))
16921733
.with_require_alignment(false)
16931734
.with_skip_validation(self.skip_validation.clone())
1735+
.with_custom_metadata(custom_metadata)
16941736
.read_record_batch()?;
16951737
IpcMessage::RecordBatch(record_batch)
16961738
}

arrow-ipc/src/reader/stream.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use arrow_data::UnsafeFlag;
2525
use arrow_schema::{ArrowError, SchemaRef};
2626

2727
use crate::convert::MessageBuffer;
28-
use crate::reader::{RecordBatchDecoder, read_dictionary_impl};
28+
use crate::reader::{RecordBatchDecoder, message_custom_metadata, read_dictionary_impl};
2929
use crate::{CONTINUATION_MARKER, MessageHeader};
3030

3131
/// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
@@ -236,6 +236,7 @@ impl StreamDecoder {
236236
}
237237
MessageHeader::RecordBatch => {
238238
let batch = message.header_as_record_batch().unwrap();
239+
let custom_metadata = message_custom_metadata(&message);
239240
let schema = self.schema.clone().ok_or_else(|| {
240241
ArrowError::IpcError("Missing schema".to_string())
241242
})?;
@@ -247,6 +248,7 @@ impl StreamDecoder {
247248
&version,
248249
)?
249250
.with_require_alignment(self.require_alignment)
251+
.with_custom_metadata(custom_metadata)
250252
.read_record_batch()?;
251253
self.state = DecoderState::default();
252254
return Ok(Some(batch));

0 commit comments

Comments
 (0)