Skip to content

Commit 49d3cfd

Browse files
committed
wip: flattened union builder first working version
1 parent 5a815e3 commit 49d3cfd

4 files changed

Lines changed: 92 additions & 96 deletions

File tree

serde_arrow/src/internal/arrow/data_type.rs

Lines changed: 34 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,18 +12,49 @@ pub struct Field {
1212
pub metadata: HashMap<String, String>,
1313
}
1414

15+
impl PartialOrd for Field {
16+
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
17+
self.name.partial_cmp(&other.name)
18+
}
19+
}
20+
1521
impl Field {
16-
pub fn from_flattened_enum(&self) -> bool {
22+
pub fn to_flattened_union_field(mut self, variant_name: &str) -> Self {
23+
self.name = format!("{}::{}", variant_name, self.name);
24+
self.nullable = true;
25+
self
26+
}
27+
28+
fn from_flattened_union(&self) -> bool {
1729
self.name.contains("::")
1830
}
1931

20-
pub fn enum_variant_name(&self) -> Option<&str> {
21-
if self.from_flattened_enum() {
32+
pub fn union_variant_name(&self) -> Option<&str> {
33+
if self.from_flattened_union() {
2234
self.name.split("::").next()
2335
} else {
2436
None
2537
}
2638
}
39+
40+
pub fn union_field_name(&self) -> Option<String> {
41+
if self.from_flattened_union() {
42+
Some(
43+
self.name
44+
.split("::")
45+
.skip(1)
46+
.fold(String::new(), |acc: String, e| {
47+
if acc.is_empty() {
48+
String::from(e)
49+
} else {
50+
format!("{acc}::{e}")
51+
}
52+
}),
53+
)
54+
} else {
55+
None
56+
}
57+
}
2758
}
2859

2960
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]

serde_arrow/src/internal/schema/tracer.rs

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1086,10 +1086,8 @@ impl UnionTracer {
10861086
for variant in &self.variants {
10871087
if let Some(variant) = variant {
10881088
let schema = variant.tracer.to_schema()?;
1089-
for mut field in schema.fields {
1090-
field.name = format!("{}::{}", variant.name, field.name);
1091-
field.nullable = true;
1092-
fields.push(field)
1089+
for field in schema.fields {
1090+
fields.push(field.to_flattened_union_field(variant.name.as_str()))
10931091
}
10941092
} else {
10951093
fields.push(unknown_variant_field())

serde_arrow/src/internal/serialization/flattened_union_builder.rs

Lines changed: 30 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -1,81 +1,77 @@
11
use std::collections::BTreeMap;
22

33
use crate::internal::{
4-
arrow::{Array, StructArray},
4+
arrow::{Array, FieldMeta, StructArray},
55
error::{fail, set_default, try_, Context, ContextSupport, Result},
6-
utils::array_ext::{ArrayExt, CountArray, SeqArrayExt},
76
};
87

98
use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer};
109

1110
#[derive(Debug, Clone)]
1211
pub struct FlattenedUnionBuilder {
1312
pub path: String,
14-
pub fields: Vec<ArrayBuilder>,
15-
pub seq: CountArray,
13+
pub fields: Vec<(ArrayBuilder, FieldMeta)>,
1614
}
1715

1816
impl FlattenedUnionBuilder {
19-
pub fn new(path: String, fields: Vec<ArrayBuilder>) -> Self {
20-
Self {
21-
path,
22-
fields,
23-
seq: CountArray::new(true),
24-
}
25-
}
26-
27-
pub fn take_self(&mut self) -> Self {
28-
Self {
29-
path: self.path.clone(),
30-
fields: self.fields.clone(),
31-
seq: self.seq.take(),
32-
}
17+
pub fn new(path: String, fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self {
18+
Self { path, fields }
3319
}
3420

3521
pub fn take(&mut self) -> ArrayBuilder {
36-
ArrayBuilder::FlattenedUnion(self.take_self())
22+
ArrayBuilder::FlattenedUnion(Self {
23+
path: self.path.clone(),
24+
fields: self
25+
.fields
26+
.iter_mut()
27+
.map(|(field, meta)| (field.take(), meta.clone()))
28+
.collect(),
29+
})
3730
}
3831

3932
pub fn is_nullable(&self) -> bool {
40-
self.seq.validity.is_some()
33+
false
4134
}
4235

4336
pub fn into_array(self) -> Result<Array> {
4437
let mut fields = Vec::new();
38+
let mut num_elements = 0;
4539

46-
for builder in self.fields.into_iter() {
40+
for (builder, meta) in self.fields.into_iter() {
4741
let ArrayBuilder::Struct(builder) = builder else {
48-
fail!("enum variant not built as a struct")
42+
fail!("enum variant not built as a struct") // TODO: better failure message
4943
};
5044

51-
for (sub_builder, sub_meta) in builder.fields.into_iter() {
45+
for (sub_builder, mut sub_meta) in builder.fields.into_iter() {
46+
num_elements += 1;
47+
// TODO: this mirrors the field name structure in the tracer but represents
48+
// implementation details crossing boundaries. Is there another way?
49+
// Currently necessary to allow struct field lookup to work correctly.
50+
sub_meta.name = format!("{}::{}", meta.name, sub_meta.name);
5251
fields.push((sub_builder.into_array()?, sub_meta));
5352
}
5453
}
5554

5655
Ok(Array::Struct(StructArray {
57-
len: fields.len(),
58-
validity: self.seq.validity,
56+
len: num_elements,
57+
validity: None, // TODO: is this ok?
5958
fields,
6059
}))
6160
}
6261
}
6362

6463
impl FlattenedUnionBuilder {
6564
pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> {
66-
// self.len += 1;
67-
6865
let variant_index = variant_index as usize;
6966

70-
// call push_none for any variant that was not selected
71-
for (idx, builder) in self.fields.iter_mut().enumerate() {
67+
// don't serialize any variant not selected
68+
for (idx, (builder, _meta)) in self.fields.iter_mut().enumerate() {
7269
if idx != variant_index {
7370
builder.serialize_none()?;
74-
self.seq.push_seq_none()?;
7571
}
7672
}
7773

78-
let Some(variant_builder) = self.fields.get_mut(variant_index) else {
74+
let Some((variant_builder, _variant_meta)) = self.fields.get_mut(variant_index) else {
7975
fail!("Could not find variant {variant_index} in Union");
8076
};
8177

@@ -86,40 +82,11 @@ impl FlattenedUnionBuilder {
8682
impl Context for FlattenedUnionBuilder {
8783
fn annotate(&self, annotations: &mut BTreeMap<String, String>) {
8884
set_default(annotations, "field", &self.path);
89-
set_default(annotations, "data_type", "Union(..)");
85+
set_default(annotations, "data_type", "Struct(..)");
9086
}
9187
}
9288

9389
impl SimpleSerializer for FlattenedUnionBuilder {
94-
// fn serialize_unit_variant(
95-
// &mut self,
96-
// _: &'static str,
97-
// variant_index: u32,
98-
// _: &'static str,
99-
// ) -> Result<()> {
100-
// let mut ctx = BTreeMap::new();
101-
// self.annotate(&mut ctx);
102-
103-
// try_(|| self.serialize_variant(variant_index)?.serialize_unit()).ctx(&ctx)
104-
// }
105-
106-
// fn serialize_newtype_variant<V: serde::Serialize + ?Sized>(
107-
// &mut self,
108-
// _: &'static str,
109-
// variant_index: u32,
110-
// _: &'static str,
111-
// value: &V,
112-
// ) -> Result<()> {
113-
// let mut ctx = BTreeMap::new();
114-
// self.annotate(&mut ctx);
115-
116-
// try_(|| {
117-
// let variant_builder = self.serialize_variant(variant_index)?;
118-
// value.serialize(Mut(variant_builder))
119-
// })
120-
// .ctx(&ctx)
121-
// }
122-
12390
fn serialize_struct_variant_start<'this>(
12491
&'this mut self,
12592
_: &'static str,
@@ -129,8 +96,6 @@ impl SimpleSerializer for FlattenedUnionBuilder {
12996
) -> Result<&'this mut ArrayBuilder> {
13097
let mut ctx = BTreeMap::new();
13198
self.annotate(&mut ctx);
132-
self.seq.start_seq()?;
133-
self.seq.push_seq_elements(1)?;
13499

135100
try_(|| {
136101
let variant_builder = self.serialize_variant(variant_index)?;
@@ -139,26 +104,10 @@ impl SimpleSerializer for FlattenedUnionBuilder {
139104
})
140105
.ctx(&ctx)
141106
}
142-
143-
// fn serialize_tuple_variant_start<'this>(
144-
// &'this mut self,
145-
// _: &'static str,
146-
// variant_index: u32,
147-
// variant: &'static str,
148-
// len: usize,
149-
// ) -> Result<&'this mut ArrayBuilder> {
150-
// let mut ctx = BTreeMap::new();
151-
// self.annotate(&mut ctx);
152-
153-
// try_(|| {
154-
// let variant_builder = self.serialize_variant(variant_index)?;
155-
// variant_builder.serialize_tuple_struct_start(variant, len)?;
156-
// Ok(variant_builder)
157-
// })
158-
// .ctx(&ctx)
159-
// }
160107
}
161108

109+
// TODO: add tests
110+
162111
// #[cfg(test)]
163112
// mod tests {
164113
// fn test_serialize_union() {

serde_arrow/src/internal/serialization/outer_sequence_builder.rs

Lines changed: 26 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ use std::collections::{BTreeMap, HashMap};
33
use serde::Serialize;
44

55
use crate::internal::{
6-
arrow::{DataType, Field, TimeUnit},
6+
arrow::{DataType, Field, FieldMeta, TimeUnit},
77
error::{fail, Context, ContextSupport, Result},
88
schema::{get_strategy_from_metadata, SerdeArrowSchema, Strategy},
99
serialization::{
@@ -231,23 +231,41 @@ fn build_builder(path: String, field: &Field) -> Result<ArrayBuilder> {
231231
if let Some(Strategy::EnumsWithNamedFieldsAsStructs) =
232232
get_strategy_from_metadata(&field.metadata)?
233233
{
234-
let mut related_fields: HashMap<&str, Vec<Field>> = HashMap::new();
235-
let mut builders: Vec<ArrayBuilder> = Vec::new();
234+
let mut related_fields: BTreeMap<&str, Vec<Field>> = BTreeMap::new();
235+
let mut builders: Vec<(ArrayBuilder, FieldMeta)> = Vec::new();
236236

237237
for field in children {
238-
let Some(variant_name) = field.enum_variant_name() else {
239-
// TODO: warning? fail! ?
238+
let Some(variant_name) = field.union_variant_name() else {
239+
// TODO: failure message
240240
continue;
241241
};
242+
243+
let Some(field_name) = field.union_field_name() else {
244+
// TODO: failure message
245+
continue;
246+
};
247+
248+
let mut new_field = field.clone();
249+
new_field.name = field_name;
250+
242251
related_fields
243252
.entry(variant_name)
244253
.or_default()
245-
.push(field.clone());
254+
.push(new_field);
246255
}
247256

248257
for (variant_name, fields) in related_fields {
249-
let sub_struct_name = format!("{}.{}", path, variant_name);
250-
builders.push(build_struct(sub_struct_name, fields.as_slice(), true)?.take());
258+
let builder = build_struct(
259+
format!("{}.{}", path.to_owned(), variant_name),
260+
fields.as_slice(),
261+
true,
262+
)?
263+
.take();
264+
265+
let mut meta = meta_from_field(field.clone());
266+
meta.name = variant_name.to_owned();
267+
268+
builders.push((builder, meta));
251269
}
252270

253271
A::FlattenedUnion(FlattenedUnionBuilder::new(path, builders))

0 commit comments

Comments
 (0)