Skip to content

Commit ea26c15

Browse files
committed
feat(pgls_query): add PLpgSQL protobuf parsing support
- Update libpg_query submodule to fork with PLpgSQL protobuf API
- Replace JSON-based parse_plpgsql with protobuf-based implementation
- Add PlpgsqlParseResult struct with typed access to parsed functions
- Export PLpgSQL protobuf types (PLpgSqlFunction, PLpgSqlStmt, etc.)
- Update bindgen allowlist for new FFI functions
- Filter PLpgSQL and SummaryResult types from Node iteration macros
- Add comprehensive tests verifying parsed structure (IF, WHILE, assignments)
1 parent d9964c1 commit ea26c15

File tree

9 files changed

+2084
-15
lines changed

9 files changed

+2084
-15
lines changed

.gitmodules

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[submodule "crates/pgls_query/vendor/libpg_query"]
22
path = crates/pgls_query/vendor/libpg_query
3-
url = https://github.com/pganalyze/libpg_query.git
4-
branch = 17-latest
3+
url = https://github.com/psteinroe/libpg_query.git
4+
branch = feat/plpgsql-protobuf

crates/pgls_query/build.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ fn run_bindgen(
100100
.allowlist_function("pg_query_split_with_parser")
101101
.allowlist_function("pg_query_split_with_scanner")
102102
.allowlist_function("pg_query_parse_plpgsql")
103+
.allowlist_function("pg_query_parse_plpgsql_protobuf")
103104
.allowlist_function("pg_query_free_protobuf_parse_result")
105+
.allowlist_function("pg_query_free_plpgsql_protobuf_parse_result")
104106
.allowlist_function("pg_query_free_scan_result")
105107
.allowlist_function("pg_query_free_deparse_result")
106108
.allowlist_function("pg_query_free_normalize_result")
@@ -117,6 +119,7 @@ fn run_bindgen(
117119
.allowlist_type("PgQueryFingerprintResult")
118120
.allowlist_type("PgQuerySplitResult")
119121
.allowlist_type("PgQuerySplitStmt")
122+
.allowlist_type("PgQueryPlpgsqlProtobufParseResult")
120123
// Also generate bindings for size_t since it's used in PgQueryProtobuf
121124
.allowlist_type("size_t")
122125
.allowlist_var("PG_VERSION_NUM");
@@ -269,6 +272,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
269272
bindings_content.push_str(" pub fn pg_query_scan(input: *const ::std::os::raw::c_char) -> PgQueryScanResult;\n");
270273
bindings_content.push_str(" pub fn pg_query_parse_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryProtobufParseResult;\n");
271274
bindings_content.push_str(" pub fn pg_query_parse_plpgsql(input: *const ::std::os::raw::c_char) -> PgQueryPlpgsqlParseResult;\n");
275+
bindings_content.push_str(" pub fn pg_query_parse_plpgsql_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryPlpgsqlProtobufParseResult;\n");
272276
bindings_content.push_str(" pub fn pg_query_deparse_protobuf(protobuf: PgQueryProtobuf) -> PgQueryDeparseResult;\n");
273277
bindings_content.push_str(" pub fn pg_query_normalize(input: *const ::std::os::raw::c_char) -> PgQueryNormalizeResult;\n");
274278
bindings_content.push_str(" pub fn pg_query_fingerprint(input: *const ::std::os::raw::c_char) -> PgQueryFingerprintResult;\n");
@@ -278,6 +282,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
278282
.push_str(" pub fn pg_query_free_scan_result(result: PgQueryScanResult);\n");
279283
bindings_content.push_str(" pub fn pg_query_free_protobuf_parse_result(result: PgQueryProtobufParseResult);\n");
280284
bindings_content.push_str(" pub fn pg_query_free_plpgsql_parse_result(result: PgQueryPlpgsqlParseResult);\n");
285+
bindings_content.push_str(" pub fn pg_query_free_plpgsql_protobuf_parse_result(result: PgQueryPlpgsqlProtobufParseResult);\n");
281286
bindings_content.push_str(
282287
" pub fn pg_query_free_deparse_result(result: PgQueryDeparseResult);\n",
283288
);

crates/pgls_query/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ pub use scan::*;
2828
pub use split::*;
2929

3030
pub use protobuf::Node;
31+
pub use protobuf::PLpgSqlFunction;
32+
pub use protobuf::PLpgSqlStmt;
33+
pub use protobuf::PLpgSqlParseResult;
3134

3235
// Include the generated bindings with 2024 edition compatibility
3336
#[allow(non_upper_case_globals)]

crates/pgls_query/src/plpgsql.rs

Lines changed: 256 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
use std::ffi::{CStr, CString};
22

3+
use prost::Message;
4+
35
use crate::bindings::*;
46
use crate::error::*;
7+
use crate::protobuf;
58

6-
/// An experimental API which parses a PLPGSQL function. This currently drops the returned
7-
/// structure and returns only a Result<()>.
9+
/// Parses the given PL/pgSQL function into an abstract syntax tree.
810
///
911
/// # Example
1012
///
1113
/// ```rust
12-
/// let result = pgls_query::parse_plpgsql("
14+
/// use pgls_query::parse_plpgsql;
15+
///
16+
/// let result = parse_plpgsql("
1317
/// CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
1418
/// RETURNS varchar AS $$
1519
/// BEGIN
@@ -21,18 +25,261 @@ use crate::error::*;
2125
/// $$ LANGUAGE plpgsql;
2226
/// ");
2327
/// assert!(result.is_ok());
28+
/// let result = result.unwrap();
29+
/// assert_eq!(result.functions().len(), 1);
2430
/// ```
25-
pub fn parse_plpgsql(stmt: &str) -> Result<()> {
31+
pub fn parse_plpgsql(stmt: &str) -> Result<PlpgsqlParseResult> {
2632
let input = CString::new(stmt)?;
27-
let result = unsafe { pg_query_parse_plpgsql(input.as_ptr()) };
28-
let structure = if !result.error.is_null() {
33+
let result = unsafe { pg_query_parse_plpgsql_protobuf(input.as_ptr()) };
34+
let parse_result = if !result.error.is_null() {
2935
let message = unsafe { CStr::from_ptr((*result.error).message) }
3036
.to_string_lossy()
3137
.to_string();
3238
Err(Error::Parse(message))
3339
} else {
34-
Ok(())
40+
let data = unsafe {
41+
std::slice::from_raw_parts(
42+
result.parse_tree.data as *const u8,
43+
result.parse_tree.len as usize,
44+
)
45+
};
46+
protobuf::PLpgSqlParseResult::decode(data)
47+
.map_err(Error::Decode)
48+
.map(PlpgsqlParseResult::new)
3549
};
36-
unsafe { pg_query_free_plpgsql_parse_result(result) };
37-
structure
50+
unsafe { pg_query_free_plpgsql_protobuf_parse_result(result) };
51+
parse_result
52+
}
53+
54+
/// The result of parsing a PL/pgSQL function
55+
#[derive(Debug)]
56+
pub struct PlpgsqlParseResult {
57+
/// The parsed protobuf result
58+
pub protobuf: protobuf::PLpgSqlParseResult,
59+
}
60+
61+
impl PlpgsqlParseResult {
62+
/// Create a new PlpgsqlParseResult
63+
pub fn new(protobuf: protobuf::PLpgSqlParseResult) -> Self {
64+
Self { protobuf }
65+
}
66+
67+
/// Returns a reference to the single parsed PL/pgSQL function.
68+
///
69+
/// Returns `None` if there is not exactly one function in the parse result.
70+
/// Use `functions()` if you need to handle multiple functions.
71+
pub fn function(&self) -> Option<&protobuf::PLpgSqlFunction> {
72+
if self.protobuf.plpgsql_funcs.len() != 1 {
73+
return None;
74+
}
75+
self.protobuf.plpgsql_funcs.first()
76+
}
77+
78+
/// Consumes the result and returns the single parsed function.
79+
///
80+
/// Returns `None` if there is not exactly one function in the parse result.
81+
/// Use `into_functions()` if you need to handle multiple functions.
82+
pub fn into_function(self) -> Option<protobuf::PLpgSqlFunction> {
83+
if self.protobuf.plpgsql_funcs.len() != 1 {
84+
return None;
85+
}
86+
self.protobuf.plpgsql_funcs.into_iter().next()
87+
}
88+
89+
/// Returns a reference to the list of parsed PL/pgSQL functions
90+
pub fn functions(&self) -> &[protobuf::PLpgSqlFunction] {
91+
&self.protobuf.plpgsql_funcs
92+
}
93+
94+
/// Consumes the result and returns the list of parsed functions
95+
pub fn into_functions(self) -> Vec<protobuf::PLpgSqlFunction> {
96+
self.protobuf.plpgsql_funcs
97+
}
98+
}
99+
100+
#[cfg(test)]
101+
mod tests {
102+
use super::*;
103+
use crate::protobuf::p_lpg_sql_stmt::Stmt;
104+
105+
#[test]
106+
fn test_parse_plpgsql_simple() {
107+
let result = parse_plpgsql(
108+
"
109+
CREATE OR REPLACE FUNCTION test_func()
110+
RETURNS void AS $$
111+
BEGIN
112+
NULL;
113+
END;
114+
$$ LANGUAGE plpgsql;
115+
",
116+
);
117+
assert!(result.is_ok());
118+
let result = result.unwrap();
119+
120+
// Use function() for single function access
121+
let func = result.function().expect("should have exactly one function");
122+
assert!(func.action.is_some());
123+
124+
// The body should contain statements
125+
let action = func.action.as_ref().unwrap();
126+
assert!(!action.body.is_empty());
127+
}
128+
129+
#[test]
130+
fn test_parse_plpgsql_with_assignment() {
131+
let result = parse_plpgsql(
132+
"
133+
CREATE OR REPLACE FUNCTION add_numbers(a int, b int)
134+
RETURNS int AS $$
135+
DECLARE
136+
result int;
137+
BEGIN
138+
result := a + b;
139+
RETURN result;
140+
END;
141+
$$ LANGUAGE plpgsql;
142+
",
143+
);
144+
assert!(result.is_ok());
145+
let result = result.unwrap();
146+
assert_eq!(result.functions().len(), 1);
147+
148+
let func = &result.functions()[0];
149+
let action = func.action.as_ref().unwrap();
150+
151+
// Should have assignment and return statements
152+
assert!(action.body.len() >= 2);
153+
154+
// First statement should be an assignment
155+
let first_stmt = &action.body[0];
156+
assert!(matches!(
157+
first_stmt.stmt,
158+
Some(Stmt::StmtAssign(_))
159+
));
160+
161+
// Second statement should be a return
162+
let second_stmt = &action.body[1];
163+
assert!(matches!(
164+
second_stmt.stmt,
165+
Some(Stmt::StmtReturn(_))
166+
));
167+
168+
// Verify the assignment expression contains the query
169+
if let Some(Stmt::StmtAssign(assign)) = &first_stmt.stmt {
170+
assert!(assign.expr.is_some());
171+
let expr = assign.expr.as_ref().unwrap();
172+
assert!(expr.query.contains("a + b"));
173+
}
174+
}
175+
176+
#[test]
177+
fn test_parse_plpgsql_with_if() {
178+
let result = parse_plpgsql(
179+
"
180+
CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
181+
RETURNS varchar AS $$
182+
BEGIN
183+
IF v_version IS NULL THEN
184+
RETURN v_name;
185+
END IF;
186+
RETURN v_name || '/' || v_version;
187+
END;
188+
$$ LANGUAGE plpgsql;
189+
",
190+
);
191+
assert!(result.is_ok());
192+
let result = result.unwrap();
193+
assert_eq!(result.functions().len(), 1);
194+
195+
let func = &result.functions()[0];
196+
let action = func.action.as_ref().unwrap();
197+
198+
// Should have IF and RETURN statements
199+
assert!(action.body.len() >= 2);
200+
201+
// First statement should be IF
202+
let if_stmt = &action.body[0];
203+
assert!(matches!(if_stmt.stmt, Some(Stmt::StmtIf(_))));
204+
205+
// Verify the IF statement structure
206+
if let Some(Stmt::StmtIf(if_node)) = &if_stmt.stmt {
207+
// Should have a condition
208+
assert!(if_node.cond.is_some());
209+
let cond = if_node.cond.as_ref().unwrap();
210+
assert!(cond.query.contains("v_version IS NULL"));
211+
212+
// Should have a then_body with RETURN statement
213+
assert!(!if_node.then_body.is_empty());
214+
assert!(matches!(
215+
if_node.then_body[0].stmt,
216+
Some(Stmt::StmtReturn(_))
217+
));
218+
}
219+
220+
// Second statement should be RETURN
221+
let return_stmt = &action.body[1];
222+
assert!(matches!(return_stmt.stmt, Some(Stmt::StmtReturn(_))));
223+
}
224+
225+
#[test]
226+
fn test_parse_plpgsql_with_loop() {
227+
let result = parse_plpgsql(
228+
"
229+
CREATE OR REPLACE FUNCTION count_down(n int)
230+
RETURNS void AS $$
231+
BEGIN
232+
WHILE n > 0 LOOP
233+
n := n - 1;
234+
END LOOP;
235+
END;
236+
$$ LANGUAGE plpgsql;
237+
",
238+
);
239+
assert!(result.is_ok());
240+
let result = result.unwrap();
241+
242+
let func = &result.functions()[0];
243+
let action = func.action.as_ref().unwrap();
244+
245+
// First statement should be WHILE loop
246+
let while_stmt = &action.body[0];
247+
assert!(matches!(while_stmt.stmt, Some(Stmt::StmtWhile(_))));
248+
249+
if let Some(Stmt::StmtWhile(while_node)) = &while_stmt.stmt {
250+
// Should have a condition
251+
assert!(while_node.cond.is_some());
252+
let cond = while_node.cond.as_ref().unwrap();
253+
assert!(cond.query.contains("n > 0"));
254+
255+
// Should have a body with assignment
256+
assert!(!while_node.body.is_empty());
257+
}
258+
}
259+
260+
#[test]
261+
fn test_parse_plpgsql_error() {
262+
let result = parse_plpgsql("not valid plpgsql");
263+
assert!(result.is_err());
264+
}
265+
266+
#[test]
267+
fn test_parse_plpgsql_multiple_functions() {
268+
let result = parse_plpgsql(
269+
"
270+
CREATE FUNCTION foo() RETURNS void AS $$ BEGIN NULL; END; $$ LANGUAGE plpgsql;
271+
CREATE FUNCTION bar() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
272+
",
273+
);
274+
assert!(result.is_ok());
275+
let result = result.unwrap();
276+
277+
// function() returns None when multiple functions present
278+
assert!(result.function().is_none());
279+
280+
// Use functions() for multiple
281+
assert_eq!(result.functions().len(), 2);
282+
assert_eq!(result.functions()[0].fn_signature, "foo");
283+
assert_eq!(result.functions()[1].fn_signature, "bar");
284+
}
38285
}

0 commit comments

Comments
 (0)