diff --git a/src/ts_generator/sql_parser/expressions/translate_expr.rs b/src/ts_generator/sql_parser/expressions/translate_expr.rs index 7351fe1d..d47ea135 100644 --- a/src/ts_generator/sql_parser/expressions/translate_expr.rs +++ b/src/ts_generator/sql_parser/expressions/translate_expr.rs @@ -124,6 +124,7 @@ pub async fn get_sql_query_param( single_table_name: &Option<&str>, table_with_joins: &Option>, db_conn: &DBConn, + cte_columns: &std::collections::HashMap>, ) -> Result)>, TsGeneratorError> { let table_name: Option; @@ -145,6 +146,15 @@ pub async fn get_sql_query_param( match (column_name, expr_placeholder, table_name) { (Some(column_name), Some(expr_placeholder), Some(table_name)) => { + // First check if the table is a CTE or table-valued function + if let Some(cte_table_columns) = cte_columns.get(table_name.as_str()) { + if let Some(ts_type) = cte_table_columns.get(column_name.as_str()) { + return Ok(Some((ts_type.clone(), false, Some(expr_placeholder)))); + } + // Column not found in CTE columns — return None to allow fallback handling + return Ok(None); + } + let table_names = vec![table_name.as_str()]; let columns = DB_SCHEMA .lock() @@ -312,7 +322,15 @@ pub async fn translate_expr( // OPERATORS START // ///////////////////// Expr::BinaryOp { left, op: _, right } => { - let param = get_sql_query_param(left, right, single_table_name, table_with_joins, db_conn).await?; + let param = get_sql_query_param( + left, + right, + single_table_name, + table_with_joins, + db_conn, + &ts_query.table_valued_function_columns, + ) + .await?; if let Some((value, is_nullable, index)) = param { let _ = ts_query.insert_param(&value, &is_nullable, &index); Ok(()) @@ -354,6 +372,7 @@ pub async fn translate_expr( single_table_name, table_with_joins, db_conn, + &ts_query.table_valued_function_columns, ) .await?; @@ -383,8 +402,24 @@ pub async fn translate_expr( low, high, } => { - let low = get_sql_query_param(expr, low, single_table_name, table_with_joins, db_conn).await?; - let high = get_sql_query_param(expr, high, single_table_name, table_with_joins, db_conn).await?; + let low = get_sql_query_param( + expr, + low, + single_table_name, + table_with_joins, + db_conn, + &ts_query.table_valued_function_columns, + ) + .await?; + let high = get_sql_query_param( + expr, + high, + single_table_name, + table_with_joins, + db_conn, + &ts_query.table_valued_function_columns, + ) + .await?; if let Some((value, is_nullable, placeholder)) = low { ts_query.insert_param(&value, &is_nullable, &placeholder)?; } diff --git a/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs b/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs index 768f4aee..74cfe7aa 100644 --- a/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs +++ b/src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs @@ -64,14 +64,25 @@ pub async fn translate_wildcard_expr( ts_query: &mut TsQuery, db_conn: &DBConn, ) -> Result<(), TsGeneratorError> { - let table_with_joins = get_all_table_names_from_select(select)?; + let table_names = get_all_table_names_from_select(select)?; - if table_with_joins.len() > 1 { + // Check if the table is a CTE or table-valued function registered in table_valued_function_columns. + // CTEs are processed before the main query body and their columns are stored there. + for table_name in &table_names { + if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(table_name).cloned() { + for (col_name, ts_type) in tvf_columns { + ts_query.result.insert(col_name, vec![ts_type]); + } + return Ok(()); + } + } + + if table_names.len() > 1 { warning!("Impossible to calculate appropriate field names of a wildcard query with multiple tables. Please use explicit field names instead. Query: {}", select.to_string()); } - let table_with_joins = table_with_joins.iter().map(|s| s.as_ref()).collect(); - let all_fields = DB_SCHEMA.lock().await.fetch_table(&table_with_joins, db_conn).await; + let table_refs = table_names.iter().map(|s| s.as_ref()).collect(); + let all_fields = DB_SCHEMA.lock().await.fetch_table(&table_refs, db_conn).await; if let Some(all_fields) = all_fields { for key in all_fields.keys() { diff --git a/src/ts_generator/sql_parser/translate_query.rs b/src/ts_generator/sql_parser/translate_query.rs index 8abc76d0..a295b18b 100644 --- a/src/ts_generator/sql_parser/translate_query.rs +++ b/src/ts_generator/sql_parser/translate_query.rs @@ -1,3 +1,4 @@ +use async_recursion::async_recursion; use sqlparser::ast::{FunctionArg, FunctionArgExpr, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins}; use std::collections::HashMap; @@ -195,6 +196,7 @@ pub async fn translate_select( } /// Translates a query and workout ts_query's results and params +#[async_recursion] pub async fn translate_query( ts_query: &mut TsQuery, // this parameter is used to stack table_with_joins while recursing through subqueries @@ -205,6 +207,40 @@ pub async fn translate_query( alias: Option<&str>, is_selection: bool, ) -> Result<(), TsGeneratorError> { + // Process CTEs (WITH clause) before the main query body. + // Each CTE is processed to extract its output columns, which are then registered + // as virtual table columns so the main query can reference them. + if let Some(with) = &query.with { + for cte in &with.cte_tables { + let cte_name = DisplayIndent(&cte.alias.name).to_string(); + let mut cte_ts_query = TsQuery::new(cte_name.clone()); + translate_query(&mut cte_ts_query, &None, &cte.query, db_conn, None, true).await?; + + // Merge parameters from the CTE body into the outer query's params. + // Parameters inside CTE bodies ($1, $2, ?) belong to the overall query's parameter list. + for (idx, types) in &cte_ts_query.params { + ts_query.params.insert(*idx, types.clone()); + } + + // Extract the CTE's output columns and register them as virtual table columns + // so the outer query can look them up just like table-valued function columns + let cte_columns: HashMap = cte_ts_query + .result + .into_iter() + .map(|(col_name, types)| { + // Take the first non-null type; fall back to Any if only Null is present + let ts_type = types + .into_iter() + .find(|t| *t != TsFieldType::Null) + .unwrap_or(TsFieldType::Any); + (col_name, ts_type) + }) + .collect(); + + ts_query.table_valued_function_columns.insert(cte_name, cte_columns); + } + } + let body = *query.body.clone(); match body { SetExpr::Select(select) => { diff --git a/tests/demo/cte/cte.queries.ts b/tests/demo/cte/cte.queries.ts new file mode 100644 index 00000000..2af81abb --- /dev/null +++ b/tests/demo/cte/cte.queries.ts @@ -0,0 +1,48 @@ +export type SimpleCteParams = []; + +export interface ISimpleCteResult { + id: number; + name: string; +} + +export interface ISimpleCteQuery { + params: SimpleCteParams; + result: ISimpleCteResult; +} + +export type RankWithCteParams = []; + +export interface IRankWithCteResult { + id: number; + name: string; + rk: any; +} + +export interface IRankWithCteQuery { + params: RankWithCteParams; + result: IRankWithCteResult; +} + +export type MultipleCtesParams = []; + +export interface IMultipleCtesResult { + id: number; + name: string; +} + +export interface IMultipleCtesQuery { + params: MultipleCtesParams; + result: IMultipleCtesResult; +} + +export type CteWithWildcardParams = []; + +export interface ICteWithWildcardResult { + id: number; + name: string; +} + +export interface ICteWithWildcardQuery { + params: CteWithWildcardParams; + result: ICteWithWildcardResult; +} diff --git a/tests/demo/cte/cte.snapshot.ts b/tests/demo/cte/cte.snapshot.ts new file mode 100644 index 00000000..2af81abb --- /dev/null +++ b/tests/demo/cte/cte.snapshot.ts @@ -0,0 +1,48 @@ +export type SimpleCteParams = []; + +export interface ISimpleCteResult { + id: number; + name: string; +} + +export interface ISimpleCteQuery { + params: SimpleCteParams; + result: ISimpleCteResult; +} + +export type RankWithCteParams = []; + +export interface IRankWithCteResult { + id: number; + name: string; + rk: any; +} + +export interface IRankWithCteQuery { + params: RankWithCteParams; + result: IRankWithCteResult; +} + +export type MultipleCtesParams = []; + +export interface IMultipleCtesResult { + id: number; + name: string; +} + +export interface IMultipleCtesQuery { + params: MultipleCtesParams; + result: IMultipleCtesResult; +} + +export type CteWithWildcardParams = []; + +export interface ICteWithWildcardResult { + id: number; + name: string; +} + +export interface ICteWithWildcardQuery { + params: CteWithWildcardParams; + result: ICteWithWildcardResult; +} diff --git a/tests/demo/cte/cte.ts b/tests/demo/cte/cte.ts new file mode 100644 index 00000000..fec2de17 --- /dev/null +++ b/tests/demo/cte/cte.ts @@ -0,0 +1,46 @@ +import { sql } from 'sqlx-ts' + +// Simple CTE with explicit column selection +const simpleCte = sql` +-- @name: simple cte +WITH filtered_items AS ( + SELECT id, name FROM items +) +SELECT id, name FROM filtered_items +` + +// CTE with window function (RANK) — original issue #104 +const rankWithCte = sql` +-- @name: rank with cte +WITH ranked_items AS ( + SELECT + id, + name, + rarity, + RANK() OVER (PARTITION BY rarity ORDER BY id) AS rk + FROM items +) +SELECT id, name, rk FROM ranked_items WHERE rk = 1 +` + +// Multiple CTEs +const multipleCtes = sql` +-- @name: multiple ctes +WITH + popular AS ( + SELECT id, name FROM items WHERE id > 10 + ), + rare AS ( + SELECT id, name FROM items WHERE rarity = 'rare' + ) +SELECT id, name FROM popular +` + +// CTE with wildcard in outer query +const cteWithWildcard = sql` +-- @name: cte with wildcard +WITH base AS ( + SELECT id, name FROM items +) +SELECT * FROM base +`