@@ -18,13 +18,20 @@ type (
1818 FilterColumns []FilterColumn `json:"filter_columns"`
1919 FuncWrappedColumns []FuncWrappedColumn `json:"func_wrapped_columns,omitempty"`
2020 UpdateTargets []string `json:"update_targets,omitempty"`
21+ ProceduralBodies []ProceduralBody `json:"procedural_bodies,omitempty"`
2122 HasSelectStar bool `json:"has_select_star"`
2223 HasLimit bool `json:"has_limit"`
2324 HasWhere bool `json:"has_where"`
2425 HasJoin bool `json:"has_join"`
2526 StatementType string `json:"statement_type"`
2627 }
2728
29+ // body content is opaque to pg_query, so it escapes static validation.
30+ ProceduralBody struct {
31+ Kind string `json:"kind"` // "DO", "CREATE FUNCTION", "CREATE PROCEDURE"
32+ Language string `json:"language"` // e.g. "plpgsql"
33+ }
34+
2835 ReferencedTable struct {
2936 Schema * string `json:"schema,omitempty"`
3037 Name string `json:"name"`
@@ -54,6 +61,7 @@ func ParseSQL(sql string) (*ParsedQuery, error) {
5461 tables []ReferencedTable
5562 filterColumns []FilterColumn
5663 funcWrappedColumns []FuncWrappedColumn
64+ proceduralBodies []ProceduralBody
5765 updateTargets []string
5866 hasSelectStar bool
5967 hasJoin bool
@@ -101,6 +109,26 @@ func ParseSQL(sql string) (*ParsedQuery, error) {
101109 if n .DeleteStmt .WhereClause != nil {
102110 hasWhere = true
103111 }
112+ case * pg_query.Node_DoStmt :
113+ if stmtType == "" {
114+ stmtType = "DO"
115+ }
116+ proceduralBodies = append (proceduralBodies , ProceduralBody {
117+ Kind : "DO" ,
118+ Language : doStmtLanguage (n .DoStmt ),
119+ })
120+ case * pg_query.Node_CreateFunctionStmt :
121+ kind := "CREATE FUNCTION"
122+ if n .CreateFunctionStmt .IsProcedure {
123+ kind = "CREATE PROCEDURE"
124+ }
125+ if stmtType == "" {
126+ stmtType = kind
127+ }
128+ proceduralBodies = append (proceduralBodies , ProceduralBody {
129+ Kind : kind ,
130+ Language : createFunctionLanguage (n .CreateFunctionStmt ),
131+ })
104132 }
105133
106134 // WHERE for func-wrapped columns (date_trunc(col), col::date, ...)
@@ -175,6 +203,7 @@ func ParseSQL(sql string) (*ParsedQuery, error) {
175203 HasJoin : hasJoin ,
176204 FuncWrappedColumns : funcWrappedColumns ,
177205 UpdateTargets : updateTargets ,
206+ ProceduralBodies : proceduralBodies ,
178207 StatementType : stmtType ,
179208 },
180209 }, nil
@@ -446,6 +475,37 @@ func extractTypeName(tn *pg_query.TypeName) string {
446475 return ""
447476}
448477
478+ // DO defaults to plpgsql when no LANGUAGE is given.
479+ func doStmtLanguage (s * pg_query.DoStmt ) string {
480+ if s == nil {
481+ return "plpgsql"
482+ }
483+ if lang := defElemLanguage (s .Args ); lang != "" {
484+ return lang
485+ }
486+ return "plpgsql"
487+ }
488+
489+ func createFunctionLanguage (s * pg_query.CreateFunctionStmt ) string {
490+ if s == nil {
491+ return ""
492+ }
493+ return defElemLanguage (s .Options )
494+ }
495+
496+ func defElemLanguage (opts []* pg_query.Node ) string {
497+ for _ , opt := range opts {
498+ de , ok := opt .Node .(* pg_query.Node_DefElem )
499+ if ! ok || de .DefElem == nil || de .DefElem .Defname != "language" {
500+ continue
501+ }
502+ if s , ok := de .DefElem .Arg .Node .(* pg_query.Node_String_ ); ok {
503+ return strings .ToLower (s .String_ .Sval )
504+ }
505+ }
506+ return ""
507+ }
508+
449509func strp (s string ) * string { return & s }
450510
451511func splitQualified (name string ) (* string , string ) {
0 commit comments