Skip to content

Commit 55f4e2c

Browse files
committed
feat: generate type source and add default macros
1 parent d8ea5e6 commit 55f4e2c

7 files changed

Lines changed: 174 additions & 10 deletions

File tree

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Supports **PostgreSQL**, **MySQL**, and **SQLite**. Introspects tables, views, e
2020
- Custom derives (`--derives Serialize,Deserialize`)
2121
- Type overrides (`--type-overrides jsonb=MyType`)
2222
- SQL views support (`--views`)
23-
- Table filtering (`--tables users,orders`)
23+
- Table filtering (`--tables users,orders`) and exclusion (`--exclude-tables _migrations`)
2424
- Single-file or multi-file output
2525
- Dry-run mode (preview on stdout)
2626

@@ -52,6 +52,11 @@ sqlx-gen -u sqlite:./local.db -o src/models
5252
sqlx-gen -u postgres://... --derives Serialize,Deserialize -o src/models
5353
```
5454

55+
### Exclude specific tables
56+
```sh
57+
sqlx-gen -u postgres://... --exclude-tables _migrations,schema_versions -o src/models
58+
```
59+
5560
### Include SQL views
5661
```sh
5762
sqlx-gen -u postgres://... --views -o src/models
@@ -72,6 +77,7 @@ sqlx-gen -u postgres://... --dry-run
7277
| `--derives` | | Additional derive macros (comma-separated) | none |
7378
| `--type-overrides` | | Type overrides `sql_type=RustType` (comma-separated) | none |
7479
| `--tables` | | Only generate these tables (comma-separated) | all |
80+
| `--exclude-tables` | | Exclude these tables/views (comma-separated) | none |
7581
| `--views` | | Also generate structs for SQL views | false |
7682
| `--single-file` | | Write everything to a single `models.rs` | false |
7783
| `--dry-run` | | Print to stdout, don't write files | false |

src/cli.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ pub struct Args {
3333
#[arg(long, value_delimiter = ',')]
3434
pub tables: Option<Vec<String>>,
3535

36+
/// Exclude these tables/views from generation (comma-separated)
37+
#[arg(long, value_delimiter = ',')]
38+
pub exclude_tables: Option<Vec<String>>,
39+
3640
/// Also generate structs for SQL views
3741
#[arg(long)]
3842
pub views: bool,
@@ -89,6 +93,7 @@ mod tests {
8993
type_overrides: vec![],
9094
single_file: false,
9195
tables: None,
96+
exclude_tables: None,
9297
views: false,
9398
dry_run: false,
9499
}
@@ -103,6 +108,7 @@ mod tests {
103108
type_overrides: overrides.into_iter().map(|s| s.to_string()).collect(),
104109
single_file: false,
105110
tables: None,
111+
exclude_tables: None,
106112
views: false,
107113
dry_run: false,
108114
}
@@ -240,4 +246,20 @@ mod tests {
240246
let map = args.parse_type_overrides();
241247
assert_eq!(map.get("key").unwrap(), "");
242248
}
249+
250+
// ========== exclude_tables ==========
251+
252+
#[test]
253+
fn test_exclude_tables_default_none() {
254+
let args = make_args("postgres://localhost/db");
255+
assert!(args.exclude_tables.is_none());
256+
}
257+
258+
#[test]
259+
fn test_exclude_tables_set() {
260+
let mut args = make_args("postgres://localhost/db");
261+
args.exclude_tables = Some(vec!["_migrations".to_string(), "schema_versions".to_string()]);
262+
assert_eq!(args.exclude_tables.as_ref().unwrap().len(), 2);
263+
assert!(args.exclude_tables.as_ref().unwrap().contains(&"_migrations".to_string()));
264+
}
243265
}

src/codegen/composite_gen.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,27 @@ pub fn generate_composite(
2727
composite.schema_name, composite.name
2828
);
2929

30+
imports.insert("use serde::{Serialize, Deserialize};".to_string());
3031
let mut derive_tokens = vec![
3132
quote! { Debug },
3233
quote! { Clone },
34+
quote! { PartialEq },
35+
quote! { Eq },
36+
quote! { Serialize },
37+
quote! { Deserialize },
3338
quote! { sqlx::Type },
3439
];
3540
for d in extra_derives {
3641
let ident = format_ident!("{}", d);
3742
derive_tokens.push(quote! { #ident });
3843
}
3944

40-
let pg_name = &composite.name;
45+
// Schema-qualify the type name for non-public schemas so sqlx can find the type
46+
let pg_name = if composite.schema_name != "public" {
47+
format!("{}.{}", composite.schema_name, composite.name)
48+
} else {
49+
composite.name.clone()
50+
};
4151
let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
4252

4353
let fields: Vec<TokenStream> = composite
@@ -169,6 +179,28 @@ mod tests {
169179
assert!(code.contains("sqlx(type_name = \"geo_point\")"));
170180
}
171181

182+
#[test]
183+
fn test_non_public_schema_qualified_type_name() {
184+
let c = CompositeTypeInfo {
185+
schema_name: "geo".to_string(),
186+
name: "point".to_string(),
187+
fields: vec![make_field("x", "float8", false)],
188+
};
189+
let schema = SchemaInfo::default();
190+
let (tokens, _) = generate_composite(&c, DatabaseKind::Postgres, &schema, &[], &HashMap::new());
191+
let code = parse_and_format(&tokens);
192+
assert!(code.contains("sqlx(type_name = \"geo.point\")"));
193+
}
194+
195+
#[test]
196+
fn test_public_schema_not_qualified() {
197+
let c = make_composite("address", vec![make_field("x", "text", false)]);
198+
let code = gen(&c);
199+
assert!(code.contains("sqlx(type_name = \"address\")"));
200+
// type_name should NOT be schema-qualified for public schema
201+
assert!(!code.contains("type_name = \"public.address\""));
202+
}
203+
172204
// --- fields ---
173205

174206
#[test]

src/codegen/enum_gen.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@ pub fn generate_enum(
1717
for imp in imports_for_derives(extra_derives) {
1818
imports.insert(imp);
1919
}
20-
let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
2120

21+
let enum_name = format_ident!("{}", enum_info.name.to_upper_camel_case());
2222
let doc = format!("Enum: {}.{}", enum_info.schema_name, enum_info.name);
2323

24+
imports.insert("use serde::{Serialize, Deserialize};".to_string());
2425
let mut derive_tokens = vec![
2526
quote! { Debug },
2627
quote! { Clone },
2728
quote! { PartialEq },
29+
quote! { Eq },
30+
quote! { Serialize },
31+
quote! { Deserialize },
2832
quote! { sqlx::Type },
2933
];
3034
for d in extra_derives {
@@ -33,8 +37,13 @@ pub fn generate_enum(
3337
}
3438

3539
// For PG, add #[sqlx(type_name = "...")]
40+
// Schema-qualify the type name for non-public schemas so sqlx can find the type
3641
let type_attr = if db_kind == DatabaseKind::Postgres {
37-
let pg_name = &enum_info.name;
42+
let pg_name = if enum_info.schema_name != "public" {
43+
format!("{}.{}", enum_info.schema_name, enum_info.name)
44+
} else {
45+
enum_info.name.clone()
46+
};
3847
quote! { #[sqlx(type_name = #pg_name)] }
3948
} else {
4049
quote! {}
@@ -90,7 +99,11 @@ mod tests {
9099
parse_and_format(&tokens)
91100
}
92101

93-
fn gen_with_derives(info: &EnumInfo, db: DatabaseKind, derives: &[String]) -> (String, BTreeSet<String>) {
102+
fn gen_with_derives(
103+
info: &EnumInfo,
104+
db: DatabaseKind,
105+
derives: &[String],
106+
) -> (String, BTreeSet<String>) {
94107
let (tokens, imports) = generate_enum(info, db, derives);
95108
(parse_and_format(&tokens), imports)
96109
}
@@ -128,6 +141,27 @@ mod tests {
128141
assert!(code.contains("sqlx(type_name = \"user_status\")"));
129142
}
130143

144+
#[test]
145+
fn test_postgres_non_public_schema_qualified_type_name() {
146+
let e = EnumInfo {
147+
schema_name: "auth".to_string(),
148+
name: "role".to_string(),
149+
variants: vec!["admin".to_string(), "user".to_string()],
150+
};
151+
let (tokens, _) = generate_enum(&e, DatabaseKind::Postgres, &[]);
152+
let code = parse_and_format(&tokens);
153+
assert!(code.contains("sqlx(type_name = \"auth.role\")"));
154+
}
155+
156+
#[test]
157+
fn test_postgres_public_schema_not_qualified() {
158+
let e = make_enum("status", vec!["a"]);
159+
let code = gen(&e, DatabaseKind::Postgres);
160+
assert!(code.contains("sqlx(type_name = \"status\")"));
161+
// type_name should NOT be schema-qualified for public schema
162+
assert!(!code.contains("type_name = \"public.status\""));
163+
}
164+
131165
#[test]
132166
fn test_mysql_no_type_name() {
133167
let e = make_enum("status", vec!["a"]);
@@ -207,10 +241,10 @@ mod tests {
207241
// --- imports ---
208242

209243
#[test]
210-
fn test_no_derives_empty_imports() {
244+
fn test_no_extra_derives_has_serde_import() {
211245
let e = make_enum("status", vec!["a"]);
212246
let (_, imports) = gen_with_derives(&e, DatabaseKind::Postgres, &[]);
213-
assert!(imports.is_empty());
247+
assert!(imports.iter().any(|i| i.contains("serde")));
214248
}
215249

216250
#[test]

src/codegen/struct_gen.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,14 @@ pub fn generate_struct(
2323
let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
2424

2525
// Build derive list
26+
imports.insert("use serde::{Serialize, Deserialize};".to_string());
2627
let mut derive_tokens = vec![
2728
quote! { Debug },
2829
quote! { Clone },
30+
quote! { PartialEq },
31+
quote! { Eq },
32+
quote! { Serialize },
33+
quote! { Deserialize },
2934
quote! { sqlx::FromRow },
3035
];
3136
for d in extra_derives {
@@ -315,11 +320,12 @@ mod tests {
315320
}
316321

317322
#[test]
318-
fn test_int4_no_import() {
323+
fn test_int4_only_serde_import() {
319324
let table = make_table("users", vec![make_col("id", "int4", false)]);
320325
let schema = SchemaInfo::default();
321326
let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
322-
assert!(imports.is_empty());
327+
assert_eq!(imports.len(), 1);
328+
assert!(imports.iter().any(|i| i.contains("serde")));
323329
}
324330

325331
#[test]

src/main.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,17 @@ async fn main() -> Result<()> {
4040
}
4141
};
4242

43-
// Filter tables if requested
43+
// Filter tables if requested (whitelist)
4444
if let Some(ref filter) = args.tables {
4545
schema_info.tables.retain(|t| filter.contains(&t.name));
4646
}
4747

48+
// Exclude tables/views if requested (blacklist)
49+
if let Some(ref exclude) = args.exclude_tables {
50+
schema_info.tables.retain(|t| !exclude.contains(&t.name));
51+
schema_info.views.retain(|v| !exclude.contains(&v.name));
52+
}
53+
4854
let table_count = schema_info.tables.len();
4955
let view_count = schema_info.views.len();
5056
let enum_count = schema_info.enums.len();

tests/e2e_sqlite.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,61 @@ async fn test_view_pascal_case_name() {
135135
let view_file = files.iter().find(|f| f.filename == "all_active_users.rs").unwrap();
136136
assert!(view_file.code.contains("pub struct AllActiveUsers"));
137137
}
138+
139+
// --- exclude tables ---
140+
141+
#[tokio::test]
142+
async fn test_exclude_table() {
143+
let pool = setup_pool().await;
144+
exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await;
145+
exec(&pool, "CREATE TABLE _migrations (id INTEGER NOT NULL)").await;
146+
let mut schema = introspect(&pool, false).await.unwrap();
147+
let exclude = vec!["_migrations".to_string()];
148+
schema.tables.retain(|t| !exclude.contains(&t.name));
149+
let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false);
150+
assert_eq!(files.len(), 1);
151+
assert_eq!(files[0].filename, "users.rs");
152+
}
153+
154+
#[tokio::test]
155+
async fn test_exclude_nonexistent_table() {
156+
let pool = setup_pool().await;
157+
exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await;
158+
exec(&pool, "CREATE TABLE posts (id INTEGER NOT NULL)").await;
159+
let mut schema = introspect(&pool, false).await.unwrap();
160+
let exclude = vec!["nonexistent".to_string()];
161+
schema.tables.retain(|t| !exclude.contains(&t.name));
162+
assert_eq!(schema.tables.len(), 2);
163+
}
164+
165+
#[tokio::test]
166+
async fn test_tables_include_then_exclude() {
167+
let pool = setup_pool().await;
168+
exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await;
169+
exec(&pool, "CREATE TABLE posts (id INTEGER NOT NULL)").await;
170+
exec(&pool, "CREATE TABLE comments (id INTEGER NOT NULL)").await;
171+
let mut schema = introspect(&pool, false).await.unwrap();
172+
// Simulate --tables users,posts
173+
let include = vec!["users".to_string(), "posts".to_string()];
174+
schema.tables.retain(|t| include.contains(&t.name));
175+
// Simulate --exclude-tables posts
176+
let exclude = vec!["posts".to_string()];
177+
schema.tables.retain(|t| !exclude.contains(&t.name));
178+
assert_eq!(schema.tables.len(), 1);
179+
assert_eq!(schema.tables[0].name, "users");
180+
}
181+
182+
#[tokio::test]
183+
async fn test_exclude_view() {
184+
let pool = setup_pool().await;
185+
exec(&pool, "CREATE TABLE users (id INTEGER NOT NULL)").await;
186+
exec(&pool, "CREATE VIEW v1 AS SELECT id FROM users").await;
187+
exec(&pool, "CREATE VIEW v2 AS SELECT id FROM users").await;
188+
let mut schema = introspect(&pool, true).await.unwrap();
189+
let exclude = vec!["v1".to_string()];
190+
schema.views.retain(|v| !exclude.contains(&v.name));
191+
let files = codegen::generate(&schema, DatabaseKind::Sqlite, &[], &HashMap::new(), false);
192+
let view_files: Vec<_> = files.iter().filter(|f| f.origin.as_ref().is_some_and(|o| o.starts_with("View:"))).collect();
193+
assert_eq!(view_files.len(), 1);
194+
assert_eq!(view_files[0].filename, "v2.rs");
195+
}

0 commit comments

Comments
 (0)