Skip to content

Commit 6cc978f

Browse files
authored
fix(formatter): properly handle child statements (#678)
1 parent d89118d commit 6cc978f

3 files changed

Lines changed: 177 additions & 29 deletions

File tree

crates/pgls_workspace/src/workspace/server.rs

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,34 @@ impl Workspace for WorkspaceServer {
889889
let mut formatted_output = String::new();
890890
let path_str = params.path.as_path().display().to_string();
891891

892-
for (_id, stmt_range, text, ast_result) in doc.iter(FormatStatementMapper) {
892+
let format_candidates: Vec<_> = doc.iter(FormatStatementMapper).collect();
893+
let mut formatted_sql_fn_bodies = HashMap::new();
894+
895+
for (id, _stmt_range, _text, ast_result) in &format_candidates {
896+
if !id.is_child() {
897+
continue;
898+
}
899+
900+
let Some(parent_id) = id.parent() else {
901+
continue;
902+
};
903+
904+
let Ok(ast) = ast_result else {
905+
continue;
906+
};
907+
908+
let Ok(result) = pgls_pretty_print::format_statement(ast, &config) else {
909+
continue;
910+
};
911+
912+
formatted_sql_fn_bodies.insert(parent_id, result.formatted);
913+
}
914+
915+
for (id, stmt_range, text, ast_result) in format_candidates {
916+
if id.is_child() {
917+
continue;
918+
}
919+
893920
if let Some(filter_range) = params.range
894921
&& stmt_range.intersect(filter_range).is_none()
895922
{
@@ -901,35 +928,42 @@ impl Workspace for WorkspaceServer {
901928
}
902929

903930
match ast_result {
904-
Ok(ast) => match pgls_pretty_print::format_statement(&ast, &config) {
905-
Ok(result) => {
906-
if text != result.formatted {
907-
statements.push(StatementFormatResult {
908-
original: text.clone(),
909-
formatted: result.formatted.clone(),
910-
range: stmt_range,
911-
});
912-
}
913-
if !formatted_output.is_empty() {
914-
formatted_output.push_str("\n\n");
915-
}
916-
formatted_output.push_str(&result.formatted);
931+
Ok(ast) => {
932+
let mut ast = ast;
933+
if let Some(formatted_sql_fn_body) = formatted_sql_fn_bodies.get(&id) {
934+
sql_function::set_sql_fn_body(&mut ast, formatted_sql_fn_body);
917935
}
918-
Err(err) => {
919-
diagnostics.push(SDiagnostic::new(
920-
pgls_diagnostics::Error::from(WorkspaceError::format_error(
921-
err.to_string(),
922-
))
923-
.with_file_path(&path_str)
924-
.with_file_span(stmt_range),
925-
));
926-
927-
if !formatted_output.is_empty() {
928-
formatted_output.push_str("\n\n");
936+
937+
match pgls_pretty_print::format_statement(&ast, &config) {
938+
Ok(result) => {
939+
if text != result.formatted {
940+
statements.push(StatementFormatResult {
941+
original: text.clone(),
942+
formatted: result.formatted.clone(),
943+
range: stmt_range,
944+
});
945+
}
946+
if !formatted_output.is_empty() {
947+
formatted_output.push_str("\n\n");
948+
}
949+
formatted_output.push_str(&result.formatted);
950+
}
951+
Err(err) => {
952+
diagnostics.push(SDiagnostic::new(
953+
pgls_diagnostics::Error::from(WorkspaceError::format_error(
954+
err.to_string(),
955+
))
956+
.with_file_path(&path_str)
957+
.with_file_span(stmt_range),
958+
));
959+
960+
if !formatted_output.is_empty() {
961+
formatted_output.push_str("\n\n");
962+
}
963+
formatted_output.push_str(&text);
929964
}
930-
formatted_output.push_str(&text);
931965
}
932-
},
966+
}
933967
Err(syntax_err) => {
934968
diagnostics.push(SDiagnostic::new(
935969
pgls_diagnostics::Error::from(syntax_err).with_file_path(&path_str),

crates/pgls_workspace/src/workspace/server.tests.rs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::sync::Arc;
33
use biome_deserialize::{Merge, StringSet};
44
use pgls_analyse::RuleCategories;
55
use pgls_configuration::{
6-
PartialConfiguration, PartialTypecheckConfiguration, database::PartialDatabaseConfiguration,
7-
files::PartialFilesConfiguration,
6+
PartialConfiguration, PartialFormatConfiguration, PartialTypecheckConfiguration,
7+
database::PartialDatabaseConfiguration, files::PartialFilesConfiguration,
88
};
99

1010
#[cfg(not(target_os = "windows"))]
@@ -17,6 +17,7 @@ use sqlx::{Executor, PgPool};
1717
use crate::{
1818
Workspace, WorkspaceError,
1919
features::code_actions::ExecuteStatementResult,
20+
features::format::PullFileFormattingParams,
2021
workspace::{
2122
OpenFileParams, RegisterProjectFolderParams, StatementId, UpdateSettingsParams,
2223
server::WorkspaceServer,
@@ -629,6 +630,52 @@ FOR NO KEY UPDATE;
629630
);
630631
}
631632

633+
#[tokio::test]
634+
async fn test_format_keeps_sql_function_body_intact() {
635+
let mut conf = PartialConfiguration::init();
636+
conf.merge_with(PartialConfiguration {
637+
format: Some(PartialFormatConfiguration {
638+
enabled: Some(true),
639+
..Default::default()
640+
}),
641+
..Default::default()
642+
});
643+
644+
let workspace = get_test_workspace(Some(conf)).expect("Unable to create test workspace");
645+
646+
let path = PgLSPath::new("test.sql");
647+
let content =
648+
"create function add(a int, b int) returns int as 'SELECT 424242+$1+$2;' language sql;";
649+
650+
workspace
651+
.open_file(OpenFileParams {
652+
path: path.clone(),
653+
content: content.into(),
654+
version: 1,
655+
})
656+
.expect("Unable to open test file");
657+
658+
let result = workspace
659+
.pull_file_formatting(PullFileFormattingParams {
660+
path: path.clone(),
661+
range: None,
662+
})
663+
.expect("Unable to pull formatting");
664+
665+
assert_eq!(
666+
result.formatted.matches("424242").count(),
667+
1,
668+
"SQL function body should not be emitted as a second standalone statement:\n{}",
669+
result.formatted,
670+
);
671+
672+
assert!(
673+
result.formatted.contains("select 424242 + $1 + $2;"),
674+
"SQL function body should be formatted inline:\n{}",
675+
result.formatted,
676+
);
677+
}
678+
632679
#[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")]
633680
async fn test_cstyle_comments(test_db: PgPool) {
634681
let mut conf = PartialConfiguration::init();

crates/pgls_workspace/src/workspace/server/sql_function.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,52 @@ pub fn get_sql_fn_body(ast: &pgls_query::NodeEnum, content: &str) -> Option<SQLF
106106
})
107107
}
108108

109+
/// Replaces the SQL function body in CREATE FUNCTION ... AS '<body>' for LANGUAGE SQL functions.
110+
/// Returns true when a body was replaced.
111+
pub fn set_sql_fn_body(ast: &mut pgls_query::NodeEnum, new_body: &str) -> bool {
112+
let create_fn = match ast {
113+
pgls_query::NodeEnum::CreateFunctionStmt(cf) => cf,
114+
_ => return false,
115+
};
116+
117+
let language = pgls_query_ext::utils::find_option_value(create_fn, "language");
118+
if language.as_deref() != Some("sql") {
119+
return false;
120+
}
121+
122+
for option in &mut create_fn.options {
123+
let Some(pgls_query::NodeEnum::DefElem(def_elem)) = option.node.as_mut() else {
124+
continue;
125+
};
126+
127+
if def_elem.defname != "as" {
128+
continue;
129+
}
130+
131+
let Some(arg) = def_elem.arg.as_mut() else {
132+
continue;
133+
};
134+
135+
match arg.node.as_mut() {
136+
Some(pgls_query::NodeEnum::String(s)) => {
137+
s.sval = new_body.to_string();
138+
return true;
139+
}
140+
Some(pgls_query::NodeEnum::List(list)) => {
141+
if list.items.len() == 1
142+
&& let Some(pgls_query::NodeEnum::String(s)) = list.items[0].node.as_mut()
143+
{
144+
s.sval = new_body.to_string();
145+
return true;
146+
}
147+
}
148+
_ => {}
149+
}
150+
}
151+
152+
false
153+
}
154+
109155
#[cfg(test)]
110156
mod tests {
111157
use super::*;
@@ -163,4 +209,25 @@ mod tests {
163209
.is_some()
164210
);
165211
}
212+
213+
#[test]
214+
fn set_sql_function_body() {
215+
let input = "CREATE FUNCTION add(a integer, b integer) RETURNS integer
216+
AS 'SELECT $1+$2;'
217+
LANGUAGE SQL
218+
IMMUTABLE;";
219+
220+
let mut ast = pgls_query::parse(input).unwrap().into_root().unwrap();
221+
222+
let changed = set_sql_fn_body(&mut ast, "select $1 + $2;");
223+
assert!(changed);
224+
225+
let create_fn = match &ast {
226+
pgls_query::NodeEnum::CreateFunctionStmt(stmt) => stmt,
227+
_ => panic!("Expected CreateFunctionStmt"),
228+
};
229+
230+
let body = pgls_query_ext::utils::find_option_value(create_fn, "as");
231+
assert_eq!(body, Some("select $1 + $2;".to_string()));
232+
}
166233
}

0 commit comments

Comments
 (0)