Skip to content

Commit 0a9caa2

Browse files
committed
feat(pgls_query): add PLpgSQL protobuf parsing support
- Update libpg_query submodule to fork with PLpgSQL protobuf API - Replace JSON-based parse_plpgsql with protobuf-based implementation - Add PlpgsqlParseResult struct with typed access to parsed functions - Export PLpgSQL protobuf types (PLpgSqlFunction, PLpgSqlStmt, etc.) - Update bindgen allowlist for new FFI functions - Filter PLpgSQL and SummaryResult types from Node iteration macros - Add comprehensive tests verifying parsed structure (IF, WHILE, assignments)
1 parent d9964c1 commit 0a9caa2

File tree

9 files changed

+2060
-15
lines changed

9 files changed

+2060
-15
lines changed

.gitmodules

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[submodule "crates/pgls_query/vendor/libpg_query"]
22
path = crates/pgls_query/vendor/libpg_query
3-
url = https://github.com/pganalyze/libpg_query.git
4-
branch = 17-latest
3+
url = https://github.com/psteinroe/libpg_query.git
4+
branch = feat/plpgsql-protobuf

crates/pgls_query/build.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ fn run_bindgen(
100100
.allowlist_function("pg_query_split_with_parser")
101101
.allowlist_function("pg_query_split_with_scanner")
102102
.allowlist_function("pg_query_parse_plpgsql")
103+
.allowlist_function("pg_query_parse_plpgsql_protobuf")
103104
.allowlist_function("pg_query_free_protobuf_parse_result")
105+
.allowlist_function("pg_query_free_plpgsql_protobuf_parse_result")
104106
.allowlist_function("pg_query_free_scan_result")
105107
.allowlist_function("pg_query_free_deparse_result")
106108
.allowlist_function("pg_query_free_normalize_result")
@@ -117,6 +119,7 @@ fn run_bindgen(
117119
.allowlist_type("PgQueryFingerprintResult")
118120
.allowlist_type("PgQuerySplitResult")
119121
.allowlist_type("PgQuerySplitStmt")
122+
.allowlist_type("PgQueryPlpgsqlProtobufParseResult")
120123
// Also generate bindings for size_t since it's used in PgQueryProtobuf
121124
.allowlist_type("size_t")
122125
.allowlist_var("PG_VERSION_NUM");
@@ -269,6 +272,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
269272
bindings_content.push_str(" pub fn pg_query_scan(input: *const ::std::os::raw::c_char) -> PgQueryScanResult;\n");
270273
bindings_content.push_str(" pub fn pg_query_parse_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryProtobufParseResult;\n");
271274
bindings_content.push_str(" pub fn pg_query_parse_plpgsql(input: *const ::std::os::raw::c_char) -> PgQueryPlpgsqlParseResult;\n");
275+
bindings_content.push_str(" pub fn pg_query_parse_plpgsql_protobuf(input: *const ::std::os::raw::c_char) -> PgQueryPlpgsqlProtobufParseResult;\n");
272276
bindings_content.push_str(" pub fn pg_query_deparse_protobuf(protobuf: PgQueryProtobuf) -> PgQueryDeparseResult;\n");
273277
bindings_content.push_str(" pub fn pg_query_normalize(input: *const ::std::os::raw::c_char) -> PgQueryNormalizeResult;\n");
274278
bindings_content.push_str(" pub fn pg_query_fingerprint(input: *const ::std::os::raw::c_char) -> PgQueryFingerprintResult;\n");
@@ -278,6 +282,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
278282
.push_str(" pub fn pg_query_free_scan_result(result: PgQueryScanResult);\n");
279283
bindings_content.push_str(" pub fn pg_query_free_protobuf_parse_result(result: PgQueryProtobufParseResult);\n");
280284
bindings_content.push_str(" pub fn pg_query_free_plpgsql_parse_result(result: PgQueryPlpgsqlParseResult);\n");
285+
bindings_content.push_str(" pub fn pg_query_free_plpgsql_protobuf_parse_result(result: PgQueryPlpgsqlProtobufParseResult);\n");
281286
bindings_content.push_str(
282287
" pub fn pg_query_free_deparse_result(result: PgQueryDeparseResult);\n",
283288
);

crates/pgls_query/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ pub use scan::*;
2828
pub use split::*;
2929

3030
pub use protobuf::Node;
31+
pub use protobuf::PLpgSqlFunction;
32+
pub use protobuf::PLpgSqlStmt;
33+
pub use protobuf::PLpgSqlParseResult;
3134

3235
// Include the generated bindings with 2024 edition compatibility
3336
#[allow(non_upper_case_globals)]

crates/pgls_query/src/plpgsql.rs

Lines changed: 232 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
use std::ffi::{CStr, CString};
22

3+
use prost::Message;
4+
35
use crate::bindings::*;
46
use crate::error::*;
7+
use crate::protobuf;
58

6-
/// An experimental API which parses a PLPGSQL function. This currently drops the returned
7-
/// structure and returns only a Result<()>.
9+
/// Parses the given PL/pgSQL function into an abstract syntax tree.
810
///
911
/// # Example
1012
///
1113
/// ```rust
12-
/// let result = pgls_query::parse_plpgsql("
14+
/// use pgls_query::parse_plpgsql;
15+
///
16+
/// let result = parse_plpgsql("
1317
/// CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
1418
/// RETURNS varchar AS $$
1519
/// BEGIN
@@ -21,18 +25,237 @@ use crate::error::*;
2125
/// $$ LANGUAGE plpgsql;
2226
/// ");
2327
/// assert!(result.is_ok());
28+
/// let result = result.unwrap();
29+
/// assert_eq!(result.functions().len(), 1);
2430
/// ```
25-
pub fn parse_plpgsql(stmt: &str) -> Result<()> {
31+
pub fn parse_plpgsql(stmt: &str) -> Result<PlpgsqlParseResult> {
2632
let input = CString::new(stmt)?;
27-
let result = unsafe { pg_query_parse_plpgsql(input.as_ptr()) };
28-
let structure = if !result.error.is_null() {
33+
let result = unsafe { pg_query_parse_plpgsql_protobuf(input.as_ptr()) };
34+
let parse_result = if !result.error.is_null() {
2935
let message = unsafe { CStr::from_ptr((*result.error).message) }
3036
.to_string_lossy()
3137
.to_string();
3238
Err(Error::Parse(message))
3339
} else {
34-
Ok(())
40+
let data = unsafe {
41+
std::slice::from_raw_parts(
42+
result.parse_tree.data as *const u8,
43+
result.parse_tree.len as usize,
44+
)
45+
};
46+
protobuf::PLpgSqlParseResult::decode(data)
47+
.map_err(Error::Decode)
48+
.map(PlpgsqlParseResult::new)
3549
};
36-
unsafe { pg_query_free_plpgsql_parse_result(result) };
37-
structure
50+
unsafe { pg_query_free_plpgsql_protobuf_parse_result(result) };
51+
parse_result
52+
}
53+
54+
/// The result of parsing a PL/pgSQL function
55+
#[derive(Debug)]
56+
pub struct PlpgsqlParseResult {
57+
/// The parsed protobuf result
58+
pub protobuf: protobuf::PLpgSqlParseResult,
59+
}
60+
61+
impl PlpgsqlParseResult {
62+
/// Create a new PlpgsqlParseResult
63+
pub fn new(protobuf: protobuf::PLpgSqlParseResult) -> Self {
64+
Self { protobuf }
65+
}
66+
67+
/// Returns a reference to the list of parsed PL/pgSQL functions
68+
pub fn functions(&self) -> &[protobuf::PLpgSqlFunction] {
69+
&self.protobuf.plpgsql_funcs
70+
}
71+
72+
/// Consumes the result and returns the list of parsed functions
73+
pub fn into_functions(self) -> Vec<protobuf::PLpgSqlFunction> {
74+
self.protobuf.plpgsql_funcs
75+
}
76+
}
77+
78+
#[cfg(test)]
79+
mod tests {
80+
use super::*;
81+
use crate::protobuf::p_lpg_sql_stmt::Stmt;
82+
83+
#[test]
84+
fn test_parse_plpgsql_simple() {
85+
let result = parse_plpgsql(
86+
"
87+
CREATE OR REPLACE FUNCTION test_func()
88+
RETURNS void AS $$
89+
BEGIN
90+
NULL;
91+
END;
92+
$$ LANGUAGE plpgsql;
93+
",
94+
);
95+
assert!(result.is_ok());
96+
let result = result.unwrap();
97+
assert_eq!(result.functions().len(), 1);
98+
99+
// Verify function has an action block
100+
let func = &result.functions()[0];
101+
assert!(func.action.is_some());
102+
103+
// The body should contain statements
104+
let action = func.action.as_ref().unwrap();
105+
assert!(!action.body.is_empty());
106+
}
107+
108+
#[test]
109+
fn test_parse_plpgsql_with_assignment() {
110+
let result = parse_plpgsql(
111+
"
112+
CREATE OR REPLACE FUNCTION add_numbers(a int, b int)
113+
RETURNS int AS $$
114+
DECLARE
115+
result int;
116+
BEGIN
117+
result := a + b;
118+
RETURN result;
119+
END;
120+
$$ LANGUAGE plpgsql;
121+
",
122+
);
123+
assert!(result.is_ok());
124+
let result = result.unwrap();
125+
assert_eq!(result.functions().len(), 1);
126+
127+
let func = &result.functions()[0];
128+
let action = func.action.as_ref().unwrap();
129+
130+
// Should have assignment and return statements
131+
assert!(action.body.len() >= 2);
132+
133+
// First statement should be an assignment
134+
let first_stmt = &action.body[0];
135+
assert!(matches!(
136+
first_stmt.stmt,
137+
Some(Stmt::StmtAssign(_))
138+
));
139+
140+
// Second statement should be a return
141+
let second_stmt = &action.body[1];
142+
assert!(matches!(
143+
second_stmt.stmt,
144+
Some(Stmt::StmtReturn(_))
145+
));
146+
147+
// Verify the assignment expression contains the query
148+
if let Some(Stmt::StmtAssign(assign)) = &first_stmt.stmt {
149+
assert!(assign.expr.is_some());
150+
let expr = assign.expr.as_ref().unwrap();
151+
assert!(expr.query.contains("a + b"));
152+
}
153+
}
154+
155+
#[test]
156+
fn test_parse_plpgsql_with_if() {
157+
let result = parse_plpgsql(
158+
"
159+
CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
160+
RETURNS varchar AS $$
161+
BEGIN
162+
IF v_version IS NULL THEN
163+
RETURN v_name;
164+
END IF;
165+
RETURN v_name || '/' || v_version;
166+
END;
167+
$$ LANGUAGE plpgsql;
168+
",
169+
);
170+
assert!(result.is_ok());
171+
let result = result.unwrap();
172+
assert_eq!(result.functions().len(), 1);
173+
174+
let func = &result.functions()[0];
175+
let action = func.action.as_ref().unwrap();
176+
177+
// Should have IF and RETURN statements
178+
assert!(action.body.len() >= 2);
179+
180+
// First statement should be IF
181+
let if_stmt = &action.body[0];
182+
assert!(matches!(if_stmt.stmt, Some(Stmt::StmtIf(_))));
183+
184+
// Verify the IF statement structure
185+
if let Some(Stmt::StmtIf(if_node)) = &if_stmt.stmt {
186+
// Should have a condition
187+
assert!(if_node.cond.is_some());
188+
let cond = if_node.cond.as_ref().unwrap();
189+
assert!(cond.query.contains("v_version IS NULL"));
190+
191+
// Should have a then_body with RETURN statement
192+
assert!(!if_node.then_body.is_empty());
193+
assert!(matches!(
194+
if_node.then_body[0].stmt,
195+
Some(Stmt::StmtReturn(_))
196+
));
197+
}
198+
199+
// Second statement should be RETURN
200+
let return_stmt = &action.body[1];
201+
assert!(matches!(return_stmt.stmt, Some(Stmt::StmtReturn(_))));
202+
}
203+
204+
#[test]
205+
fn test_parse_plpgsql_with_loop() {
206+
let result = parse_plpgsql(
207+
"
208+
CREATE OR REPLACE FUNCTION count_down(n int)
209+
RETURNS void AS $$
210+
BEGIN
211+
WHILE n > 0 LOOP
212+
n := n - 1;
213+
END LOOP;
214+
END;
215+
$$ LANGUAGE plpgsql;
216+
",
217+
);
218+
assert!(result.is_ok());
219+
let result = result.unwrap();
220+
221+
let func = &result.functions()[0];
222+
let action = func.action.as_ref().unwrap();
223+
224+
// First statement should be WHILE loop
225+
let while_stmt = &action.body[0];
226+
assert!(matches!(while_stmt.stmt, Some(Stmt::StmtWhile(_))));
227+
228+
if let Some(Stmt::StmtWhile(while_node)) = &while_stmt.stmt {
229+
// Should have a condition
230+
assert!(while_node.cond.is_some());
231+
let cond = while_node.cond.as_ref().unwrap();
232+
assert!(cond.query.contains("n > 0"));
233+
234+
// Should have a body with assignment
235+
assert!(!while_node.body.is_empty());
236+
}
237+
}
238+
239+
#[test]
240+
fn test_parse_plpgsql_error() {
241+
let result = parse_plpgsql("not valid plpgsql");
242+
assert!(result.is_err());
243+
}
244+
245+
#[test]
246+
fn test_parse_plpgsql_multiple_functions() {
247+
let result = parse_plpgsql(
248+
"
249+
CREATE FUNCTION foo() RETURNS void AS $$ BEGIN NULL; END; $$ LANGUAGE plpgsql;
250+
CREATE FUNCTION bar() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
251+
",
252+
);
253+
assert!(result.is_ok());
254+
let result = result.unwrap();
255+
256+
// Parser handles multiple CREATE FUNCTION statements
257+
assert_eq!(result.functions().len(), 2);
258+
assert_eq!(result.functions()[0].fn_signature, "foo");
259+
assert_eq!(result.functions()[1].fn_signature, "bar");
260+
}
38261
}

0 commit comments

Comments
 (0)