Skip to content

Commit 83a3883

Browse files
committed
feat typed multipart
1 parent 37a7465 commit 83a3883

2 files changed

Lines changed: 315 additions & 16 deletions

File tree

crates/vespera_macro/src/multipart_impl.rs

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use proc_macro2::TokenStream;
2222
use quote::quote;
2323
use syn::{DeriveInput, Fields, Type};
2424

25-
use crate::parser::{extract_field_rename, extract_rename_all, rename_field};
25+
use crate::parser::{extract_default, extract_field_rename, extract_rename_all, rename_field};
2626

2727
/// Collected codegen fragments for each struct field.
2828
struct FieldCodegen<'a> {
@@ -32,11 +32,22 @@ struct FieldCodegen<'a> {
3232
idents: Vec<&'a syn::Ident>,
3333
}
3434

35+
/// How a missing field should be handled.
36+
enum DefaultKind {
37+
/// No default — field is required; emit `MissingField` error.
38+
None,
39+
/// Use `Default::default()` — from `#[serde(default)]` or `#[form_data(default)]`.
40+
Trait,
41+
/// Call a custom function — from `#[serde(default = "path::to::fn")]`.
42+
Function(String),
43+
}
44+
3545
/// Process all named fields into codegen fragments.
3646
fn process_fields<'a>(
3747
fields: impl Iterator<Item = &'a syn::Field>,
3848
rename_all: Option<&str>,
3949
strict: bool,
50+
struct_default: bool,
4051
) -> FieldCodegen<'a> {
4152
let mut cg = FieldCodegen {
4253
declarations: Vec::new(),
@@ -52,7 +63,7 @@ fn process_fields<'a>(
5263
let is_option = is_option_type(ty);
5364
let field_name = resolve_field_name(ident, &field.attrs, rename_all);
5465
let limit_tokens = extract_limit_tokens(&field.attrs);
55-
let has_default = extract_default_flag(&field.attrs);
66+
let default_kind = resolve_default_kind(&field.attrs, struct_default);
5667

5768
// The concrete type for TryFromFieldWithState turbofish. For Option<T>
5869
// and Vec<T> the derive wraps the parsed value, so the trait Self is T.
@@ -110,18 +121,28 @@ fn process_fields<'a>(
110121

111122
// Post-loop: required field checks / defaults
112123
if !is_option && !is_vec {
113-
if has_default {
114-
cg.post_loop.push(quote! {
115-
let #ident: #ty = #ident.unwrap_or_default();
116-
});
117-
} else {
118-
cg.post_loop.push(quote! {
119-
let #ident = #ident.ok_or(
120-
vespera::multipart::TypedMultipartError::MissingField {
121-
field_name: std::string::String::from(#field_name)
122-
}
123-
)?;
124-
});
124+
match &default_kind {
125+
DefaultKind::Trait => {
126+
cg.post_loop.push(quote! {
127+
let #ident: #ty = #ident.unwrap_or_default();
128+
});
129+
}
130+
DefaultKind::Function(fn_path) => {
131+
let path: syn::ExprPath =
132+
syn::parse_str(fn_path).expect("invalid default function path");
133+
cg.post_loop.push(quote! {
134+
let #ident: #ty = #ident.unwrap_or_else(#path);
135+
});
136+
}
137+
DefaultKind::None => {
138+
cg.post_loop.push(quote! {
139+
let #ident = #ident.ok_or(
140+
vespera::multipart::TypedMultipartError::MissingField {
141+
field_name: std::string::String::from(#field_name)
142+
}
143+
)?;
144+
});
145+
}
125146
}
126147
}
127148

@@ -136,6 +157,7 @@ pub fn process_derive(input: &DeriveInput) -> TokenStream {
136157
let struct_name = &input.ident;
137158
let rename_all = extract_rename_all(&input.attrs);
138159
let strict = extract_strict(&input.attrs);
160+
let struct_default = extract_struct_default(&input.attrs);
139161

140162
let fields = match &input.data {
141163
syn::Data::Struct(data) => match &data.fields {
@@ -157,7 +179,7 @@ pub fn process_derive(input: &DeriveInput) -> TokenStream {
157179
}
158180
};
159181

160-
let mut cg = process_fields(fields.iter(), rename_all.as_deref(), strict);
182+
let mut cg = process_fields(fields.iter(), rename_all.as_deref(), strict, struct_default);
161183

162184
if strict {
163185
cg.assignments.push(quote! {
@@ -317,8 +339,35 @@ fn extract_limit_tokens(attrs: &[syn::Attribute]) -> TokenStream {
317339
quote! { std::option::Option::None }
318340
}
319341

342+
/// Resolve the default behavior for a field.
343+
///
344+
/// Priority:
345+
/// 1. `#[form_data(default)]` — explicit form_data override (bare default)
346+
/// 2. `#[serde(default)]` — bare default via `Default::default()`
347+
/// 3. `#[serde(default = "fn_path")]` — custom default function
348+
/// 4. Struct-level `#[serde(default)]` — all fields get `Default::default()`
349+
/// 5. No default — field is required
350+
fn resolve_default_kind(attrs: &[syn::Attribute], struct_default: bool) -> DefaultKind {
351+
// 1. Check #[form_data(default)]
352+
if extract_form_data_default(attrs) {
353+
return DefaultKind::Trait;
354+
}
355+
356+
// 2-3. Check #[serde(default)] or #[serde(default = "fn")]
357+
if let Some(serde_default) = extract_default(attrs) {
358+
return serde_default.map_or(DefaultKind::Trait, DefaultKind::Function);
359+
}
360+
361+
// 4. Struct-level #[serde(default)]
362+
if struct_default {
363+
return DefaultKind::Trait;
364+
}
365+
366+
DefaultKind::None
367+
}
368+
320369
/// Extract `default` flag from `#[form_data(default)]`.
321-
fn extract_default_flag(attrs: &[syn::Attribute]) -> bool {
370+
fn extract_form_data_default(attrs: &[syn::Attribute]) -> bool {
322371
for attr in attrs {
323372
if attr.path().is_ident("form_data") {
324373
let mut has_default = false;
@@ -336,6 +385,13 @@ fn extract_default_flag(attrs: &[syn::Attribute]) -> bool {
336385
false
337386
}
338387

388+
/// Check if the struct has `#[serde(default)]` at the struct level.
389+
fn extract_struct_default(attrs: &[syn::Attribute]) -> bool {
390+
// Reuse extract_default — if it returns Some(None), it's bare #[serde(default)]
391+
// For struct-level, we only support bare default (no custom function)
392+
extract_default(attrs).is_some()
393+
}
394+
339395
// ─── Type Utilities ─────────────────────────────────────────────────────────
340396

341397
/// Extract the first generic type argument from a type like `Option<T>` or `Vec<T>`.

examples/axum-example/tests/integration_test.rs

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,3 +1140,246 @@ async fn test_form_data_limit_unlimited_keyword() {
11401140
let response = server.post("/limit-test").multipart(form).await;
11411141
response.assert_status_ok();
11421142
}
1143+
1144+
// ============== #[serde(rename)] and #[serde(default)] tests ==============
1145+
//
1146+
// These tests verify that `#[derive(Multipart)]` correctly handles serde
1147+
// attributes for field renaming and default values.
1148+
1149+
fn default_greeting() -> String {
1150+
"hello".to_string()
1151+
}
1152+
1153+
/// Test struct with serde rename and default attributes.
1154+
#[derive(Debug, Multipart)]
1155+
#[serde(rename_all = "camelCase")]
1156+
#[allow(dead_code)]
1157+
struct SerdeAttrTestRequest {
1158+
/// Uses camelCase rename from struct-level rename_all.
1159+
pub user_name: String,
1160+
/// Explicit field rename overrides rename_all.
1161+
#[serde(rename = "customTag")]
1162+
pub tag_value: String,
1163+
/// `#[serde(default)]` uses `Default::default()` when missing.
1164+
#[serde(default)]
1165+
pub score: i32,
1166+
/// `#[serde(default = "fn")]` calls custom function when missing.
1167+
#[serde(default = "default_greeting")]
1168+
pub greeting: String,
1169+
}
1170+
1171+
async fn serde_attr_handler(
1172+
TypedMultipart(req): TypedMultipart<SerdeAttrTestRequest>,
1173+
) -> axum::Json<serde_json::Value> {
1174+
axum::Json(serde_json::json!({
1175+
"userName": req.user_name,
1176+
"tagValue": req.tag_value,
1177+
"score": req.score,
1178+
"greeting": req.greeting,
1179+
}))
1180+
}
1181+
1182+
/// Test struct with struct-level `#[serde(default)]`.
1183+
#[derive(Debug, Multipart)]
1184+
#[serde(default)]
1185+
#[allow(dead_code)]
1186+
struct StructDefaultTestRequest {
1187+
pub name: String,
1188+
pub count: i32,
1189+
pub active: bool,
1190+
}
1191+
1192+
async fn struct_default_handler(
1193+
TypedMultipart(req): TypedMultipart<StructDefaultTestRequest>,
1194+
) -> axum::Json<serde_json::Value> {
1195+
axum::Json(serde_json::json!({
1196+
"name": req.name,
1197+
"count": req.count,
1198+
"active": req.active,
1199+
}))
1200+
}
1201+
1202+
fn create_serde_test_app() -> axum::Router {
1203+
axum::Router::new()
1204+
.route("/serde-test", axum::routing::post(serde_attr_handler))
1205+
.route(
1206+
"/struct-default-test",
1207+
axum::routing::post(struct_default_handler),
1208+
)
1209+
}
1210+
1211+
// ─── serde(rename_all) tests ────────────────────────────────────────────────
1212+
1213+
#[tokio::test]
1214+
async fn test_serde_rename_all_camel_case() {
1215+
let server = TestServer::new(create_serde_test_app());
1216+
1217+
// Field "user_name" is renamed to "userName" by rename_all = "camelCase"
1218+
let form = MultipartForm::new()
1219+
.add_text("userName", "Alice")
1220+
.add_text("customTag", "rust");
1221+
1222+
let response = server.post("/serde-test").multipart(form).await;
1223+
response.assert_status_ok();
1224+
1225+
let result: serde_json::Value = response.json();
1226+
assert_eq!(result["userName"], "Alice");
1227+
assert_eq!(result["tagValue"], "rust");
1228+
}
1229+
1230+
#[tokio::test]
1231+
async fn test_serde_rename_all_rust_name_rejected() {
1232+
let server = TestServer::new(create_serde_test_app());
1233+
1234+
// Using Rust field name "user_name" instead of "userName" should fail
1235+
let form = MultipartForm::new()
1236+
.add_text("user_name", "Alice")
1237+
.add_text("customTag", "rust");
1238+
1239+
let response = server.post("/serde-test").multipart(form).await;
1240+
// "userName" is missing → MissingField error
1241+
response.assert_status(axum::http::StatusCode::BAD_REQUEST);
1242+
}
1243+
1244+
// ─── serde(rename = "...") tests ────────────────────────────────────────────
1245+
1246+
#[tokio::test]
1247+
async fn test_serde_rename_explicit() {
1248+
let server = TestServer::new(create_serde_test_app());
1249+
1250+
// "tag_value" is renamed to "customTag" by #[serde(rename = "customTag")]
1251+
let form = MultipartForm::new()
1252+
.add_text("userName", "Alice")
1253+
.add_text("customTag", "explicit");
1254+
1255+
let response = server.post("/serde-test").multipart(form).await;
1256+
response.assert_status_ok();
1257+
1258+
let result: serde_json::Value = response.json();
1259+
assert_eq!(result["tagValue"], "explicit");
1260+
}
1261+
1262+
#[tokio::test]
1263+
async fn test_serde_rename_camel_case_of_field_rejected() {
1264+
let server = TestServer::new(create_serde_test_app());
1265+
1266+
// "tagValue" (camelCase of Rust name) should NOT work — explicit rename takes priority
1267+
let form = MultipartForm::new()
1268+
.add_text("userName", "Alice")
1269+
.add_text("tagValue", "wrong");
1270+
1271+
let response = server.post("/serde-test").multipart(form).await;
1272+
// "customTag" is missing → MissingField error
1273+
response.assert_status(axum::http::StatusCode::BAD_REQUEST);
1274+
}
1275+
1276+
// ─── serde(default) field-level tests ───────────────────────────────────────
1277+
1278+
#[tokio::test]
1279+
async fn test_serde_default_uses_default_trait() {
1280+
let server = TestServer::new(create_serde_test_app());
1281+
1282+
// Omit "score" (has #[serde(default)]) — should get i32::default() = 0
1283+
let form = MultipartForm::new()
1284+
.add_text("userName", "Alice")
1285+
.add_text("customTag", "test");
1286+
1287+
let response = server.post("/serde-test").multipart(form).await;
1288+
response.assert_status_ok();
1289+
1290+
let result: serde_json::Value = response.json();
1291+
assert_eq!(result["score"], 0, "score should default to 0");
1292+
}
1293+
1294+
#[tokio::test]
1295+
async fn test_serde_default_fn_uses_custom_function() {
1296+
let server = TestServer::new(create_serde_test_app());
1297+
1298+
// Omit "greeting" (has #[serde(default = "default_greeting")])
1299+
// Should get "hello" from the custom function
1300+
let form = MultipartForm::new()
1301+
.add_text("userName", "Alice")
1302+
.add_text("customTag", "test");
1303+
1304+
let response = server.post("/serde-test").multipart(form).await;
1305+
response.assert_status_ok();
1306+
1307+
let result: serde_json::Value = response.json();
1308+
assert_eq!(
1309+
result["greeting"], "hello",
1310+
"greeting should default to 'hello' from default_greeting()"
1311+
);
1312+
}
1313+
1314+
#[tokio::test]
1315+
async fn test_serde_default_overridden_when_provided() {
1316+
let server = TestServer::new(create_serde_test_app());
1317+
1318+
// Provide both default fields — explicit values should win
1319+
let form = MultipartForm::new()
1320+
.add_text("userName", "Alice")
1321+
.add_text("customTag", "test")
1322+
.add_text("score", "42")
1323+
.add_text("greeting", "world");
1324+
1325+
let response = server.post("/serde-test").multipart(form).await;
1326+
response.assert_status_ok();
1327+
1328+
let result: serde_json::Value = response.json();
1329+
assert_eq!(result["score"], 42);
1330+
assert_eq!(result["greeting"], "world");
1331+
}
1332+
1333+
// ─── serde(default) struct-level tests ──────────────────────────────────────
1334+
1335+
#[tokio::test]
1336+
async fn test_struct_level_serde_default_all_omitted() {
1337+
let server = TestServer::new(create_serde_test_app());
1338+
1339+
// No recognized fields — struct has #[serde(default)], all get Default::default().
1340+
// Send an unrecognized field to produce a valid multipart body (non-strict ignores it).
1341+
let form = MultipartForm::new().add_text("_ignored", "");
1342+
1343+
let response = server.post("/struct-default-test").multipart(form).await;
1344+
response.assert_status_ok();
1345+
1346+
let result: serde_json::Value = response.json();
1347+
assert_eq!(result["name"], "", "String::default() is empty string");
1348+
assert_eq!(result["count"], 0, "i32::default() is 0");
1349+
assert_eq!(result["active"], false, "bool::default() is false");
1350+
}
1351+
1352+
#[tokio::test]
1353+
async fn test_struct_level_serde_default_partial() {
1354+
let server = TestServer::new(create_serde_test_app());
1355+
1356+
// Only provide "name" — other fields should get defaults
1357+
let form = MultipartForm::new().add_text("name", "Bob");
1358+
1359+
let response = server.post("/struct-default-test").multipart(form).await;
1360+
response.assert_status_ok();
1361+
1362+
let result: serde_json::Value = response.json();
1363+
assert_eq!(result["name"], "Bob");
1364+
assert_eq!(result["count"], 0);
1365+
assert_eq!(result["active"], false);
1366+
}
1367+
1368+
#[tokio::test]
1369+
async fn test_struct_level_serde_default_all_provided() {
1370+
let server = TestServer::new(create_serde_test_app());
1371+
1372+
// Provide all fields — explicit values should win
1373+
let form = MultipartForm::new()
1374+
.add_text("name", "Charlie")
1375+
.add_text("count", "99")
1376+
.add_text("active", "true");
1377+
1378+
let response = server.post("/struct-default-test").multipart(form).await;
1379+
response.assert_status_ok();
1380+
1381+
let result: serde_json::Value = response.json();
1382+
assert_eq!(result["name"], "Charlie");
1383+
assert_eq!(result["count"], 99);
1384+
assert_eq!(result["active"], true);
1385+
}

0 commit comments

Comments
 (0)