diff --git a/pgconn/cancel_and_drain.go b/pgconn/cancel_and_drain.go
new file mode 100644
index 000000000..005b18f71
--- /dev/null
+++ b/pgconn/cancel_and_drain.go
@@ -0,0 +1,131 @@
+package pgconn
+
+import (
+	"context"
+	"time"
+
+	"github.com/jackc/pgx/v5/pgconn/ctxwatch"
+	"github.com/jackc/pgx/v5/pgproto3"
+)
+
+// CancelAndDrainContextWatcherHandler handles cancelled contexts by sending a cancel request to the server and then
+// draining any pending SQLSTATE 57014 (query_canceled) error with a single ";" round-trip. Unlike [CancelRequestContextWatcherHandler],
+// no fixed sleep is used; the drain is deterministic.
+type CancelAndDrainContextWatcherHandler struct {
+	Conn *PgConn
+
+	// DeadlineDelay is how far in the future to set the net.Conn deadline when the context
+	// is cancelled, used as a fallback to unblock any blocked read. Defaults to 1s.
+	DeadlineDelay time.Duration
+
+	// DrainTimeout is the maximum time to spend draining a cancelled query's
+	// in-flight results with the single ";" round-trip. Defaults to 5s.
+	DrainTimeout time.Duration
+
+	cancelFinishedChan chan struct{}
+	stopFn             context.CancelFunc
+}
+
+var _ ctxwatch.Handler = (*CancelAndDrainContextWatcherHandler)(nil)
+
+func (h *CancelAndDrainContextWatcherHandler) deadlineDelay() time.Duration {
+	if h.DeadlineDelay == 0 {
+		return time.Second
+	}
+	return h.DeadlineDelay
+}
+
+func (h *CancelAndDrainContextWatcherHandler) drainTimeout() time.Duration {
+	if h.DrainTimeout == 0 {
+		return 5 * time.Second
+	}
+	return h.DrainTimeout
+}
+
+// HandleCancel is called when the context is cancelled. It sets a net.Conn deadline
+// as a fallback and sends a PostgreSQL cancel request in a goroutine.
+func (h *CancelAndDrainContextWatcherHandler) HandleCancel(_ context.Context) {
+	h.cancelFinishedChan = make(chan struct{})
+	cancelCtx, stop := context.WithCancel(context.Background())
+	h.stopFn = stop
+
+	deadline := time.Now().Add(h.deadlineDelay())
+	h.Conn.conn.SetDeadline(deadline)
+
+	doneCh := h.cancelFinishedChan
+	go func() {
+		defer close(doneCh)
+		reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)
+		defer cancel()
+		h.Conn.CancelRequest(reqCtx)
+	}()
+}
+
+// HandleUnwatchAfterCancel is called after the cancelled query returns. It stops the cancel goroutine (if still
+// running), clears the net.Conn deadline, and drains any in-flight cancel with a single ";" round-trip.
+func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() {
+	if h.stopFn != nil {
+		h.stopFn()
+	}
+	if h.cancelFinishedChan != nil {
+		<-h.cancelFinishedChan
+	}
+	h.Conn.conn.SetDeadline(time.Time{})
+	h.cancelFinishedChan = nil
+	h.stopFn = nil
+
+	if !h.Conn.IsClosed() {
+		ctx, cancel := context.WithTimeout(context.Background(), h.drainTimeout())
+		defer cancel()
+		h.Conn.execInternalForDrain(ctx)
+	}
+}
+
+// queryCanceledSQLStateCode is SQLSTATE 57014 (query_canceled).
+const queryCanceledSQLStateCode = "57014"
+
+// execInternalForDrain sends a single ";" and reads until ReadyForQuery, absorbing any
+// SQLSTATE 57014 (query_canceled). One round-trip is sufficient: PostgreSQL sets
+// QueryCancelPending at most once per cancel signal, so at most one 57014 can arrive.
+// On any failure the connection is closed via asyncClose.
+//
+// Called while the connection is still logically "busy" from pgconn's perspective
+// (the lock is held and contextWatcher.Unwatch has been called) but idle from the
+// PostgreSQL server's perspective (ReadyForQuery was just received). This means
+// it bypasses the normal lock/unlock and contextWatcher.Watch paths.
+//
+// The deadline from ctx is applied directly to the net.Conn.
+func (pgConn *PgConn) execInternalForDrain(ctx context.Context) {
+	if deadline, ok := ctx.Deadline(); ok {
+		pgConn.conn.SetDeadline(deadline)
+		defer pgConn.conn.SetDeadline(time.Time{})
+	}
+
+	pgConn.frontend.Send(&pgproto3.Query{String: ";"})
+	if err := pgConn.frontend.Flush(); err != nil {
+		pgConn.asyncClose()
+		return
+	}
+
+	for {
+		msg, err := pgConn.receiveMessage()
+		if err != nil {
+			pgConn.asyncClose()
+			return
+		}
+
+		switch msg := msg.(type) {
+		case *pgproto3.ReadyForQuery:
+			return
+		case *pgproto3.ErrorResponse:
+			pgErr := ErrorResponseToPgError(msg)
+			if pgErr.Code != queryCanceledSQLStateCode {
+				pgConn.asyncClose()
+				return
+			}
+			// 57014 absorbed — continue reading until ReadyForQuery
+		case *pgproto3.EmptyQueryResponse:
+			// Expected response for ";".
+		}
+	}
+}
diff --git a/pgconn/cancel_and_drain_test.go b/pgconn/cancel_and_drain_test.go
new file mode 100644
index 000000000..fef8783a2
--- /dev/null
+++ b/pgconn/cancel_and_drain_test.go
@@ -0,0 +1,261 @@
+package pgconn_test
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/jackc/pgx/v5/pgconn"
+	"github.com/jackc/pgx/v5/pgconn/ctxwatch"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func buildCancelAndDrainConfig(t *testing.T) *pgconn.Config {
+	t.Helper()
+	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
+	require.NoError(t, err)
+	config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
+		return &pgconn.CancelAndDrainContextWatcherHandler{Conn: conn}
+	}
+	config.ConnectTimeout = 5 * time.Second
+	return config
+}
+
+func TestCancelAndDrainContextWatcherHandler(t *testing.T) {
+	t.Parallel()
+
+	t.Run("connection reused after cancel", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		_, err = pgConn.Exec(ctx, "select pg_sleep(10)").ReadAll()
+		require.Error(t, err)
+		require.False(t, pgConn.IsClosed(), "connection should not be closed after cancel with drain handler")
+
+		ensureConnValid(t, pgConn)
+	})
+
+	t.Run("no stale cancel bleed", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		for i := range 50 {
+			func() {
+				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
+				defer cancel()
+				pgConn.Exec(ctx, "select pg_sleep(0.020)").ReadAll()
+			}()
+
+			if pgConn.IsClosed() {
+				var err error
+				pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+				require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
+			}
+
+			ensureConnValid(t, pgConn)
+		}
+	})
+
+	t.Run("stress", func(t *testing.T) {
+		t.Parallel()
+
+		for i := range 10 {
+			t.Run(fmt.Sprintf("goroutine_%d", i), func(t *testing.T) {
+				t.Parallel()
+
+				pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+				require.NoError(t, err)
+				defer closeConn(t, pgConn)
+
+				for j := range 20 {
+					func() {
+						ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
+						defer cancel()
+						pgConn.Exec(ctx, "select pg_sleep(0.010)").ReadAll()
+					}()
+
+					if pgConn.IsClosed() {
+						var err error
+						pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+						require.NoError(t, err, "goroutine %d iteration %d: failed to reconnect", i, j)
+					}
+
+					ensureConnValid(t, pgConn)
+				}
+			})
+		}
+	})
+
+	t.Run("ExecParams", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		rr := pgConn.ExecParams(ctx, "select pg_sleep(10)", nil, nil, nil, nil)
+		rr.Read()
+		_, err = rr.Close()
+		assert.Error(t, err)
+
+		if !pgConn.IsClosed() {
+			ensureConnValid(t, pgConn)
+		}
+	})
+
+	t.Run("CopyTo", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		_, err = pgConn.CopyTo(ctx, io.Discard, "COPY (SELECT pg_sleep(10)) TO STDOUT")
+		assert.Error(t, err)
+
+		if !pgConn.IsClosed() {
+			ensureConnValid(t, pgConn)
+		}
+	})
+
+	t.Run("CopyFrom", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		_, err = pgConn.Exec(context.Background(), "CREATE TEMP TABLE drain_test_copyfrom (id int)").ReadAll()
+		require.NoError(t, err)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		pr, pw := io.Pipe()
+		defer pr.Close()
+		defer pw.Close()
+
+		_, err = pgConn.CopyFrom(ctx, pr, "COPY drain_test_copyfrom FROM STDIN")
+		assert.Error(t, err)
+
+		if !pgConn.IsClosed() {
+			ensureConnValid(t, pgConn)
+		}
+	})
+
+	t.Run("Pipeline", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		pipeline := pgConn.StartPipeline(ctx)
+
+		pipeline.SendQueryParams("select pg_sleep(10)", nil, nil, nil, nil)
+		err = pipeline.Sync()
+		require.NoError(t, err)
+
+		pipeline.Close()
+
+		require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled pipeline with drain handler")
+		ensureConnValid(t, pgConn)
+	})
+
+	t.Run("Prepare", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		for i := range 20 {
+			func() {
+				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
+				defer cancel()
+				pgConn.Prepare(ctx, "", "select pg_sleep(0.010)", nil)
+			}()
+
+			if pgConn.IsClosed() {
+				var err error
+				pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+				require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
+			}
+
+			ensureConnValid(t, pgConn)
+		}
+	})
+
+	t.Run("Deallocate", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		for i := range 20 {
+			_, err := pgConn.Prepare(context.Background(), "drain_dealloc_test", "select 1", nil)
+			require.NoError(t, err, "iteration %d: prepare failed", i)
+
+			func() {
+				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
+				defer cancel()
+				pgConn.Deallocate(ctx, "drain_dealloc_test")
+			}()
+
+			if pgConn.IsClosed() {
+				var err error
+				pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+				require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
+			}
+
+			ensureConnValid(t, pgConn)
+		}
+	})
+
+	t.Run("WaitForNotification", func(t *testing.T) {
+		t.Parallel()
+
+		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
+		require.NoError(t, err)
+		defer closeConn(t, pgConn)
+
+		if pgConn.ParameterStatus("crdb_version") != "" {
+			t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
+		}
+
+		_, err = pgConn.Exec(context.Background(), "LISTEN drain_test_channel").ReadAll()
+		require.NoError(t, err)
+
+		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+		defer cancel()
+
+		err = pgConn.WaitForNotification(ctx)
+		require.Error(t, err)
+
+		require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled WaitForNotification with drain handler")
+		ensureConnValid(t, pgConn)
+	})
+}
diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go
index ca9a48cad..14f523cb6
--- a/pgconn/pgconn.go
+++ b/pgconn/pgconn.go
@@ -16,6 +16,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/jackc/pgx/v5/internal/iobufpool"
@@ -76,7 +77,7 @@ type NotificationHandler func(*PgConn, *Notification)
 // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
 type PgConn struct {
 	conn              net.Conn
-	pid               uint32            // backend pid
+	pid               atomic.Uint32     // backend pid; atomic because CancelRequest reads it from a separate goroutine
 	secretKey         []byte            // key to use to send a cancel query message to the server
 	parameterStatuses map[string]string // parameters that have been reported by the server
 	txStatus          byte
@@ -410,8 +411,8 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
 
 		switch msg := msg.(type) {
 		case *pgproto3.BackendKeyData:
-			pgConn.pid = msg.ProcessID
 			pgConn.secretKey = msg.SecretKey
+			pgConn.pid.Store(msg.ProcessID)
 
 		case *pgproto3.AuthenticationOk:
 		case *pgproto3.AuthenticationCleartextPassword:
@@ -656,7 +657,7 @@ func (pgConn *PgConn) Conn() net.Conn {
 
 // PID returns the backend PID.
 func (pgConn *PgConn) PID() uint32 {
-	return pgConn.pid
+	return pgConn.pid.Load()
 }
 
 // TxStatus returns the current TxStatus as reported by the server in the ReadyForQuery message.
@@ -1037,6 +1038,14 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
 // is no way to be sure a query was canceled.
 // See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS
 func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
+	// Nothing to cancel if we haven't completed the handshake yet. The atomic load synchronizes with the
+	// store in connectOne, making the subsequent read of secretKey safe.
+	pid := pgConn.pid.Load()
+	if pid == 0 {
+		return nil
+	}
+	secretKey := pgConn.secretKey
+
 	// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
 	// the connection config. This is important in high availability configurations where fallback connections may be
 	// specified or DNS may be used to load balance.
@@ -1072,11 +1081,11 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
 		defer contextWatcher.Unwatch()
 	}
 
-	buf := make([]byte, 12+len(pgConn.secretKey))
+	buf := make([]byte, 12+len(secretKey))
 	binary.BigEndian.PutUint32(buf[0:4], uint32(len(buf)))
 	binary.BigEndian.PutUint32(buf[4:8], 80877102)
-	binary.BigEndian.PutUint32(buf[8:12], pgConn.pid)
+	binary.BigEndian.PutUint32(buf[8:12], pid)
-	copy(buf[12:], pgConn.secretKey)
+	copy(buf[12:], secretKey)
 
 	if _, err := cancelConn.Write(buf); err != nil {
 		return fmt.Errorf("write to connection for cancellation: %w", err)
 	}
@@ -2128,7 +2137,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
 
 	return &HijackedConn{
 		Conn:              pgConn.conn,
-		PID:               pgConn.pid,
+		PID:               pgConn.pid.Load(),
 		SecretKey:         pgConn.secretKey,
 		ParameterStatuses: pgConn.parameterStatuses,
 		TxStatus:          pgConn.txStatus,
@@ -2148,7 +2157,6 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) {
 func Construct(hc *HijackedConn) (*PgConn, error) {
 	pgConn := &PgConn{
 		conn:              hc.Conn,
-		pid:               hc.PID,
 		secretKey:         hc.SecretKey,
 		parameterStatuses: hc.ParameterStatuses,
 		txStatus:          hc.TxStatus,
@@ -2160,6 +2168,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
 		cleanupDone:       make(chan struct{}),
 	}
 
+	pgConn.pid.Store(hc.PID)
 	pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn))
 	pgConn.bgReader = bgreader.New(pgConn.conn)
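
Usage note (not part of the patch): a minimal sketch of opting in to the new handler through Config.BuildContextWatcherHandler, mirroring buildCancelAndDrainConfig in the test file above. The connection string is a placeholder, and the field values shown are simply the defaults the handler applies when the fields are left zero.

package main

import (
	"context"
	"log"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgconn/ctxwatch"
)

func main() {
	// "postgres://localhost/example" is a placeholder connection string.
	config, err := pgconn.ParseConfig("postgres://localhost/example")
	if err != nil {
		log.Fatal(err)
	}

	// Install the drain-on-cancel handler for every connection built from this config.
	config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
		return &pgconn.CancelAndDrainContextWatcherHandler{
			Conn:          conn,
			DeadlineDelay: time.Second,     // fallback net.Conn deadline (default when zero)
			DrainTimeout:  5 * time.Second, // cap on the ";" drain round-trip (default when zero)
		}
	}

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	// A query cancelled via context is drained instead of forcing the connection closed.
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, _ = conn.Exec(ctx, "select pg_sleep(10)").ReadAll()

	// The same connection remains usable, per the "connection reused after cancel" test.
	if _, err := conn.Exec(context.Background(), "select 1").ReadAll(); err != nil {
		log.Fatal(err)
	}
}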