Skip to content

Commit 78c2f2a

Browse files
committed
fix(instrument): use CTE-based pg_notify for SQL-language functions
SELECT pg_notify(...) injected as a standalone statement inside SQL-language function bodies produces an extra result set that conflicts with the function's declared return type, causing runtime errors. Replace standalone SELECT pg_notify(...) with a CTE prefix: WITH _pgcov_signal AS (SELECT pg_notify('pgcov', '<signal>')) <stmt> This wraps each statement without adding extra result rows, preserving the function's return type semantics. Changes: - instrumentBody() gains a useCTE parameter; when true, coverage signals are injected as CTE prefixes instead of separate statements - instrumentStatement() passes useCTE=true for sql-language functions - Add unit tests for CTE instrumentation (single and multi-statement) - Add testdata/sqlfunc/ with SQL-language function sources and tests - Add TestSQLFunctionInstrumentation e2e test verifying instrumented SQL functions deploy and return correct values against real Postgres
1 parent 812e5f0 commit 78c2f2a

6 files changed

Lines changed: 913 additions & 585 deletions

File tree

internal/instrument/instrumenter.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ func instrumentStatement(stmt *parser.Statement, filePath string) (string, []Cov
7474
case parser.StmtFunction, parser.StmtProcedure, parser.StmtDO:
7575
switch stmt.Language {
7676
case "plpgsql":
77-
instrumented, locs := instrumentBody(stmt, filePath, true, "PERFORM")
77+
instrumented, locs := instrumentBody(stmt, filePath, true, false)
7878
return instrumented, locs
7979
case "sql":
80-
instrumented, locs := instrumentBody(stmt, filePath, false, "SELECT")
80+
instrumented, locs := instrumentBody(stmt, filePath, false, true)
8181
return instrumented, locs
8282
default:
8383
// Unknown language, mark as implicitly covered
@@ -101,8 +101,12 @@ func instrumentStatement(stmt *parser.Statement, filePath string) (string, []Cov
101101
//
102102
// For PL/pgSQL (skipToBegin=true), tokens before the first BEGIN are skipped.
103103
// For SQL functions (skipToBegin=false), instrumentation starts immediately.
104-
// notifyCmd is "PERFORM" for PL/pgSQL or "SELECT" for SQL functions.
105-
func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, notifyCmd string) (string, []CoveragePoint) {
104+
// When useCTE is true, coverage signals are injected as a CTE prefix
105+
// (WITH _pgcov_signal AS (SELECT pg_notify(...)) <original statement>)
106+
// instead of a standalone statement, avoiding extra result sets that break
107+
// SQL-language function return types.
108+
// When useCTE is false, signals use PERFORM pg_notify(...) (PL/pgSQL).
109+
func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, useCTE bool) (string, []CoveragePoint) {
106110
bodyContent := stmt.Body
107111
if bodyContent == "" {
108112
return stmt.RawSQL, nil
@@ -132,7 +136,7 @@ func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, n
132136
// For terminal statements (RETURN, RAISE EXCEPTION) the signal is
133137
// emitted *before* the statement because nothing executes after them.
134138
// For all other statements the signal is emitted *after* the statement
135-
// so that coverage is recorded only on successful execution (B1 fix).
139+
// so that coverage is recorded only on successful execution.
136140
emitSegment := func(segEnd int) {
137141
segText := bodyContent[segStart:segEnd]
138142
if !isExecutableSegment(segText) {
@@ -167,16 +171,25 @@ func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, n
167171

168172
escapedSignal := strings.ReplaceAll(cp.SignalID, "'", "''")
169173

170-
if termPos := findTerminalPos(segText); termPos >= 0 {
174+
if useCTE {
175+
// SQL-language functions: inject coverage signal as a CTE
176+
// prefix so we don't produce an extra result set that would
177+
// conflict with the function's declared return type (B6).
178+
ctePrefix := fmt.Sprintf("WITH _pgcov_signal AS (SELECT pg_notify('pgcov', '%s')) ",
179+
escapedSignal)
180+
instrumentedBody.WriteString(ctePrefix)
181+
instrumentedBody.WriteString(segText)
182+
lastWrittenPos = segEnd
183+
} else if termPos := findTerminalPos(segText); termPos >= 0 {
171184
// Segment contains (or starts with) a terminal statement
172185
// (RETURN, RAISE EXCEPTION). Place the signal before the
173186
// terminal so it fires before the scope exits.
174187
// When termPos == 0 the segment starts with the terminal;
175188
// when termPos > 0 it is nested inside a control structure
176189
// (e.g. IF … THEN RETURN …).
177190
termIndent := indentOf(segText[termPos:])
178-
notifyCall := fmt.Sprintf("%s%s pg_notify('pgcov', '%s');",
179-
termIndent, notifyCmd, escapedSignal)
191+
notifyCall := fmt.Sprintf("%sPERFORM pg_notify('pgcov', '%s');",
192+
termIndent, escapedSignal)
180193
instrumentedBody.WriteString(segText[:termPos])
181194
fmt.Fprintf(&instrumentedBody, "%s\n", notifyCall)
182195
instrumentedBody.WriteString(segText[termPos:])
@@ -193,8 +206,8 @@ func instrumentBody(stmt *parser.Statement, filePath string, skipToBegin bool, n
193206
} else {
194207
lastWrittenPos = segEnd
195208
}
196-
notifyCall := fmt.Sprintf("%s%s pg_notify('pgcov', '%s');",
197-
indent, notifyCmd, escapedSignal)
209+
notifyCall := fmt.Sprintf("%sPERFORM pg_notify('pgcov', '%s');",
210+
indent, escapedSignal)
198211
fmt.Fprintf(&instrumentedBody, "\n%s", notifyCall)
199212
}
200213
}

internal/instrument/instrumenter_test.go

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ $$ LANGUAGE plpgsql;`
3232
}
3333
stmt := stmts[0]
3434

35-
instrumentedSQL, coveragePoints := instrumentBody(stmt, "test.sql", true, "PERFORM")
35+
instrumentedSQL, coveragePoints := instrumentBody(stmt, "test.sql", true, false)
3636
if instrumentedSQL == "" {
3737
t.Error("instrumentWithLexer() returned empty instrumented SQL")
3838
}
@@ -127,6 +127,99 @@ $$ LANGUAGE plpgsql;`
127127
}
128128
}
129129

130+
func TestInstrument_SQLFunction_UsesCTE(t *testing.T) {
131+
sql := `CREATE OR REPLACE FUNCTION double_val(x INT)
132+
RETURNS INT AS $$
133+
SELECT x * 2;
134+
$$ LANGUAGE sql;`
135+
136+
tmpDir := t.TempDir()
137+
tmpFile := filepath.Join(tmpDir, "sqlfunc.sql")
138+
if err := os.WriteFile(tmpFile, []byte(sql), 0644); err != nil {
139+
t.Fatalf("failed to write temp file: %v", err)
140+
}
141+
142+
file := &discovery.DiscoveredFile{
143+
Path: tmpFile,
144+
RelativePath: "sqlfunc.sql",
145+
Type: discovery.FileTypeSource,
146+
}
147+
148+
parsed, err := parser.Parse(file)
149+
if err != nil {
150+
t.Fatalf("Parse() error = %v", err)
151+
}
152+
153+
instrumented, err := GenerateCoverageInstrument(parsed)
154+
if err != nil {
155+
t.Fatalf("Instrument() error = %v", err)
156+
}
157+
158+
// Should have coverage points
159+
if len(instrumented.Locations) == 0 {
160+
t.Fatal("Instrument() produced no coverage points for SQL function")
161+
}
162+
163+
// Should use CTE-based instrumentation, not standalone SELECT pg_notify
164+
if !strings.Contains(instrumented.InstrumentedText, "WITH _pgcov_signal AS (SELECT pg_notify(") {
165+
t.Error("SQL function should use CTE-based instrumentation")
166+
}
167+
168+
// Should NOT have a standalone SELECT pg_notify (without CTE wrapper)
169+
// Check that 'SELECT pg_notify' only appears inside CTE definitions
170+
text := instrumented.InstrumentedText
171+
cteRemoved := strings.ReplaceAll(text, "WITH _pgcov_signal AS (SELECT pg_notify(", "")
172+
if strings.Contains(cteRemoved, "SELECT pg_notify(") {
173+
t.Error("SQL function should not have standalone SELECT pg_notify calls")
174+
}
175+
176+
// Should NOT use PERFORM (that's for PL/pgSQL)
177+
if strings.Contains(instrumented.InstrumentedText, "PERFORM pg_notify") {
178+
t.Error("SQL function should not use PERFORM")
179+
}
180+
181+
t.Logf("Instrumented SQL:\n%s", instrumented.InstrumentedText)
182+
}
183+
184+
func TestInstrument_SQLFunction_MultipleStatements(t *testing.T) {
185+
// SQL functions can have multiple statements; the last one determines the return value
186+
sql := `CREATE OR REPLACE FUNCTION insert_and_count()
187+
RETURNS BIGINT AS $$
188+
INSERT INTO log(msg) VALUES ('hello');
189+
SELECT count(*) FROM log;
190+
$$ LANGUAGE sql;`
191+
192+
tmpDir := t.TempDir()
193+
tmpFile := filepath.Join(tmpDir, "sqlfunc_multi.sql")
194+
if err := os.WriteFile(tmpFile, []byte(sql), 0644); err != nil {
195+
t.Fatalf("failed to write temp file: %v", err)
196+
}
197+
198+
file := &discovery.DiscoveredFile{
199+
Path: tmpFile,
200+
RelativePath: "sqlfunc_multi.sql",
201+
Type: discovery.FileTypeSource,
202+
}
203+
204+
parsed, err := parser.Parse(file)
205+
if err != nil {
206+
t.Fatalf("Parse() error = %v", err)
207+
}
208+
209+
instrumented, err := GenerateCoverageInstrument(parsed)
210+
if err != nil {
211+
t.Fatalf("Instrument() error = %v", err)
212+
}
213+
214+
// Both statements should get CTE-based instrumentation
215+
cteCount := strings.Count(instrumented.InstrumentedText, "WITH _pgcov_signal AS (SELECT pg_notify(")
216+
if cteCount < 2 {
217+
t.Errorf("Expected at least 2 CTE injections for multi-statement SQL function, got %d", cteCount)
218+
}
219+
220+
t.Logf("Instrumented SQL:\n%s", instrumented.InstrumentedText)
221+
}
222+
130223
func TestInstrument_MultipleStatements(t *testing.T) {
131224
sql := `SELECT 1;
132225
SELECT 2;

0 commit comments

Comments
 (0)