Skip to content

Commit e7eeeed

Browse files
committed
feat: add plpgsql_scanner
This scanner is fully compatible with the core Postgres scanner and allows scanning tokens faster and more flexibly.
1 parent cbd0d42 commit e7eeeed

3 files changed

Lines changed: 2356 additions & 57 deletions

File tree

internal/instrument/instrumenter.go

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -128,66 +128,62 @@ func isDOBlock(stmt *parser.Statement) bool {
128128
return stmt.Node != nil && stmt.Node.GetDoStmt() != nil
129129
}
130130

131-
func instrumentWithLexer(stmt *parser.Statement, filePath string) (string, []CoveragePoint) {
131+
// extractFunctionBody extracts the function/DO-block body text from the AST node.
132+
// Returns the body text or "" if not found.
133+
func extractFunctionBody(stmt *parser.Statement) string {
132134
if stmt.Node == nil {
133-
return stmt.RawSQL, nil
135+
return ""
134136
}
135137

136-
// Extract function body from the parsed AST node
137-
var functionBodyText string
138-
139138
switch node := stmt.Node.Node; node.(type) {
140139
case *pgquery.Node_CreateFunctionStmt:
141-
// Get the function body from the "as" option
142140
createFunc := stmt.Node.GetCreateFunctionStmt()
143141
for _, opt := range createFunc.Options {
144142
if defElem := opt.GetDefElem(); defElem != nil && defElem.Defname == "as" {
145143
if defElem.Arg != nil {
146144
if strList := defElem.Arg.GetList(); strList != nil && len(strList.Items) > 0 {
147145
if strNode := strList.Items[0].GetString_(); strNode != nil {
148-
functionBodyText = strNode.Sval
149-
break
146+
return strNode.Sval
150147
}
151148
} else if strNode := defElem.Arg.GetString_(); strNode != nil {
152-
functionBodyText = strNode.Sval
153-
break
149+
return strNode.Sval
154150
}
155151
}
156152
}
157153
}
158154

159155
case *pgquery.Node_DoStmt:
160-
// For DO blocks
161156
if doStmt := stmt.Node.GetDoStmt(); len(doStmt.Args) > 0 {
162157
if strNode := doStmt.Args[0].GetString_(); strNode != nil {
163-
functionBodyText = strNode.Sval
158+
return strNode.Sval
164159
}
165160
}
166-
167-
default:
168-
return stmt.RawSQL, nil
169161
}
170162

171-
// Scan the function body content with lexer
172-
ScanRes, err := pgquery.Scan(functionBodyText)
173-
if err != nil {
174-
// Return original on scan error
163+
return ""
164+
}
165+
166+
func instrumentWithLexer(stmt *parser.Statement, filePath string) (string, []CoveragePoint) {
167+
functionBodyText := extractFunctionBody(stmt)
168+
if functionBodyText == "" {
175169
return stmt.RawSQL, nil
176170
}
177171

178-
tokens := ScanRes.GetTokens()
172+
// Scan the function body content with our PL/pgSQL lexer
173+
scanner := parser.NewScanner(functionBodyText)
174+
tokens := scanner.ScanAll()
179175
if len(tokens) == 0 {
180176
return stmt.RawSQL, nil
181177
}
182178

183-
// Find executable statement boundaries in the body content
184-
executableSegments := findExecutableSegments(functionBodyText, tokens)
179+
// Find executable statement boundaries in the body content (skip to BEGIN for PL/pgSQL)
180+
executableSegments := findExecutableSegments(functionBodyText, tokens, true)
185181
if len(executableSegments) == 0 {
186182
return stmt.RawSQL, nil
187183
}
188184

189185
// Create coverage points and inject PERFORM calls
190-
return instrumentFunctionBodyFromAST(stmt, filePath, functionBodyText, executableSegments)
186+
return instrumentFunctionBodyFromAST(stmt, filePath, functionBodyText, executableSegments, "PERFORM")
191187
}
192188

193189
type executableSegment struct {
@@ -197,31 +193,35 @@ type executableSegment struct {
197193
lineEnd int // Line number in body content (0-based)
198194
}
199195

200-
// findExecutableSegments finds executable statement segments in function body
201-
func findExecutableSegments(bodyContent string, tokens []*pgquery.ScanToken) []executableSegment {
196+
// findExecutableSegments finds executable statement segments in function body.
197+
// When skipToBegin is true (PL/pgSQL), tokens before the first BEGIN are skipped.
198+
// When skipToBegin is false (SQL functions), all tokens are considered immediately.
199+
func findExecutableSegments(bodyContent string, tokens []parser.Token, skipToBegin bool) []executableSegment {
202200
var segments []executableSegment
203201

204202
hasExecutableContent := false
205203
firstExecutableTokenPos := -1
206204

207-
for idx, token := range tokens {
208-
if token.Token == pgquery.Token_BEGIN_P { // Skip until BEGIN token
209-
tokens = tokens[idx+1:]
210-
break
205+
if skipToBegin {
206+
for idx, token := range tokens {
207+
if token.Type == parser.KBegin { // Skip until BEGIN token
208+
tokens = tokens[idx+1:]
209+
break
210+
}
211211
}
212212
}
213213

214214
for _, token := range tokens {
215215
// Skip comment tokens
216-
if isCommentToken(int32(token.Token)) {
216+
if token.Type == parser.Comment {
217217
continue
218218
}
219219

220220
// Check if this is a semicolon (statement separator)
221-
if token.Token == pgquery.Token_ASCII_59 { // Token_ASCII_59 - ";"
221+
if token.Type == parser.TokenType(';') {
222222
if hasExecutableContent && firstExecutableTokenPos >= 0 {
223223
// Check if this segment represents an executable statement
224-
segmentEnd := int(token.Start)
224+
segmentEnd := token.Pos
225225
segmentContent := bodyContent[firstExecutableTokenPos:segmentEnd]
226226

227227
if isExecutableSegment(segmentContent) {
@@ -241,7 +241,7 @@ func findExecutableSegments(bodyContent string, tokens []*pgquery.ScanToken) []e
241241
} else {
242242
// This is some non-comment token, so we have content
243243
if !hasExecutableContent {
244-
firstExecutableTokenPos = int(token.Start)
244+
firstExecutableTokenPos = token.Pos
245245
}
246246
hasExecutableContent = true
247247
}
@@ -321,8 +321,9 @@ func convertByteOffsetToLine(sql string, byteOffset int) int {
321321
return lineIdx
322322
}
323323

324-
// instrumentFunctionBodyFromAST injects PERFORM calls using AST-extracted function body
325-
func instrumentFunctionBodyFromAST(stmt *parser.Statement, filePath string, bodyContent string, segments []executableSegment) (string, []CoveragePoint) {
324+
// instrumentFunctionBodyFromAST injects coverage-tracking calls using AST-extracted function body.
325+
// notifyCmd is the SQL command used for the pg_notify call: "PERFORM" for PL/pgSQL, "SELECT" for SQL functions.
326+
func instrumentFunctionBodyFromAST(stmt *parser.Statement, filePath string, bodyContent string, segments []executableSegment, notifyCmd string) (string, []CoveragePoint) {
326327
var locations []CoveragePoint
327328

328329
// Find where the function body content actually starts in the original SQL
@@ -381,9 +382,9 @@ func instrumentFunctionBodyFromAST(stmt *parser.Statement, filePath string, body
381382
}
382383
}
383384

384-
// Inject PERFORM pg_notify call before the segment
385-
notifyCall := fmt.Sprintf("%sPERFORM pg_notify('pgcov', '%s');\n",
386-
indent, strings.ReplaceAll(cp.SignalID, "'", "''"))
385+
// Inject coverage-tracking pg_notify call before the segment
386+
notifyCall := fmt.Sprintf("%s%s pg_notify('pgcov', '%s');\n",
387+
indent, notifyCmd, strings.ReplaceAll(cp.SignalID, "'", "''"))
387388
instrumentedBody.WriteString(notifyCall)
388389

389390
// Write the original segment content
@@ -403,33 +404,35 @@ func instrumentFunctionBodyFromAST(stmt *parser.Statement, filePath string, body
403404
return result, locations
404405
}
405406

406-
// isCommentToken checks if a token is a comment token that should be excluded
407-
func isCommentToken(tokenType int32) bool {
408-
return tokenType == 275 || tokenType == 276 // Token_SQL_COMMENT || Token_C_COMMENT
409-
}
410-
411407
// getIndentation returns the leading whitespace of a line
412408
func getIndentation(line string) string {
413409
return line[:len(line)-len(strings.TrimLeft(line, " \t"))]
414410
}
415411

416-
// instrumentSQLFunction instruments a SQL function
412+
// instrumentSQLFunction instruments a SQL-language function.
413+
// SQL functions have no DECLARE/BEGIN block, so we scan the body immediately.
414+
// Since PERFORM is not valid in SQL functions, we use SELECT pg_notify(...) instead.
417415
func instrumentSQLFunction(stmt *parser.Statement, filePath string) (string, []CoveragePoint) {
418-
// For SQL functions, we can't inject PERFORM, so we mark the function definition
419-
// Use the byte position from the parsed statement
420-
bytePos := stmt.StartPos
421-
stmtLength := len(stmt.RawSQL)
422-
cp := CoveragePoint{
423-
File: filePath,
424-
StartPos: bytePos,
425-
Length: stmtLength,
426-
Branch: "",
416+
functionBodyText := extractFunctionBody(stmt)
417+
if functionBodyText == "" {
418+
return stmt.RawSQL, nil
419+
}
420+
421+
// Scan the function body with our PL/pgSQL lexer (works for plain SQL too)
422+
scanner := parser.NewScanner(functionBodyText)
423+
tokens := scanner.ScanAll()
424+
if len(tokens) == 0 {
425+
return stmt.RawSQL, nil
426+
}
427+
428+
// Find executable segments without skipping to BEGIN (SQL functions have no BEGIN)
429+
executableSegments := findExecutableSegments(functionBodyText, tokens, false)
430+
if len(executableSegments) == 0 {
431+
return stmt.RawSQL, nil
427432
}
428-
cp.SignalID = FormatSignalID(cp.File, cp.StartPos, cp.Length, cp.Branch)
429433

430-
// SQL functions are harder to instrument - for now, just track the function call
431-
// This would require wrapping the SQL expression which is complex
432-
return stmt.RawSQL, []CoveragePoint{cp}
434+
// Inject SELECT pg_notify calls
435+
return instrumentFunctionBodyFromAST(stmt, filePath, functionBodyText, executableSegments, "SELECT")
433436
}
434437

435438
// markStatementLinesAsCovered creates coverage points for all non-comment lines

0 commit comments

Comments
 (0)