Skip to content

Commit 65af30a

Browse files
committed
feat(pgls_query): add PLpgSQL protobuf parsing support
- Update libpg_query submodule to fork with PLpgSQL protobuf API - Replace JSON-based parse_plpgsql with protobuf-based implementation - Add PlpgsqlParseResult struct with typed access to parsed functions - Export PLpgSQL protobuf types (PLpgSqlFunction, PLpgSqlStmt, etc.) - Update bindgen allowlist for new FFI functions - Filter PLpgSQL and SummaryResult types from Node iteration macros - Add comprehensive tests verifying parsed structure (IF, WHILE, assignments)
1 parent d9964c1 commit 65af30a

10 files changed

Lines changed: 2088 additions & 15 deletions

File tree

.gitmodules

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[submodule "crates/pgls_query/vendor/libpg_query"]
22
path = crates/pgls_query/vendor/libpg_query
3-
url = https://github.com/pganalyze/libpg_query.git
4-
branch = 17-latest
3+
url = https://github.com/psteinroe/libpg_query.git
4+
branch = feat/plpgsql-protobuf

crates/pgls_query/build.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ fn run_bindgen(
100100
.allowlist_function("pg_query_split_with_parser")
101101
.allowlist_function("pg_query_split_with_scanner")
102102
.allowlist_function("pg_query_parse_plpgsql")
103+
.allowlist_function("pg_query_parse_plpgsql_protobuf")
103104
.allowlist_function("pg_query_free_protobuf_parse_result")
105+
.allowlist_function("pg_query_free_plpgsql_protobuf_parse_result")
104106
.allowlist_function("pg_query_free_scan_result")
105107
.allowlist_function("pg_query_free_deparse_result")
106108
.allowlist_function("pg_query_free_normalize_result")
@@ -117,6 +119,7 @@ fn run_bindgen(
117119
.allowlist_type("PgQueryFingerprintResult")
118120
.allowlist_type("PgQuerySplitResult")
119121
.allowlist_type("PgQuerySplitStmt")
122+
.allowlist_type("PgQueryPlpgsqlProtobufParseResult")
120123
// Also generate bindings for size_t since it's used in PgQueryProtobuf
121124
.allowlist_type("size_t")
122125
.allowlist_var("PG_VERSION_NUM");
@@ -269,6 +272,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
269272
bindings_content.push_str(" pub fn pg_query_scan(input: *const ::std::os::raw::c_char) -> PgQueryScanResult;\n");
270273
bindings_content.push_str(" pub fn pg_query_parse_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryProtobufParseResult;\n");
271274
bindings_content.push_str(" pub fn pg_query_parse_plpgsql(input: *const ::std::os::raw::c_char) -> PgQueryPlpgsqlParseResult;\n");
275+
bindings_content.push_str(" pub fn pg_query_parse_plpgsql_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryPlpgsqlProtobufParseResult;\n");
272276
bindings_content.push_str(" pub fn pg_query_deparse_protobuf(protobuf: PgQueryProtobuf) -> PgQueryDeparseResult;\n");
273277
bindings_content.push_str(" pub fn pg_query_normalize(input: *const ::std::os::raw::c_char) -> PgQueryNormalizeResult;\n");
274278
bindings_content.push_str(" pub fn pg_query_fingerprint(input: *const ::std::os::raw::c_char) -> PgQueryFingerprintResult;\n");
@@ -278,6 +282,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
278282
.push_str(" pub fn pg_query_free_scan_result(result: PgQueryScanResult);\n");
279283
bindings_content.push_str(" pub fn pg_query_free_protobuf_parse_result(result: PgQueryProtobufParseResult);\n");
280284
bindings_content.push_str(" pub fn pg_query_free_plpgsql_parse_result(result: PgQueryPlpgsqlParseResult);\n");
285+
bindings_content.push_str(" pub fn pg_query_free_plpgsql_protobuf_parse_result(result: PgQueryPlpgsqlProtobufParseResult);\n");
281286
bindings_content.push_str(
282287
" pub fn pg_query_free_deparse_result(result: PgQueryDeparseResult);\n",
283288
);

crates/pgls_query/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ pub use scan::*;
2828
pub use split::*;
2929

3030
pub use protobuf::Node;
31+
pub use protobuf::PLpgSqlFunction;
32+
pub use protobuf::PLpgSqlParseResult;
33+
pub use protobuf::PLpgSqlStmt;
3134

3235
// Include the generated bindings with 2024 edition compatibility
3336
#[allow(non_upper_case_globals)]

crates/pgls_query/src/plpgsql.rs

Lines changed: 250 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
use std::ffi::{CStr, CString};
22

3+
use prost::Message;
4+
35
use crate::bindings::*;
46
use crate::error::*;
7+
use crate::protobuf;
58

6-
/// An experimental API which parses a PLPGSQL function. This currently drops the returned
7-
/// structure and returns only a Result<()>.
9+
/// Parses the given PL/pgSQL function into an abstract syntax tree.
810
///
911
/// # Example
1012
///
1113
/// ```rust
12-
/// let result = pgls_query::parse_plpgsql("
14+
/// use pgls_query::parse_plpgsql;
15+
///
16+
/// let result = parse_plpgsql("
1317
/// CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
1418
/// RETURNS varchar AS $$
1519
/// BEGIN
@@ -21,18 +25,255 @@ use crate::error::*;
2125
/// $$ LANGUAGE plpgsql;
2226
/// ");
2327
/// assert!(result.is_ok());
28+
/// let result = result.unwrap();
29+
/// assert_eq!(result.functions().len(), 1);
2430
/// ```
25-
pub fn parse_plpgsql(stmt: &str) -> Result<()> {
31+
pub fn parse_plpgsql(stmt: &str) -> Result<PlpgsqlParseResult> {
2632
let input = CString::new(stmt)?;
27-
let result = unsafe { pg_query_parse_plpgsql(input.as_ptr()) };
28-
let structure = if !result.error.is_null() {
33+
let result = unsafe { pg_query_parse_plpgsql_protobuf(input.as_ptr()) };
34+
let parse_result = if !result.error.is_null() {
2935
let message = unsafe { CStr::from_ptr((*result.error).message) }
3036
.to_string_lossy()
3137
.to_string();
3238
Err(Error::Parse(message))
3339
} else {
34-
Ok(())
40+
let data = unsafe {
41+
std::slice::from_raw_parts(
42+
result.parse_tree.data as *const u8,
43+
result.parse_tree.len as usize,
44+
)
45+
};
46+
protobuf::PLpgSqlParseResult::decode(data)
47+
.map_err(Error::Decode)
48+
.map(PlpgsqlParseResult::new)
3549
};
36-
unsafe { pg_query_free_plpgsql_parse_result(result) };
37-
structure
50+
unsafe { pg_query_free_plpgsql_protobuf_parse_result(result) };
51+
parse_result
52+
}
53+
54+
/// The result of parsing a PL/pgSQL function
55+
#[derive(Debug)]
56+
pub struct PlpgsqlParseResult {
57+
/// The parsed protobuf result
58+
pub protobuf: protobuf::PLpgSqlParseResult,
59+
}
60+
61+
impl PlpgsqlParseResult {
62+
/// Create a new PlpgsqlParseResult
63+
pub fn new(protobuf: protobuf::PLpgSqlParseResult) -> Self {
64+
Self { protobuf }
65+
}
66+
67+
/// Returns a reference to the single parsed PL/pgSQL function.
68+
///
69+
/// Returns `None` if there is not exactly one function in the parse result.
70+
/// Use `functions()` if you need to handle multiple functions.
71+
pub fn function(&self) -> Option<&protobuf::PLpgSqlFunction> {
72+
if self.protobuf.plpgsql_funcs.len() != 1 {
73+
return None;
74+
}
75+
self.protobuf.plpgsql_funcs.first()
76+
}
77+
78+
/// Consumes the result and returns the single parsed function.
79+
///
80+
/// Returns `None` if there is not exactly one function in the parse result.
81+
/// Use `into_functions()` if you need to handle multiple functions.
82+
pub fn into_function(self) -> Option<protobuf::PLpgSqlFunction> {
83+
if self.protobuf.plpgsql_funcs.len() != 1 {
84+
return None;
85+
}
86+
self.protobuf.plpgsql_funcs.into_iter().next()
87+
}
88+
89+
/// Returns a reference to the list of parsed PL/pgSQL functions
90+
pub fn functions(&self) -> &[protobuf::PLpgSqlFunction] {
91+
&self.protobuf.plpgsql_funcs
92+
}
93+
94+
/// Consumes the result and returns the list of parsed functions
95+
pub fn into_functions(self) -> Vec<protobuf::PLpgSqlFunction> {
96+
self.protobuf.plpgsql_funcs
97+
}
98+
}
99+
100+
#[cfg(test)]
101+
mod tests {
102+
use super::*;
103+
use crate::protobuf::p_lpg_sql_stmt::Stmt;
104+
105+
#[test]
106+
fn test_parse_plpgsql_simple() {
107+
let result = parse_plpgsql(
108+
"
109+
CREATE OR REPLACE FUNCTION test_func()
110+
RETURNS void AS $$
111+
BEGIN
112+
NULL;
113+
END;
114+
$$ LANGUAGE plpgsql;
115+
",
116+
);
117+
assert!(result.is_ok());
118+
let result = result.unwrap();
119+
120+
// Use function() for single function access
121+
let func = result.function().expect("should have exactly one function");
122+
assert!(func.action.is_some());
123+
124+
// The body should contain statements
125+
let action = func.action.as_ref().unwrap();
126+
assert!(!action.body.is_empty());
127+
}
128+
129+
#[test]
130+
fn test_parse_plpgsql_with_assignment() {
131+
let result = parse_plpgsql(
132+
"
133+
CREATE OR REPLACE FUNCTION add_numbers(a int, b int)
134+
RETURNS int AS $$
135+
DECLARE
136+
result int;
137+
BEGIN
138+
result := a + b;
139+
RETURN result;
140+
END;
141+
$$ LANGUAGE plpgsql;
142+
",
143+
);
144+
assert!(result.is_ok());
145+
let result = result.unwrap();
146+
assert_eq!(result.functions().len(), 1);
147+
148+
let func = &result.functions()[0];
149+
let action = func.action.as_ref().unwrap();
150+
151+
// Should have assignment and return statements
152+
assert!(action.body.len() >= 2);
153+
154+
// First statement should be an assignment
155+
let first_stmt = &action.body[0];
156+
assert!(matches!(first_stmt.stmt, Some(Stmt::StmtAssign(_))));
157+
158+
// Second statement should be a return
159+
let second_stmt = &action.body[1];
160+
assert!(matches!(second_stmt.stmt, Some(Stmt::StmtReturn(_))));
161+
162+
// Verify the assignment expression contains the query
163+
if let Some(Stmt::StmtAssign(assign)) = &first_stmt.stmt {
164+
assert!(assign.expr.is_some());
165+
let expr = assign.expr.as_ref().unwrap();
166+
assert!(expr.query.contains("a + b"));
167+
}
168+
}
169+
170+
#[test]
171+
fn test_parse_plpgsql_with_if() {
172+
let result = parse_plpgsql(
173+
"
174+
CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
175+
RETURNS varchar AS $$
176+
BEGIN
177+
IF v_version IS NULL THEN
178+
RETURN v_name;
179+
END IF;
180+
RETURN v_name || '/' || v_version;
181+
END;
182+
$$ LANGUAGE plpgsql;
183+
",
184+
);
185+
assert!(result.is_ok());
186+
let result = result.unwrap();
187+
assert_eq!(result.functions().len(), 1);
188+
189+
let func = &result.functions()[0];
190+
let action = func.action.as_ref().unwrap();
191+
192+
// Should have IF and RETURN statements
193+
assert!(action.body.len() >= 2);
194+
195+
// First statement should be IF
196+
let if_stmt = &action.body[0];
197+
assert!(matches!(if_stmt.stmt, Some(Stmt::StmtIf(_))));
198+
199+
// Verify the IF statement structure
200+
if let Some(Stmt::StmtIf(if_node)) = &if_stmt.stmt {
201+
// Should have a condition
202+
assert!(if_node.cond.is_some());
203+
let cond = if_node.cond.as_ref().unwrap();
204+
assert!(cond.query.contains("v_version IS NULL"));
205+
206+
// Should have a then_body with RETURN statement
207+
assert!(!if_node.then_body.is_empty());
208+
assert!(matches!(
209+
if_node.then_body[0].stmt,
210+
Some(Stmt::StmtReturn(_))
211+
));
212+
}
213+
214+
// Second statement should be RETURN
215+
let return_stmt = &action.body[1];
216+
assert!(matches!(return_stmt.stmt, Some(Stmt::StmtReturn(_))));
217+
}
218+
219+
#[test]
220+
fn test_parse_plpgsql_with_loop() {
221+
let result = parse_plpgsql(
222+
"
223+
CREATE OR REPLACE FUNCTION count_down(n int)
224+
RETURNS void AS $$
225+
BEGIN
226+
WHILE n > 0 LOOP
227+
n := n - 1;
228+
END LOOP;
229+
END;
230+
$$ LANGUAGE plpgsql;
231+
",
232+
);
233+
assert!(result.is_ok());
234+
let result = result.unwrap();
235+
236+
let func = &result.functions()[0];
237+
let action = func.action.as_ref().unwrap();
238+
239+
// First statement should be WHILE loop
240+
let while_stmt = &action.body[0];
241+
assert!(matches!(while_stmt.stmt, Some(Stmt::StmtWhile(_))));
242+
243+
if let Some(Stmt::StmtWhile(while_node)) = &while_stmt.stmt {
244+
// Should have a condition
245+
assert!(while_node.cond.is_some());
246+
let cond = while_node.cond.as_ref().unwrap();
247+
assert!(cond.query.contains("n > 0"));
248+
249+
// Should have a body with assignment
250+
assert!(!while_node.body.is_empty());
251+
}
252+
}
253+
254+
#[test]
255+
fn test_parse_plpgsql_error() {
256+
let result = parse_plpgsql("not valid plpgsql");
257+
assert!(result.is_err());
258+
}
259+
260+
#[test]
261+
fn test_parse_plpgsql_multiple_functions() {
262+
let result = parse_plpgsql(
263+
"
264+
CREATE FUNCTION foo() RETURNS void AS $$ BEGIN NULL; END; $$ LANGUAGE plpgsql;
265+
CREATE FUNCTION bar() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
266+
",
267+
);
268+
assert!(result.is_ok());
269+
let result = result.unwrap();
270+
271+
// function() returns None when multiple functions present
272+
assert!(result.function().is_none());
273+
274+
// Use functions() for multiple
275+
assert_eq!(result.functions().len(), 2);
276+
assert_eq!(result.functions()[0].fn_signature, "foo");
277+
assert_eq!(result.functions()[1].fn_signature, "bar");
278+
}
38279
}

0 commit comments

Comments
 (0)