Skip to content

Commit a49018a

Browse files
authored
CTE support (#251)
* cte support * fmt
1 parent f8dc398 commit a49018a

6 files changed

Lines changed: 231 additions & 7 deletions

File tree

src/ts_generator/sql_parser/expressions/translate_expr.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ pub async fn get_sql_query_param(
124124
single_table_name: &Option<&str>,
125125
table_with_joins: &Option<Vec<TableWithJoins>>,
126126
db_conn: &DBConn,
127+
cte_columns: &std::collections::HashMap<String, std::collections::HashMap<String, TsFieldType>>,
127128
) -> Result<Option<(TsFieldType, bool, Option<String>)>, TsGeneratorError> {
128129
let table_name: Option<String>;
129130

@@ -145,6 +146,15 @@ pub async fn get_sql_query_param(
145146

146147
match (column_name, expr_placeholder, table_name) {
147148
(Some(column_name), Some(expr_placeholder), Some(table_name)) => {
149+
// First check if the table is a CTE or table-valued function
150+
if let Some(cte_table_columns) = cte_columns.get(table_name.as_str()) {
151+
if let Some(ts_type) = cte_table_columns.get(column_name.as_str()) {
152+
return Ok(Some((ts_type.clone(), false, Some(expr_placeholder))));
153+
}
154+
// Column not found in CTE columns — return None to allow fallback handling
155+
return Ok(None);
156+
}
157+
148158
let table_names = vec![table_name.as_str()];
149159
let columns = DB_SCHEMA
150160
.lock()
@@ -312,7 +322,15 @@ pub async fn translate_expr(
312322
// OPERATORS START //
313323
/////////////////////
314324
Expr::BinaryOp { left, op: _, right } => {
315-
let param = get_sql_query_param(left, right, single_table_name, table_with_joins, db_conn).await?;
325+
let param = get_sql_query_param(
326+
left,
327+
right,
328+
single_table_name,
329+
table_with_joins,
330+
db_conn,
331+
&ts_query.table_valued_function_columns,
332+
)
333+
.await?;
316334
if let Some((value, is_nullable, index)) = param {
317335
let _ = ts_query.insert_param(&value, &is_nullable, &index);
318336
Ok(())
@@ -354,6 +372,7 @@ pub async fn translate_expr(
354372
single_table_name,
355373
table_with_joins,
356374
db_conn,
375+
&ts_query.table_valued_function_columns,
357376
)
358377
.await?;
359378

@@ -383,8 +402,24 @@ pub async fn translate_expr(
383402
low,
384403
high,
385404
} => {
386-
let low = get_sql_query_param(expr, low, single_table_name, table_with_joins, db_conn).await?;
387-
let high = get_sql_query_param(expr, high, single_table_name, table_with_joins, db_conn).await?;
405+
let low = get_sql_query_param(
406+
expr,
407+
low,
408+
single_table_name,
409+
table_with_joins,
410+
db_conn,
411+
&ts_query.table_valued_function_columns,
412+
)
413+
.await?;
414+
let high = get_sql_query_param(
415+
expr,
416+
high,
417+
single_table_name,
418+
table_with_joins,
419+
db_conn,
420+
&ts_query.table_valued_function_columns,
421+
)
422+
.await?;
388423
if let Some((value, is_nullable, placeholder)) = low {
389424
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
390425
}

src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,25 @@ pub async fn translate_wildcard_expr(
6464
ts_query: &mut TsQuery,
6565
db_conn: &DBConn,
6666
) -> Result<(), TsGeneratorError> {
67-
let table_with_joins = get_all_table_names_from_select(select)?;
67+
let table_names = get_all_table_names_from_select(select)?;
6868

69-
if table_with_joins.len() > 1 {
69+
// Check if the table is a CTE or table-valued function registered in table_valued_function_columns.
70+
// CTEs are processed before the main query body and their columns are stored there.
71+
for table_name in &table_names {
72+
if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(table_name).cloned() {
73+
for (col_name, ts_type) in tvf_columns {
74+
ts_query.result.insert(col_name, vec![ts_type]);
75+
}
76+
return Ok(());
77+
}
78+
}
79+
80+
if table_names.len() > 1 {
7081
warning!("Impossible to calculate appropriate field names of a wildcard query with multiple tables. Please use explicit field names instead. Query: {}", select.to_string());
7182
}
7283

73-
let table_with_joins = table_with_joins.iter().map(|s| s.as_ref()).collect();
74-
let all_fields = DB_SCHEMA.lock().await.fetch_table(&table_with_joins, db_conn).await;
84+
let table_refs = table_names.iter().map(|s| s.as_ref()).collect();
85+
let all_fields = DB_SCHEMA.lock().await.fetch_table(&table_refs, db_conn).await;
7586

7687
if let Some(all_fields) = all_fields {
7788
for key in all_fields.keys() {

src/ts_generator/sql_parser/translate_query.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use async_recursion::async_recursion;
12
use sqlparser::ast::{FunctionArg, FunctionArgExpr, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins};
23
use std::collections::HashMap;
34

@@ -195,6 +196,7 @@ pub async fn translate_select(
195196
}
196197

197198
/// Translates a query and workout ts_query's results and params
199+
#[async_recursion]
198200
pub async fn translate_query(
199201
ts_query: &mut TsQuery,
200202
// this parameter is used to stack table_with_joins while recursing through subqueries
@@ -205,6 +207,40 @@ pub async fn translate_query(
205207
alias: Option<&str>,
206208
is_selection: bool,
207209
) -> Result<(), TsGeneratorError> {
210+
// Process CTEs (WITH clause) before the main query body.
211+
// Each CTE is processed to extract its output columns, which are then registered
212+
// as virtual table columns so the main query can reference them.
213+
if let Some(with) = &query.with {
214+
for cte in &with.cte_tables {
215+
let cte_name = DisplayIndent(&cte.alias.name).to_string();
216+
let mut cte_ts_query = TsQuery::new(cte_name.clone());
217+
translate_query(&mut cte_ts_query, &None, &cte.query, db_conn, None, true).await?;
218+
219+
// Merge parameters from the CTE body into the outer query's params.
220+
// Parameters inside CTE bodies ($1, $2, ?) belong to the overall query's parameter list.
221+
for (idx, types) in &cte_ts_query.params {
222+
ts_query.params.insert(*idx, types.clone());
223+
}
224+
225+
// Extract the CTE's output columns and register them as virtual table columns
226+
// so the outer query can look them up just like table-valued function columns
227+
let cte_columns: HashMap<String, TsFieldType> = cte_ts_query
228+
.result
229+
.into_iter()
230+
.map(|(col_name, types)| {
231+
// Take the first non-null type; fall back to Any if only Null is present
232+
let ts_type = types
233+
.into_iter()
234+
.find(|t| *t != TsFieldType::Null)
235+
.unwrap_or(TsFieldType::Any);
236+
(col_name, ts_type)
237+
})
238+
.collect();
239+
240+
ts_query.table_valued_function_columns.insert(cte_name, cte_columns);
241+
}
242+
}
243+
208244
let body = *query.body.clone();
209245
match body {
210246
SetExpr::Select(select) => {

tests/demo/cte/cte.queries.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
export type SimpleCteParams = [];
2+
3+
export interface ISimpleCteResult {
4+
id: number;
5+
name: string;
6+
}
7+
8+
export interface ISimpleCteQuery {
9+
params: SimpleCteParams;
10+
result: ISimpleCteResult;
11+
}
12+
13+
export type RankWithCteParams = [];
14+
15+
export interface IRankWithCteResult {
16+
id: number;
17+
name: string;
18+
rk: any;
19+
}
20+
21+
export interface IRankWithCteQuery {
22+
params: RankWithCteParams;
23+
result: IRankWithCteResult;
24+
}
25+
26+
export type MultipleCtesParams = [];
27+
28+
export interface IMultipleCtesResult {
29+
id: number;
30+
name: string;
31+
}
32+
33+
export interface IMultipleCtesQuery {
34+
params: MultipleCtesParams;
35+
result: IMultipleCtesResult;
36+
}
37+
38+
export type CteWithWildcardParams = [];
39+
40+
export interface ICteWithWildcardResult {
41+
id: number;
42+
name: string;
43+
}
44+
45+
export interface ICteWithWildcardQuery {
46+
params: CteWithWildcardParams;
47+
result: ICteWithWildcardResult;
48+
}

tests/demo/cte/cte.snapshot.ts

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
export type SimpleCteParams = [];
2+
3+
export interface ISimpleCteResult {
4+
id: number;
5+
name: string;
6+
}
7+
8+
export interface ISimpleCteQuery {
9+
params: SimpleCteParams;
10+
result: ISimpleCteResult;
11+
}
12+
13+
export type RankWithCteParams = [];
14+
15+
export interface IRankWithCteResult {
16+
id: number;
17+
name: string;
18+
rk: any;
19+
}
20+
21+
export interface IRankWithCteQuery {
22+
params: RankWithCteParams;
23+
result: IRankWithCteResult;
24+
}
25+
26+
export type MultipleCtesParams = [];
27+
28+
export interface IMultipleCtesResult {
29+
id: number;
30+
name: string;
31+
}
32+
33+
export interface IMultipleCtesQuery {
34+
params: MultipleCtesParams;
35+
result: IMultipleCtesResult;
36+
}
37+
38+
export type CteWithWildcardParams = [];
39+
40+
export interface ICteWithWildcardResult {
41+
id: number;
42+
name: string;
43+
}
44+
45+
export interface ICteWithWildcardQuery {
46+
params: CteWithWildcardParams;
47+
result: ICteWithWildcardResult;
48+
}

tests/demo/cte/cte.ts

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import { sql } from 'sqlx-ts'
2+
3+
// Simple CTE with explicit column selection
4+
const simpleCte = sql`
5+
-- @name: simple cte
6+
WITH filtered_items AS (
7+
SELECT id, name FROM items
8+
)
9+
SELECT id, name FROM filtered_items
10+
`
11+
12+
// CTE with window function (RANK) — original issue #104
13+
const rankWithCte = sql`
14+
-- @name: rank with cte
15+
WITH ranked_items AS (
16+
SELECT
17+
id,
18+
name,
19+
rarity,
20+
RANK() OVER (PARTITION BY rarity ORDER BY id) AS rk
21+
FROM items
22+
)
23+
SELECT id, name, rk FROM ranked_items WHERE rk = 1
24+
`
25+
26+
// Multiple CTEs
27+
const multipleCtes = sql`
28+
-- @name: multiple ctes
29+
WITH
30+
popular AS (
31+
SELECT id, name FROM items WHERE id > 10
32+
),
33+
rare AS (
34+
SELECT id, name FROM items WHERE rarity = 'rare'
35+
)
36+
SELECT id, name FROM popular
37+
`
38+
39+
// CTE with wildcard in outer query
40+
const cteWithWildcard = sql`
41+
-- @name: cte with wildcard
42+
WITH base AS (
43+
SELECT id, name FROM items
44+
)
45+
SELECT * FROM base
46+
`

0 commit comments

Comments
 (0)