11use std:: ffi:: { CStr , CString } ;
22
3+ use prost:: Message ;
4+
35use crate :: bindings:: * ;
46use crate :: error:: * ;
7+ use crate :: protobuf;
58
6- /// An experimental API which parses a PLPGSQL function. This currently drops the returned
7- /// structure and returns only a Result<()>.
9+ /// Parses the given PL/pgSQL function into an abstract syntax tree.
810///
911/// # Example
1012///
1113/// ```rust
12- /// let result = pgls_query::parse_plpgsql("
14+ /// use pgls_query::parse_plpgsql;
15+ ///
16+ /// let result = parse_plpgsql("
1317/// CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
1418/// RETURNS varchar AS $$
1519/// BEGIN
@@ -21,18 +25,255 @@ use crate::error::*;
2125/// $$ LANGUAGE plpgsql;
2226/// ");
2327/// assert!(result.is_ok());
28+ /// let result = result.unwrap();
29+ /// assert_eq!(result.functions().len(), 1);
2430/// ```
25- pub fn parse_plpgsql ( stmt : & str ) -> Result < ( ) > {
31+ pub fn parse_plpgsql ( stmt : & str ) -> Result < PlpgsqlParseResult > {
2632 let input = CString :: new ( stmt) ?;
27- let result = unsafe { pg_query_parse_plpgsql ( input. as_ptr ( ) ) } ;
28- let structure = if !result. error . is_null ( ) {
33+ let result = unsafe { pg_query_parse_plpgsql_protobuf ( input. as_ptr ( ) ) } ;
34+ let parse_result = if !result. error . is_null ( ) {
2935 let message = unsafe { CStr :: from_ptr ( ( * result. error ) . message ) }
3036 . to_string_lossy ( )
3137 . to_string ( ) ;
3238 Err ( Error :: Parse ( message) )
3339 } else {
34- Ok ( ( ) )
40+ let data = unsafe {
41+ std:: slice:: from_raw_parts (
42+ result. parse_tree . data as * const u8 ,
43+ result. parse_tree . len as usize ,
44+ )
45+ } ;
46+ protobuf:: PLpgSqlParseResult :: decode ( data)
47+ . map_err ( Error :: Decode )
48+ . map ( PlpgsqlParseResult :: new)
3549 } ;
36- unsafe { pg_query_free_plpgsql_parse_result ( result) } ;
37- structure
50+ unsafe { pg_query_free_plpgsql_protobuf_parse_result ( result) } ;
51+ parse_result
52+ }
53+
54+ /// The result of parsing a PL/pgSQL function
55+ #[ derive( Debug ) ]
56+ pub struct PlpgsqlParseResult {
57+ /// The parsed protobuf result
58+ pub protobuf : protobuf:: PLpgSqlParseResult ,
59+ }
60+
61+ impl PlpgsqlParseResult {
62+ /// Create a new PlpgsqlParseResult
63+ pub fn new ( protobuf : protobuf:: PLpgSqlParseResult ) -> Self {
64+ Self { protobuf }
65+ }
66+
67+ /// Returns a reference to the single parsed PL/pgSQL function.
68+ ///
69+ /// Returns `None` if there is not exactly one function in the parse result.
70+ /// Use `functions()` if you need to handle multiple functions.
71+ pub fn function ( & self ) -> Option < & protobuf:: PLpgSqlFunction > {
72+ if self . protobuf . plpgsql_funcs . len ( ) != 1 {
73+ return None ;
74+ }
75+ self . protobuf . plpgsql_funcs . first ( )
76+ }
77+
78+ /// Consumes the result and returns the single parsed function.
79+ ///
80+ /// Returns `None` if there is not exactly one function in the parse result.
81+ /// Use `into_functions()` if you need to handle multiple functions.
82+ pub fn into_function ( self ) -> Option < protobuf:: PLpgSqlFunction > {
83+ if self . protobuf . plpgsql_funcs . len ( ) != 1 {
84+ return None ;
85+ }
86+ self . protobuf . plpgsql_funcs . into_iter ( ) . next ( )
87+ }
88+
89+ /// Returns a reference to the list of parsed PL/pgSQL functions
90+ pub fn functions ( & self ) -> & [ protobuf:: PLpgSqlFunction ] {
91+ & self . protobuf . plpgsql_funcs
92+ }
93+
94+ /// Consumes the result and returns the list of parsed functions
95+ pub fn into_functions ( self ) -> Vec < protobuf:: PLpgSqlFunction > {
96+ self . protobuf . plpgsql_funcs
97+ }
98+ }
99+
100+ #[ cfg( test) ]
101+ mod tests {
102+ use super :: * ;
103+ use crate :: protobuf:: p_lpg_sql_stmt:: Stmt ;
104+
105+ #[ test]
106+ fn test_parse_plpgsql_simple ( ) {
107+ let result = parse_plpgsql (
108+ "
109+ CREATE OR REPLACE FUNCTION test_func()
110+ RETURNS void AS $$
111+ BEGIN
112+ NULL;
113+ END;
114+ $$ LANGUAGE plpgsql;
115+ " ,
116+ ) ;
117+ assert ! ( result. is_ok( ) ) ;
118+ let result = result. unwrap ( ) ;
119+
120+ // Use function() for single function access
121+ let func = result. function ( ) . expect ( "should have exactly one function" ) ;
122+ assert ! ( func. action. is_some( ) ) ;
123+
124+ // The body should contain statements
125+ let action = func. action . as_ref ( ) . unwrap ( ) ;
126+ assert ! ( !action. body. is_empty( ) ) ;
127+ }
128+
129+ #[ test]
130+ fn test_parse_plpgsql_with_assignment ( ) {
131+ let result = parse_plpgsql (
132+ "
133+ CREATE OR REPLACE FUNCTION add_numbers(a int, b int)
134+ RETURNS int AS $$
135+ DECLARE
136+ result int;
137+ BEGIN
138+ result := a + b;
139+ RETURN result;
140+ END;
141+ $$ LANGUAGE plpgsql;
142+ " ,
143+ ) ;
144+ assert ! ( result. is_ok( ) ) ;
145+ let result = result. unwrap ( ) ;
146+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
147+
148+ let func = & result. functions ( ) [ 0 ] ;
149+ let action = func. action . as_ref ( ) . unwrap ( ) ;
150+
151+ // Should have assignment and return statements
152+ assert ! ( action. body. len( ) >= 2 ) ;
153+
154+ // First statement should be an assignment
155+ let first_stmt = & action. body [ 0 ] ;
156+ assert ! ( matches!( first_stmt. stmt, Some ( Stmt :: StmtAssign ( _) ) ) ) ;
157+
158+ // Second statement should be a return
159+ let second_stmt = & action. body [ 1 ] ;
160+ assert ! ( matches!( second_stmt. stmt, Some ( Stmt :: StmtReturn ( _) ) ) ) ;
161+
162+ // Verify the assignment expression contains the query
163+ if let Some ( Stmt :: StmtAssign ( assign) ) = & first_stmt. stmt {
164+ assert ! ( assign. expr. is_some( ) ) ;
165+ let expr = assign. expr . as_ref ( ) . unwrap ( ) ;
166+ assert ! ( expr. query. contains( "a + b" ) ) ;
167+ }
168+ }
169+
170+ #[ test]
171+ fn test_parse_plpgsql_with_if ( ) {
172+ let result = parse_plpgsql (
173+ "
174+ CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
175+ RETURNS varchar AS $$
176+ BEGIN
177+ IF v_version IS NULL THEN
178+ RETURN v_name;
179+ END IF;
180+ RETURN v_name || '/' || v_version;
181+ END;
182+ $$ LANGUAGE plpgsql;
183+ " ,
184+ ) ;
185+ assert ! ( result. is_ok( ) ) ;
186+ let result = result. unwrap ( ) ;
187+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
188+
189+ let func = & result. functions ( ) [ 0 ] ;
190+ let action = func. action . as_ref ( ) . unwrap ( ) ;
191+
192+ // Should have IF and RETURN statements
193+ assert ! ( action. body. len( ) >= 2 ) ;
194+
195+ // First statement should be IF
196+ let if_stmt = & action. body [ 0 ] ;
197+ assert ! ( matches!( if_stmt. stmt, Some ( Stmt :: StmtIf ( _) ) ) ) ;
198+
199+ // Verify the IF statement structure
200+ if let Some ( Stmt :: StmtIf ( if_node) ) = & if_stmt. stmt {
201+ // Should have a condition
202+ assert ! ( if_node. cond. is_some( ) ) ;
203+ let cond = if_node. cond . as_ref ( ) . unwrap ( ) ;
204+ assert ! ( cond. query. contains( "v_version IS NULL" ) ) ;
205+
206+ // Should have a then_body with RETURN statement
207+ assert ! ( !if_node. then_body. is_empty( ) ) ;
208+ assert ! ( matches!(
209+ if_node. then_body[ 0 ] . stmt,
210+ Some ( Stmt :: StmtReturn ( _) )
211+ ) ) ;
212+ }
213+
214+ // Second statement should be RETURN
215+ let return_stmt = & action. body [ 1 ] ;
216+ assert ! ( matches!( return_stmt. stmt, Some ( Stmt :: StmtReturn ( _) ) ) ) ;
217+ }
218+
219+ #[ test]
220+ fn test_parse_plpgsql_with_loop ( ) {
221+ let result = parse_plpgsql (
222+ "
223+ CREATE OR REPLACE FUNCTION count_down(n int)
224+ RETURNS void AS $$
225+ BEGIN
226+ WHILE n > 0 LOOP
227+ n := n - 1;
228+ END LOOP;
229+ END;
230+ $$ LANGUAGE plpgsql;
231+ " ,
232+ ) ;
233+ assert ! ( result. is_ok( ) ) ;
234+ let result = result. unwrap ( ) ;
235+
236+ let func = & result. functions ( ) [ 0 ] ;
237+ let action = func. action . as_ref ( ) . unwrap ( ) ;
238+
239+ // First statement should be WHILE loop
240+ let while_stmt = & action. body [ 0 ] ;
241+ assert ! ( matches!( while_stmt. stmt, Some ( Stmt :: StmtWhile ( _) ) ) ) ;
242+
243+ if let Some ( Stmt :: StmtWhile ( while_node) ) = & while_stmt. stmt {
244+ // Should have a condition
245+ assert ! ( while_node. cond. is_some( ) ) ;
246+ let cond = while_node. cond . as_ref ( ) . unwrap ( ) ;
247+ assert ! ( cond. query. contains( "n > 0" ) ) ;
248+
249+ // Should have a body with assignment
250+ assert ! ( !while_node. body. is_empty( ) ) ;
251+ }
252+ }
253+
254+ #[ test]
255+ fn test_parse_plpgsql_error ( ) {
256+ let result = parse_plpgsql ( "not valid plpgsql" ) ;
257+ assert ! ( result. is_err( ) ) ;
258+ }
259+
260+ #[ test]
261+ fn test_parse_plpgsql_multiple_functions ( ) {
262+ let result = parse_plpgsql (
263+ "
264+ CREATE FUNCTION foo() RETURNS void AS $$ BEGIN NULL; END; $$ LANGUAGE plpgsql;
265+ CREATE FUNCTION bar() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
266+ " ,
267+ ) ;
268+ assert ! ( result. is_ok( ) ) ;
269+ let result = result. unwrap ( ) ;
270+
271+ // function() returns None when multiple functions present
272+ assert ! ( result. function( ) . is_none( ) ) ;
273+
274+ // Use functions() for multiple
275+ assert_eq ! ( result. functions( ) . len( ) , 2 ) ;
276+ assert_eq ! ( result. functions( ) [ 0 ] . fn_signature, "foo" ) ;
277+ assert_eq ! ( result. functions( ) [ 1 ] . fn_signature, "bar" ) ;
278+ }
38279}
0 commit comments