Skip to content

Commit a763b5b

Browse files
committed
feat: make view properties type as same as origin
1 parent 09e92e1 commit a763b5b

File tree

5 files changed

+540
-13
lines changed

5 files changed

+540
-13
lines changed

crates/sqlx_gen/src/codegen/mod.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,9 @@ pub fn generate(
124124
let imports = filter_imports(&imports, single_file);
125125
let code = format_tokens_with_imports(&tokens, &imports);
126126
let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127-
let origin = format!("Table: {}.{}", table.schema_name, table.name);
128127
files.push(GeneratedFile {
129128
filename: format!("{}.rs", module_name),
130-
origin: Some(origin),
129+
origin: None,
131130
code,
132131
});
133132
}
@@ -139,10 +138,9 @@ pub fn generate(
139138
let imports = filter_imports(&imports, single_file);
140139
let code = format_tokens_with_imports(&tokens, &imports);
141140
let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
142-
let origin = format!("View: {}.{}", view.schema_name, view.name);
143141
files.push(GeneratedFile {
144142
filename: format!("{}.rs", module_name),
145-
origin: Some(origin),
143+
origin: None,
146144
code,
147145
});
148146
}
@@ -805,13 +803,13 @@ mod tests {
805803
}
806804

807805
#[test]
808-
fn test_generate_origin_correct() {
806+
fn test_generate_no_origin_for_tables() {
809807
let schema = SchemaInfo {
810808
tables: vec![make_table("users", vec![make_col("id", "int4")])],
811809
..Default::default()
812810
};
813811
let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
814-
assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
812+
assert_eq!(files[0].origin, None);
815813
}
816814

817815
#[test]
@@ -949,13 +947,13 @@ mod tests {
949947
}
950948

951949
#[test]
952-
fn test_generate_view_origin() {
950+
fn test_generate_no_origin_for_views() {
953951
let schema = SchemaInfo {
954952
views: vec![make_view("active_users", vec![make_col("id", "int4")])],
955953
..Default::default()
956954
};
957955
let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
958-
assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
956+
assert_eq!(files[0].origin, None);
959957
}
960958

961959
#[test]

crates/sqlx_gen/src/introspect/mysql.rs

Lines changed: 227 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::collections::HashMap;
2+
13
use crate::error::Result;
24
use sqlx::MySqlPool;
35

@@ -9,11 +11,17 @@ pub async fn introspect(
911
include_views: bool,
1012
) -> Result<SchemaInfo> {
1113
let tables = fetch_tables(pool, schemas).await?;
12-
let views = if include_views {
14+
let mut views = if include_views {
1315
fetch_views(pool, schemas).await?
1416
} else {
1517
Vec::new()
1618
};
19+
20+
if !views.is_empty() {
21+
let sources = fetch_view_column_sources(pool, schemas).await?;
22+
resolve_view_nullability(&mut views, &sources, &tables);
23+
}
24+
1725
let enums = extract_enums(&tables);
1826

1927
Ok(SchemaInfo {
@@ -141,6 +149,105 @@ async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableIn
141149
Ok(views)
142150
}
143151

152+
struct ViewColumnSource {
153+
view_schema: String,
154+
view_name: String,
155+
table_schema: String,
156+
table_name: String,
157+
column_name: String,
158+
}
159+
160+
async fn fetch_view_column_sources(
161+
pool: &MySqlPool,
162+
schemas: &[String],
163+
) -> Result<Vec<ViewColumnSource>> {
164+
let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
165+
let query = format!(
166+
r#"
167+
SELECT
168+
vcu.VIEW_SCHEMA,
169+
vcu.VIEW_NAME,
170+
vcu.TABLE_SCHEMA,
171+
vcu.TABLE_NAME,
172+
vcu.COLUMN_NAME
173+
FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
174+
WHERE vcu.VIEW_SCHEMA IN ({})
175+
"#,
176+
placeholders.join(",")
177+
);
178+
179+
let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
180+
for schema in schemas {
181+
q = q.bind(schema);
182+
}
183+
184+
match q.fetch_all(pool).await {
185+
Ok(rows) => Ok(rows
186+
.into_iter()
187+
.map(
188+
|(view_schema, view_name, table_schema, table_name, column_name)| {
189+
ViewColumnSource {
190+
view_schema,
191+
view_name,
192+
table_schema,
193+
table_name,
194+
column_name,
195+
}
196+
},
197+
)
198+
.collect()),
199+
Err(_) => {
200+
// VIEW_COLUMN_USAGE may not exist on older MySQL versions
201+
Ok(Vec::new())
202+
}
203+
}
204+
}
205+
206+
fn resolve_view_nullability(
207+
views: &mut [TableInfo],
208+
sources: &[ViewColumnSource],
209+
tables: &[TableInfo],
210+
) {
211+
// Build table column lookup: (schema, table, column) -> is_nullable
212+
let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
213+
for table in tables {
214+
for col in &table.columns {
215+
table_lookup.insert(
216+
(&table.schema_name, &table.name, &col.name),
217+
col.is_nullable,
218+
);
219+
}
220+
}
221+
222+
// Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_nullable>
223+
let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
224+
for src in sources {
225+
if let Some(&is_nullable) =
226+
table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
227+
{
228+
view_lookup
229+
.entry((&src.view_schema, &src.view_name, &src.column_name))
230+
.or_default()
231+
.push(is_nullable);
232+
}
233+
}
234+
235+
for view in views.iter_mut() {
236+
for col in view.columns.iter_mut() {
237+
if let Some(nullable_flags) = view_lookup.get(&(
238+
view.schema_name.as_str(),
239+
view.name.as_str(),
240+
col.name.as_str(),
241+
)) {
242+
// Only mark as non-nullable if ALL sources are NOT nullable
243+
if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
244+
col.is_nullable = false;
245+
}
246+
}
247+
}
248+
}
249+
}
250+
144251
/// Extract inline ENUMs from column types.
145252
/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
146253
/// keyed by table_name + column_name.
@@ -339,4 +446,123 @@ mod tests {
339446
let enums = extract_enums(&tables);
340447
assert_eq!(enums.len(), 1);
341448
}
449+
450+
// ========== resolve_view_nullability ==========
451+
452+
fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
453+
TableInfo {
454+
schema_name: schema.to_string(),
455+
name: name.to_string(),
456+
columns: columns
457+
.into_iter()
458+
.enumerate()
459+
.map(|(i, col)| ColumnInfo {
460+
name: col.to_string(),
461+
data_type: "varchar".to_string(),
462+
udt_name: "varchar(255)".to_string(),
463+
is_nullable: true,
464+
is_primary_key: false,
465+
ordinal_position: i as i32,
466+
schema_name: schema.to_string(),
467+
column_default: None,
468+
})
469+
.collect(),
470+
}
471+
}
472+
473+
fn make_table_with_nullability(
474+
schema: &str,
475+
name: &str,
476+
columns: Vec<(&str, bool)>,
477+
) -> TableInfo {
478+
TableInfo {
479+
schema_name: schema.to_string(),
480+
name: name.to_string(),
481+
columns: columns
482+
.into_iter()
483+
.enumerate()
484+
.map(|(i, (col, nullable))| ColumnInfo {
485+
name: col.to_string(),
486+
data_type: "varchar".to_string(),
487+
udt_name: "varchar(255)".to_string(),
488+
is_nullable: nullable,
489+
is_primary_key: false,
490+
ordinal_position: i as i32,
491+
schema_name: schema.to_string(),
492+
column_default: None,
493+
})
494+
.collect(),
495+
}
496+
}
497+
498+
fn make_source(
499+
view_schema: &str,
500+
view_name: &str,
501+
table_schema: &str,
502+
table_name: &str,
503+
column_name: &str,
504+
) -> ViewColumnSource {
505+
ViewColumnSource {
506+
view_schema: view_schema.to_string(),
507+
view_name: view_name.to_string(),
508+
table_schema: table_schema.to_string(),
509+
table_name: table_name.to_string(),
510+
column_name: column_name.to_string(),
511+
}
512+
}
513+
514+
#[test]
515+
fn test_resolve_not_null_column() {
516+
let tables = vec![make_table_with_nullability(
517+
"db",
518+
"users",
519+
vec![("id", false), ("name", false)],
520+
)];
521+
let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
522+
let sources = vec![
523+
make_source("db", "my_view", "db", "users", "id"),
524+
make_source("db", "my_view", "db", "users", "name"),
525+
];
526+
resolve_view_nullability(&mut views, &sources, &tables);
527+
assert!(!views[0].columns[0].is_nullable);
528+
assert!(!views[0].columns[1].is_nullable);
529+
}
530+
531+
#[test]
532+
fn test_resolve_nullable_source() {
533+
let tables = vec![make_table_with_nullability(
534+
"db",
535+
"users",
536+
vec![("id", false), ("name", true)],
537+
)];
538+
let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
539+
let sources = vec![
540+
make_source("db", "my_view", "db", "users", "id"),
541+
make_source("db", "my_view", "db", "users", "name"),
542+
];
543+
resolve_view_nullability(&mut views, &sources, &tables);
544+
assert!(!views[0].columns[0].is_nullable);
545+
assert!(views[0].columns[1].is_nullable);
546+
}
547+
548+
#[test]
549+
fn test_resolve_no_match_stays_nullable() {
550+
let tables = vec![make_table_with_nullability(
551+
"db",
552+
"users",
553+
vec![("id", false)],
554+
)];
555+
let mut views = vec![make_view("db", "my_view", vec!["computed"])];
556+
let sources = vec![];
557+
resolve_view_nullability(&mut views, &sources, &tables);
558+
assert!(views[0].columns[0].is_nullable);
559+
}
560+
561+
#[test]
562+
fn test_resolve_empty_sources() {
563+
let tables = vec![];
564+
let mut views = vec![make_view("db", "my_view", vec!["id"])];
565+
resolve_view_nullability(&mut views, &[], &tables);
566+
assert!(views[0].columns[0].is_nullable);
567+
}
342568
}

0 commit comments

Comments
 (0)