@@ -128,66 +128,62 @@ func isDOBlock(stmt *parser.Statement) bool {
128128 return stmt .Node != nil && stmt .Node .GetDoStmt () != nil
129129}
130130
131- func instrumentWithLexer (stmt * parser.Statement , filePath string ) (string , []CoveragePoint ) {
131+ // extractFunctionBody returns the body text carried by the statement's AST node: for CREATE FUNCTION, the string under the "as" option (first item of a list argument, or the argument itself when it is a plain string); for DO, the first argument when it is a string node.
132+ // It returns "" when stmt.Node is nil, the node is neither of those kinds, or no body string is found. NOTE(review): the DO branch only inspects Args[0] as a string node; confirm this matches how the parser represents DO arguments (they may be wrapped in option elements).
133+ func extractFunctionBody (stmt * parser.Statement ) string {
132134 if stmt .Node == nil {
133- return stmt . RawSQL , nil
135+ return ""
134136 }
135137
136- // Extract function body from the parsed AST node
137- var functionBodyText string
138-
139138 switch node := stmt .Node .Node ; node .(type ) {
140139 case * pgquery.Node_CreateFunctionStmt :
141- // Get the function body from the "as" option
142140 createFunc := stmt .Node .GetCreateFunctionStmt ()
143141 for _ , opt := range createFunc .Options {
144142 if defElem := opt .GetDefElem (); defElem != nil && defElem .Defname == "as" {
145143 if defElem .Arg != nil {
146144 if strList := defElem .Arg .GetList (); strList != nil && len (strList .Items ) > 0 {
147145 if strNode := strList .Items [0 ].GetString_ (); strNode != nil {
148- functionBodyText = strNode .Sval
149- break
146+ return strNode .Sval
150147 }
151148 } else if strNode := defElem .Arg .GetString_ (); strNode != nil {
152- functionBodyText = strNode .Sval
153- break
149+ return strNode .Sval
154150 }
155151 }
156152 }
157153 }
158154
159155 case * pgquery.Node_DoStmt :
160- // For DO blocks
161156 if doStmt := stmt .Node .GetDoStmt (); len (doStmt .Args ) > 0 {
162157 if strNode := doStmt .Args [0 ].GetString_ (); strNode != nil {
163- functionBodyText = strNode .Sval
158+ return strNode .Sval
164159 }
165160 }
166-
167- default :
168- return stmt .RawSQL , nil
169161 }
170162
171- // Scan the function body content with lexer
172- ScanRes , err := pgquery .Scan (functionBodyText )
173- if err != nil {
174- // Return original on scan error
163+ return ""
164+ }
165+
166+ func instrumentWithLexer (stmt * parser.Statement , filePath string ) (string , []CoveragePoint ) { // Instruments the extracted body with PERFORM-based notify calls; every early exit below returns stmt.RawSQL unchanged with nil coverage points.
167+ functionBodyText := extractFunctionBody (stmt )
168+ if functionBodyText == "" {
175169 return stmt .RawSQL , nil
176170 }
177171
178- tokens := ScanRes .GetTokens ()
172+ // Tokenize the body with the project's scanner (parser.NewScanner + ScanAll); an empty token list means there is nothing to instrument.
173+ scanner := parser .NewScanner (functionBodyText )
174+ tokens := scanner .ScanAll ()
179175 if len (tokens ) == 0 {
180176 return stmt .RawSQL , nil
181177 }
182178
183- // Find executable statement boundaries in the body content
184- executableSegments := findExecutableSegments (functionBodyText , tokens )
179+ // Locate executable statement segments; skipToBegin=true makes the search ignore every token up to and including the first BEGIN (PL/pgSQL).
180+ executableSegments := findExecutableSegments (functionBodyText , tokens , true )
185181 if len (executableSegments ) == 0 {
186182 return stmt .RawSQL , nil
187183 }
188184
189185 // Build the coverage points and inject a PERFORM pg_notify call ahead of each segment
190- return instrumentFunctionBodyFromAST (stmt , filePath , functionBodyText , executableSegments )
186+ return instrumentFunctionBodyFromAST (stmt , filePath , functionBodyText , executableSegments , "PERFORM" )
191187}
192188
193189type executableSegment struct {
@@ -197,31 +193,35 @@ type executableSegment struct {
197193 lineEnd int // Line number in body content (0-based)
198194}
199195
200- // findExecutableSegments finds executable statement segments in function body
201- func findExecutableSegments (bodyContent string , tokens []* pgquery.ScanToken ) []executableSegment {
196+ // findExecutableSegments finds executable statement segments in function body.
197+ // When skipToBegin is true (PL/pgSQL), tokens before the first BEGIN are skipped.
198+ // When skipToBegin is false (SQL functions), all tokens are considered immediately.
199+ func findExecutableSegments (bodyContent string , tokens []parser.Token , skipToBegin bool ) []executableSegment {
202200 var segments []executableSegment
203201
204202 hasExecutableContent := false
205203 firstExecutableTokenPos := - 1
206204
207- for idx , token := range tokens {
208- if token .Token == pgquery .Token_BEGIN_P { // Skip until BEGIN token
209- tokens = tokens [idx + 1 :]
210- break
205+ if skipToBegin {
206+ for idx , token := range tokens {
207+ if token .Type == parser .KBegin { // Skip until BEGIN token
208+ tokens = tokens [idx + 1 :]
209+ break
210+ }
211211 }
212212 }
213213
214214 for _ , token := range tokens {
215215 // Skip comment tokens
216- if isCommentToken ( int32 ( token .Token )) {
216+ if token .Type == parser . Comment {
217217 continue
218218 }
219219
220220 // Check if this is a semicolon (statement separator)
221- if token .Token == pgquery . Token_ASCII_59 { // Token_ASCII_59 - ";"
221+ if token .Type == parser . TokenType ( ';' ) {
222222 if hasExecutableContent && firstExecutableTokenPos >= 0 {
223223 // Check if this segment represents an executable statement
224- segmentEnd := int ( token .Start )
224+ segmentEnd := token .Pos
225225 segmentContent := bodyContent [firstExecutableTokenPos :segmentEnd ]
226226
227227 if isExecutableSegment (segmentContent ) {
@@ -241,7 +241,7 @@ func findExecutableSegments(bodyContent string, tokens []*pgquery.ScanToken) []e
241241 } else {
242242 // This is some non-comment token, so we have content
243243 if ! hasExecutableContent {
244- firstExecutableTokenPos = int ( token .Start )
244+ firstExecutableTokenPos = token .Pos
245245 }
246246 hasExecutableContent = true
247247 }
@@ -321,8 +321,9 @@ func convertByteOffsetToLine(sql string, byteOffset int) int {
321321 return lineIdx
322322}
323323
324- // instrumentFunctionBodyFromAST injects PERFORM calls using AST-extracted function body
325- func instrumentFunctionBodyFromAST (stmt * parser.Statement , filePath string , bodyContent string , segments []executableSegment ) (string , []CoveragePoint ) {
324+ // instrumentFunctionBodyFromAST injects coverage-tracking calls using AST-extracted function body.
325+ // notifyCmd is the SQL command used for the pg_notify call: "PERFORM" for PL/pgSQL, "SELECT" for SQL functions.
326+ func instrumentFunctionBodyFromAST (stmt * parser.Statement , filePath string , bodyContent string , segments []executableSegment , notifyCmd string ) (string , []CoveragePoint ) {
326327 var locations []CoveragePoint
327328
328329 // Find where the function body content actually starts in the original SQL
@@ -381,9 +382,9 @@ func instrumentFunctionBodyFromAST(stmt *parser.Statement, filePath string, body
381382 }
382383 }
383384
384- // Inject PERFORM pg_notify call before the segment
385- notifyCall := fmt .Sprintf ("%sPERFORM pg_notify('pgcov', '%s');\n " ,
386- indent , strings .ReplaceAll (cp .SignalID , "'" , "''" ))
385+ // Inject coverage-tracking pg_notify call before the segment
386+ notifyCall := fmt .Sprintf ("%s%s pg_notify('pgcov', '%s');\n " ,
387+ indent , notifyCmd , strings .ReplaceAll (cp .SignalID , "'" , "''" ))
387388 instrumentedBody .WriteString (notifyCall )
388389
389390 // Write the original segment content
@@ -403,33 +404,35 @@ func instrumentFunctionBodyFromAST(stmt *parser.Statement, filePath string, body
403404 return result , locations
404405}
405406
406- // isCommentToken checks if a token is a comment token that should be excluded
407- func isCommentToken (tokenType int32 ) bool {
408- return tokenType == 275 || tokenType == 276 // Token_SQL_COMMENT || Token_C_COMMENT
409- }
410-
411407// getIndentation returns the leading run of spaces and tabs of line, i.e. everything before its first other character ("" when the line starts with neither).
412408func getIndentation (line string ) string {
413409 return line [:len (line )- len (strings .TrimLeft (line , " \t " ))]
414410}
415411
416- // instrumentSQLFunction instruments a SQL function
412+ // instrumentSQLFunction instruments a SQL-language function by scanning its whole body and injecting SELECT pg_notify(...) calls.
413+ // SQL functions have no DECLARE/BEGIN block, so segment detection starts at the first token instead of skipping to BEGIN.
414+ // PERFORM is not valid in SQL functions, so "SELECT" is passed as the notify command. When the body cannot be extracted, yields no tokens, or contains no executable segments, stmt.RawSQL is returned unchanged with nil coverage points.
417415func instrumentSQLFunction (stmt * parser.Statement , filePath string ) (string , []CoveragePoint ) {
418- // For SQL functions, we can't inject PERFORM, so we mark the function definition
419- // Use the byte position from the parsed statement
420- bytePos := stmt .StartPos
421- stmtLength := len (stmt .RawSQL )
422- cp := CoveragePoint {
423- File : filePath ,
424- StartPos : bytePos ,
425- Length : stmtLength ,
426- Branch : "" ,
416+ functionBodyText := extractFunctionBody (stmt )
417+ if functionBodyText == "" {
418+ return stmt .RawSQL , nil
419+ }
420+
421+ // Tokenize the body with the project's scanner (described as a PL/pgSQL lexer; assumed to cope with plain SQL as well - TODO confirm)
422+ scanner := parser .NewScanner (functionBodyText )
423+ tokens := scanner .ScanAll ()
424+ if len (tokens ) == 0 {
425+ return stmt .RawSQL , nil
426+ }
427+
428+ // Locate executable segments with skipToBegin=false, so every token is considered from the start of the body
429+ executableSegments := findExecutableSegments (functionBodyText , tokens , false )
430+ if len (executableSegments ) == 0 {
431+ return stmt .RawSQL , nil
427432 }
428- cp .SignalID = FormatSignalID (cp .File , cp .StartPos , cp .Length , cp .Branch )
429433
430- // SQL functions are harder to instrument - for now, just track the function call
431- // This would require wrapping the SQL expression which is complex
432- return stmt .RawSQL , []CoveragePoint {cp }
434+ // Build the coverage points and inject a SELECT pg_notify(...) call ahead of each segment
435+ return instrumentFunctionBodyFromAST (stmt , filePath , functionBodyText , executableSegments , "SELECT" )
433436}
434437
435438// markStatementLinesAsCovered creates coverage points for all non-comment lines
0 commit comments