Skip to content

Commit 5a815e3

Browse files
committed
add flattened union builder
1 parent 7d66061 commit 5a815e3

6 files changed

Lines changed: 234 additions & 6 deletions

File tree

serde_arrow/src/internal/arrow/data_type.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ pub struct Field {
1212
pub metadata: HashMap<String, String>,
1313
}
1414

15+
impl Field {
16+
pub fn from_flattened_enum(&self) -> bool {
17+
self.name.contains("::")
18+
}
19+
20+
pub fn enum_variant_name(&self) -> Option<&str> {
21+
if self.from_flattened_enum() {
22+
self.name.split("::").next()
23+
} else {
24+
None
25+
}
26+
}
27+
}
28+
1529
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
1630
#[non_exhaustive]
1731
pub enum DataType {

serde_arrow/src/internal/schema/tracer.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,17 @@ impl Tracer {
116116
let tracing_mode = dispatch_tracer!(self, tracer => tracer.options.tracing_mode);
117117

118118
let fields = match root.data_type {
119-
DataType::Struct(children) => children,
119+
DataType::Struct(children) => {
120+
if let Some(strategy) = root
121+
.metadata.get(STRATEGY_KEY) {
122+
if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() {
123+
// TODO: combine with fail messaging below
124+
fail!("Schema tracing is not directly supported for the root data Union. Consider using the `Item` / `Items` wrappers.");
125+
}
126+
}
127+
128+
children
129+
}
120130
DataType::Null => fail!("No records found to determine schema"),
121131
dt => fail!(
122132
concat!(

serde_arrow/src/internal/serialization/array_builder.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ use super::{
1313
date64_builder::Date64Builder, decimal_builder::DecimalBuilder,
1414
dictionary_utf8_builder::DictionaryUtf8Builder, duration_builder::DurationBuilder,
1515
fixed_size_binary_builder::FixedSizeBinaryBuilder,
16-
fixed_size_list_builder::FixedSizeListBuilder, float_builder::FloatBuilder,
17-
int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder,
18-
null_builder::NullBuilder, simple_serializer::SimpleSerializer, struct_builder::StructBuilder,
19-
time_builder::TimeBuilder, union_builder::UnionBuilder,
16+
fixed_size_list_builder::FixedSizeListBuilder, flattened_union_builder::FlattenedUnionBuilder,
17+
float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder,
18+
map_builder::MapBuilder, null_builder::NullBuilder, simple_serializer::SimpleSerializer,
19+
struct_builder::StructBuilder, time_builder::TimeBuilder, union_builder::UnionBuilder,
2020
unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder,
2121
};
2222

@@ -53,6 +53,7 @@ pub enum ArrayBuilder {
5353
LargeUtf8(Utf8Builder<i64>),
5454
DictionaryUtf8(DictionaryUtf8Builder),
5555
Union(UnionBuilder),
56+
FlattenedUnion(FlattenedUnionBuilder),
5657
UnknownVariant(UnknownVariantBuilder),
5758
}
5859

@@ -90,6 +91,7 @@ macro_rules! dispatch {
9091
$wrapper::Struct($name) => $expr,
9192
$wrapper::DictionaryUtf8($name) => $expr,
9293
$wrapper::Union($name) => $expr,
94+
$wrapper::FlattenedUnion($name) => $expr,
9395
$wrapper::UnknownVariant($name) => $expr,
9496
}
9597
};
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
use std::collections::BTreeMap;
2+
3+
use crate::internal::{
4+
arrow::{Array, StructArray},
5+
error::{fail, set_default, try_, Context, ContextSupport, Result},
6+
utils::array_ext::{ArrayExt, CountArray, SeqArrayExt},
7+
};
8+
9+
use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer};
10+
11+
#[derive(Debug, Clone)]
12+
pub struct FlattenedUnionBuilder {
13+
pub path: String,
14+
pub fields: Vec<ArrayBuilder>,
15+
pub seq: CountArray,
16+
}
17+
18+
impl FlattenedUnionBuilder {
19+
pub fn new(path: String, fields: Vec<ArrayBuilder>) -> Self {
20+
Self {
21+
path,
22+
fields,
23+
seq: CountArray::new(true),
24+
}
25+
}
26+
27+
pub fn take_self(&mut self) -> Self {
28+
Self {
29+
path: self.path.clone(),
30+
fields: self.fields.clone(),
31+
seq: self.seq.take(),
32+
}
33+
}
34+
35+
pub fn take(&mut self) -> ArrayBuilder {
36+
ArrayBuilder::FlattenedUnion(self.take_self())
37+
}
38+
39+
pub fn is_nullable(&self) -> bool {
40+
self.seq.validity.is_some()
41+
}
42+
43+
pub fn into_array(self) -> Result<Array> {
44+
let mut fields = Vec::new();
45+
46+
for builder in self.fields.into_iter() {
47+
let ArrayBuilder::Struct(builder) = builder else {
48+
fail!("enum variant not built as a struct")
49+
};
50+
51+
for (sub_builder, sub_meta) in builder.fields.into_iter() {
52+
fields.push((sub_builder.into_array()?, sub_meta));
53+
}
54+
}
55+
56+
Ok(Array::Struct(StructArray {
57+
len: fields.len(),
58+
validity: self.seq.validity,
59+
fields,
60+
}))
61+
}
62+
}
63+
64+
impl FlattenedUnionBuilder {
65+
pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> {
66+
// self.len += 1;
67+
68+
let variant_index = variant_index as usize;
69+
70+
// call push_none for any variant that was not selected
71+
for (idx, builder) in self.fields.iter_mut().enumerate() {
72+
if idx != variant_index {
73+
builder.serialize_none()?;
74+
self.seq.push_seq_none()?;
75+
}
76+
}
77+
78+
let Some(variant_builder) = self.fields.get_mut(variant_index) else {
79+
fail!("Could not find variant {variant_index} in Union");
80+
};
81+
82+
Ok(variant_builder)
83+
}
84+
}
85+
86+
impl Context for FlattenedUnionBuilder {
87+
fn annotate(&self, annotations: &mut BTreeMap<String, String>) {
88+
set_default(annotations, "field", &self.path);
89+
set_default(annotations, "data_type", "Union(..)");
90+
}
91+
}
92+
93+
impl SimpleSerializer for FlattenedUnionBuilder {
94+
// fn serialize_unit_variant(
95+
// &mut self,
96+
// _: &'static str,
97+
// variant_index: u32,
98+
// _: &'static str,
99+
// ) -> Result<()> {
100+
// let mut ctx = BTreeMap::new();
101+
// self.annotate(&mut ctx);
102+
103+
// try_(|| self.serialize_variant(variant_index)?.serialize_unit()).ctx(&ctx)
104+
// }
105+
106+
// fn serialize_newtype_variant<V: serde::Serialize + ?Sized>(
107+
// &mut self,
108+
// _: &'static str,
109+
// variant_index: u32,
110+
// _: &'static str,
111+
// value: &V,
112+
// ) -> Result<()> {
113+
// let mut ctx = BTreeMap::new();
114+
// self.annotate(&mut ctx);
115+
116+
// try_(|| {
117+
// let variant_builder = self.serialize_variant(variant_index)?;
118+
// value.serialize(Mut(variant_builder))
119+
// })
120+
// .ctx(&ctx)
121+
// }
122+
123+
fn serialize_struct_variant_start<'this>(
124+
&'this mut self,
125+
_: &'static str,
126+
variant_index: u32,
127+
variant: &'static str,
128+
len: usize,
129+
) -> Result<&'this mut ArrayBuilder> {
130+
let mut ctx = BTreeMap::new();
131+
self.annotate(&mut ctx);
132+
self.seq.start_seq()?;
133+
self.seq.push_seq_elements(1)?;
134+
135+
try_(|| {
136+
let variant_builder = self.serialize_variant(variant_index)?;
137+
variant_builder.serialize_struct_start(variant, len)?;
138+
Ok(variant_builder)
139+
})
140+
.ctx(&ctx)
141+
}
142+
143+
// fn serialize_tuple_variant_start<'this>(
144+
// &'this mut self,
145+
// _: &'static str,
146+
// variant_index: u32,
147+
// variant: &'static str,
148+
// len: usize,
149+
// ) -> Result<&'this mut ArrayBuilder> {
150+
// let mut ctx = BTreeMap::new();
151+
// self.annotate(&mut ctx);
152+
153+
// try_(|| {
154+
// let variant_builder = self.serialize_variant(variant_index)?;
155+
// variant_builder.serialize_tuple_struct_start(variant, len)?;
156+
// Ok(variant_builder)
157+
// })
158+
// .ctx(&ctx)
159+
// }
160+
}
161+
162+
// #[cfg(test)]
163+
// mod tests {
164+
// fn test_serialize_union() {
165+
// #[derive(Serialize, Deserialize)]
166+
// enum Number {
167+
// Real { value: f32 },
168+
// Complex { i: f32, j: f32 },
169+
// }
170+
171+
// let numbers = vec![];
172+
// }
173+
// }

serde_arrow/src/internal/serialization/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod dictionary_utf8_builder;
1010
pub mod duration_builder;
1111
pub mod fixed_size_binary_builder;
1212
pub mod fixed_size_list_builder;
13+
pub mod flattened_union_builder;
1314
pub mod float_builder;
1415
pub mod int_builder;
1516
pub mod list_builder;

serde_arrow/src/internal/serialization/outer_sequence_builder.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::internal::{
1010
binary_builder::BinaryBuilder, duration_builder::DurationBuilder,
1111
fixed_size_binary_builder::FixedSizeBinaryBuilder,
1212
fixed_size_list_builder::FixedSizeListBuilder,
13+
flattened_union_builder::FlattenedUnionBuilder,
1314
},
1415
utils::{btree_map, meta_from_field, ChildName, Mut},
1516
};
@@ -226,7 +227,34 @@ fn build_builder(path: String, field: &Field) -> Result<ArrayBuilder> {
226227
.ctx(&ctx)?,
227228
)
228229
}
229-
T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?),
230+
T::Struct(children) => {
231+
if let Some(Strategy::EnumsWithNamedFieldsAsStructs) =
232+
get_strategy_from_metadata(&field.metadata)?
233+
{
234+
let mut related_fields: HashMap<&str, Vec<Field>> = HashMap::new();
235+
let mut builders: Vec<ArrayBuilder> = Vec::new();
236+
237+
for field in children {
238+
let Some(variant_name) = field.enum_variant_name() else {
239+
// TODO: warning? fail! ?
240+
continue;
241+
};
242+
related_fields
243+
.entry(variant_name)
244+
.or_default()
245+
.push(field.clone());
246+
}
247+
248+
for (variant_name, fields) in related_fields {
249+
let sub_struct_name = format!("{}.{}", path, variant_name);
250+
builders.push(build_struct(sub_struct_name, fields.as_slice(), true)?.take());
251+
}
252+
253+
A::FlattenedUnion(FlattenedUnionBuilder::new(path, builders))
254+
} else {
255+
A::Struct(build_struct(path, children, field.nullable)?)
256+
}
257+
}
230258
T::Dictionary(key, value, _) => {
231259
let key_path = format!("{path}.key");
232260
let key_field = Field {

0 commit comments

Comments
 (0)