11use std:: ffi:: { CStr , CString } ;
22
3+ use prost:: Message ;
4+
35use crate :: bindings:: * ;
46use crate :: error:: * ;
7+ use crate :: protobuf;
58
6- /// An experimental API which parses a PLPGSQL function. This currently drops the returned
7- /// structure and returns only a Result<()>.
9+ /// Parses the given PL/pgSQL function into an abstract syntax tree.
810///
911/// # Example
1012///
1113/// ```rust
12- /// let result = pgls_query::parse_plpgsql("
14+ /// use pgls_query::parse_plpgsql;
15+ ///
16+ /// let result = parse_plpgsql("
1317/// CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
1418/// RETURNS varchar AS $$
1519/// BEGIN
@@ -21,18 +25,261 @@ use crate::error::*;
2125/// $$ LANGUAGE plpgsql;
2226/// ");
2327/// assert!(result.is_ok());
28+ /// let result = result.unwrap();
29+ /// assert_eq!(result.functions().len(), 1);
2430/// ```
25- pub fn parse_plpgsql ( stmt : & str ) -> Result < ( ) > {
31+ pub fn parse_plpgsql ( stmt : & str ) -> Result < PlpgsqlParseResult > {
2632 let input = CString :: new ( stmt) ?;
27- let result = unsafe { pg_query_parse_plpgsql ( input. as_ptr ( ) ) } ;
28- let structure = if !result. error . is_null ( ) {
33+ let result = unsafe { pg_query_parse_plpgsql_protobuf ( input. as_ptr ( ) ) } ;
34+ let parse_result = if !result. error . is_null ( ) {
2935 let message = unsafe { CStr :: from_ptr ( ( * result. error ) . message ) }
3036 . to_string_lossy ( )
3137 . to_string ( ) ;
3238 Err ( Error :: Parse ( message) )
3339 } else {
34- Ok ( ( ) )
40+ let data = unsafe {
41+ std:: slice:: from_raw_parts (
42+ result. parse_tree . data as * const u8 ,
43+ result. parse_tree . len as usize ,
44+ )
45+ } ;
46+ protobuf:: PLpgSqlParseResult :: decode ( data)
47+ . map_err ( Error :: Decode )
48+ . map ( PlpgsqlParseResult :: new)
3549 } ;
36- unsafe { pg_query_free_plpgsql_parse_result ( result) } ;
37- structure
50+ unsafe { pg_query_free_plpgsql_protobuf_parse_result ( result) } ;
51+ parse_result
52+ }
53+
54+ /// The result of parsing a PL/pgSQL function
55+ #[ derive( Debug ) ]
56+ pub struct PlpgsqlParseResult {
57+ /// The parsed protobuf result
58+ pub protobuf : protobuf:: PLpgSqlParseResult ,
59+ }
60+
61+ impl PlpgsqlParseResult {
62+ /// Create a new PlpgsqlParseResult
63+ pub fn new ( protobuf : protobuf:: PLpgSqlParseResult ) -> Self {
64+ Self { protobuf }
65+ }
66+
67+ /// Returns a reference to the single parsed PL/pgSQL function.
68+ ///
69+ /// Returns `None` if there is not exactly one function in the parse result.
70+ /// Use `functions()` if you need to handle multiple functions.
71+ pub fn function ( & self ) -> Option < & protobuf:: PLpgSqlFunction > {
72+ if self . protobuf . plpgsql_funcs . len ( ) != 1 {
73+ return None ;
74+ }
75+ self . protobuf . plpgsql_funcs . first ( )
76+ }
77+
78+ /// Consumes the result and returns the single parsed function.
79+ ///
80+ /// Returns `None` if there is not exactly one function in the parse result.
81+ /// Use `into_functions()` if you need to handle multiple functions.
82+ pub fn into_function ( self ) -> Option < protobuf:: PLpgSqlFunction > {
83+ if self . protobuf . plpgsql_funcs . len ( ) != 1 {
84+ return None ;
85+ }
86+ self . protobuf . plpgsql_funcs . into_iter ( ) . next ( )
87+ }
88+
89+ /// Returns a reference to the list of parsed PL/pgSQL functions
90+ pub fn functions ( & self ) -> & [ protobuf:: PLpgSqlFunction ] {
91+ & self . protobuf . plpgsql_funcs
92+ }
93+
94+ /// Consumes the result and returns the list of parsed functions
95+ pub fn into_functions ( self ) -> Vec < protobuf:: PLpgSqlFunction > {
96+ self . protobuf . plpgsql_funcs
97+ }
98+ }
99+
100+ #[ cfg( test) ]
101+ mod tests {
102+ use super :: * ;
103+ use crate :: protobuf:: p_lpg_sql_stmt:: Stmt ;
104+
105+ #[ test]
106+ fn test_parse_plpgsql_simple ( ) {
107+ let result = parse_plpgsql (
108+ "
109+ CREATE OR REPLACE FUNCTION test_func()
110+ RETURNS void AS $$
111+ BEGIN
112+ NULL;
113+ END;
114+ $$ LANGUAGE plpgsql;
115+ " ,
116+ ) ;
117+ assert ! ( result. is_ok( ) ) ;
118+ let result = result. unwrap ( ) ;
119+
120+ // Use function() for single function access
121+ let func = result. function ( ) . expect ( "should have exactly one function" ) ;
122+ assert ! ( func. action. is_some( ) ) ;
123+
124+ // The body should contain statements
125+ let action = func. action . as_ref ( ) . unwrap ( ) ;
126+ assert ! ( !action. body. is_empty( ) ) ;
127+ }
128+
129+ #[ test]
130+ fn test_parse_plpgsql_with_assignment ( ) {
131+ let result = parse_plpgsql (
132+ "
133+ CREATE OR REPLACE FUNCTION add_numbers(a int, b int)
134+ RETURNS int AS $$
135+ DECLARE
136+ result int;
137+ BEGIN
138+ result := a + b;
139+ RETURN result;
140+ END;
141+ $$ LANGUAGE plpgsql;
142+ " ,
143+ ) ;
144+ assert ! ( result. is_ok( ) ) ;
145+ let result = result. unwrap ( ) ;
146+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
147+
148+ let func = & result. functions ( ) [ 0 ] ;
149+ let action = func. action . as_ref ( ) . unwrap ( ) ;
150+
151+ // Should have assignment and return statements
152+ assert ! ( action. body. len( ) >= 2 ) ;
153+
154+ // First statement should be an assignment
155+ let first_stmt = & action. body [ 0 ] ;
156+ assert ! ( matches!(
157+ first_stmt. stmt,
158+ Some ( Stmt :: StmtAssign ( _) )
159+ ) ) ;
160+
161+ // Second statement should be a return
162+ let second_stmt = & action. body [ 1 ] ;
163+ assert ! ( matches!(
164+ second_stmt. stmt,
165+ Some ( Stmt :: StmtReturn ( _) )
166+ ) ) ;
167+
168+ // Verify the assignment expression contains the query
169+ if let Some ( Stmt :: StmtAssign ( assign) ) = & first_stmt. stmt {
170+ assert ! ( assign. expr. is_some( ) ) ;
171+ let expr = assign. expr . as_ref ( ) . unwrap ( ) ;
172+ assert ! ( expr. query. contains( "a + b" ) ) ;
173+ }
174+ }
175+
176+ #[ test]
177+ fn test_parse_plpgsql_with_if ( ) {
178+ let result = parse_plpgsql (
179+ "
180+ CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
181+ RETURNS varchar AS $$
182+ BEGIN
183+ IF v_version IS NULL THEN
184+ RETURN v_name;
185+ END IF;
186+ RETURN v_name || '/' || v_version;
187+ END;
188+ $$ LANGUAGE plpgsql;
189+ " ,
190+ ) ;
191+ assert ! ( result. is_ok( ) ) ;
192+ let result = result. unwrap ( ) ;
193+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
194+
195+ let func = & result. functions ( ) [ 0 ] ;
196+ let action = func. action . as_ref ( ) . unwrap ( ) ;
197+
198+ // Should have IF and RETURN statements
199+ assert ! ( action. body. len( ) >= 2 ) ;
200+
201+ // First statement should be IF
202+ let if_stmt = & action. body [ 0 ] ;
203+ assert ! ( matches!( if_stmt. stmt, Some ( Stmt :: StmtIf ( _) ) ) ) ;
204+
205+ // Verify the IF statement structure
206+ if let Some ( Stmt :: StmtIf ( if_node) ) = & if_stmt. stmt {
207+ // Should have a condition
208+ assert ! ( if_node. cond. is_some( ) ) ;
209+ let cond = if_node. cond . as_ref ( ) . unwrap ( ) ;
210+ assert ! ( cond. query. contains( "v_version IS NULL" ) ) ;
211+
212+ // Should have a then_body with RETURN statement
213+ assert ! ( !if_node. then_body. is_empty( ) ) ;
214+ assert ! ( matches!(
215+ if_node. then_body[ 0 ] . stmt,
216+ Some ( Stmt :: StmtReturn ( _) )
217+ ) ) ;
218+ }
219+
220+ // Second statement should be RETURN
221+ let return_stmt = & action. body [ 1 ] ;
222+ assert ! ( matches!( return_stmt. stmt, Some ( Stmt :: StmtReturn ( _) ) ) ) ;
223+ }
224+
225+ #[ test]
226+ fn test_parse_plpgsql_with_loop ( ) {
227+ let result = parse_plpgsql (
228+ "
229+ CREATE OR REPLACE FUNCTION count_down(n int)
230+ RETURNS void AS $$
231+ BEGIN
232+ WHILE n > 0 LOOP
233+ n := n - 1;
234+ END LOOP;
235+ END;
236+ $$ LANGUAGE plpgsql;
237+ " ,
238+ ) ;
239+ assert ! ( result. is_ok( ) ) ;
240+ let result = result. unwrap ( ) ;
241+
242+ let func = & result. functions ( ) [ 0 ] ;
243+ let action = func. action . as_ref ( ) . unwrap ( ) ;
244+
245+ // First statement should be WHILE loop
246+ let while_stmt = & action. body [ 0 ] ;
247+ assert ! ( matches!( while_stmt. stmt, Some ( Stmt :: StmtWhile ( _) ) ) ) ;
248+
249+ if let Some ( Stmt :: StmtWhile ( while_node) ) = & while_stmt. stmt {
250+ // Should have a condition
251+ assert ! ( while_node. cond. is_some( ) ) ;
252+ let cond = while_node. cond . as_ref ( ) . unwrap ( ) ;
253+ assert ! ( cond. query. contains( "n > 0" ) ) ;
254+
255+ // Should have a body with assignment
256+ assert ! ( !while_node. body. is_empty( ) ) ;
257+ }
258+ }
259+
260+ #[ test]
261+ fn test_parse_plpgsql_error ( ) {
262+ let result = parse_plpgsql ( "not valid plpgsql" ) ;
263+ assert ! ( result. is_err( ) ) ;
264+ }
265+
266+ #[ test]
267+ fn test_parse_plpgsql_multiple_functions ( ) {
268+ let result = parse_plpgsql (
269+ "
270+ CREATE FUNCTION foo() RETURNS void AS $$ BEGIN NULL; END; $$ LANGUAGE plpgsql;
271+ CREATE FUNCTION bar() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
272+ " ,
273+ ) ;
274+ assert ! ( result. is_ok( ) ) ;
275+ let result = result. unwrap ( ) ;
276+
277+ // function() returns None when multiple functions present
278+ assert ! ( result. function( ) . is_none( ) ) ;
279+
280+ // Use functions() for multiple
281+ assert_eq ! ( result. functions( ) . len( ) , 2 ) ;
282+ assert_eq ! ( result. functions( ) [ 0 ] . fn_signature, "foo" ) ;
283+ assert_eq ! ( result. functions( ) [ 1 ] . fn_signature, "bar" ) ;
284+ }
38285}
0 commit comments