Skip to content

Commit d51d3bf

Browse files
committed
fix(instrument): place signal before terminal stmts inside branches
When a segment wraps a terminal statement (RETURN, RAISE EXCEPTION) inside a control-flow keyword (IF/ELSIF/ELSE … THEN … RETURN …), the coverage signal was placed after the entire segment — past the RETURN — making it unreachable. Coverage for these branches was always reported as 0 regardless of test execution.

Add findTerminalPos() to scan a segment for an embedded terminal statement and return its byte offset. In emitSegment, use a three-way dispatch:

1. Segment starts with terminal → signal before (existing logic)
2. Segment contains terminal inside → signal before the inner terminal (new logic)
3. No terminal → signal after (existing logic)

This ensures every injected PERFORM pg_notify() is reachable by at least one execution path.

Add indentOf() helper to extract indentation for the inner terminal line (may differ from the segment's first-line indent).

Add tests: TestFindTerminalPos (8 cases), TestInstrumentBody_ReturnInBranches, TestInstrumentBody_RaiseExceptionInBranch, and TestInstrumentBody_MixedTerminalNonTerminalBranches.
1 parent 07bd2bd commit d51d3bf

2 files changed

Lines changed: 266 additions & 49 deletions

File tree

internal/instrument/instrumenter.go

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,21 @@ func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, n
165165
}
166166
}
167167

168-
notifyCall := fmt.Sprintf("%s%s pg_notify('pgcov', '%s');",
169-
indent, notifyCmd, strings.ReplaceAll(cp.SignalID, "'", "''"))
170-
171-
if isTerminalSegment(segText) {
172-
// Terminal statements (RETURN, RAISE EXCEPTION) exit the
173-
// current scope — the signal must fire before the statement.
168+
escapedSignal := strings.ReplaceAll(cp.SignalID, "'", "''")
169+
170+
if termPos := findTerminalPos(segText); termPos >= 0 {
171+
// Segment contains (or starts with) a terminal statement
172+
// (RETURN, RAISE EXCEPTION). Place the signal before the
173+
// terminal so it fires before the scope exits.
174+
// When termPos == 0 the segment starts with the terminal;
175+
// when termPos > 0 it is nested inside a control structure
176+
// (e.g. IF … THEN RETURN …).
177+
termIndent := indentOf(segText[termPos:])
178+
notifyCall := fmt.Sprintf("%s%s pg_notify('pgcov', '%s');",
179+
termIndent, notifyCmd, escapedSignal)
180+
instrumentedBody.WriteString(segText[:termPos])
174181
fmt.Fprintf(&instrumentedBody, "%s\n", notifyCall)
175-
instrumentedBody.WriteString(segText)
182+
instrumentedBody.WriteString(segText[termPos:])
176183
lastWrittenPos = segEnd
177184
} else {
178185
// Non-terminal statements: emit the signal after the
@@ -186,6 +193,8 @@ func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, n
186193
} else {
187194
lastWrittenPos = segEnd
188195
}
196+
notifyCall := fmt.Sprintf("%s%s pg_notify('pgcov', '%s');",
197+
indent, notifyCmd, escapedSignal)
189198
fmt.Fprintf(&instrumentedBody, "\n%s", notifyCall)
190199
}
191200
}
@@ -274,49 +283,52 @@ func isExecutableSegment(segmentContent string) bool {
274283
return true
275284
}
276285

277-
// isTerminalSegment checks whether a segment starts with a terminal statement
278-
// (RETURN or RAISE EXCEPTION / bare RAISE) that exits the current scope.
279-
// NOTIFY calls placed after such statements would be unreachable.
280-
// Non-fatal RAISE levels (NOTICE, WARNING, INFO, LOG, DEBUG) are not terminal.
281-
func isTerminalSegment(segmentContent string) bool {
286+
// findTerminalPos scans segmentContent for a terminal statement (RETURN or
287+
// fatal RAISE) and returns its byte position within the string. Returns -1
288+
// if no terminal statement is found. This is used for segments that wrap a
289+
// terminal inside a control-flow keyword (e.g. IF/ELSIF/ELSE … RETURN …).
290+
func findTerminalPos(segmentContent string) int {
282291
sc := pglex.NewScanner(segmentContent)
283-
284-
// Find the first non-comment token.
285-
var first pglex.Token
286292
for {
287-
first = sc.Scan()
288-
if first.Type == pglex.EOF {
289-
return false
293+
tok := sc.Scan()
294+
if tok.Type == pglex.EOF {
295+
return -1
290296
}
291-
if first.Type != pglex.Comment {
292-
break
297+
if tok.Type == pglex.KReturn {
298+
return tok.Pos
293299
}
294-
}
295-
296-
if first.Type == pglex.KReturn {
297-
return true
298-
}
299-
300-
if first.Type == pglex.KRaise {
301-
// RAISE is terminal unless followed by a non-fatal level.
302-
for {
303-
tok := sc.Scan()
304-
if tok.Type == pglex.EOF {
305-
return true // bare RAISE; — re-raise in exception handler
306-
}
307-
if tok.Type == pglex.Comment {
308-
continue
309-
}
310-
switch tok.Type {
311-
case pglex.KNotice, pglex.KWarning, pglex.KInfo, pglex.KLog, pglex.KDebug:
312-
return false
313-
default:
314-
return true // RAISE EXCEPTION, RAISE 'message', etc.
300+
if tok.Type == pglex.KRaise {
301+
pos := tok.Pos
302+
// Peek at the next non-comment token to decide fatality.
303+
for {
304+
next := sc.Scan()
305+
if next.Type == pglex.EOF {
306+
return pos // bare RAISE — re-raise
307+
}
308+
if next.Type == pglex.Comment {
309+
continue
310+
}
311+
switch next.Type {
312+
case pglex.KNotice, pglex.KWarning, pglex.KInfo, pglex.KLog, pglex.KDebug:
313+
// Non-fatal RAISE; continue scanning for a later terminal.
314+
break
315+
default:
316+
return pos // RAISE EXCEPTION, RAISE 'msg', etc.
317+
}
318+
break
315319
}
316320
}
317321
}
322+
}
318323

319-
return false
324+
// indentOf returns the leading whitespace of the first non-empty line in s.
325+
func indentOf(s string) string {
326+
for line := range strings.SplitSeq(s, "\n") {
327+
if strings.TrimSpace(line) != "" {
328+
return getIndentation(line)
329+
}
330+
}
331+
return ""
320332
}
321333

322334
// getIndentation returns the leading whitespace of a line.

internal/instrument/plpgsql_ast_test.go

Lines changed: 212 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,13 @@ $$ LANGUAGE plpgsql;`
238238
t.Logf("Coverage points: %d (may be 0 for malformed SQL)", len(instrumented.Locations))
239239
}
240240

241-
func TestIsTerminalSegment(t *testing.T) {
241+
func TestFindTerminalPos_StartingTerminals(t *testing.T) {
242+
// Segments that start with a terminal should return pos 0.
243+
// Segments without any terminal should return -1.
242244
tests := []struct {
243-
name string
244-
segment string
245-
terminal bool
245+
name string
246+
segment string
247+
wantFound bool
246248
}{
247249
{"RETURN value", "RETURN a + b", true},
248250
{"RETURN bare", "RETURN", true},
@@ -265,9 +267,11 @@ func TestIsTerminalSegment(t *testing.T) {
265267

266268
for _, tt := range tests {
267269
t.Run(tt.name, func(t *testing.T) {
268-
got := isTerminalSegment(tt.segment)
269-
if got != tt.terminal {
270-
t.Errorf("isTerminalSegment(%q) = %v, want %v", tt.segment, got, tt.terminal)
270+
got := findTerminalPos(tt.segment)
271+
if tt.wantFound && got < 0 {
272+
t.Errorf("findTerminalPos(%q) = -1, want >= 0", tt.segment)
273+
} else if !tt.wantFound && got >= 0 {
274+
t.Errorf("findTerminalPos(%q) = %d, want -1", tt.segment, got)
271275
}
272276
})
273277
}
@@ -367,3 +371,204 @@ $$ LANGUAGE plpgsql;`
367371

368372
t.Log(instrumentedSQL)
369373
}
374+
375+
func TestFindTerminalPos(t *testing.T) {
376+
tests := []struct {
377+
name string
378+
segment string
379+
wantFound bool // whether a terminal is expected
380+
wantText string // expected text at the found position (prefix check)
381+
}{
382+
{"bare RETURN", "RETURN x", true, "RETURN"},
383+
{"IF with RETURN", "IF x > 0 THEN\n RETURN 1", true, "RETURN"},
384+
{"ELSIF with RETURN", "\n ELSIF x > 5 THEN\n RETURN 2", true, "RETURN"},
385+
{"ELSE with RETURN", "\n ELSE\n RETURN 3", true, "RETURN"},
386+
{"no terminal", "x := x + 1", false, ""},
387+
{"RAISE NOTICE", "RAISE NOTICE 'hello'", false, ""},
388+
{"IF with RAISE EXCEPTION", "IF x < 0 THEN\n RAISE EXCEPTION 'bad'", true, "RAISE"},
389+
{"IF with non-terminal", "IF x > 0 THEN\n x := 1", false, ""},
390+
}
391+
392+
for _, tt := range tests {
393+
t.Run(tt.name, func(t *testing.T) {
394+
got := findTerminalPos(tt.segment)
395+
if tt.wantFound {
396+
if got < 0 {
397+
t.Fatalf("findTerminalPos(%q) = -1, want >= 0", tt.segment)
398+
}
399+
if !strings.HasPrefix(tt.segment[got:], tt.wantText) {
400+
t.Errorf("findTerminalPos(%q) at %d: got %q, want prefix %q",
401+
tt.segment, got, tt.segment[got:], tt.wantText)
402+
}
403+
} else {
404+
if got >= 0 {
405+
t.Errorf("findTerminalPos(%q) = %d, want -1", tt.segment, got)
406+
}
407+
}
408+
})
409+
}
410+
}
411+
412+
func TestInstrumentBody_ReturnInBranches(t *testing.T) {
413+
// B2 scenario: IF/ELSIF/ELSE with RETURN in each branch.
414+
// All signals must be reachable (placed before the RETURN inside each branch).
415+
sql := `CREATE OR REPLACE FUNCTION check_stock(v_stock INT)
416+
RETURNS TEXT AS $$
417+
BEGIN
418+
IF v_stock = 0 THEN
419+
RETURN 'out_of_stock';
420+
ELSIF v_stock <= 10 THEN
421+
RETURN 'low_stock';
422+
ELSE
423+
RETURN 'in_stock';
424+
END IF;
425+
END;
426+
$$ LANGUAGE plpgsql;`
427+
428+
stmts := parser.ParseStatements(sql)
429+
if len(stmts) == 0 {
430+
t.Fatal("ParseStatements() returned no statements")
431+
}
432+
433+
instrumentedSQL, coveragePoints := instrumentBody(stmts[0], "test.sql", true, "PERFORM")
434+
if len(coveragePoints) != 3 {
435+
t.Fatalf("expected 3 coverage points, got %d", len(coveragePoints))
436+
}
437+
438+
// For each coverage point, verify the NOTIFY comes BEFORE the RETURN
439+
// inside the branch—not after it (which would be unreachable).
440+
returns := []string{"RETURN 'out_of_stock'", "RETURN 'low_stock'", "RETURN 'in_stock'"}
441+
for i, cp := range coveragePoints {
442+
notify := fmt.Sprintf("PERFORM pg_notify('pgcov', '%s');", cp.SignalID)
443+
notifyIdx := strings.Index(instrumentedSQL, notify)
444+
returnIdx := strings.Index(instrumentedSQL, returns[i])
445+
if notifyIdx < 0 || returnIdx < 0 {
446+
t.Fatalf("cp %d: could not find notify or %s", i, returns[i])
447+
}
448+
if notifyIdx > returnIdx {
449+
t.Errorf("cp %d (%s): NOTIFY at %d after RETURN at %d — signal is unreachable",
450+
i, returns[i], notifyIdx, returnIdx)
451+
}
452+
}
453+
454+
// Also verify the instrumented SQL does NOT have a PERFORM between
455+
// a RETURN and the next ELSIF/ELSE/END (that would be unreachable code).
456+
for _, kw := range []string{"ELSIF", "ELSE", "END IF"} {
457+
kwIdx := strings.Index(instrumentedSQL, kw)
458+
if kwIdx < 0 {
459+
continue
460+
}
461+
// Check a narrow window before the keyword for a rogue PERFORM.
462+
before := instrumentedSQL[max(0, kwIdx-80):kwIdx]
463+
// There should be a RETURN between the PERFORM and the keyword boundary.
464+
lastPerform := strings.LastIndex(before, "PERFORM pg_notify")
465+
lastReturn := strings.LastIndex(before, "RETURN")
466+
if lastPerform >= 0 && lastReturn >= 0 && lastPerform > lastReturn {
467+
t.Errorf("unreachable PERFORM found between RETURN and %s", kw)
468+
}
469+
}
470+
471+
t.Log(instrumentedSQL)
472+
}
473+
474+
func TestInstrumentBody_RaiseExceptionInBranch(t *testing.T) {
475+
// Segment: IF ... THEN RAISE EXCEPTION ... — terminal inside control structure.
476+
sql := `CREATE OR REPLACE FUNCTION validate(x INT)
477+
RETURNS VOID AS $$
478+
BEGIN
479+
IF x < 0 THEN
480+
RAISE EXCEPTION 'negative: %', x;
481+
ELSIF x = 0 THEN
482+
RAISE EXCEPTION 'zero';
483+
END IF;
484+
RAISE NOTICE 'ok: %', x;
485+
END;
486+
$$ LANGUAGE plpgsql;`
487+
488+
stmts := parser.ParseStatements(sql)
489+
if len(stmts) == 0 {
490+
t.Fatal("ParseStatements() returned no statements")
491+
}
492+
493+
instrumentedSQL, coveragePoints := instrumentBody(stmts[0], "test.sql", true, "PERFORM")
494+
if len(coveragePoints) != 3 {
495+
t.Fatalf("expected 3 coverage points, got %d", len(coveragePoints))
496+
}
497+
498+
// First two are in IF/ELSIF branches with RAISE EXCEPTION (terminal).
499+
// Signals must appear before the RAISE EXCEPTION.
500+
for i, target := range []string{"RAISE EXCEPTION 'negative", "RAISE EXCEPTION 'zero"} {
501+
notify := fmt.Sprintf("PERFORM pg_notify('pgcov', '%s');", coveragePoints[i].SignalID)
502+
notifyIdx := strings.Index(instrumentedSQL, notify)
503+
stmtIdx := strings.Index(instrumentedSQL, target)
504+
if notifyIdx < 0 || stmtIdx < 0 {
505+
t.Fatalf("cp %d: could not find notify or %q", i, target)
506+
}
507+
if notifyIdx > stmtIdx {
508+
t.Errorf("cp %d: NOTIFY at %d after %q at %d — unreachable", i, notifyIdx, target, stmtIdx)
509+
}
510+
}
511+
512+
// Third is RAISE NOTICE (non-terminal, standalone). Signal should come after.
513+
notify2 := fmt.Sprintf("PERFORM pg_notify('pgcov', '%s');", coveragePoints[2].SignalID)
514+
notifyIdx2 := strings.Index(instrumentedSQL, notify2)
515+
noticeIdx := strings.Index(instrumentedSQL, "RAISE NOTICE")
516+
if notifyIdx2 < 0 || noticeIdx < 0 {
517+
t.Fatal("could not find RAISE NOTICE or its notify")
518+
}
519+
if notifyIdx2 <= noticeIdx {
520+
t.Errorf("RAISE NOTICE: NOTIFY at %d should come after statement at %d", notifyIdx2, noticeIdx)
521+
}
522+
523+
t.Log(instrumentedSQL)
524+
}
525+
526+
func TestInstrumentBody_MixedTerminalNonTerminalBranches(t *testing.T) {
527+
// Branch with RETURN vs branch with assignment — signal placement differs.
528+
sql := `CREATE OR REPLACE FUNCTION classify(x INT)
529+
RETURNS TEXT AS $$
530+
DECLARE
531+
result TEXT;
532+
BEGIN
533+
IF x > 0 THEN
534+
result := 'positive';
535+
ELSE
536+
RETURN 'non-positive';
537+
END IF;
538+
RETURN result;
539+
END;
540+
$$ LANGUAGE plpgsql;`
541+
542+
stmts := parser.ParseStatements(sql)
543+
if len(stmts) == 0 {
544+
t.Fatal("ParseStatements() returned no statements")
545+
}
546+
547+
instrumentedSQL, coveragePoints := instrumentBody(stmts[0], "test.sql", true, "PERFORM")
548+
if len(coveragePoints) != 3 {
549+
t.Fatalf("expected 3 coverage points, got %d", len(coveragePoints))
550+
}
551+
552+
// cp0: IF ... result := 'positive' — no terminal, signal after.
553+
assign := "result := 'positive'"
554+
assignNotify := fmt.Sprintf("PERFORM pg_notify('pgcov', '%s');", coveragePoints[0].SignalID)
555+
if strings.Index(instrumentedSQL, assignNotify) < strings.Index(instrumentedSQL, assign) {
556+
t.Error("assignment branch: NOTIFY should come AFTER the assignment")
557+
}
558+
559+
// cp1: ELSE RETURN 'non-positive' — terminal inside branch, signal before.
560+
ret1 := "RETURN 'non-positive'"
561+
retNotify1 := fmt.Sprintf("PERFORM pg_notify('pgcov', '%s');", coveragePoints[1].SignalID)
562+
if strings.Index(instrumentedSQL, retNotify1) > strings.Index(instrumentedSQL, ret1) {
563+
t.Error("ELSE RETURN branch: NOTIFY should come BEFORE the RETURN")
564+
}
565+
566+
// cp2: standalone RETURN result — terminal at start, signal before.
567+
ret2 := "RETURN result"
568+
retNotify2 := fmt.Sprintf("PERFORM pg_notify('pgcov', '%s');", coveragePoints[2].SignalID)
569+
if strings.Index(instrumentedSQL, retNotify2) > strings.Index(instrumentedSQL, ret2) {
570+
t.Error("standalone RETURN: NOTIFY should come BEFORE the RETURN")
571+
}
572+
573+
t.Log(instrumentedSQL)
574+
}

0 commit comments

Comments
 (0)