Skip to content

Commit 76f19cb

Browse files
authored
Clean up DType serde seeds (#6244)
And include a test demonstrating using the seeded deserialization Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 6ff3d41 commit 76f19cb

2 files changed

Lines changed: 28 additions & 146 deletions

File tree

vortex-dtype/src/serde/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ pub use serde::*;
1717
#[cfg(test)]
1818
#[cfg(feature = "serde")]
1919
mod test {
20+
use serde::de::DeserializeSeed;
2021
use serde_test::Token;
2122
use serde_test::assert_tokens;
2223

2324
use crate::DType;
2425
use crate::Nullability;
2526
use crate::PType;
27+
use crate::serde::DTypeSerde;
28+
use crate::test::SESSION;
2629

2730
#[test]
2831
fn test_serde_ptype_json() {
@@ -117,7 +120,10 @@ mod test {
117120
"#);
118121

119122
// Deserialize back and verify round-trip
120-
let deserialized: DType = serde_json::from_str(&json).unwrap();
123+
let mut deserializer = serde_json::Deserializer::from_str(&json);
124+
let deserialized: DType = DTypeSerde::<DType>::new(&SESSION)
125+
.deserialize(&mut deserializer)
126+
.unwrap();
121127
assert_eq!(struct_dtype, deserialized);
122128
}
123129
}

vortex-dtype/src/serde/serde.rs

Lines changed: 21 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use serde::de::VariantAccess;
1919
use serde::de::Visitor;
2020
use serde::ser::SerializeStruct;
2121
use serde::ser::SerializeTupleVariant;
22+
use vortex_session::VortexSession;
2223

2324
use crate::DType;
2425
use crate::ExtID;
@@ -28,7 +29,7 @@ use crate::PType;
2829
use crate::StructFields;
2930
use crate::decimal::DecimalDType;
3031
use crate::extension::ExtDTypeRef;
31-
use crate::session::DTypeSession;
32+
use crate::session::DTypeSessionExt;
3233

3334
/// Serialize Nullability as a boolean
3435
impl Serialize for Nullability {
@@ -52,13 +53,13 @@ impl<'de> Deserialize<'de> for Nullability {
5253

5354
/// Seed for deserializing DType references that require session context.
5455
pub struct DTypeSerde<'a, T> {
55-
session: &'a DTypeSession,
56+
session: &'a VortexSession,
5657
_marker: PhantomData<T>,
5758
}
5859

5960
impl<'a, T> DTypeSerde<'a, T> {
6061
/// Create a new DTypeSerde seed.
61-
pub fn new(session: &'a DTypeSession) -> Self {
62+
pub fn new(session: &'a VortexSession) -> Self {
6263
Self {
6364
session,
6465
_marker: PhantomData,
@@ -132,121 +133,6 @@ impl Serialize for StructFields {
132133
}
133134
}
134135

135-
// ============================================================================
136-
// DType Deserialization (without session - errors on Extension types)
137-
// ============================================================================
138-
139-
impl<'de> Deserialize<'de> for DType {
140-
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
141-
where
142-
D: Deserializer<'de>,
143-
{
144-
const VARIANTS: &[&str] = &[
145-
"Null",
146-
"Bool",
147-
"Primitive",
148-
"Decimal",
149-
"Utf8",
150-
"Binary",
151-
"List",
152-
"FixedSizeList",
153-
"Struct",
154-
"Extension",
155-
];
156-
157-
struct DTypeVisitor;
158-
159-
impl<'de> Visitor<'de> for DTypeVisitor {
160-
type Value = DType;
161-
162-
fn expecting(&self, f: &mut Formatter) -> fmt::Result {
163-
f.write_str("enum DType")
164-
}
165-
166-
fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
167-
where
168-
A: EnumAccess<'de>,
169-
{
170-
let (variant, access) = data.variant::<&str>()?;
171-
match variant {
172-
"Null" => {
173-
access.unit_variant()?;
174-
Ok(DType::Null)
175-
}
176-
"Bool" => {
177-
let n = access.newtype_variant()?;
178-
Ok(DType::Bool(n))
179-
}
180-
"Primitive" => {
181-
#[derive(Deserialize)]
182-
struct Fields(PType, Nullability);
183-
let Fields(ptype, n) = access.newtype_variant()?;
184-
Ok(DType::Primitive(ptype, n))
185-
}
186-
"Decimal" => {
187-
#[derive(Deserialize)]
188-
struct Fields(DecimalDType, Nullability);
189-
let Fields(decimal, n) = access.newtype_variant()?;
190-
Ok(DType::Decimal(decimal, n))
191-
}
192-
"Utf8" => {
193-
let n = access.newtype_variant()?;
194-
Ok(DType::Utf8(n))
195-
}
196-
"Binary" => {
197-
let n = access.newtype_variant()?;
198-
Ok(DType::Binary(n))
199-
}
200-
"List" => {
201-
#[derive(Deserialize)]
202-
struct Fields(Box<DType>, Nullability);
203-
let Fields(element, n) = access.newtype_variant()?;
204-
Ok(DType::List(Arc::from(*element), n))
205-
}
206-
"FixedSizeList" => {
207-
#[derive(Deserialize)]
208-
struct Fields(Box<DType>, u32, Nullability);
209-
let Fields(element, size, n) = access.newtype_variant()?;
210-
Ok(DType::FixedSizeList(Arc::from(*element), size, n))
211-
}
212-
"Struct" => {
213-
#[derive(Deserialize)]
214-
struct Fields(StructFieldsDeserialize, Nullability);
215-
let Fields(fields, n) = access.newtype_variant()?;
216-
Ok(DType::Struct(fields.0, n))
217-
}
218-
"Extension" => Err(de::Error::custom(
219-
"Extension types require a session context for deserialization. \
220-
Use DTypeSerde::new(session) with DeserializeSeed instead.",
221-
)),
222-
_ => Err(de::Error::unknown_variant(variant, VARIANTS)),
223-
}
224-
}
225-
}
226-
227-
deserializer.deserialize_enum("DType", VARIANTS, DTypeVisitor)
228-
}
229-
}
230-
231-
// Helper for deserializing StructFields without session
232-
#[derive(Deserialize)]
233-
struct StructFieldsDeserialize(
234-
#[serde(deserialize_with = "deserialize_struct_fields")] StructFields,
235-
);
236-
237-
fn deserialize_struct_fields<'de, D>(deserializer: D) -> Result<StructFields, D::Error>
238-
where
239-
D: Deserializer<'de>,
240-
{
241-
#[derive(Deserialize)]
242-
struct Inner {
243-
names: FieldNames,
244-
dtypes: Vec<DType>,
245-
}
246-
let inner = Inner::deserialize(deserializer)?;
247-
Ok(StructFields::new(inner.names, inner.dtypes))
248-
}
249-
250136
// ============================================================================
251137
// DType Deserialization with session context (DeserializeSeed)
252138
// ============================================================================
@@ -272,7 +158,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, DType> {
272158
];
273159

274160
struct DTypeVisitor<'a> {
275-
session: &'a DTypeSession,
161+
session: &'a VortexSession,
276162
}
277163

278164
impl<'de> Visitor<'de> for DTypeVisitor<'_> {
@@ -350,7 +236,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, DType> {
350236
// ============================================================================
351237

352238
struct ListFieldsSeed<'a> {
353-
session: &'a DTypeSession,
239+
session: &'a VortexSession,
354240
}
355241

356242
impl<'de> DeserializeSeed<'de> for ListFieldsSeed<'_> {
@@ -361,7 +247,7 @@ impl<'de> DeserializeSeed<'de> for ListFieldsSeed<'_> {
361247
D: Deserializer<'de>,
362248
{
363249
struct ListVisitor<'a> {
364-
session: &'a DTypeSession,
250+
session: &'a VortexSession,
365251
}
366252

367253
impl<'de> Visitor<'de> for ListVisitor<'_> {
@@ -395,7 +281,7 @@ impl<'de> DeserializeSeed<'de> for ListFieldsSeed<'_> {
395281
}
396282

397283
struct FixedSizeListFieldsSeed<'a> {
398-
session: &'a DTypeSession,
284+
session: &'a VortexSession,
399285
}
400286

401287
impl<'de> DeserializeSeed<'de> for FixedSizeListFieldsSeed<'_> {
@@ -406,7 +292,7 @@ impl<'de> DeserializeSeed<'de> for FixedSizeListFieldsSeed<'_> {
406292
D: Deserializer<'de>,
407293
{
408294
struct FixedSizeListVisitor<'a> {
409-
session: &'a DTypeSession,
295+
session: &'a VortexSession,
410296
}
411297

412298
impl<'de> Visitor<'de> for FixedSizeListVisitor<'_> {
@@ -447,7 +333,7 @@ impl<'de> DeserializeSeed<'de> for FixedSizeListFieldsSeed<'_> {
447333
}
448334

449335
struct StructFieldsSeed<'a> {
450-
session: &'a DTypeSession,
336+
session: &'a VortexSession,
451337
}
452338

453339
impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
@@ -458,7 +344,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
458344
D: Deserializer<'de>,
459345
{
460346
struct StructVisitor<'a> {
461-
session: &'a DTypeSession,
347+
session: &'a VortexSession,
462348
}
463349

464350
impl<'de> Visitor<'de> for StructVisitor<'_> {
@@ -473,9 +359,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
473359
A: SeqAccess<'de>,
474360
{
475361
let fields = seq
476-
.next_element_seed(StructFieldsDeserializeSeed {
477-
session: self.session,
478-
})?
362+
.next_element_seed(DTypeSerde::<StructFields>::new(self.session))?
479363
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
480364
let nullability = seq
481365
.next_element()?
@@ -493,11 +377,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
493377
}
494378
}
495379

496-
struct StructFieldsDeserializeSeed<'a> {
497-
session: &'a DTypeSession,
498-
}
499-
500-
impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
380+
impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, StructFields> {
501381
type Value = StructFields;
502382

503383
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
@@ -507,7 +387,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
507387
const FIELDS: &[&str] = &["names", "dtypes"];
508388

509389
struct StructFieldsInnerVisitor<'a> {
510-
session: &'a DTypeSession,
390+
session: &'a VortexSession,
511391
}
512392

513393
impl<'de> Visitor<'de> for StructFieldsInnerVisitor<'_> {
@@ -536,9 +416,9 @@ impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
536416
if dtypes.is_some() {
537417
return Err(de::Error::duplicate_field("dtypes"));
538418
}
539-
dtypes = Some(map.next_value_seed(DTypeVecSeed {
540-
session: self.session,
541-
})?);
419+
dtypes = Some(
420+
map.next_value_seed(DTypeSerde::<Vec<DType>>::new(self.session))?,
421+
);
542422
}
543423
_ => {
544424
let _ = map.next_value::<de::IgnoredAny>()?;
@@ -563,19 +443,15 @@ impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
563443
}
564444
}
565445

566-
struct DTypeVecSeed<'a> {
567-
session: &'a DTypeSession,
568-
}
569-
570-
impl<'de> DeserializeSeed<'de> for DTypeVecSeed<'_> {
446+
impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, Vec<DType>> {
571447
type Value = Vec<DType>;
572448

573449
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
574450
where
575451
D: Deserializer<'de>,
576452
{
577453
struct DTypeVecVisitor<'a> {
578-
session: &'a DTypeSession,
454+
session: &'a VortexSession,
579455
}
580456

581457
impl<'de> Visitor<'de> for DTypeVecVisitor<'_> {
@@ -642,7 +518,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, ExtDTypeRef> {
642518
const FIELDS: &[&str] = &["id", "storage_dtype", "metadata"];
643519

644520
struct ExtDTypeVisitor<'a> {
645-
session: &'a DTypeSession,
521+
session: &'a VortexSession,
646522
}
647523

648524
impl<'de> Visitor<'de> for ExtDTypeVisitor<'_> {
@@ -689,7 +565,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, ExtDTypeRef> {
689565

690566
let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
691567
let id = ExtID::new_arc(id);
692-
let vtable = self.session.registry().find(&id).ok_or_else(|| {
568+
let vtable = self.session.dtypes().registry().find(&id).ok_or_else(|| {
693569
de::Error::custom(format!("unknown extension dtype id: {}", id))
694570
})?;
695571

0 commit comments

Comments
 (0)