Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/ts_generator/sql_parser/expressions/translate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ pub async fn get_sql_query_param(
single_table_name: &Option<&str>,
table_with_joins: &Option<Vec<TableWithJoins>>,
db_conn: &DBConn,
cte_columns: &std::collections::HashMap<String, std::collections::HashMap<String, TsFieldType>>,
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This parameter is named cte_columns, but call sites pass &ts_query.table_valued_function_columns, and translate_query also stores CTE columns in table_valued_function_columns. This naming mismatch makes the API harder to understand and maintain. Consider renaming the parameter (and related comments/locals) to something like virtual_table_columns (or splitting CTE vs TVF maps if they’re intended to be distinct concepts).

Copilot uses AI. Check for mistakes.
) -> Result<Option<(TsFieldType, bool, Option<String>)>, TsGeneratorError> {
let table_name: Option<String>;

Expand All @@ -145,6 +146,15 @@ pub async fn get_sql_query_param(

match (column_name, expr_placeholder, table_name) {
(Some(column_name), Some(expr_placeholder), Some(table_name)) => {
// First check if the table is a CTE or table-valued function
if let Some(cte_table_columns) = cte_columns.get(table_name.as_str()) {
if let Some(ts_type) = cte_table_columns.get(column_name.as_str()) {
return Ok(Some((ts_type.clone(), false, Some(expr_placeholder))));
}
// Column not found in CTE columns — return None to allow fallback handling
return Ok(None);
Comment thread
JasonShin marked this conversation as resolved.
}

let table_names = vec![table_name.as_str()];
let columns = DB_SCHEMA
.lock()
Expand Down Expand Up @@ -312,7 +322,15 @@ pub async fn translate_expr(
// OPERATORS START //
/////////////////////
Expr::BinaryOp { left, op: _, right } => {
let param = get_sql_query_param(left, right, single_table_name, table_with_joins, db_conn).await?;
let param = get_sql_query_param(
left,
right,
single_table_name,
table_with_joins,
db_conn,
&ts_query.table_valued_function_columns,
)
.await?;
if let Some((value, is_nullable, index)) = param {
let _ = ts_query.insert_param(&value, &is_nullable, &index);
Ok(())
Expand Down Expand Up @@ -354,6 +372,7 @@ pub async fn translate_expr(
single_table_name,
table_with_joins,
db_conn,
&ts_query.table_valued_function_columns,
)
.await?;

Expand Down Expand Up @@ -383,8 +402,24 @@ pub async fn translate_expr(
low,
high,
} => {
let low = get_sql_query_param(expr, low, single_table_name, table_with_joins, db_conn).await?;
let high = get_sql_query_param(expr, high, single_table_name, table_with_joins, db_conn).await?;
let low = get_sql_query_param(
expr,
low,
single_table_name,
table_with_joins,
db_conn,
&ts_query.table_valued_function_columns,
)
.await?;
let high = get_sql_query_param(
expr,
high,
single_table_name,
table_with_joins,
db_conn,
&ts_query.table_valued_function_columns,
)
.await?;
if let Some((value, is_nullable, placeholder)) = low {
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
}
Expand Down
19 changes: 15 additions & 4 deletions src/ts_generator/sql_parser/expressions/translate_wildcard_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,25 @@ pub async fn translate_wildcard_expr(
ts_query: &mut TsQuery,
db_conn: &DBConn,
) -> Result<(), TsGeneratorError> {
let table_with_joins = get_all_table_names_from_select(select)?;
let table_names = get_all_table_names_from_select(select)?;

if table_with_joins.len() > 1 {
// Check if the table is a CTE or table-valued function registered in table_valued_function_columns.
// CTEs are processed before the main query body and their columns are stored there.
for table_name in &table_names {
if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(table_name).cloned() {
for (col_name, ts_type) in tvf_columns {
ts_query.result.insert(col_name, vec![ts_type]);
}
return Ok(());
}
}

if table_names.len() > 1 {
warning!("Impossible to calculate appropriate field names of a wildcard query with multiple tables. Please use explicit field names instead. Query: {}", select.to_string());
}
Comment on lines +71 to 82
Copy link

Copilot AI Feb 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns early as soon as any table name matches a registered virtual table, which drops columns from other tables in FROM (CTE + real table, or multiple CTEs). A safer behavior is: only short-circuit when table_names.len() == 1, otherwise merge virtual-table columns into the result and continue with the existing multi-table handling (including schema fetch + warning).

Copilot uses AI. Check for mistakes.

let table_with_joins = table_with_joins.iter().map(|s| s.as_ref()).collect();
let all_fields = DB_SCHEMA.lock().await.fetch_table(&table_with_joins, db_conn).await;
let table_refs = table_names.iter().map(|s| s.as_ref()).collect();
let all_fields = DB_SCHEMA.lock().await.fetch_table(&table_refs, db_conn).await;

if let Some(all_fields) = all_fields {
for key in all_fields.keys() {
Expand Down
36 changes: 36 additions & 0 deletions src/ts_generator/sql_parser/translate_query.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use async_recursion::async_recursion;
use sqlparser::ast::{FunctionArg, FunctionArgExpr, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins};
use std::collections::HashMap;

Expand Down Expand Up @@ -195,6 +196,7 @@ pub async fn translate_select(
}

/// Translates a query and workout ts_query's results and params
#[async_recursion]
pub async fn translate_query(
ts_query: &mut TsQuery,
// this parameter is used to stack table_with_joins while recursing through subqueries
Expand All @@ -205,6 +207,40 @@ pub async fn translate_query(
alias: Option<&str>,
is_selection: bool,
) -> Result<(), TsGeneratorError> {
// Process CTEs (WITH clause) before the main query body.
// Each CTE is processed to extract its output columns, which are then registered
// as virtual table columns so the main query can reference them.
if let Some(with) = &query.with {
for cte in &with.cte_tables {
let cte_name = DisplayIndent(&cte.alias.name).to_string();
let mut cte_ts_query = TsQuery::new(cte_name.clone());
translate_query(&mut cte_ts_query, &None, &cte.query, db_conn, None, true).await?;

// Merge parameters from the CTE body into the outer query's params.
// Parameters inside CTE bodies ($1, $2, ?) belong to the overall query's parameter list.
for (idx, types) in &cte_ts_query.params {
ts_query.params.insert(*idx, types.clone());
Comment thread
JasonShin marked this conversation as resolved.
}

// Extract the CTE's output columns and register them as virtual table columns
// so the outer query can look them up just like table-valued function columns
let cte_columns: HashMap<String, TsFieldType> = cte_ts_query
.result
.into_iter()
.map(|(col_name, types)| {
// Take the first non-null type; fall back to Any if only Null is present
let ts_type = types
.into_iter()
.find(|t| *t != TsFieldType::Null)
.unwrap_or(TsFieldType::Any);
(col_name, ts_type)
})
.collect();

ts_query.table_valued_function_columns.insert(cte_name, cte_columns);
}
}

let body = *query.body.clone();
match body {
SetExpr::Select(select) => {
Expand Down
48 changes: 48 additions & 0 deletions tests/demo/cte/cte.queries.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
export type SimpleCteParams = [];

export interface ISimpleCteResult {
id: number;
name: string;
}

export interface ISimpleCteQuery {
params: SimpleCteParams;
result: ISimpleCteResult;
}

export type RankWithCteParams = [];

export interface IRankWithCteResult {
id: number;
name: string;
rk: any;
}

export interface IRankWithCteQuery {
params: RankWithCteParams;
result: IRankWithCteResult;
}

export type MultipleCtesParams = [];

export interface IMultipleCtesResult {
id: number;
name: string;
}

export interface IMultipleCtesQuery {
params: MultipleCtesParams;
result: IMultipleCtesResult;
}

export type CteWithWildcardParams = [];

export interface ICteWithWildcardResult {
id: number;
name: string;
}

export interface ICteWithWildcardQuery {
params: CteWithWildcardParams;
result: ICteWithWildcardResult;
}
48 changes: 48 additions & 0 deletions tests/demo/cte/cte.snapshot.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
export type SimpleCteParams = [];

export interface ISimpleCteResult {
id: number;
name: string;
}

export interface ISimpleCteQuery {
params: SimpleCteParams;
result: ISimpleCteResult;
}

export type RankWithCteParams = [];

export interface IRankWithCteResult {
id: number;
name: string;
rk: any;
}

export interface IRankWithCteQuery {
params: RankWithCteParams;
result: IRankWithCteResult;
}

export type MultipleCtesParams = [];

export interface IMultipleCtesResult {
id: number;
name: string;
}

export interface IMultipleCtesQuery {
params: MultipleCtesParams;
result: IMultipleCtesResult;
}

export type CteWithWildcardParams = [];

export interface ICteWithWildcardResult {
id: number;
name: string;
}

export interface ICteWithWildcardQuery {
params: CteWithWildcardParams;
result: ICteWithWildcardResult;
}
46 changes: 46 additions & 0 deletions tests/demo/cte/cte.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { sql } from 'sqlx-ts'

// Simple CTE with explicit column selection
const simpleCte = sql`
-- @name: simple cte
WITH filtered_items AS (
SELECT id, name FROM items
)
SELECT id, name FROM filtered_items
`

// CTE with window function (RANK) — original issue #104
const rankWithCte = sql`
-- @name: rank with cte
WITH ranked_items AS (
SELECT
id,
name,
rarity,
RANK() OVER (PARTITION BY rarity ORDER BY id) AS rk
FROM items
)
SELECT id, name, rk FROM ranked_items WHERE rk = 1
`

// Multiple CTEs
const multipleCtes = sql`
-- @name: multiple ctes
WITH
popular AS (
SELECT id, name FROM items WHERE id > 10
),
rare AS (
SELECT id, name FROM items WHERE rarity = 'rare'
)
SELECT id, name FROM popular
`

// CTE with wildcard in outer query
const cteWithWildcard = sql`
-- @name: cte with wildcard
WITH base AS (
SELECT id, name FROM items
)
SELECT * FROM base
`
Loading