Skip to content

Commit 2ec3058

Browse files
committed
pgconn: add CancelAndDrainContextWatcherHandler
Add a context watcher handler that deterministically drains stale cancel signals instead of relying on a fixed 100ms sleep. When a context is cancelled, HandleCancel sends the cancel request immediately. After the query returns, HandleUnwatchAfterCancel polls SELECT 1 to absorb any in-flight SQLSTATE 57014 before releasing the connection. The drain runs inside Unwatch via HandleUnwatchAfterCancel, so every pgconn operation that watches a context (Exec, ExecParams, Prepare, Deallocate, CopyTo, CopyFrom, Pipeline, WaitForNotification) gets drain-after-cancel automatically with no per-call-site wiring.
1 parent da20f82 commit 2ec3058

2 files changed

Lines changed: 395 additions & 0 deletions

File tree

pgconn/cancel_and_drain.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package pgconn
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
8+
"github.com/jackc/pgx/v5/pgproto3"
9+
)
10+
11+
// CancelAndDrainContextWatcherHandler handles cancelled contexts by sending a cancel request to the server and then
12+
// draining any in-flight SQLSTATE 57014 (query_canceled) by polling SELECT 1. Unlike [CancelRequestContextWatcherHandler],
13+
// no fixed sleep is used; the drain is deterministic.
14+
type CancelAndDrainContextWatcherHandler struct {
15+
Conn *PgConn
16+
17+
// DeadlineDelay is the network deadline set on the connection when the context
18+
// is cancelled, used as a fallback to unblock any blocked read. Defaults to 1s.
19+
DeadlineDelay time.Duration
20+
21+
// DrainTimeout is the maximum time to spend draining a cancelled query's
22+
// in-flight results via SELECT 1 polling. Defaults to 5s.
23+
DrainTimeout time.Duration
24+
25+
cancelFinishedChan chan struct{}
26+
stopFn context.CancelFunc
27+
}
28+
29+
var _ ctxwatch.Handler = (*CancelAndDrainContextWatcherHandler)(nil)
30+
31+
func (h *CancelAndDrainContextWatcherHandler) deadlineDelay() time.Duration {
32+
if h.DeadlineDelay == 0 {
33+
return time.Second
34+
}
35+
return h.DeadlineDelay
36+
}
37+
38+
func (h *CancelAndDrainContextWatcherHandler) drainTimeout() time.Duration {
39+
if h.DrainTimeout == 0 {
40+
return 5 * time.Second
41+
}
42+
return h.DrainTimeout
43+
}
44+
45+
// HandleCancel is called when the context is cancelled. It sets a net.Conn deadline
46+
// as a fallback and sends a PostgreSQL cancel request in a goroutine.
47+
func (h *CancelAndDrainContextWatcherHandler) HandleCancel(_ context.Context) {
48+
h.cancelFinishedChan = make(chan struct{})
49+
cancelCtx, stop := context.WithCancel(context.Background())
50+
h.stopFn = stop
51+
52+
deadline := time.Now().Add(h.deadlineDelay())
53+
h.Conn.conn.SetDeadline(deadline)
54+
55+
doneCh := h.cancelFinishedChan
56+
go func() {
57+
defer close(doneCh)
58+
reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)
59+
defer cancel()
60+
h.Conn.CancelRequest(reqCtx)
61+
}()
62+
}
63+
64+
// HandleUnwatchAfterCancel is called after the cancelled query returns. It stops the cancel goroutine (if still
65+
// running), clears the net.Conn deadline, and drains any in-flight cancel by polling SELECT 1.
66+
func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() {
67+
if h.stopFn != nil {
68+
h.stopFn()
69+
}
70+
if h.cancelFinishedChan != nil {
71+
<-h.cancelFinishedChan
72+
}
73+
h.Conn.conn.SetDeadline(time.Time{})
74+
h.cancelFinishedChan = nil
75+
h.stopFn = nil
76+
77+
if !h.Conn.IsClosed() {
78+
ctx, cancel := context.WithTimeout(context.Background(), h.drainTimeout())
79+
defer cancel()
80+
h.Conn.execInternalForDrain(ctx)
81+
}
82+
}
83+
84+
// queryCanceledSQLStateCode is SQLSTATE 57014 (query_canceled).
85+
const queryCanceledSQLStateCode = "57014"
86+
87+
// execInternalForDrain sends SELECT 1 in a loop, absorbing any SQLSTATE 57014
88+
// responses, until the connection is confirmed clean or a non-57014 error occurs.
89+
// On any failure the connection is asyncClosed.
90+
//
91+
// Called while the connection is still logically "busy" from pgconn's perspective
92+
// (lock is held and contextWatcher.Unwatch has been called) but idle from the
93+
// PostgreSQL server's perspective (ReadyForQuery was just received). This means
94+
// it bypasses the normal lock/unlock and contextWatcher.Watch paths.
95+
//
96+
// The deadline from ctx is applied directly to the net.Conn.
97+
func (pgConn *PgConn) execInternalForDrain(ctx context.Context) {
98+
if deadline, ok := ctx.Deadline(); ok {
99+
pgConn.conn.SetDeadline(deadline)
100+
defer pgConn.conn.SetDeadline(time.Time{})
101+
}
102+
103+
outer:
104+
for {
105+
pgConn.frontend.Send(&pgproto3.Query{String: "SELECT 1"})
106+
if err := pgConn.frontend.Flush(); err != nil {
107+
pgConn.asyncClose()
108+
return
109+
}
110+
clean := true
111+
for {
112+
msg, err := pgConn.receiveMessage()
113+
if err != nil {
114+
pgConn.asyncClose()
115+
return
116+
}
117+
118+
switch msg := msg.(type) {
119+
case *pgproto3.ReadyForQuery:
120+
if !clean {
121+
clean = true
122+
continue outer // absorbed 57014; send another SELECT 1 to confirm clean
123+
}
124+
return // clean ReadyForQuery — done
125+
case *pgproto3.ErrorResponse:
126+
pgErr := ErrorResponseToPgError(msg)
127+
if pgErr.Code == queryCanceledSQLStateCode {
128+
clean = false // cancel hit this SELECT 1; will confirm after ReadyForQuery
129+
} else {
130+
pgConn.asyncClose()
131+
return
132+
}
133+
case *pgproto3.RowDescription, *pgproto3.DataRow, *pgproto3.CommandComplete:
134+
// Normal result messages for SELECT 1.
135+
}
136+
}
137+
}
138+
}

pgconn/cancel_and_drain_test.go

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
package pgconn_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"os"
8+
"testing"
9+
"time"
10+
11+
"github.com/jackc/pgx/v5/pgconn"
12+
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func buildCancelAndDrainConfig(t *testing.T) *pgconn.Config {
18+
t.Helper()
19+
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
20+
require.NoError(t, err)
21+
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
22+
return &pgconn.CancelAndDrainContextWatcherHandler{Conn: conn}
23+
}
24+
config.ConnectTimeout = 5 * time.Second
25+
return config
26+
}
27+
28+
func TestCancelAndDrainContextWatcherHandler(t *testing.T) {
29+
t.Parallel()
30+
31+
t.Run("connection reused after cancel", func(t *testing.T) {
32+
t.Parallel()
33+
34+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
35+
require.NoError(t, err)
36+
defer closeConn(t, pgConn)
37+
38+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
39+
defer cancel()
40+
41+
_, err = pgConn.Exec(ctx, "select pg_sleep(10)").ReadAll()
42+
require.Error(t, err)
43+
require.False(t, pgConn.IsClosed(), "connection should not be closed after cancel with drain handler")
44+
45+
ensureConnValid(t, pgConn)
46+
})
47+
48+
t.Run("no stale cancel bleed", func(t *testing.T) {
49+
t.Parallel()
50+
51+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
52+
require.NoError(t, err)
53+
defer closeConn(t, pgConn)
54+
55+
for i := range 50 {
56+
func() {
57+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
58+
defer cancel()
59+
pgConn.Exec(ctx, "select pg_sleep(0.020)").ReadAll()
60+
}()
61+
62+
if pgConn.IsClosed() {
63+
var err error
64+
pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
65+
require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
66+
}
67+
68+
ensureConnValid(t, pgConn)
69+
}
70+
})
71+
72+
t.Run("stress", func(t *testing.T) {
73+
t.Parallel()
74+
75+
for i := range 10 {
76+
t.Run(fmt.Sprintf("goroutine_%d", i), func(t *testing.T) {
77+
t.Parallel()
78+
79+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
80+
require.NoError(t, err)
81+
defer closeConn(t, pgConn)
82+
83+
for j := range 20 {
84+
func() {
85+
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
86+
defer cancel()
87+
pgConn.Exec(ctx, "select pg_sleep(0.010)").ReadAll()
88+
}()
89+
90+
if pgConn.IsClosed() {
91+
var err error
92+
pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
93+
require.NoError(t, err, "goroutine %d iteration %d: failed to reconnect", i, j)
94+
}
95+
96+
ensureConnValid(t, pgConn)
97+
}
98+
})
99+
}
100+
})
101+
102+
t.Run("ExecParams", func(t *testing.T) {
103+
t.Parallel()
104+
105+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
106+
require.NoError(t, err)
107+
defer closeConn(t, pgConn)
108+
109+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
110+
defer cancel()
111+
112+
rr := pgConn.ExecParams(ctx, "select pg_sleep(10)", nil, nil, nil, nil)
113+
rr.Read()
114+
_, err = rr.Close()
115+
assert.Error(t, err)
116+
117+
if !pgConn.IsClosed() {
118+
ensureConnValid(t, pgConn)
119+
}
120+
})
121+
122+
t.Run("CopyTo", func(t *testing.T) {
123+
t.Parallel()
124+
125+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
126+
require.NoError(t, err)
127+
defer closeConn(t, pgConn)
128+
129+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
130+
defer cancel()
131+
132+
_, err = pgConn.CopyTo(ctx, io.Discard, "COPY (SELECT pg_sleep(10)) TO STDOUT")
133+
assert.Error(t, err)
134+
135+
if !pgConn.IsClosed() {
136+
ensureConnValid(t, pgConn)
137+
}
138+
})
139+
140+
t.Run("CopyFrom", func(t *testing.T) {
141+
t.Parallel()
142+
143+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
144+
require.NoError(t, err)
145+
defer closeConn(t, pgConn)
146+
147+
_, err = pgConn.Exec(context.Background(), "CREATE TEMP TABLE drain_test_copyfrom (id int)").ReadAll()
148+
require.NoError(t, err)
149+
150+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
151+
defer cancel()
152+
153+
pr, pw := io.Pipe()
154+
defer pr.Close()
155+
defer pw.Close()
156+
157+
_, err = pgConn.CopyFrom(ctx, pr, "COPY drain_test_copyfrom FROM STDIN")
158+
assert.Error(t, err)
159+
160+
if !pgConn.IsClosed() {
161+
ensureConnValid(t, pgConn)
162+
}
163+
})
164+
165+
t.Run("Pipeline", func(t *testing.T) {
166+
t.Parallel()
167+
168+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
169+
require.NoError(t, err)
170+
defer closeConn(t, pgConn)
171+
172+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
173+
defer cancel()
174+
175+
pipeline := pgConn.StartPipeline(ctx)
176+
177+
pipeline.SendQueryParams("select pg_sleep(10)", nil, nil, nil, nil)
178+
err = pipeline.Sync()
179+
require.NoError(t, err)
180+
181+
pipeline.Close()
182+
183+
require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled pipeline with drain handler")
184+
ensureConnValid(t, pgConn)
185+
})
186+
187+
t.Run("Prepare", func(t *testing.T) {
188+
t.Parallel()
189+
190+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
191+
require.NoError(t, err)
192+
defer closeConn(t, pgConn)
193+
194+
for i := range 20 {
195+
func() {
196+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
197+
defer cancel()
198+
pgConn.Prepare(ctx, "", "select pg_sleep(0.010)", nil)
199+
}()
200+
201+
if pgConn.IsClosed() {
202+
var err error
203+
pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
204+
require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
205+
}
206+
207+
ensureConnValid(t, pgConn)
208+
}
209+
})
210+
211+
t.Run("Deallocate", func(t *testing.T) {
212+
t.Parallel()
213+
214+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
215+
require.NoError(t, err)
216+
defer closeConn(t, pgConn)
217+
218+
for i := range 20 {
219+
_, err := pgConn.Prepare(context.Background(), "drain_dealloc_test", "select 1", nil)
220+
require.NoError(t, err, "iteration %d: prepare failed", i)
221+
222+
func() {
223+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
224+
defer cancel()
225+
pgConn.Deallocate(ctx, "drain_dealloc_test")
226+
}()
227+
228+
if pgConn.IsClosed() {
229+
var err error
230+
pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
231+
require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
232+
}
233+
234+
ensureConnValid(t, pgConn)
235+
}
236+
})
237+
238+
t.Run("WaitForNotification", func(t *testing.T) {
239+
t.Parallel()
240+
241+
pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
242+
require.NoError(t, err)
243+
defer closeConn(t, pgConn)
244+
245+
_, err = pgConn.Exec(context.Background(), "LISTEN drain_test_channel").ReadAll()
246+
require.NoError(t, err)
247+
248+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
249+
defer cancel()
250+
251+
err = pgConn.WaitForNotification(ctx)
252+
require.Error(t, err)
253+
254+
require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled WaitForNotification with drain handler")
255+
ensureConnValid(t, pgConn)
256+
})
257+
}

0 commit comments

Comments
 (0)