@@ -19,6 +19,7 @@ use serde::de::VariantAccess;
1919use serde:: de:: Visitor ;
2020use serde:: ser:: SerializeStruct ;
2121use serde:: ser:: SerializeTupleVariant ;
22+ use vortex_session:: VortexSession ;
2223
2324use crate :: DType ;
2425use crate :: ExtID ;
@@ -28,7 +29,7 @@ use crate::PType;
2829use crate :: StructFields ;
2930use crate :: decimal:: DecimalDType ;
3031use crate :: extension:: ExtDTypeRef ;
31- use crate :: session:: DTypeSession ;
32+ use crate :: session:: DTypeSessionExt ;
3233
3334/// Serialize Nullability as a boolean
3435impl Serialize for Nullability {
@@ -52,13 +53,13 @@ impl<'de> Deserialize<'de> for Nullability {
5253
5354/// Seed for deserializing DType references that require session context.
5455pub struct DTypeSerde < ' a , T > {
55- session : & ' a DTypeSession ,
56+ session : & ' a VortexSession ,
5657 _marker : PhantomData < T > ,
5758}
5859
5960impl < ' a , T > DTypeSerde < ' a , T > {
6061 /// Create a new DTypeSerde seed.
61- pub fn new ( session : & ' a DTypeSession ) -> Self {
62+ pub fn new ( session : & ' a VortexSession ) -> Self {
6263 Self {
6364 session,
6465 _marker : PhantomData ,
@@ -132,121 +133,6 @@ impl Serialize for StructFields {
132133 }
133134}
134135
135- // ============================================================================
136- // DType Deserialization (without session - errors on Extension types)
137- // ============================================================================
138-
139- impl < ' de > Deserialize < ' de > for DType {
140- fn deserialize < D > ( deserializer : D ) -> Result < Self , D :: Error >
141- where
142- D : Deserializer < ' de > ,
143- {
144- const VARIANTS : & [ & str ] = & [
145- "Null" ,
146- "Bool" ,
147- "Primitive" ,
148- "Decimal" ,
149- "Utf8" ,
150- "Binary" ,
151- "List" ,
152- "FixedSizeList" ,
153- "Struct" ,
154- "Extension" ,
155- ] ;
156-
157- struct DTypeVisitor ;
158-
159- impl < ' de > Visitor < ' de > for DTypeVisitor {
160- type Value = DType ;
161-
162- fn expecting ( & self , f : & mut Formatter ) -> fmt:: Result {
163- f. write_str ( "enum DType" )
164- }
165-
166- fn visit_enum < A > ( self , data : A ) -> Result < Self :: Value , A :: Error >
167- where
168- A : EnumAccess < ' de > ,
169- {
170- let ( variant, access) = data. variant :: < & str > ( ) ?;
171- match variant {
172- "Null" => {
173- access. unit_variant ( ) ?;
174- Ok ( DType :: Null )
175- }
176- "Bool" => {
177- let n = access. newtype_variant ( ) ?;
178- Ok ( DType :: Bool ( n) )
179- }
180- "Primitive" => {
181- #[ derive( Deserialize ) ]
182- struct Fields ( PType , Nullability ) ;
183- let Fields ( ptype, n) = access. newtype_variant ( ) ?;
184- Ok ( DType :: Primitive ( ptype, n) )
185- }
186- "Decimal" => {
187- #[ derive( Deserialize ) ]
188- struct Fields ( DecimalDType , Nullability ) ;
189- let Fields ( decimal, n) = access. newtype_variant ( ) ?;
190- Ok ( DType :: Decimal ( decimal, n) )
191- }
192- "Utf8" => {
193- let n = access. newtype_variant ( ) ?;
194- Ok ( DType :: Utf8 ( n) )
195- }
196- "Binary" => {
197- let n = access. newtype_variant ( ) ?;
198- Ok ( DType :: Binary ( n) )
199- }
200- "List" => {
201- #[ derive( Deserialize ) ]
202- struct Fields ( Box < DType > , Nullability ) ;
203- let Fields ( element, n) = access. newtype_variant ( ) ?;
204- Ok ( DType :: List ( Arc :: from ( * element) , n) )
205- }
206- "FixedSizeList" => {
207- #[ derive( Deserialize ) ]
208- struct Fields ( Box < DType > , u32 , Nullability ) ;
209- let Fields ( element, size, n) = access. newtype_variant ( ) ?;
210- Ok ( DType :: FixedSizeList ( Arc :: from ( * element) , size, n) )
211- }
212- "Struct" => {
213- #[ derive( Deserialize ) ]
214- struct Fields ( StructFieldsDeserialize , Nullability ) ;
215- let Fields ( fields, n) = access. newtype_variant ( ) ?;
216- Ok ( DType :: Struct ( fields. 0 , n) )
217- }
218- "Extension" => Err ( de:: Error :: custom (
219- "Extension types require a session context for deserialization. \
220- Use DTypeSerde::new(session) with DeserializeSeed instead.",
221- ) ) ,
222- _ => Err ( de:: Error :: unknown_variant ( variant, VARIANTS ) ) ,
223- }
224- }
225- }
226-
227- deserializer. deserialize_enum ( "DType" , VARIANTS , DTypeVisitor )
228- }
229- }
230-
231- // Helper for deserializing StructFields without session
232- #[ derive( Deserialize ) ]
233- struct StructFieldsDeserialize (
234- #[ serde( deserialize_with = "deserialize_struct_fields" ) ] StructFields ,
235- ) ;
236-
237- fn deserialize_struct_fields < ' de , D > ( deserializer : D ) -> Result < StructFields , D :: Error >
238- where
239- D : Deserializer < ' de > ,
240- {
241- #[ derive( Deserialize ) ]
242- struct Inner {
243- names : FieldNames ,
244- dtypes : Vec < DType > ,
245- }
246- let inner = Inner :: deserialize ( deserializer) ?;
247- Ok ( StructFields :: new ( inner. names , inner. dtypes ) )
248- }
249-
250136// ============================================================================
251137// DType Deserialization with session context (DeserializeSeed)
252138// ============================================================================
@@ -272,7 +158,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, DType> {
272158 ] ;
273159
274160 struct DTypeVisitor < ' a > {
275- session : & ' a DTypeSession ,
161+ session : & ' a VortexSession ,
276162 }
277163
278164 impl < ' de > Visitor < ' de > for DTypeVisitor < ' _ > {
@@ -350,7 +236,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, DType> {
350236// ============================================================================
351237
352238struct ListFieldsSeed < ' a > {
353- session : & ' a DTypeSession ,
239+ session : & ' a VortexSession ,
354240}
355241
356242impl < ' de > DeserializeSeed < ' de > for ListFieldsSeed < ' _ > {
@@ -361,7 +247,7 @@ impl<'de> DeserializeSeed<'de> for ListFieldsSeed<'_> {
361247 D : Deserializer < ' de > ,
362248 {
363249 struct ListVisitor < ' a > {
364- session : & ' a DTypeSession ,
250+ session : & ' a VortexSession ,
365251 }
366252
367253 impl < ' de > Visitor < ' de > for ListVisitor < ' _ > {
@@ -395,7 +281,7 @@ impl<'de> DeserializeSeed<'de> for ListFieldsSeed<'_> {
395281}
396282
397283struct FixedSizeListFieldsSeed < ' a > {
398- session : & ' a DTypeSession ,
284+ session : & ' a VortexSession ,
399285}
400286
401287impl < ' de > DeserializeSeed < ' de > for FixedSizeListFieldsSeed < ' _ > {
@@ -406,7 +292,7 @@ impl<'de> DeserializeSeed<'de> for FixedSizeListFieldsSeed<'_> {
406292 D : Deserializer < ' de > ,
407293 {
408294 struct FixedSizeListVisitor < ' a > {
409- session : & ' a DTypeSession ,
295+ session : & ' a VortexSession ,
410296 }
411297
412298 impl < ' de > Visitor < ' de > for FixedSizeListVisitor < ' _ > {
@@ -447,7 +333,7 @@ impl<'de> DeserializeSeed<'de> for FixedSizeListFieldsSeed<'_> {
447333}
448334
449335struct StructFieldsSeed < ' a > {
450- session : & ' a DTypeSession ,
336+ session : & ' a VortexSession ,
451337}
452338
453339impl < ' de > DeserializeSeed < ' de > for StructFieldsSeed < ' _ > {
@@ -458,7 +344,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
458344 D : Deserializer < ' de > ,
459345 {
460346 struct StructVisitor < ' a > {
461- session : & ' a DTypeSession ,
347+ session : & ' a VortexSession ,
462348 }
463349
464350 impl < ' de > Visitor < ' de > for StructVisitor < ' _ > {
@@ -473,9 +359,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
473359 A : SeqAccess < ' de > ,
474360 {
475361 let fields = seq
476- . next_element_seed ( StructFieldsDeserializeSeed {
477- session : self . session ,
478- } ) ?
362+ . next_element_seed ( DTypeSerde :: < StructFields > :: new ( self . session ) ) ?
479363 . ok_or_else ( || de:: Error :: invalid_length ( 0 , & self ) ) ?;
480364 let nullability = seq
481365 . next_element ( ) ?
@@ -493,11 +377,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsSeed<'_> {
493377 }
494378}
495379
496- struct StructFieldsDeserializeSeed < ' a > {
497- session : & ' a DTypeSession ,
498- }
499-
500- impl < ' de > DeserializeSeed < ' de > for StructFieldsDeserializeSeed < ' _ > {
380+ impl < ' de > DeserializeSeed < ' de > for DTypeSerde < ' _ , StructFields > {
501381 type Value = StructFields ;
502382
503383 fn deserialize < D > ( self , deserializer : D ) -> Result < Self :: Value , D :: Error >
@@ -507,7 +387,7 @@ impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
507387 const FIELDS : & [ & str ] = & [ "names" , "dtypes" ] ;
508388
509389 struct StructFieldsInnerVisitor < ' a > {
510- session : & ' a DTypeSession ,
390+ session : & ' a VortexSession ,
511391 }
512392
513393 impl < ' de > Visitor < ' de > for StructFieldsInnerVisitor < ' _ > {
@@ -536,9 +416,9 @@ impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
536416 if dtypes. is_some ( ) {
537417 return Err ( de:: Error :: duplicate_field ( "dtypes" ) ) ;
538418 }
539- dtypes = Some ( map . next_value_seed ( DTypeVecSeed {
540- session : self . session ,
541- } ) ? ) ;
419+ dtypes = Some (
420+ map . next_value_seed ( DTypeSerde :: < Vec < DType > > :: new ( self . session ) ) ? ,
421+ ) ;
542422 }
543423 _ => {
544424 let _ = map. next_value :: < de:: IgnoredAny > ( ) ?;
@@ -563,19 +443,15 @@ impl<'de> DeserializeSeed<'de> for StructFieldsDeserializeSeed<'_> {
563443 }
564444}
565445
566- struct DTypeVecSeed < ' a > {
567- session : & ' a DTypeSession ,
568- }
569-
570- impl < ' de > DeserializeSeed < ' de > for DTypeVecSeed < ' _ > {
446+ impl < ' de > DeserializeSeed < ' de > for DTypeSerde < ' _ , Vec < DType > > {
571447 type Value = Vec < DType > ;
572448
573449 fn deserialize < D > ( self , deserializer : D ) -> Result < Self :: Value , D :: Error >
574450 where
575451 D : Deserializer < ' de > ,
576452 {
577453 struct DTypeVecVisitor < ' a > {
578- session : & ' a DTypeSession ,
454+ session : & ' a VortexSession ,
579455 }
580456
581457 impl < ' de > Visitor < ' de > for DTypeVecVisitor < ' _ > {
@@ -642,7 +518,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, ExtDTypeRef> {
642518 const FIELDS : & [ & str ] = & [ "id" , "storage_dtype" , "metadata" ] ;
643519
644520 struct ExtDTypeVisitor < ' a > {
645- session : & ' a DTypeSession ,
521+ session : & ' a VortexSession ,
646522 }
647523
648524 impl < ' de > Visitor < ' de > for ExtDTypeVisitor < ' _ > {
@@ -689,7 +565,7 @@ impl<'de> DeserializeSeed<'de> for DTypeSerde<'_, ExtDTypeRef> {
689565
690566 let id = id. ok_or_else ( || de:: Error :: missing_field ( "id" ) ) ?;
691567 let id = ExtID :: new_arc ( id) ;
692- let vtable = self . session . registry ( ) . find ( & id) . ok_or_else ( || {
568+ let vtable = self . session . dtypes ( ) . registry ( ) . find ( & id) . ok_or_else ( || {
693569 de:: Error :: custom ( format ! ( "unknown extension dtype id: {}" , id) )
694570 } ) ?;
695571
0 commit comments