diff --git a/CHANGELOG.md b/CHANGELOG.md index 1308c8185..ccd7c89ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# Unreleased + +* Add CancelAndDrainContextWatcherHandler that replaces the racy 100ms sleep in CancelRequestContextWatcherHandler with a deterministic single-";" drain, preventing a cancel request from producing a 57014 (query_canceled) on the wrong query (Sean Chittenden) +* Fix data race on pid and secretKey in CancelRequest by grouping both into an atomically-published backendKeyData struct (Sean Chittenden) + # 5.9.1 (March 22, 2026) * Fix: batch result format corruption when using cached prepared statements (reported by Dirkjan Bussink) diff --git a/pgconn/cancel_and_drain.go b/pgconn/cancel_and_drain.go new file mode 100644 index 000000000..816995f1b --- /dev/null +++ b/pgconn/cancel_and_drain.go @@ -0,0 +1,154 @@ +package pgconn + +import ( + "context" + "time" + + "github.com/jackc/pgx/v5/pgconn/ctxwatch" + "github.com/jackc/pgx/v5/pgproto3" +) + +const ( + defaultDeadlineDelay = time.Second + defaultDrainTimeout = 5 * time.Second + + queryCanceledSQLStateCode = "57014" + + cancelStateIdle = 0 + cancelStateInFlight = 1 + cancelStateSent = 2 +) + +// CancelAndDrainContextWatcherHandler handles cancelled contexts by first sending a cancel request, then draining any +// pending SQLSTATE 57014 (query_canceled) with a single ";" round-trip. +// +// Correctness depends on at most one cancel request being in flight per connection at any time. Each cancel request +// causes the server to set QueryCancelPending, which produces exactly one 57014. If two cancel requests were sent, +// two 57014s could arrive -- the first absorbed by the drain, the second bleeding into the next real query. This +// invariant is enforced by [PgConn.CancelRequest]'s mutex-guarded state machine, which blocks concurrent callers +// until the in-flight cancel completes. +type CancelAndDrainContextWatcherHandler struct { + Conn *PgConn + + // DeadlineDelay is a net.Conn deadline set when the context is cancelled, used as a fallback to unblock blocked + // reads. Defaults to defaultDeadlineDelay (1s). + DeadlineDelay time.Duration + + // DrainTimeout caps the single drain round-trip. Defaults to defaultDrainTimeout (5s). + DrainTimeout time.Duration + + doneCtx context.Context //nolint:containedctx // synchronization primitive, not a request-scoped context + doneFn context.CancelFunc + stopFn context.CancelFunc +} + +var _ ctxwatch.Handler = (*CancelAndDrainContextWatcherHandler)(nil) + +func (h *CancelAndDrainContextWatcherHandler) deadlineDelay() time.Duration { + if h.DeadlineDelay == 0 { + return defaultDeadlineDelay + } + return h.DeadlineDelay +} + +func (h *CancelAndDrainContextWatcherHandler) drainTimeout() time.Duration { + if h.DrainTimeout == 0 { + return defaultDrainTimeout + } + return h.DrainTimeout +} + +// HandleCancel is called when the watched context is cancelled. It applies a net.Conn deadline as a fallback and fires +// a cancel request in a goroutine. Mutual exclusion (at most one cancel in flight) is enforced by +// [PgConn.CancelRequest], not here -- the ctxwatch.Handler interface does not permit a return value, but CancelRequest +// will block if another cancel is already in progress. +// +// The parent context is inherited (via WithoutCancel) so that values like trace IDs propagate into the cancel request +// without inheriting its already-fired cancellation. 
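+//
+// As a small illustration of that inheritance (a minimal sketch, not part of this
+// API; traceKey and its value are hypothetical):
+//
+//	parent := context.WithValue(context.Background(), traceKey{}, "trace-123")
+//	parent, cancel := context.WithCancel(parent)
+//	cancel()
+//	base := context.WithoutCancel(parent)
+//	_ = base.Err()             // nil -- the already-fired cancellation is not inherited
+//	_ = base.Value(traceKey{}) // "trace-123" -- values still propagate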
+func (h *CancelAndDrainContextWatcherHandler) HandleCancel(ctx context.Context) { + baseCtx := context.WithoutCancel(ctx) + cancelCtx, stop := context.WithCancel(baseCtx) + h.stopFn = stop + + h.doneCtx, h.doneFn = context.WithCancel(context.Background()) + + deadline := time.Now().Add(h.deadlineDelay()) + h.Conn.conn.SetDeadline(deadline) + + go func() { + defer h.doneFn() + reqCtx, cancel := context.WithDeadline(cancelCtx, deadline) + defer cancel() + h.Conn.CancelRequest(reqCtx) + }() +} + +// HandleUnwatchAfterCancel is called after the cancelled query returns. It waits for the cancel goroutine, clears the +// deadline, and -- if the cancel was successfully sent (cancelStateSent) -- sends exactly one ";" to absorb any pending +// 57014. Finally it transitions back to idle. +func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() { + if h.stopFn != nil { + h.stopFn() + } + if h.doneCtx != nil { + <-h.doneCtx.Done() + } + h.Conn.conn.SetDeadline(time.Time{}) + h.doneCtx = nil + h.doneFn = nil + h.stopFn = nil + + h.Conn.cancelMu.Lock() + needsDrain := h.Conn.cancelMu.state == cancelStateSent + if needsDrain { + h.Conn.cancelMu.state = cancelStateIdle + } + h.Conn.cancelMu.Unlock() + + if !h.Conn.IsClosed() && needsDrain { + ctx, cancel := context.WithTimeout(context.Background(), h.drainTimeout()) + defer cancel() + h.Conn.drainOnce(ctx) + } +} + +// drainOnce sends a single ";" and reads the response. If the server returns 57014, the cancel was still pending and is +// now consumed. If the server returns a clean EmptyQueryResponse, the cancel was already consumed by the original query. +// Either way the connection is clean after one round-trip -- no loop required. +// +// This design assumes at most one cancel is in flight per connection (enforced by [PgConn.CancelRequest]). A single +// cancel produces at most one QueryCancelPending flag on the server, which yields at most one 57014. 
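+//
+// The two possible exchanges for the drain round-trip are, in sketch form:
+//
+//	-> Query(";")                      -> Query(";")
+//	<- EmptyQueryResponse              <- ErrorResponse(57014)
+//	<- ReadyForQuery                   <- ReadyForQuery
+//	(cancel was already consumed)      (stale cancel absorbed by the drain)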
+func (pgConn *PgConn) drainOnce(ctx context.Context) { + if deadline, ok := ctx.Deadline(); ok { + pgConn.conn.SetDeadline(deadline) + defer pgConn.conn.SetDeadline(time.Time{}) + } + + pgConn.frontend.Send(&pgproto3.Query{String: ";"}) + if err := pgConn.frontend.Flush(); err != nil { + pgConn.asyncClose() + return + } + + for { + msg, err := pgConn.receiveMessage() + if err != nil { + pgConn.asyncClose() + return + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + return + case *pgproto3.ErrorResponse: + pgErr := ErrorResponseToPgError(msg) + if pgErr.Code != queryCanceledSQLStateCode { + pgConn.asyncClose() + return + } + // 57014 absorbed -- continue reading until ReadyForQuery + case *pgproto3.EmptyQueryResponse: + // Expected response for ";" -- continue reading until ReadyForQuery + } + } +} diff --git a/pgconn/cancel_and_drain_test.go b/pgconn/cancel_and_drain_test.go new file mode 100644 index 000000000..60e719ea7 --- /dev/null +++ b/pgconn/cancel_and_drain_test.go @@ -0,0 +1,560 @@ +package pgconn_test + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgconn/ctxwatch" +) + +const pgSleepBlock = "pg_sleep(10)" + +func buildCancelAndDrainConfig(t *testing.T) *pgconn.Config { + t.Helper() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelAndDrainContextWatcherHandler{ + Conn: conn, + DeadlineDelay: 5 * time.Second, + DrainTimeout: 5 * time.Second, + } + } + config.ConnectTimeout = 5 * time.Second + return config +} + +// waitUntilActive polls pg_stat_activity from observer until targetPID is in "active" state, then +// returns. If the poll fails unexpectedly (e.g. PID vanishes), t.Errorf surfaces the diagnostic. +func waitUntilActive(t *testing.T, ctx context.Context, observer *pgconn.PgConn, targetPID []byte) { + t.Helper() + var polls int + for { + result := observer.ExecParams(ctx, + "SELECT state FROM pg_stat_activity WHERE pid = $1", + [][]byte{targetPID}, nil, nil, nil, + ).Read() + polls++ + if result.Err != nil { + if ctx.Err() == nil { + t.Errorf("waitUntilActive: poll failed for pid %s after %d polls: %v", targetPID, polls, result.Err) + } + return + } + if len(result.Rows) == 0 { + t.Errorf("waitUntilActive: pid %s not found in pg_stat_activity after %d polls", targetPID, polls) + return + } + if string(result.Rows[0][0]) == "active" { + return + } + time.Sleep(time.Millisecond) + } +} + +// cancelOnActive creates a cancellable child of t.Context(), starts polling pg_stat_activity in a +// goroutine, and returns the child context and a cleanup function. +// +// The cleanup function cancels the context (stopping the poller) AND waits for the goroutine to +// finish its last ExecParams on the observer. Callers MUST call cleanup before reusing the observer. +// +// <-ctx.Done() is NOT a safe synchronization point here because the caller may also call cancel +// (via cleanup) to break out of a deadlock when Exec returns before the poller fires. In that case, +// ctx.Done() closes immediately -- before the goroutine finishes with the observer. The WaitGroup +// inside cleanup is what actually signals goroutine completion. 
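+//
+// Intended usage, mirroring the tests below:
+//
+//	ctx, cleanup := cancelOnActive(t, observer, pid)
+//	defer cleanup()
+//	_, err := pgConn.Exec(ctx, "SELECT pg_sleep(10)").ReadAll() // cancelled once the backend is active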
+func cancelOnActive(t *testing.T, observer *pgconn.PgConn, targetPID []byte) (ctx context.Context, cleanup func()) { + t.Helper() + ctx, cancel := context.WithCancel(t.Context()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + defer cancel() + waitUntilActive(t, ctx, observer, targetPID) + }() + return ctx, func() { + cancel() + wg.Wait() + } +} + +func getBackendPID(t *testing.T, conn *pgconn.PgConn) []byte { + t.Helper() + result := conn.ExecParams(t.Context(), + "SELECT pg_backend_pid()::TEXT", nil, nil, nil, nil, + ).Read() + require.NoError(t, result.Err) + require.Equal(t, 1, len(result.Rows)) + return result.Rows[0][0] +} + +func newObserver(t *testing.T) *pgconn.PgConn { + t.Helper() + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnectTimeout = 5 * time.Second + conn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + t.Cleanup(func() { conn.Close(context.Background()) }) + return conn +} + +// Scenario 1: Cancel arrives before query completes +// +// The 57014 is consumed by the original query. Connection is clean afterward. The drain sends ";" +// anyway (harmless -- one extra round-trip). +// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(sql)─────────▶│ │ │ +// │ │ [executing] │ │ +// │ [ctx cancelled] │ │ │ +// │ │ │──CancelReq──────▶│ +// │ │ │ │──SIGINT──▶│ +// │ │ [interrupted] │◀──close──────────│ +// │◀──ErrorResponse(57014)│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// │ [HandleUnwatchAfterCancel] │ │ +// │ [cancelState: sent -> idle] │ │ +// │───Query(;)───────────▶│ [drain] │ │ +// │◀──EmptyQueryResponse──│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainExecCanceled(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + ctx, cleanup := cancelOnActive(t, observer, pid) + defer cleanup() + + _, err = pgConn.Exec(ctx, "SELECT 1, "+pgSleepBlock).ReadAll() + require.Error(t, err) + + ensureConnValid(t, pgConn) +} + +// Scenario 1 variant: same flow as TestCancelAndDrainExecCanceled but exercises the extended query +// protocol path (ExecParams -> Parse/Bind/Describe/Execute instead of simple Query). 
+// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Parse/Bind/Exec────▶│ │ │ +// │ │ [executing] │ │ +// │ [ctx cancelled] │ │ │ +// │ │ │──CancelReq──────▶│ +// │ │ │ │──SIGINT──▶│ +// │ │ [interrupted] │◀──close──────────│ +// │◀──ErrorResponse(57014)│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// │ [HandleUnwatchAfterCancel] │ │ +// │ [cancelState: sent -> idle] │ │ +// │───Query(;)───────────▶│ [drain] │ │ +// │◀──EmptyQueryResponse──│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainExecParamsCanceled(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + ctx, cleanup := cancelOnActive(t, observer, pid) + defer cleanup() + + result := pgConn.ExecParams(ctx, "SELECT 1, "+pgSleepBlock, nil, nil, nil, nil) + _, err = result.Close() + require.Error(t, err) + + ensureConnValid(t, pgConn) +} + +// Scenario 1 variant: same flow exercised through the COPY TO protocol path. +// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(COPY...)─────▶│ │ │ +// │ │ [executing] │ │ +// │ [ctx cancelled] │ │ │ +// │ │ │──CancelReq──────▶│ +// │ │ │ │──SIGINT──▶│ +// │ │ [interrupted] │◀──close──────────│ +// │◀──ErrorResponse(57014)│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// │ [HandleUnwatchAfterCancel] │ │ +// │ [cancelState: sent -> idle] │ │ +// │───Query(;)───────────▶│ [drain] │ │ +// │◀──EmptyQueryResponse──│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainCopyToCanceled(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + ctx, cleanup := cancelOnActive(t, observer, pid) + defer cleanup() + + _, err = pgConn.CopyTo(ctx, nil, "COPY (SELECT "+pgSleepBlock+") TO STDOUT") + require.Error(t, err) + + ensureConnValid(t, pgConn) +} + +// Scenario 1 followed by a Prepare: after the cancel+drain cycle cleans up, the extended query +// protocol (Prepare -> Parse + DescribeStatement) works on the same connection. 
+// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(sql)─────────▶│ │ │ +// │ [ctx cancelled, cancel, drain] │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +// │───Parse/Describe─────▶│ │ │ +// │◀──ParseComplete───────│ │ │ +// │◀──ParameterDescription│ │ │ +// │◀──RowDescription──────│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// ▼ prepare succeeded (ok)▼ ▼ ▼ +func TestCancelAndDrainPrepareSurvivesCancelCycle(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + ctx, cleanup := cancelOnActive(t, observer, pid) + defer cleanup() + + _, err = pgConn.Exec(ctx, "SELECT "+pgSleepBlock).ReadAll() + require.Error(t, err) + + ensureConnValid(t, pgConn) + + sd, err := pgConn.Prepare(t.Context(), "test_stmt", "SELECT 1", nil) + require.NoError(t, err) + require.NotNil(t, sd) +} + +// Scenario 3: Single-";" drain absorbs the stale 57014 +// +// Same race as scenario 2 (the bug), but HandleUnwatchAfterCancel sends exactly one ";" to flush +// the pending cancel before the connection is reused. This test runs 50 cancel+query cycles and +// verifies that no 57014 ever bleeds into the subsequent query. +// +// One ";" is sufficient because PostgreSQL sets QueryCancelPending at most once per cancel signal. +// After the 57014 is raised and sent, the flag is cleared. There is no mechanism for a second +// 57014 from the same cancel. +// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(sql)─────────▶│ │ │ +// │ │ [executing] │ │ +// │ [ctx cancelled] │ │ │ +// │◀──CommandComplete─────│ │ │ +// │◀──ReadyForQuery───────│ [idle] │ │ +// │ │ │──CancelReq──────▶│ +// │ │ │ │──SIGINT──▶│ +// │ │ [QueryCancelPending]│◀──close──────────│ +// │ │ │ │ +// │ [HandleUnwatchAfterCancel] │ │ +// │ [cancelState: sent -> idle] │ │ +// │───Query(;)───────────▶│ [drain] │ │ +// │◀──ErrorResponse(57014)│ [flag consumed] │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// │───Query(next sql)────▶│ [clean] │ │ +// │◀──CommandComplete─────│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainNoStale57014Bleed(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + for i := range 50 { + ctx, cleanup := cancelOnActive(t, observer, pid) + pgConn.Exec(ctx, "SELECT "+pgSleepBlock).ReadAll() + cleanup() + + result := pgConn.ExecParams( + t.Context(), + "SELECT $1::TEXT", + [][]byte{[]byte(fmt.Sprintf("iter_%d", i))}, + nil, nil, nil, + ).Read() + require.NoError(t, result.Err, "iteration %d: stale cancel leaked into next query", i) + require.Equal(t, 1, len(result.Rows)) + require.Equal(t, fmt.Sprintf("iter_%d", i), string(result.Rows[0][0])) + } +} + +// Scenario 1 / Scenario 3 repeated: exercises multiple cancel+drain cycles on the same connection +// to verify that the state machine resets cleanly each time. 
+// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │ │ │ │ +// │ [repeat 20x: ctx cancelled -> cancel -> drain -> ensureConnValid] +// │ │ │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainConnectionReuseCycles(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + for range 20 { + ctx, cleanup := cancelOnActive(t, observer, pid) + pgConn.Exec(ctx, "SELECT "+pgSleepBlock).ReadAll() + cleanup() + ensureConnValid(t, pgConn) + } +} + +// No-cancel path: the query completes before the context deadline. The context watcher never fires +// HandleCancel. No cancel request is sent, no drain is needed. This is the steady-state happy path +// and must not regress. +// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(sql)─────────▶│ │ │ +// │ │ [executing] │ │ +// │◀──RowDescription──────│ │ │ +// │◀──DataRow─────────────│ │ │ +// │◀──CommandComplete─────│ │ │ +// │◀──ReadyForQuery───────│ │ │ +// │ │ │ │ +// │ [ctx not cancelled -- no cancel, no drain] │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainQueryCompletesBeforeCancel(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + results, err := pgConn.Exec(t.Context(), "SELECT 42").ReadAll() + require.NoError(t, err) + require.Len(t, results, 1) + require.Equal(t, "42", string(results[0].Rows[0][0])) + + ensureConnValid(t, pgConn) +} + +// Scenario 5: Duplicate cancel (mutex prevents double-send) +// +// CancelRequest is called twice on the same connection while a query is running. The first call +// transitions idle -> inFlight -> sent. The second call sees cancelStateSent and returns nil +// immediately -- no second cancel packet is sent, so at most one 57014 is produced. 
+// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(sql)─────────▶│ │ │ +// │ │ [executing] │ │ +// │ │ │ │ +// │ [CancelRequest #1: idle -> inFlight (ok)] │ │ +// │ │ │──CancelReq──────▶│ +// │ [CancelRequest #1: inFlight -> sent] │ │ +// │ │ │ │ +// │ [CancelRequest #2: state == sent -> no-op] │ │ +// │ │ │ │ +// ▼ only one cancel sent ▼ ▼ ▼ +func TestCancelAndDrainCancelRequestIdempotent(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + multiResult := pgConn.Exec(t.Context(), "SELECT "+pgSleepBlock) + + waitUntilActive(t, t.Context(), observer, pid) + + err = pgConn.CancelRequest(t.Context()) + require.NoError(t, err) + + err = pgConn.CancelRequest(t.Context()) + require.NoError(t, err) + + for multiResult.NextResult() { + } + err = multiResult.Close() + require.Error(t, err) + + ensureConnValid(t, pgConn) +} + +// Scenario 5 variant: multiple goroutines race to call CancelRequest concurrently. The mutex +// ensures only one caller transitions idle -> inFlight -> sent. All others either block on the +// in-flight done context and return nil, or see cancelStateSent and return nil immediately. +// Either way, exactly one cancel packet is sent. +// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │───Query(sql)─────────▶│ │ │ +// │ │ [executing] │ │ +// │ [ctx cancelled] │ │ │ +// │ [goroutine 1: idle -> inFlight (ok)] │ │ +// │ │ │──CancelReq──────▶│ +// │ │ │ │ +// │ [goroutines 2-5: inFlight -> block on done] │ │ +// │ │ │ │ +// │ [goroutine 1 completes -> sent, doneFn()] │ │ +// │ [goroutines 2-5 unblock -> return nil] │ │ +// │ │ │ │ +// ▼ only one cancel sent ▼ ▼ ▼ +func TestCancelAndDrainConcurrentCancelRequest(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + multiResult := pgConn.Exec(t.Context(), "SELECT "+pgSleepBlock) + + waitUntilActive(t, t.Context(), observer, pid) + + const goroutines = 5 + var wg sync.WaitGroup + wg.Add(goroutines) + errs := make([]error, goroutines) + for i := range goroutines { + go func(idx int) { + defer wg.Done() + errs[idx] = pgConn.CancelRequest(t.Context()) + }(i) + } + wg.Wait() + + for i, err := range errs { + assert.NoError(t, err, "goroutine %d", i) + } + + for multiResult.NextResult() { + } + err = multiResult.Close() + require.Error(t, err) + + ensureConnValid(t, pgConn) +} + +// Stress test: 10 parallel connections x 20 cancel+drain cycles each. Exercises the full +// cancel -> drain -> reuse path under contention, covering scenarios 1 and 3 in aggregate. +// The connection must remain valid after every cycle. 
+// +// Client (conn A) Server Backend Client (conn B) Postmaster +// │ │ │ │ +// │ [repeat 20x: cancel on active -> drain -> ensureConnValid] +// │ │ │ │ +// ▼ connection clean (ok) ▼ ▼ ▼ +func TestCancelAndDrainStress(t *testing.T) { + t.Parallel() + + for i := range 10 { + t.Run(fmt.Sprintf("Worker %d", i), func(t *testing.T) { + t.Parallel() + + config := buildCancelAndDrainConfig(t) + pgConn, err := pgconn.ConnectConfig(t.Context(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("CockroachDB incompatible with PostgreSQL: pg_stat_activity") + } + + observer := newObserver(t) + pid := getBackendPID(t, pgConn) + + for range 20 { + ctx, cleanup := cancelOnActive(t, observer, pid) + pgConn.Exec(ctx, "SELECT 1, "+pgSleepBlock).ReadAll() + cleanup() + ensureConnValid(t, pgConn) + } + }) + } +} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index d6587cef8..d41eb92c4 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -16,6 +16,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/jackc/pgx/v5/internal/iobufpool" @@ -34,6 +35,24 @@ const ( connStatusBusy ) +// PostgreSQL protocol negotiation codes from src/include/libpq/pqcomm.h. +// Each is PG_PROTOCOL(1234, N) = (1234 << 16) | N. +const ( + cancelRequestCode = 80877102 // PG_PROTOCOL(1234, 5678) -- CANCEL_REQUEST_CODE + negotiateSSLCode = 80877103 // PG_PROTOCOL(1234, 5679) -- NEGOTIATE_SSL_CODE +) + +// CancelRequestPacket layout from src/include/libpq/pqcomm.h. All fixed fields are uint32 (4 bytes, network order). +// The cancel auth key is variable-length in pgx (historically 4 bytes, but stored as []byte for forward compatibility). +const ( + cancelPacketFieldSize = 4 // sizeof(uint32) -- each fixed field in the cancel packet + cancelPacketHeaderLen = 3 * cancelPacketFieldSize // packet length + request code + backend PID + cancelPacketLenOffset = 0 + cancelPacketCodeOffset = cancelPacketFieldSize + cancelPacketPIDOffset = 2 * cancelPacketFieldSize + cancelPacketKeyOffset = cancelPacketHeaderLen +) + // Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from // LISTEN/NOTIFY notification. type Notice PgError @@ -73,11 +92,18 @@ type NoticeHandler func(*PgConn, *Notice) // notice event. type NotificationHandler func(*PgConn, *Notification) +// backendKeyData holds the PID and secret key received during the connection handshake. These are published +// atomically because CancelRequest reads them from a goroutine spawned by the context watcher, which may race +// with the handshake if a context is cancelled during connection setup. +type backendKeyData struct { + pid uint32 + secretKey []byte +} + // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { conn net.Conn - pid uint32 // backend pid - secretKey []byte // key to use to send a cancel query message to the server + backendKey atomic.Pointer[backendKeyData] parameterStatuses map[string]string // parameters that have been reported by the server txStatus byte frontend *pgproto3.Frontend @@ -106,6 +132,16 @@ type PgConn struct { fieldDescriptions [16]FieldDescription cleanupDone chan struct{} + + // Cancel coordination. CancelRequest is the only PgConn method designed to be called from another goroutine + // (via the context watcher, asyncClose, or direct user calls). 
The mutex protects the state + done context + // pair so that concurrent callers either proceed (idle), block (inFlight), or no-op (sent). + cancelMu struct { + sync.Mutex + state uint32 // one of cancelState* constants + done context.Context //nolint:containedctx // synchronization primitive, not a request-scoped context + doneFn context.CancelFunc + } } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value @@ -410,8 +446,10 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo switch msg := msg.(type) { case *pgproto3.BackendKeyData: - pgConn.pid = msg.ProcessID - pgConn.secretKey = msg.SecretKey + pgConn.backendKey.Store(&backendKeyData{ + pid: msg.ProcessID, + secretKey: msg.SecretKey, + }) case *pgproto3.AuthenticationOk: case *pgproto3.AuthenticationCleartextPassword: @@ -491,7 +529,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo } func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { - err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) + err := binary.Write(conn, binary.BigEndian, []int32{8, negotiateSSLCode}) if err != nil { return nil, err } @@ -656,7 +694,10 @@ func (pgConn *PgConn) Conn() net.Conn { // PID returns the backend PID. func (pgConn *PgConn) PID() uint32 { - return pgConn.pid + if bk := pgConn.backendKey.Load(); bk != nil { + return bk.pid + } + return 0 } // TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message. @@ -674,7 +715,10 @@ func (pgConn *PgConn) TxStatus() byte { // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() []byte { - return pgConn.secretKey + if bk := pgConn.backendKey.Load(); bk != nil { + return bk.secretKey + } + return nil } // Frontend returns the underlying *pgproto3.Frontend. This rarely necessary. @@ -1036,7 +1080,62 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. // See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS +// +// CancelRequest is safe to call from multiple goroutines concurrently. If a cancel is already in flight, the caller +// blocks until it completes and returns nil (the cancel was handled). This prevents multiple cancel requests from +// producing multiple 57014 responses, which the single-";" drain cannot reconcile. func (pgConn *PgConn) CancelRequest(ctx context.Context) error { + pgConn.cancelMu.Lock() + switch pgConn.cancelMu.state { + case cancelStateInFlight: + // Another cancel is in progress -- grab the done context and wait for it. + done := pgConn.cancelMu.done + pgConn.cancelMu.Unlock() + if done != nil { + select { + case <-done.Done(): + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + return nil + + case cancelStateSent: + // A cancel was already sent and is pending drain -- nothing more to do. + pgConn.cancelMu.Unlock() + return nil + } + + // cancelStateIdle -- we own the cancel. 
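+	// State transitions, for reference (the drain step happens in
+	// CancelAndDrainContextWatcherHandler.HandleUnwatchAfterCancel):
+	//
+	//	idle --(own cancel)--> inFlight --(send ok)---> sent --(drain)--> idle
+	//	                                --(send fail)-> idle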
+ pgConn.cancelMu.done, pgConn.cancelMu.doneFn = context.WithCancel(context.Background()) + pgConn.cancelMu.state = cancelStateInFlight + pgConn.cancelMu.Unlock() + + err := pgConn.sendCancelRequest(ctx) + + pgConn.cancelMu.Lock() + if err != nil { + pgConn.cancelMu.state = cancelStateIdle + } else { + pgConn.cancelMu.state = cancelStateSent + } + pgConn.cancelMu.doneFn() + pgConn.cancelMu.done = nil + pgConn.cancelMu.doneFn = nil + pgConn.cancelMu.Unlock() + + return err +} + +// sendCancelRequest performs the actual network I/O for a cancel request: dial a new TCP connection to the server, +// write the CancelRequestPacket, and wait for acknowledgement. +func (pgConn *PgConn) sendCancelRequest(ctx context.Context) error { + bk := pgConn.backendKey.Load() + if bk == nil { + return nil + } + // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. @@ -1072,11 +1171,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer contextWatcher.Unwatch() } - buf := make([]byte, 12+len(pgConn.secretKey)) - binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf))) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], pgConn.pid) - copy(buf[12:], pgConn.secretKey) + buf := make([]byte, cancelPacketHeaderLen+len(bk.secretKey)) + binary.BigEndian.PutUint32(buf[cancelPacketLenOffset:], uint32(len(buf))) + binary.BigEndian.PutUint32(buf[cancelPacketCodeOffset:], cancelRequestCode) + binary.BigEndian.PutUint32(buf[cancelPacketPIDOffset:], bk.pid) + copy(buf[cancelPacketKeyOffset:], bk.secretKey) if _, err := cancelConn.Write(buf); err != nil { return fmt.Errorf("write to connection for cancellation: %w", err) @@ -2126,16 +2225,19 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { } pgConn.status = connStatusClosed - return &HijackedConn{ + hc := &HijackedConn{ Conn: pgConn.conn, - PID: pgConn.pid, - SecretKey: pgConn.secretKey, ParameterStatuses: pgConn.parameterStatuses, TxStatus: pgConn.txStatus, Frontend: pgConn.frontend, Config: pgConn.config, CustomData: pgConn.customData, - }, nil + } + if bk := pgConn.backendKey.Load(); bk != nil { + hc.PID = bk.pid + hc.SecretKey = bk.secretKey + } + return hc, nil } // Construct created a PgConn from an already established connection to a PostgreSQL server. This is the inverse of @@ -2148,8 +2250,6 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { func Construct(hc *HijackedConn) (*PgConn, error) { pgConn := &PgConn{ conn: hc.Conn, - pid: hc.PID, - secretKey: hc.SecretKey, parameterStatuses: hc.ParameterStatuses, txStatus: hc.TxStatus, frontend: hc.Frontend, @@ -2160,6 +2260,10 @@ func Construct(hc *HijackedConn) (*PgConn, error) { cleanupDone: make(chan struct{}), } + pgConn.backendKey.Store(&backendKeyData{ + pid: hc.PID, + secretKey: hc.SecretKey, + }) pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) pgConn.bgReader = bgreader.New(pgConn.conn)