Skip to content

Commit 8d09eb8

Browse files
authored
bidirectional params fix (#268)
* fix * fmt * fix * add tests cases
1 parent 4857691 commit 8d09eb8

File tree

7 files changed

+572
-33
lines changed

7 files changed

+572
-33
lines changed

src/ts_generator/sql_parser/expressions/translate_expr.rs

Lines changed: 90 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -123,31 +123,53 @@ pub fn translate_column_name_assignment(assignment: &Assignment) -> Option<Strin
123123
///
124124
/// some_field = $1
125125
/// some_table.some_field = $1
126+
///
127+
/// Also handles the reversed case where the placeholder is on the left:
128+
/// ? >= some_field
129+
/// $1 >= some_table.some_field
126130
pub async fn get_sql_query_param(
127131
left: &Expr,
128132
right: &Expr,
129133
single_table_name: &Option<&str>,
130134
table_with_joins: &Option<Vec<TableWithJoins>>,
131135
db_conn: &DBConn,
132136
cte_columns: &std::collections::HashMap<String, std::collections::HashMap<String, TsFieldType>>,
137+
) -> Result<Option<(TsFieldType, bool, Option<String>)>, TsGeneratorError> {
138+
// Try the standard order first: left=column, right=placeholder
139+
let result =
140+
get_sql_query_param_directed(left, right, single_table_name, table_with_joins, db_conn, cte_columns).await?;
141+
if result.is_some() {
142+
return Ok(result);
143+
}
144+
145+
// Try the reversed order: left=placeholder, right=column
146+
get_sql_query_param_directed(right, left, single_table_name, table_with_joins, db_conn, cte_columns).await
147+
}
148+
149+
/// Internal helper that checks a specific direction: column_expr for column name, placeholder_expr for placeholder
150+
async fn get_sql_query_param_directed(
151+
column_expr: &Expr,
152+
placeholder_expr: &Expr,
153+
single_table_name: &Option<&str>,
154+
table_with_joins: &Option<Vec<TableWithJoins>>,
155+
db_conn: &DBConn,
156+
cte_columns: &std::collections::HashMap<String, std::collections::HashMap<String, TsFieldType>>,
133157
) -> Result<Option<(TsFieldType, bool, Option<String>)>, TsGeneratorError> {
134158
let table_name: Option<String>;
135159

136160
if table_with_joins.is_some() {
137-
table_name = translate_table_from_expr(table_with_joins, &left.clone()).ok();
161+
table_name = translate_table_from_expr(table_with_joins, &column_expr.clone()).ok();
138162
} else if single_table_name.is_some() {
139163
table_name = single_table_name.map(|x| x.to_string());
140164
} else {
141-
return Err(TsGeneratorError::TableNameInferenceFailedInWhere {
142-
query: left.to_string(),
143-
});
165+
return Ok(None);
144166
}
145167

146-
let column_name = translate_column_name_expr(left);
168+
let column_name = translate_column_name_expr(column_expr);
147169

148-
// If the right side of the expression is a placeholder `?` or `$n`
170+
// If the placeholder side of the expression is a placeholder `?` or `$n`
149171
// they are valid query parameter to process
150-
let expr_placeholder = get_expr_placeholder(right);
172+
let expr_placeholder = get_expr_placeholder(placeholder_expr);
151173

152174
match (column_name, expr_placeholder, table_name) {
153175
(Some(column_name), Some(expr_placeholder), Some(table_name)) => {
@@ -407,30 +429,67 @@ pub async fn translate_expr(
407429
low,
408430
high,
409431
} => {
410-
let low = get_sql_query_param(
411-
expr,
412-
low,
413-
single_table_name,
414-
table_with_joins,
415-
db_conn,
416-
&ts_query.table_valued_function_columns,
417-
)
418-
.await?;
419-
let high = get_sql_query_param(
420-
expr,
421-
high,
422-
single_table_name,
423-
table_with_joins,
424-
db_conn,
425-
&ts_query.table_valued_function_columns,
426-
)
427-
.await?;
428-
if let Some((value, is_nullable, placeholder)) = low {
429-
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
430-
}
431-
432-
if let Some((value, is_nullable, placeholder)) = high {
433-
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
432+
// BETWEEN has two forms:
433+
// 1. `column BETWEEN ? AND ?` — expr is column, low/high are placeholders
434+
// 2. `? BETWEEN low_col AND high_col` — expr is placeholder, low/high are columns
435+
let expr_is_placeholder = get_expr_placeholder(expr).is_some();
436+
437+
if expr_is_placeholder {
438+
// Case 2: `? BETWEEN low_col AND high_col`
439+
// The placeholder is the expr itself, infer its type from low (or high) column
440+
let result = get_sql_query_param_directed(
441+
low,
442+
expr,
443+
single_table_name,
444+
table_with_joins,
445+
db_conn,
446+
&ts_query.table_valued_function_columns,
447+
)
448+
.await?;
449+
if let Some((value, is_nullable, placeholder)) = result {
450+
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
451+
} else {
452+
// Fallback: try high column
453+
let result = get_sql_query_param_directed(
454+
high,
455+
expr,
456+
single_table_name,
457+
table_with_joins,
458+
db_conn,
459+
&ts_query.table_valued_function_columns,
460+
)
461+
.await?;
462+
if let Some((value, is_nullable, placeholder)) = result {
463+
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
464+
}
465+
}
466+
} else {
467+
// Case 1: `column BETWEEN ? AND ?`
468+
// The expr is a column, low and high may be placeholders
469+
let low_result = get_sql_query_param_directed(
470+
expr,
471+
low,
472+
single_table_name,
473+
table_with_joins,
474+
db_conn,
475+
&ts_query.table_valued_function_columns,
476+
)
477+
.await?;
478+
let high_result = get_sql_query_param_directed(
479+
expr,
480+
high,
481+
single_table_name,
482+
table_with_joins,
483+
db_conn,
484+
&ts_query.table_valued_function_columns,
485+
)
486+
.await?;
487+
if let Some((value, is_nullable, placeholder)) = low_result {
488+
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
489+
}
490+
if let Some((value, is_nullable, placeholder)) = high_result {
491+
ts_query.insert_param(&value, &is_nullable, &placeholder)?;
492+
}
434493
}
435494
Ok(())
436495
}

test-utils/src/sandbox.rs

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,139 @@ impl TestConfig {
9696
}
9797
}
9898

99+
/// Checks if the MySQL server at the given host:port is above the specified major.minor version.
100+
/// Uses `docker exec` to query the version. Returns true if version cannot be determined.
101+
pub fn is_mysql_version_above(db_host: &str, db_port: i32, major: u32, minor: u32) -> bool {
102+
let output = std::process::Command::new("docker")
103+
.args(["exec", "sqlx-ts-mysql-1", "mysql", "-u", "root", "-N", "-e", "SELECT VERSION();"])
104+
.output();
105+
match output {
106+
Ok(out) => {
107+
let version = String::from_utf8_lossy(&out.stdout);
108+
let version = version.trim();
109+
let parts: Vec<&str> = version.split('.').collect();
110+
if parts.len() >= 2 {
111+
let srv_major: u32 = parts[0].parse().unwrap_or(0);
112+
let srv_minor: u32 = parts[1].parse().unwrap_or(0);
113+
srv_major > major || (srv_major == major && srv_minor > minor)
114+
} else {
115+
true
116+
}
117+
}
118+
Err(_) => true,
119+
}
120+
}
121+
99122
#[macro_export]
100123
macro_rules! run_test {
124+
// Arm with minimum MySQL version requirement: (major, minor)
125+
($($name: ident, $test_config: expr, $ts_content: expr, $generated_types: expr, min_mysql: ($maj:expr, $min:expr))*) => {
126+
$(
127+
#[test]
128+
fn $name() -> Result<(), Box<dyn std::error::Error>> {
129+
use assert_cmd::cargo::cargo_bin_cmd;
130+
let ts_content = $ts_content;
131+
let test_config: TestConfig = $test_config;
132+
133+
// Check minimum MySQL version requirement
134+
if test_config.db_type == "mysql" {
135+
if !test_utils::sandbox::is_mysql_version_above(&test_config.db_host, test_config.db_port, $maj, $min) {
136+
eprintln!("Skipping test {}: requires MySQL > {}.{}", stringify!($name), $maj, $min);
137+
return Ok(());
138+
}
139+
}
140+
141+
println!("checking test config {:?}", test_config);
142+
let file_extension = test_config.file_extension;
143+
let db_type = test_config.db_type;
144+
let db_host = test_config.db_host;
145+
let db_port = test_config.db_port;
146+
let db_user = test_config.db_user;
147+
let db_pass = test_config.db_pass;
148+
let db_name = test_config.db_name;
149+
let config_file_name = test_config.config_file_name;
150+
let generate_path = test_config.generate_path;
151+
152+
// SETUP
153+
let dir = tempdir()?;
154+
let parent_path = dir.path();
155+
let file_path = parent_path.join(format!("index.{file_extension}"));
156+
157+
let mut temp_file = fs::File::create(&file_path)?;
158+
writeln!(temp_file, "{}", ts_content)?;
159+
let file_result = fs::read_to_string(&file_path)?;
160+
161+
// EXECUTE
162+
let mut cmd = cargo_bin_cmd!("sqlx-ts");
163+
164+
cmd.arg(parent_path.to_str().unwrap())
165+
.arg(format!("--ext={file_extension}"))
166+
.arg(format!("--db-type={db_type}"))
167+
.arg(format!("--db-host={db_host}"))
168+
.arg(format!("--db-port={db_port}"))
169+
.arg(format!("--db-user={db_user}"))
170+
.arg(format!("--db-name={db_name}"));
171+
172+
if &generate_path.is_some() == &true {
173+
let generate_path = generate_path.clone();
174+
let generate_path = generate_path.unwrap();
175+
let generate_path = generate_path.as_path();
176+
let generate_path = parent_path.join(generate_path);
177+
let generate_path = generate_path.display();
178+
cmd.arg(format!("--generate-path={generate_path}"));
179+
}
180+
181+
if (test_config.generate_types) {
182+
cmd.arg("-g");
183+
}
184+
185+
if (config_file_name.is_some()) {
186+
let cwd = env::current_dir()?;
187+
let config_file_name = format!("{}", config_file_name.unwrap());
188+
let config_path = cwd.join(format!("tests/configs/{config_file_name}"));
189+
let config_path = config_path.display();
190+
cmd.arg(format!("--config={config_path}"));
191+
}
192+
193+
if (db_pass.is_some()) {
194+
let db_pass = db_pass.unwrap();
195+
cmd.arg(format!("--db-pass={db_pass}"));
196+
} else {
197+
cmd.arg("--db-pass=");
198+
}
199+
200+
cmd.assert()
201+
.success()
202+
.stdout(predicates::str::contains("No SQL errors detected!"));
203+
204+
let generated_types: &str = $generated_types.clone();
205+
206+
if generate_path.is_some() {
207+
let generate_path = parent_path.join(generate_path.unwrap().as_path());
208+
let type_file = fs::read_to_string(generate_path);
209+
let type_file = type_file.unwrap();
210+
211+
assert_eq!(
212+
generated_types.trim().to_string().flatten(),
213+
type_file.trim().to_string().flatten()
214+
);
215+
return Ok(());
216+
}
217+
218+
let type_file = fs::read_to_string(parent_path.join("index.queries.ts"));
219+
if type_file.is_ok() {
220+
let type_file = type_file.unwrap().clone();
221+
let type_file = type_file.trim();
222+
assert_eq!(
223+
generated_types.trim().to_string().flatten(),
224+
type_file.to_string().flatten()
225+
);
226+
}
227+
Ok(())
228+
}
229+
)*};
230+
231+
// Original arm without version requirement
101232
($($name: ident, $test_config: expr, $ts_content: expr, $generated_types: expr)*) => {
102233
$(
103234
// MACRO STARTS
@@ -117,7 +248,7 @@ $(
117248
let db_name = test_config.db_name;
118249
let config_file_name = test_config.config_file_name;
119250
let generate_path = test_config.generate_path;
120-
251+
121252
// SETUP
122253
let dir = tempdir()?;
123254
let parent_path = dir.path();
@@ -126,7 +257,7 @@ $(
126257
let mut temp_file = fs::File::create(&file_path)?;
127258
writeln!(temp_file, "{}", ts_content)?;
128259
let file_result = fs::read_to_string(&file_path)?;
129-
260+
130261
// EXECUTE
131262
let mut cmd = cargo_bin_cmd!("sqlx-ts");
132263

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
export type PlaceholderBeforeComparisonParams = [number | null, number | null];
2+
3+
export interface IPlaceholderBeforeComparisonResult {
4+
character_id: number | null;
5+
id: number;
6+
quantity: number | null;
7+
}
8+
9+
export interface IPlaceholderBeforeComparisonQuery {
10+
params: PlaceholderBeforeComparisonParams;
11+
result: IPlaceholderBeforeComparisonResult;
12+
}
13+
14+
export type PlaceholderBetweenExprParams = [number | null];
15+
16+
export interface IPlaceholderBetweenExprResult {
17+
character_id: number | null;
18+
id: number;
19+
quantity: number | null;
20+
}
21+
22+
export interface IPlaceholderBetweenExprQuery {
23+
params: PlaceholderBetweenExprParams;
24+
result: IPlaceholderBetweenExprResult;
25+
}
26+
27+
export type BetweenPlaceholderBoundsParams = [number | null, number | null];
28+
29+
export interface IBetweenPlaceholderBoundsResult {
30+
character_id: number | null;
31+
id: number;
32+
quantity: number | null;
33+
}
34+
35+
export interface IBetweenPlaceholderBoundsQuery {
36+
params: BetweenPlaceholderBoundsParams;
37+
result: IBetweenPlaceholderBoundsResult;
38+
}
39+
40+
export type MixedParamPositionsParams = [number | null, number | null];
41+
42+
export interface IMixedParamPositionsResult {
43+
character_id: number | null;
44+
id: number;
45+
quantity: number | null;
46+
}
47+
48+
export interface IMixedParamPositionsQuery {
49+
params: MixedParamPositionsParams;
50+
result: IMixedParamPositionsResult;
51+
}

0 commit comments

Comments
 (0)