Skip to content

Commit a25c0d8

Browse files
committed
Fix ref vs. value type enum encoding
1 parent cfd6f21 commit a25c0d8

File tree

2 files changed

+103
-44
lines changed

2 files changed

+103
-44
lines changed

rust/bufferfish-derive/src/lib.rs

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
extern crate proc_macro;
22

33
use proc_macro_error::{abort, proc_macro_error};
4-
use proc_macro2::TokenStream;
4+
use proc_macro2::{Ident, Literal, Span, TokenStream};
55
use quote::quote;
66
use syn::{
7-
Data, DeriveInput, Expr, Fields, Index, Type, TypePath, parse_macro_input, spanned::Spanned,
7+
Data, DataEnum, DeriveInput, Expr, Fields, Index, Type, TypePath, parse_macro_input,
8+
spanned::Spanned,
89
};
910

1011
fn extract_message_id(ast: &DeriveInput) -> Option<Expr> {
@@ -194,59 +195,56 @@ fn generate_struct_field_encoders(data: &syn::DataStruct) -> Vec<TokenStream> {
194195
encoded_snippets
195196
}
196197

197-
fn generate_enum_variant_encoders(
198-
name: &proc_macro2::Ident,
199-
data_enum: &syn::DataEnum,
200-
) -> TokenStream {
201-
let mut variant_match_arms = Vec::new();
198+
fn generate_enum_variant_encoders(name: &Ident, data_enum: &DataEnum) -> TokenStream {
199+
let mut arms = Vec::new();
202200

203-
for (discriminant_value, variant) in data_enum.variants.iter().enumerate() {
204-
let variant_ident = &variant.ident;
205-
let discriminant_lit = Index::from(discriminant_value);
201+
for (idx, variant) in data_enum.variants.iter().enumerate() {
202+
let v_ident = &variant.ident;
203+
let discrim = Literal::u8_unsuffixed(idx as u8);
206204

207205
match &variant.fields {
208206
Fields::Unit => {
209-
variant_match_arms.push(quote! {
210-
#name::#variant_ident => {
211-
bf.write_u8(#discriminant_lit as u8)?;
207+
arms.push(quote! {
208+
#name::#v_ident => {
209+
bf.write_u8(#discrim)?;
212210
}
213211
});
214212
}
215213
Fields::Unnamed(fields) => {
216-
let field_idents: Vec<_> = (0..fields.unnamed.len())
217-
.map(|i| syn::Ident::new(&format!("f{i}"), fields.unnamed.span()))
214+
let idents: Vec<Ident> = (0..fields.unnamed.len())
215+
.map(|i| Ident::new(&format!("f{i}"), Span::call_site()))
218216
.collect();
219217

220-
let mut field_encoders = Vec::new();
218+
let mut encoders = Vec::new();
221219
for (i, field) in fields.unnamed.iter().enumerate() {
222-
let field_ident = &field_idents[i];
223-
encode_type(quote! { #field_ident }, &field.ty, &mut field_encoders);
220+
let fld = &idents[i];
221+
encode_type(quote! { #fld }, &field.ty, &mut encoders);
224222
}
225223

226-
variant_match_arms.push(quote! {
227-
#name::#variant_ident( #(#field_idents),* ) => {
228-
bf.write_u8(#discriminant_lit as u8)?;
229-
#(#field_encoders)*
224+
arms.push(quote! {
225+
#name::#v_ident( #(#idents),* ) => {
226+
bf.write_u8(#discrim)?;
227+
#(#encoders)*
230228
}
231229
});
232230
}
233231
Fields::Named(fields) => {
234-
let field_idents: Vec<_> = fields
232+
let idents: Vec<Ident> = fields
235233
.named
236234
.iter()
237-
.map(|f| f.ident.as_ref().unwrap().clone())
235+
.map(|f| f.ident.clone().unwrap())
238236
.collect();
239237

240-
let mut field_encoders = Vec::new();
238+
let mut encoders = Vec::new();
241239
for (i, field) in fields.named.iter().enumerate() {
242-
let field_ident = &field_idents[i];
243-
encode_type(quote! { #field_ident }, &field.ty, &mut field_encoders);
240+
let fld = &idents[i];
241+
encode_type(quote! { #fld }, &field.ty, &mut encoders);
244242
}
245243

246-
variant_match_arms.push(quote! {
247-
#name::#variant_ident { #(#field_idents),* } => {
248-
bf.write_u8(#discriminant_lit as u8)?;
249-
#(#field_encoders)*
244+
arms.push(quote! {
245+
#name::#v_ident { #(#idents),* } => {
246+
bf.write_u8(#discrim)?;
247+
#(#encoders)*
250248
}
251249
});
252250
}
@@ -255,7 +253,7 @@ fn generate_enum_variant_encoders(
255253

256254
quote! {
257255
match self {
258-
#(#variant_match_arms)*
256+
#(#arms),*
259257
}
260258
}
261259
}
@@ -537,19 +535,18 @@ fn generate_enum_variant_max_size_calc(variant: &syn::Variant) -> TokenStream {
537535
}
538536

539537
fn encode_type(accessor: TokenStream, field_type: &Type, dst: &mut Vec<TokenStream>) {
540-
let target_accessor = if let Type::Reference(_) = field_type {
541-
accessor
542-
} else {
543-
quote! { &(#accessor) }
544-
};
545-
546538
let effective_type = if let Type::Reference(type_ref) = field_type {
547539
&*type_ref.elem
548540
} else {
549541
field_type
550542
};
551543

552544
match effective_type {
545+
Type::Path(TypePath { path, .. }) if path.is_ident("String") => {
546+
dst.push(quote! {
547+
(#accessor).encode(bf)?;
548+
});
549+
}
553550
Type::Path(TypePath { path, .. })
554551
if path.is_ident("u8")
555552
|| path.is_ident("u16")
@@ -559,28 +556,29 @@ fn encode_type(accessor: TokenStream, field_type: &Type, dst: &mut Vec<TokenStre
559556
|| path.is_ident("i16")
560557
|| path.is_ident("i32")
561558
|| path.is_ident("i64")
562-
|| path.is_ident("bool")
563-
|| path.is_ident("String") =>
559+
|| path.is_ident("bool") =>
564560
{
565561
dst.push(quote! {
566-
bufferfish::Encodable::encode(#target_accessor, bf)?;
562+
(#accessor).encode(bf)?;
567563
});
568564
}
569565
Type::Path(TypePath { path, .. })
570566
if path.segments.len() == 1 && path.segments[0].ident == "Vec" =>
571567
{
572568
dst.push(quote! {
573-
bf.write_array(#target_accessor)?;
569+
(#accessor).encode(bf)?;
574570
});
575571
}
576572
Type::Array(_type_array) => {
577573
dst.push(quote! {
578-
bufferfish::Encodable::encode(#target_accessor, bf)?;
574+
(#accessor).encode(bf)?;
579575
});
580576
}
581577
Type::Path(TypePath { .. }) => {
578+
// Catch-all for other user-defined types (structs/enums)
579+
// These are assumed to implement Encodable.
582580
dst.push(quote! {
583-
bufferfish::Encodable::encode(#target_accessor, bf)?;
581+
(#accessor).encode(bf)?;
584582
});
585583
}
586584
_ => abort!(

rust/bufferfish/src/lib.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,4 +1023,65 @@ mod tests {
10231023
assert_eq!(bf.len(), 0);
10241024
assert_eq!(bf.as_ref(), &[]);
10251025
}
1026+
1027+
#[test]
1028+
fn test_encode_decode_complex_enums() {
1029+
use bufferfish_core as bufferfish;
1030+
use bufferfish_core::{Decodable, Encodable};
1031+
use bufferfish_derive::{Decode, Encode};
1032+
1033+
#[derive(Encode, Decode)]
1034+
enum Object {
1035+
Variant1 { a: u32, b: String },
1036+
Variant2 { c: i32, d: bool },
1037+
}
1038+
1039+
#[derive(Encode, Decode)]
1040+
enum Complex {
1041+
Object(Object),
1042+
Stringly(String),
1043+
Simple(u8),
1044+
Other,
1045+
}
1046+
1047+
#[derive(Encode, Decode)]
1048+
#[bufferfish(0_u16)]
1049+
enum ObjectId {
1050+
Test,
1051+
}
1052+
1053+
impl From<ObjectId> for u16 {
1054+
fn from(value: ObjectId) -> Self {
1055+
match value {
1056+
ObjectId::Test => 0,
1057+
}
1058+
}
1059+
}
1060+
1061+
#[derive(Encode, Decode)]
1062+
#[bufferfish(ObjectId::Test)]
1063+
pub struct ObjectContainer {
1064+
complex: Complex,
1065+
}
1066+
1067+
let mut bf = Bufferfish::new();
1068+
let complex = ObjectContainer {
1069+
complex: Complex::Object(Object::Variant1 {
1070+
a: 42,
1071+
b: "Hello".to_string(),
1072+
}),
1073+
};
1074+
1075+
complex.encode(&mut bf).unwrap();
1076+
1077+
let decoded = ObjectContainer::decode(&mut bf).unwrap();
1078+
1079+
match decoded.complex {
1080+
Complex::Object(Object::Variant1 { a, b }) => {
1081+
assert_eq!(a, 42);
1082+
assert_eq!(b, "Hello");
1083+
}
1084+
_ => panic!("Decoded complex type did not match expected variant"),
1085+
}
1086+
}
10261087
}

0 commit comments

Comments
 (0)