11use std:: ffi:: { CStr , CString } ;
22
3+ use prost:: Message ;
4+
35use crate :: bindings:: * ;
46use crate :: error:: * ;
7+ use crate :: protobuf;
58
6- /// An experimental API which parses a PLPGSQL function. This currently drops the returned
7- /// structure and returns only a Result<()>.
9+ /// Parses the given PL/pgSQL function into an abstract syntax tree.
810///
911/// # Example
1012///
1113/// ```rust
12- /// let result = pgls_query::parse_plpgsql("
14+ /// use pgls_query::parse_plpgsql;
15+ ///
16+ /// let result = parse_plpgsql("
1317/// CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
1418/// RETURNS varchar AS $$
1519/// BEGIN
@@ -21,18 +25,237 @@ use crate::error::*;
2125/// $$ LANGUAGE plpgsql;
2226/// ");
2327/// assert!(result.is_ok());
28+ /// let result = result.unwrap();
29+ /// assert_eq!(result.functions().len(), 1);
2430/// ```
25- pub fn parse_plpgsql ( stmt : & str ) -> Result < ( ) > {
31+ pub fn parse_plpgsql ( stmt : & str ) -> Result < PlpgsqlParseResult > {
2632 let input = CString :: new ( stmt) ?;
27- let result = unsafe { pg_query_parse_plpgsql ( input. as_ptr ( ) ) } ;
28- let structure = if !result. error . is_null ( ) {
33+ let result = unsafe { pg_query_parse_plpgsql_protobuf ( input. as_ptr ( ) ) } ;
34+ let parse_result = if !result. error . is_null ( ) {
2935 let message = unsafe { CStr :: from_ptr ( ( * result. error ) . message ) }
3036 . to_string_lossy ( )
3137 . to_string ( ) ;
3238 Err ( Error :: Parse ( message) )
3339 } else {
34- Ok ( ( ) )
40+ let data = unsafe {
41+ std:: slice:: from_raw_parts (
42+ result. parse_tree . data as * const u8 ,
43+ result. parse_tree . len as usize ,
44+ )
45+ } ;
46+ protobuf:: PLpgSqlParseResult :: decode ( data)
47+ . map_err ( Error :: Decode )
48+ . map ( PlpgsqlParseResult :: new)
3549 } ;
36- unsafe { pg_query_free_plpgsql_parse_result ( result) } ;
37- structure
50+ unsafe { pg_query_free_plpgsql_protobuf_parse_result ( result) } ;
51+ parse_result
52+ }
53+
54+ /// The result of parsing a PL/pgSQL function
55+ #[ derive( Debug ) ]
56+ pub struct PlpgsqlParseResult {
57+ /// The parsed protobuf result
58+ pub protobuf : protobuf:: PLpgSqlParseResult ,
59+ }
60+
61+ impl PlpgsqlParseResult {
62+ /// Create a new PlpgsqlParseResult
63+ pub fn new ( protobuf : protobuf:: PLpgSqlParseResult ) -> Self {
64+ Self { protobuf }
65+ }
66+
67+ /// Returns a reference to the list of parsed PL/pgSQL functions
68+ pub fn functions ( & self ) -> & [ protobuf:: PLpgSqlFunction ] {
69+ & self . protobuf . plpgsql_funcs
70+ }
71+
72+ /// Consumes the result and returns the list of parsed functions
73+ pub fn into_functions ( self ) -> Vec < protobuf:: PLpgSqlFunction > {
74+ self . protobuf . plpgsql_funcs
75+ }
76+ }
77+
78+ #[ cfg( test) ]
79+ mod tests {
80+ use super :: * ;
81+ use crate :: protobuf:: p_lpg_sql_stmt:: Stmt ;
82+
83+ #[ test]
84+ fn test_parse_plpgsql_simple ( ) {
85+ let result = parse_plpgsql (
86+ "
87+ CREATE OR REPLACE FUNCTION test_func()
88+ RETURNS void AS $$
89+ BEGIN
90+ NULL;
91+ END;
92+ $$ LANGUAGE plpgsql;
93+ " ,
94+ ) ;
95+ assert ! ( result. is_ok( ) ) ;
96+ let result = result. unwrap ( ) ;
97+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
98+
99+ // Verify function has an action block
100+ let func = & result. functions ( ) [ 0 ] ;
101+ assert ! ( func. action. is_some( ) ) ;
102+
103+ // The body should contain statements
104+ let action = func. action . as_ref ( ) . unwrap ( ) ;
105+ assert ! ( !action. body. is_empty( ) ) ;
106+ }
107+
108+ #[ test]
109+ fn test_parse_plpgsql_with_assignment ( ) {
110+ let result = parse_plpgsql (
111+ "
112+ CREATE OR REPLACE FUNCTION add_numbers(a int, b int)
113+ RETURNS int AS $$
114+ DECLARE
115+ result int;
116+ BEGIN
117+ result := a + b;
118+ RETURN result;
119+ END;
120+ $$ LANGUAGE plpgsql;
121+ " ,
122+ ) ;
123+ assert ! ( result. is_ok( ) ) ;
124+ let result = result. unwrap ( ) ;
125+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
126+
127+ let func = & result. functions ( ) [ 0 ] ;
128+ let action = func. action . as_ref ( ) . unwrap ( ) ;
129+
130+ // Should have assignment and return statements
131+ assert ! ( action. body. len( ) >= 2 ) ;
132+
133+ // First statement should be an assignment
134+ let first_stmt = & action. body [ 0 ] ;
135+ assert ! ( matches!(
136+ first_stmt. stmt,
137+ Some ( Stmt :: StmtAssign ( _) )
138+ ) ) ;
139+
140+ // Second statement should be a return
141+ let second_stmt = & action. body [ 1 ] ;
142+ assert ! ( matches!(
143+ second_stmt. stmt,
144+ Some ( Stmt :: StmtReturn ( _) )
145+ ) ) ;
146+
147+ // Verify the assignment expression contains the query
148+ if let Some ( Stmt :: StmtAssign ( assign) ) = & first_stmt. stmt {
149+ assert ! ( assign. expr. is_some( ) ) ;
150+ let expr = assign. expr . as_ref ( ) . unwrap ( ) ;
151+ assert ! ( expr. query. contains( "a + b" ) ) ;
152+ }
153+ }
154+
155+ #[ test]
156+ fn test_parse_plpgsql_with_if ( ) {
157+ let result = parse_plpgsql (
158+ "
159+ CREATE OR REPLACE FUNCTION cs_fmt_browser_version(v_name varchar, v_version varchar)
160+ RETURNS varchar AS $$
161+ BEGIN
162+ IF v_version IS NULL THEN
163+ RETURN v_name;
164+ END IF;
165+ RETURN v_name || '/' || v_version;
166+ END;
167+ $$ LANGUAGE plpgsql;
168+ " ,
169+ ) ;
170+ assert ! ( result. is_ok( ) ) ;
171+ let result = result. unwrap ( ) ;
172+ assert_eq ! ( result. functions( ) . len( ) , 1 ) ;
173+
174+ let func = & result. functions ( ) [ 0 ] ;
175+ let action = func. action . as_ref ( ) . unwrap ( ) ;
176+
177+ // Should have IF and RETURN statements
178+ assert ! ( action. body. len( ) >= 2 ) ;
179+
180+ // First statement should be IF
181+ let if_stmt = & action. body [ 0 ] ;
182+ assert ! ( matches!( if_stmt. stmt, Some ( Stmt :: StmtIf ( _) ) ) ) ;
183+
184+ // Verify the IF statement structure
185+ if let Some ( Stmt :: StmtIf ( if_node) ) = & if_stmt. stmt {
186+ // Should have a condition
187+ assert ! ( if_node. cond. is_some( ) ) ;
188+ let cond = if_node. cond . as_ref ( ) . unwrap ( ) ;
189+ assert ! ( cond. query. contains( "v_version IS NULL" ) ) ;
190+
191+ // Should have a then_body with RETURN statement
192+ assert ! ( !if_node. then_body. is_empty( ) ) ;
193+ assert ! ( matches!(
194+ if_node. then_body[ 0 ] . stmt,
195+ Some ( Stmt :: StmtReturn ( _) )
196+ ) ) ;
197+ }
198+
199+ // Second statement should be RETURN
200+ let return_stmt = & action. body [ 1 ] ;
201+ assert ! ( matches!( return_stmt. stmt, Some ( Stmt :: StmtReturn ( _) ) ) ) ;
202+ }
203+
204+ #[ test]
205+ fn test_parse_plpgsql_with_loop ( ) {
206+ let result = parse_plpgsql (
207+ "
208+ CREATE OR REPLACE FUNCTION count_down(n int)
209+ RETURNS void AS $$
210+ BEGIN
211+ WHILE n > 0 LOOP
212+ n := n - 1;
213+ END LOOP;
214+ END;
215+ $$ LANGUAGE plpgsql;
216+ " ,
217+ ) ;
218+ assert ! ( result. is_ok( ) ) ;
219+ let result = result. unwrap ( ) ;
220+
221+ let func = & result. functions ( ) [ 0 ] ;
222+ let action = func. action . as_ref ( ) . unwrap ( ) ;
223+
224+ // First statement should be WHILE loop
225+ let while_stmt = & action. body [ 0 ] ;
226+ assert ! ( matches!( while_stmt. stmt, Some ( Stmt :: StmtWhile ( _) ) ) ) ;
227+
228+ if let Some ( Stmt :: StmtWhile ( while_node) ) = & while_stmt. stmt {
229+ // Should have a condition
230+ assert ! ( while_node. cond. is_some( ) ) ;
231+ let cond = while_node. cond . as_ref ( ) . unwrap ( ) ;
232+ assert ! ( cond. query. contains( "n > 0" ) ) ;
233+
234+ // Should have a body with assignment
235+ assert ! ( !while_node. body. is_empty( ) ) ;
236+ }
237+ }
238+
239+ #[ test]
240+ fn test_parse_plpgsql_error ( ) {
241+ let result = parse_plpgsql ( "not valid plpgsql" ) ;
242+ assert ! ( result. is_err( ) ) ;
243+ }
244+
245+ #[ test]
246+ fn test_parse_plpgsql_multiple_functions ( ) {
247+ let result = parse_plpgsql (
248+ "
249+ CREATE FUNCTION foo() RETURNS void AS $$ BEGIN NULL; END; $$ LANGUAGE plpgsql;
250+ CREATE FUNCTION bar() RETURNS int AS $$ BEGIN RETURN 1; END; $$ LANGUAGE plpgsql;
251+ " ,
252+ ) ;
253+ assert ! ( result. is_ok( ) ) ;
254+ let result = result. unwrap ( ) ;
255+
256+ // Parser handles multiple CREATE FUNCTION statements
257+ assert_eq ! ( result. functions( ) . len( ) , 2 ) ;
258+ assert_eq ! ( result. functions( ) [ 0 ] . fn_signature, "foo" ) ;
259+ assert_eq ! ( result. functions( ) [ 1 ] . fn_signature, "bar" ) ;
260+ }
38261}
0 commit comments