Skip to content

Commit 95f4d69

Browse files
committed
feat: implement typemap support
1 parent d46700d commit 95f4d69

File tree

5 files changed

+146
-10
lines changed

5 files changed

+146
-10
lines changed

crates/sqlx_gen/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "sqlx-gen"
3-
version = "0.4.2"
3+
version = "0.4.3"
44
edition = "2021"
55
description = "Generate Rust structs from database schema introspection"
66
license = "MIT"

crates/sqlx_gen/src/codegen/crud_gen.rs

Lines changed: 110 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,7 @@ pub fn generate_crud_from_parsed(
365365
let update_macro_args: Vec<TokenStream> = non_pk_fields
366366
.iter()
367367
.chain(pk_fields.iter())
368-
.map(|f| {
369-
let name = format_ident!("{}", f.rust_name);
370-
quote! { params.#name }
371-
})
368+
.map(|f| macro_arg_for_field(f))
372369
.collect();
373370

374371
let update_method = if use_macro {
@@ -586,6 +583,21 @@ fn build_where_clause_parsed(
586583
.join(" AND ")
587584
}
588585

586+
fn macro_arg_for_field(field: &ParsedField) -> TokenStream {
587+
let name = format_ident!("{}", field.rust_name);
588+
let check_type = if field.is_nullable {
589+
&field.inner_type
590+
} else {
591+
&field.rust_type
592+
};
593+
let normalized = check_type.replace(' ', "");
594+
if normalized.starts_with("Vec<") {
595+
quote! { params.#name.as_slice() }
596+
} else {
597+
quote! { params.#name }
598+
}
599+
}
600+
589601
fn build_where_clause_cast(
590602
pk_fields: &[&ParsedField],
591603
db_kind: DatabaseKind,
@@ -618,10 +630,7 @@ fn build_insert_method_parsed(
618630
if use_macro {
619631
let macro_args: Vec<TokenStream> = non_pk_fields
620632
.iter()
621-
.map(|f| {
622-
let name = format_ident!("{}", f.rust_name);
623-
quote! { params.#name }
624-
})
633+
.map(|f| macro_arg_for_field(f))
625634
.collect();
626635

627636
match db_kind {
@@ -1432,4 +1441,97 @@ mod tests {
14321441
// DELETE still uses query! macro
14331442
assert!(code.contains("query!"));
14341443
}
1444+
1445+
// --- Vec<String> native array uses .as_slice() in macro mode ---
1446+
1447+
fn entity_with_vec_string() -> ParsedEntity {
1448+
ParsedEntity {
1449+
struct_name: "PromptHistory".to_string(),
1450+
table_name: "prompt_history".to_string(),
1451+
schema_name: None,
1452+
is_view: false,
1453+
fields: vec![
1454+
ParsedField {
1455+
rust_name: "id".to_string(),
1456+
column_name: "id".to_string(),
1457+
rust_type: "Uuid".to_string(),
1458+
inner_type: "Uuid".to_string(),
1459+
is_nullable: false,
1460+
is_primary_key: true,
1461+
sql_type: None,
1462+
is_sql_array: false,
1463+
},
1464+
ParsedField {
1465+
rust_name: "content".to_string(),
1466+
column_name: "content".to_string(),
1467+
rust_type: "String".to_string(),
1468+
inner_type: "String".to_string(),
1469+
is_nullable: false,
1470+
is_primary_key: false,
1471+
sql_type: None,
1472+
is_sql_array: false,
1473+
},
1474+
ParsedField {
1475+
rust_name: "tags".to_string(),
1476+
column_name: "tags".to_string(),
1477+
rust_type: "Vec<String>".to_string(),
1478+
inner_type: "Vec<String>".to_string(),
1479+
is_nullable: false,
1480+
is_primary_key: false,
1481+
sql_type: None,
1482+
is_sql_array: false,
1483+
},
1484+
],
1485+
imports: vec!["use uuid::Uuid;".to_string()],
1486+
}
1487+
}
1488+
1489+
#[test]
1490+
fn test_vec_string_macro_insert_uses_as_slice() {
1491+
let skip = Methods::all();
1492+
let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true);
1493+
let code = parse_and_format(&tokens);
1494+
assert!(code.contains("as_slice()"));
1495+
}
1496+
1497+
#[test]
1498+
fn test_vec_string_macro_update_uses_as_slice() {
1499+
let skip = Methods::all();
1500+
let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true);
1501+
let code = parse_and_format(&tokens);
1502+
// Should have as_slice() for both insert and update
1503+
let count = code.matches("as_slice()").count();
1504+
assert!(count >= 2, "expected at least 2 as_slice() calls (insert + update), found {}", count);
1505+
}
1506+
1507+
#[test]
1508+
fn test_vec_string_non_macro_no_as_slice() {
1509+
let skip = Methods::all();
1510+
let (tokens, _) = generate_crud_from_parsed(&entity_with_vec_string(), DatabaseKind::Postgres, "crate::models::prompt_history", &skip, false);
1511+
let code = parse_and_format(&tokens);
1512+
// Runtime mode uses .bind() so no as_slice needed
1513+
assert!(!code.contains("as_slice()"));
1514+
}
1515+
1516+
#[test]
1517+
fn test_vec_string_parsed_from_source_uses_as_slice() {
1518+
use crate::codegen::entity_parser::parse_entity_source;
1519+
let source = r#"
1520+
use uuid::Uuid;
1521+
1522+
#[derive(Debug, Clone, sqlx::FromRow, SqlxGen)]
1523+
#[sqlx_gen(kind = "table", schema = "agent", table = "prompt_history")]
1524+
pub struct PromptHistory {
1525+
#[sqlx_gen(primary_key)]
1526+
pub id: Uuid,
1527+
pub content: String,
1528+
pub tags: Vec<String>,
1529+
}
1530+
"#;
1531+
let entity = parse_entity_source(source).unwrap();
1532+
let skip = Methods::all();
1533+
let (tokens, _) = generate_crud_from_parsed(&entity, DatabaseKind::Postgres, "crate::models::prompt_history", &skip, true);
1534+
let code = parse_and_format(&tokens);
1535+
assert!(code.contains("as_slice()"), "Expected as_slice() in generated code:\n{}", code);
1536+
}
14351537
}

crates/sqlx_gen/src/codegen/struct_gen.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,14 @@ fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<S
149149
return (Some(qualified), is_array);
150150
}
151151

152+
// Check if this is a non-builtin type that would hit the typemap fallback
153+
// (e.g. range types like "timerange", "tsrange", etc.)
154+
// Domains resolve to their base type, so they don't need marking.
155+
let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
156+
if !is_domain && !typemap::postgres::is_builtin(base_name) {
157+
return (Some(base_name.to_string()), is_array);
158+
}
159+
152160
(None, false)
153161
}
154162

crates/sqlx_gen/src/typemap/postgres.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,32 @@ use heck::ToUpperCamelCase;
33
use super::RustType;
44
use crate::introspect::SchemaInfo;
55

6+
/// Returns true if the udt_name is a known PostgreSQL builtin type
7+
/// (i.e., not a fallback to String).
8+
pub fn is_builtin(udt_name: &str) -> bool {
9+
matches!(
10+
udt_name,
11+
"bool"
12+
| "int2" | "smallint" | "smallserial"
13+
| "int4" | "int" | "integer" | "serial"
14+
| "int8" | "bigint" | "bigserial"
15+
| "float4" | "real"
16+
| "float8" | "double precision"
17+
| "numeric" | "decimal"
18+
| "varchar" | "text" | "bpchar" | "char" | "name" | "citext"
19+
| "bytea"
20+
| "timestamp" | "timestamp without time zone"
21+
| "timestamptz" | "timestamp with time zone"
22+
| "date"
23+
| "time" | "time without time zone"
24+
| "timetz" | "time with time zone"
25+
| "uuid"
26+
| "json" | "jsonb"
27+
| "inet" | "cidr"
28+
| "oid"
29+
)
30+
}
31+
632
pub fn map_type(udt_name: &str, schema_info: &SchemaInfo) -> RustType {
733
// Handle array types (prefixed with '_' in PG)
834
if let Some(inner) = udt_name.strip_prefix('_') {

crates/sqlx_gen_macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "sqlx-gen-macros"
3-
version = "0.4.2"
3+
version = "0.4.3"
44
edition = "2021"
55
description = "No-op attribute macros for sqlx-gen generated code"
66
license = "MIT"

0 commit comments

Comments
 (0)