Skip to content
Merged
9 changes: 9 additions & 0 deletions src/core/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::common::lazy::CONFIG;
use crate::common::types::DatabaseType;
use crate::common::SQL;
use crate::core::mysql::prepare as mysql_explain;
use crate::core::postgres::prepare as postgres_explain;
Expand Down Expand Up @@ -34,6 +35,14 @@ impl DBConn {

Ok((explain_failed, ts_query))
}

/// Get the database type for this connection
pub fn get_db_type(&self) -> DatabaseType {
match self {
DBConn::MySQLPooledConn(_) => DatabaseType::Mysql,
DBConn::PostgresConn(_) => DatabaseType::Postgres,
}
}
}

pub struct DBConnections<'a> {
Expand Down
14 changes: 11 additions & 3 deletions src/ts_generator/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ use crate::ts_generator::annotations::extract_result_annotations;
use crate::ts_generator::sql_parser::translate_stmt::translate_stmt;
use crate::ts_generator::types::ts_query::TsQuery;

use crate::common::types::DatabaseType;
use color_eyre::eyre::eyre;
use color_eyre::eyre::Result;
use convert_case::{Case, Casing};
use regex::Regex;
use sqlparser::{dialect::GenericDialect, parser::Parser};
use sqlparser::{
dialect::{Dialect, MySqlDialect, PostgreSqlDialect},
parser::Parser,
};

use super::errors::TsGeneratorError;

Expand Down Expand Up @@ -117,9 +121,13 @@ pub fn clear_single_ts_file_if_exists() -> Result<()> {
}

pub async fn generate_ts_interface(sql: &SQL, db_conn: &DBConn) -> Result<TsQuery> {
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
// Use the appropriate SQL dialect based on the database type
let dialect: Box<dyn Dialect> = match db_conn.get_db_type() {
DatabaseType::Postgres => Box::new(PostgreSqlDialect {}),
DatabaseType::Mysql => Box::new(MySqlDialect {}),
};

let sql_ast = Parser::parse_sql(&dialect, &sql.query)?;
let sql_ast = Parser::parse_sql(&*dialect, &sql.query)?;
let mut ts_query = TsQuery::new(get_query_name(sql)?);

let annotated_result_types = extract_result_annotations(sql.query.as_str());
Expand Down
133 changes: 101 additions & 32 deletions src/ts_generator/sql_parser/expressions/translate_expr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::functions::{is_date_function, is_numeric_function, is_type_polymorphic_function};
use crate::common::lazy::DB_SCHEMA;
use crate::common::logger::warning;
use crate::common::logger::{error, warning};
use crate::core::connection::DBConn;
use crate::ts_generator::errors::TsGeneratorError;
use crate::ts_generator::sql_parser::expressions::translate_data_type::translate_value;
Expand Down Expand Up @@ -180,20 +180,46 @@ pub async fn translate_expr(
Expr::Identifier(ident) => {
let column_name = DisplayIndent(ident).to_string();
let table_name = single_table_name.expect("Missing table name for identifier");

// First check if this is a table-valued function column
if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(table_name) {
if let Some(ts_type) = tvf_columns.get(&column_name) {
let field_name = alias.unwrap_or(column_name.as_str());
ts_query.insert_result(
Some(field_name),
&[ts_type.to_owned()],
is_selection,
false, // Table-valued function columns are not nullable by default
expr_for_logging,
)?;
return Ok(());
}
}

// Fall back to database schema
let table_details = &DB_SCHEMA.lock().await.fetch_table(&vec![table_name], db_conn).await;

// TODO: We can also memoize this method
if let Some(table_details) = table_details {
let field = table_details.get(&column_name).unwrap();

let field_name = alias.unwrap_or(column_name.as_str());
ts_query.insert_result(
Some(field_name),
&[field.field_type.to_owned()],
is_selection,
field.is_nullable,
expr_for_logging,
)?
if let Some(field) = table_details.get(&column_name) {
Copy link

Copilot AI Nov 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The nested if-let chain (lines 202-217) should use if-let-else chaining or early returns to reduce nesting depth and improve readability.

Copilot uses AI. Check for mistakes.
let field_name = alias.unwrap_or(column_name.as_str());
ts_query.insert_result(
Some(field_name),
&[field.field_type.to_owned()],
is_selection,
field.is_nullable,
expr_for_logging,
)?
} else {
error!(
"Column '{}' not found in table '{}'. If '{}' is a table-valued function, verify that the column is defined in its alias. Otherwise, the column may not exist in the table.",
column_name, table_name, table_name
);
}
} else {
error!(
"Table '{}' not found in schema. This may be a table-valued function.",
Comment thread
JasonShin marked this conversation as resolved.
table_name
);
}
Ok(())
}
Expand All @@ -203,32 +229,67 @@ pub async fn translate_expr(

let table_name = translate_table_from_expr(table_with_joins, expr)?;

// First check if this is a table-valued function column
if let Some(tvf_columns) = ts_query.table_valued_function_columns.get(&table_name) {
if let Some(ts_type) = tvf_columns.get(&ident) {
// if the select item is a compound identifier and does not has an alias, we should use `table_name.ident` as the key name
let key_name = format!("{table_name}_{ident}");
let key_name = &alias.unwrap_or_else(|| {
warning!(
"Missing an alias for a compound identifier, using {} as the key name. Prefer adding an alias for example: `{} AS {}`",
key_name, expr, ident
);
key_name.as_str()
});

ts_query.insert_result(
Some(key_name),
&[ts_type.to_owned()],
is_selection,
false, // Table-valued function columns are not nullable by default
expr_for_logging,
)?;
return Ok(());
}
}

// Fall back to database schema
let table_details = &DB_SCHEMA
.lock()
.await
.fetch_table(&vec![table_name.as_str()], db_conn)
.await;

if let Some(table_details) = table_details {
let field = table_details.get(&ident).unwrap();

// if the select item is a compound identifier and does not has an alias, we should use `table_name.ident` as the key name
let key_name = format!("{table_name}_{ident}");
let key_name = &alias.unwrap_or_else(|| {
warning!(
"Missing an alias for a compound identifier, using {} as the key name. Prefer adding an alias for example: `{} AS {}`",
key_name, expr, ident
);
key_name.as_str()
});

ts_query.insert_result(
Some(key_name),
&[field.field_type.to_owned()],
is_selection,
field.is_nullable,
expr_for_logging,
)?;
if let Some(field) = table_details.get(&ident) {
Copy link

Copilot AI Nov 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The nested if-let chain (lines 263-292) mirrors the pattern at lines 202-223 with duplicated error messages. Consider extracting this logic into a helper function to reduce duplication.

Copilot uses AI. Check for mistakes.
// if the select item is a compound identifier and does not has an alias, we should use `table_name.ident` as the key name
let key_name = format!("{table_name}_{ident}");
let key_name = &alias.unwrap_or_else(|| {
warning!(
"Missing an alias for a compound identifier, using {} as the key name. Prefer adding an alias for example: `{} AS {}`",
key_name, expr, ident
);
key_name.as_str()
});

ts_query.insert_result(
Some(key_name),
&[field.field_type.to_owned()],
is_selection,
field.is_nullable,
expr_for_logging,
)?;
} else {
error!(
"Column '{}' not found in table '{}' for compound identifier '{}.{}'. This may be a table-valued function.",
Copy link

Copilot AI Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to previous comment, this error message suggests 'may be a table-valued function' but is reached after checking table-valued function columns. The message should be updated to reflect this is an unexpected condition.

Suggested change
"Column '{}' not found in table '{}' for compound identifier '{}.{}'. This may be a table-valued function.",
"Column '{}' not found in table '{}' for compound identifier '{}.{}'. This is unexpected: column not found in either table-valued function columns or database schema.",

Copilot uses AI. Check for mistakes.
ident, table_name, table_name, ident
);
}
} else {
error!(
"Table '{}' not found in schema for compound identifier '{}.{}'. This may be a table-valued function.",
table_name, table_name, ident
Comment thread
JasonShin marked this conversation as resolved.
);
}
}
Ok(())
Expand Down Expand Up @@ -359,7 +420,15 @@ pub async fn translate_expr(
if let Some(ts_field_type) = ts_field_type {
return ts_query.insert_result(alias, &[ts_field_type], is_selection, false, expr_for_logging);
}
ts_query.insert_param(&TsFieldType::Boolean, &false, &Some(placeholder.to_string()))
// For placeholders where we can't infer the type:
// - If we're in a WHERE clause (is_selection is false AND we have a table context), infer as Boolean
// - Otherwise, use Any for flexibility (e.g., for table-valued function arguments)
let inferred_type = if !is_selection && single_table_name.is_some() {
TsFieldType::Boolean
} else {
TsFieldType::Any
};
ts_query.insert_param(&inferred_type, &false, &Some(placeholder.to_string()))
}
Expr::JsonAccess { value: _, path: _ } => {
ts_query.insert_result(alias, &[TsFieldType::Any], is_selection, false, expr_for_logging)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,56 @@ use crate::ts_generator::sql_parser::quoted_strings::*;
use color_eyre::eyre::Result;
use sqlparser::ast::{Assignment, AssignmentTarget, Expr, Join, SelectItem, TableFactor, TableWithJoins};

/// Check if the given table name corresponds to a table-valued function alias
/// by examining the table_with_joins to see if it's a TableFactor::Function
#[allow(dead_code)]
pub fn is_table_function(table_name: &str, table_with_joins: &[TableWithJoins]) -> bool {
for twj in table_with_joins {
if let TableFactor::Function { alias: Some(alias), .. } = &twj.relation {
if DisplayTableAlias(alias).to_string() == table_name {
return true;
}
}
// Also check joins
for join in &twj.joins {
if let TableFactor::Function { alias: Some(alias), .. } = &join.relation {
if DisplayTableAlias(alias).to_string() == table_name {
return true;
}
}
}
}
false
}

Comment thread
JasonShin marked this conversation as resolved.
pub fn get_default_table(table_with_joins: &[TableWithJoins]) -> String {
table_with_joins
.first()
.and_then(|x| match &x.relation {
TableFactor::Table {
name,
alias: _,
args: _,
alias,
args,
with_hints: _,
version: _,
partitions: _,
with_ordinality: _,
json_path: _,
sample: _,
index_hints: _,
} => Some(DisplayObjectName(name).to_string()),
} => {
// If args is Some, it's a table-valued function (e.g., jsonb_to_recordset($1))
// In that case, use the alias name if available
if args.is_some() {
alias.as_ref().map(|a| DisplayTableAlias(a).to_string())
} else {
Some(DisplayObjectName(name).to_string())
}
}
TableFactor::Function { alias, .. } => {
// For LATERAL functions, use the alias name as the table name
alias.as_ref().map(|a| DisplayTableAlias(a).to_string())
}
_ => None,
})
.expect("The query does not have a default table, impossible to generate types")
Expand All @@ -41,7 +75,7 @@ pub fn find_table_name_from_identifier(
TableFactor::Table {
name,
alias,
args: _,
args,
with_hints: _,
version: _,
partitions: _,
Expand All @@ -50,11 +84,30 @@ pub fn find_table_name_from_identifier(
sample: _,
index_hints: _,
} => {
let alias = alias.clone().map(|alias| DisplayTableAlias(&alias).to_string());
let name = DisplayObjectName(name).to_string();
if Some(left.to_string()) == alias || left == name {
// If the identifier matches the alias, then return the table name
return Ok(name.to_owned());
let alias_str = alias.clone().map(|alias| DisplayTableAlias(&alias).to_string());
let name_str = DisplayObjectName(name).to_string();

// If this is a table-valued function (args is Some), use alias as the effective name
if args.is_some() {
if let Some(alias) = alias_str {
if left == alias {
return Ok(alias);
}
}
} else {
// Regular table
if Some(left.to_string()) == alias_str || left == name_str {
return Ok(name_str.to_owned());
}
}
}
TableFactor::Function { alias, .. } => {
// For LATERAL functions, the alias is the effective table name
if let Some(alias) = alias {
let alias_name = DisplayTableAlias(alias).to_string();
if left == alias_name {
return Ok(alias_name);
}
}
}
_ => {
Expand Down Expand Up @@ -89,6 +142,15 @@ pub fn find_table_name_from_identifier(
return Ok(name);
}
}
TableFactor::Function { alias, .. } => {
// For table-valued functions in joins, the alias is the effective table name
if let Some(alias) = alias {
let alias_name = DisplayTableAlias(alias).to_string();
if left == alias_name {
return Ok(alias_name);
}
}
}
_ => {
return Err(TsGeneratorError::TableFactorWhileProcessingTableWithJoins(
join.to_string(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ pub fn get_all_table_names_from_select(select: &Select) -> Result<Vec<String>, T
let name = DisplayObjectName(&name).to_string();
Ok(name)
}
TableFactor::Function { .. } => {
// Wildcard queries with table-valued functions are not supported
// because we cannot query the database schema for function result types
Err(TsGeneratorError::WildcardStatementUnsupportedTableExpr(
select.to_string(),
))
}
_ => Err(TsGeneratorError::WildcardStatementUnsupportedTableExpr(
select.to_string(),
)),
Expand Down
Loading
Loading