131 changes: 131 additions & 0 deletions pgconn/cancel_and_drain.go
@@ -0,0 +1,131 @@
package pgconn

import (
	"context"
	"time"

	"github.com/jackc/pgx/v5/pgconn/ctxwatch"
	"github.com/jackc/pgx/v5/pgproto3"
)

// CancelAndDrainContextWatcherHandler handles cancelled contexts by sending a cancel request to the server and then
// draining any resulting SQLSTATE 57014 (query_canceled) error with a single ";" round-trip. Unlike
// [CancelRequestContextWatcherHandler], no fixed sleep is used; the drain is deterministic.
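//
// A minimal wiring sketch, mirroring the test setup later in this PR (connString
// is an assumed placeholder):
//
//	config, err := ParseConfig(connString)
//	if err != nil {
//		// handle error
//	}
//	config.BuildContextWatcherHandler = func(conn *PgConn) ctxwatch.Handler {
//		return &CancelAndDrainContextWatcherHandler{Conn: conn}
//	}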
type CancelAndDrainContextWatcherHandler struct {
	Conn *PgConn

	// DeadlineDelay is the network deadline set on the connection when the context
	// is cancelled, used as a fallback to unblock any blocked read. Defaults to 1s.
	DeadlineDelay time.Duration

	// DrainTimeout is the maximum time to spend draining a cancelled query's
	// in-flight results via the single ";" round-trip. Defaults to 5s.
	DrainTimeout time.Duration

	cancelFinishedChan chan struct{}
Contributor

Why not just use the context.Context created in HandleCancel? (e.g. shutdownCtx)

Contributor Author

I think you're suggesting something like this:

type CancelAndDrainContextWatcherHandler struct {
	Conn          *PgConn
	DeadlineDelay time.Duration
	DrainTimeout  time.Duration

	doneFn context.CancelFunc   // replaces cancelFinishedChan
	doneCtx context.Context     // replaces cancelFinishedChan
	stopFn context.CancelFunc
}

func (h *CancelAndDrainContextWatcherHandler) HandleCancel(_ context.Context) {
	h.doneCtx, h.doneFn = context.WithCancel(context.Background())
	cancelCtx, stop := context.WithCancel(context.Background())
	h.stopFn = stop

	deadline := time.Now().Add(h.deadlineDelay())
	h.Conn.conn.SetDeadline(deadline)

	go func() {
		defer h.doneFn()
		reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)
		defer cancel()
		h.Conn.CancelRequest(reqCtx)
	}()
}

func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() {
	if h.stopFn != nil {
		h.stopFn()
	}
	if h.doneCtx != nil {
		<-h.doneCtx.Done()
	}
	// ... rest unchanged
}

This would work fine, but it allocates more and diverges from the pattern in the other handler. I don't personally see much reason to prefer it, but I'm not against making that change.

Contributor

Similar to what I said below, this should all be handled in pgxpool or puddle (somehow), and not done at the PgConn layer.

	stopFn context.CancelFunc
}

var _ ctxwatch.Handler = (*CancelAndDrainContextWatcherHandler)(nil)

func (h *CancelAndDrainContextWatcherHandler) deadlineDelay() time.Duration {
	if h.DeadlineDelay == 0 {
		return time.Second
	}
	return h.DeadlineDelay
}

func (h *CancelAndDrainContextWatcherHandler) drainTimeout() time.Duration {
	if h.DrainTimeout == 0 {
		return 5 * time.Second
Contributor

	const defaultDrainTimeout = 5 * time.Second
	return defaultDrainTimeout

Contributor Author

This is how I would do it if I owned this library, but I'm matching the existing style; there are no duration constants anywhere else.

	}
	return h.DrainTimeout
}

// HandleCancel is called when the context is cancelled. It sets a net.Conn deadline
// as a fallback and sends a PostgreSQL cancel request in a goroutine.
func (h *CancelAndDrainContextWatcherHandler) HandleCancel(_ context.Context) {
Contributor

Why is HandleCancel()'s arg being ignored and not passed to context.WithCancel()? Is this a different lifetime scope? It looks like we intend to inherit from the passed-in context.

Contributor Author

The ctx passed here is already cancelled; this is the same pattern used by the existing handler:

pgx/pgconn/pgconn.go, lines 2885 to 2887 in a5680bc:

func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) {
	h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay))
}

	h.cancelFinishedChan = make(chan struct{})
	cancelCtx, stop := context.WithCancel(context.Background())
	h.stopFn = stop

	deadline := time.Now().Add(h.deadlineDelay())
	h.Conn.conn.SetDeadline(deadline)

	doneCh := h.cancelFinishedChan
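	// The goroutine below always closes doneCh: CancelRequest returns, the
	// deadline on reqCtx expires, or stopFn cancels cancelCtx, whichever comes
	// first. HandleUnwatchAfterCancel relies on this to avoid blocking forever.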
	go func() {
		defer close(doneCh)
		reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)
		defer cancel()
		h.Conn.CancelRequest(reqCtx)
	}()
Contributor

Can this be replaced with an errgroup.Group?

Contributor Author

This pattern matches the pre-existing handler; there's only one goroutine here, and no meaningful error to return.
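
For reference, a rough sketch of the errgroup variant (assuming golang.org/x/sync/errgroup; illustrative only, not what this PR does):

import "golang.org/x/sync/errgroup"

type CancelAndDrainContextWatcherHandler struct {
	Conn          *PgConn
	DeadlineDelay time.Duration
	DrainTimeout  time.Duration

	g      *errgroup.Group // replaces cancelFinishedChan
	stopFn context.CancelFunc
}

func (h *CancelAndDrainContextWatcherHandler) HandleCancel(_ context.Context) {
	h.g = &errgroup.Group{} // fresh group per cancel cycle
	cancelCtx, stop := context.WithCancel(context.Background())
	h.stopFn = stop

	deadline := time.Now().Add(h.deadlineDelay())
	h.Conn.conn.SetDeadline(deadline)

	h.g.Go(func() error {
		reqCtx, cancel := context.WithDeadline(cancelCtx, deadline)
		defer cancel()
		h.Conn.CancelRequest(reqCtx) // error intentionally discarded
		return nil
	})
}

// In HandleUnwatchAfterCancel, h.g.Wait() would replace <-h.cancelFinishedChan.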

}

// HandleUnwatchAfterCancel is called after the cancelled query returns. It stops the cancel goroutine (if still
// running), clears the net.Conn deadline, and drains any in-flight cancel with a single ";" round-trip.
func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() {
	if h.stopFn != nil {
		h.stopFn()
	}
	if h.cancelFinishedChan != nil {
		<-h.cancelFinishedChan
Contributor

This can block indefinitely

Contributor Author

Look at https://github.com/jackc/pgx/pull/2534/changes#diff-361e75a80ec6958cfb843b7e092b63e5bbf6c6b3a20734e1a43f1f722a35c063R57-R58:

defer close(doneCh) // doneCh here is h.cancelFinishedChan
reqCtx, cancel := context.WithDeadline(cancelCtx, deadline) // cancelCtx is cancelled by stopFn

So if stopFn gets called, this channel is closed; and if that doesn't happen for some reason, it's still closed once the deadline expires.

	}
	h.Conn.conn.SetDeadline(time.Time{})
	h.cancelFinishedChan = nil
	h.stopFn = nil

	if !h.Conn.IsClosed() {
		ctx, cancel := context.WithTimeout(context.Background(), h.drainTimeout())
		defer cancel()
		h.Conn.execInternalForDrain(ctx)
	}
}

// queryCanceledSQLStateCode is SQLSTATE 57014 (query_canceled).
const queryCanceledSQLStateCode = "57014"

// execInternalForDrain sends a single ";" and reads until ReadyForQuery, absorbing any
// SQLSTATE 57014 (query_canceled). One round-trip is sufficient: PostgreSQL sets
// QueryCancelPending at most once per cancel signal, so at most one 57014 can arrive.
// On any failure the connection is closed via asyncClose.
//
// Called while the connection is still logically "busy" from pgconn's perspective
// (lock is held and contextWatcher.Unwatch has been called) but idle from the
// PostgreSQL server's perspective (ReadyForQuery was just received). This means
// it bypasses the normal lock/unlock and contextWatcher.Watch paths.
//
// The deadline from ctx is applied directly to the net.Conn.
func (pgConn *PgConn) execInternalForDrain(ctx context.Context) {
	if deadline, ok := ctx.Deadline(); ok {
		pgConn.conn.SetDeadline(deadline)
		defer pgConn.conn.SetDeadline(time.Time{})
	}

	pgConn.frontend.Send(&pgproto3.Query{String: ";"})
	if err := pgConn.frontend.Flush(); err != nil {
		pgConn.asyncClose()
		return
	}

	for {
		msg, err := pgConn.receiveMessage()
		if err != nil {
			pgConn.asyncClose()
			return
		}

		switch msg := msg.(type) {
		case *pgproto3.ReadyForQuery:
			return
		case *pgproto3.ErrorResponse:
			pgErr := ErrorResponseToPgError(msg)
			if pgErr.Code != queryCanceledSQLStateCode {
				pgConn.asyncClose()
				return
			}
			// 57014 absorbed; continue reading until ReadyForQuery.
		case *pgproto3.EmptyQueryResponse:
			// Expected response for ";".
		}
	}
}
261 changes: 261 additions & 0 deletions pgconn/cancel_and_drain_test.go
@@ -0,0 +1,261 @@
package pgconn_test

import (
	"context"
	"fmt"
	"io"
	"os"
	"testing"
	"time"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgconn/ctxwatch"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func buildCancelAndDrainConfig(t *testing.T) *pgconn.Config {
	t.Helper()
	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
		return &pgconn.CancelAndDrainContextWatcherHandler{Conn: conn}
	}
	config.ConnectTimeout = 5 * time.Second
	return config
}

func TestCancelAndDrainContextWatcherHandler(t *testing.T) {
	t.Parallel()

	t.Run("connection reused after cancel", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		_, err = pgConn.Exec(ctx, "select pg_sleep(10)").ReadAll()
		require.Error(t, err)
		require.False(t, pgConn.IsClosed(), "connection should not be closed after cancel with drain handler")

		ensureConnValid(t, pgConn)
	})

	t.Run("no stale cancel bleed", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		for i := range 50 {
			func() {
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
				defer cancel()
				pgConn.Exec(ctx, "select pg_sleep(0.020)").ReadAll()
			}()

			if pgConn.IsClosed() {
				var err error
				pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
				require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
			}

			ensureConnValid(t, pgConn)
		}
	})

	t.Run("stress", func(t *testing.T) {
		t.Parallel()

		for i := range 10 {
			t.Run(fmt.Sprintf("goroutine_%d", i), func(t *testing.T) {
				t.Parallel()

				pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
				require.NoError(t, err)
				defer closeConn(t, pgConn)

				for j := range 20 {
					func() {
						ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
						defer cancel()
						pgConn.Exec(ctx, "select pg_sleep(0.010)").ReadAll()
					}()

					if pgConn.IsClosed() {
						var err error
						pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
						require.NoError(t, err, "goroutine %d iteration %d: failed to reconnect", i, j)
					}

					ensureConnValid(t, pgConn)
				}
			})
		}
	})

	t.Run("ExecParams", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		rr := pgConn.ExecParams(ctx, "select pg_sleep(10)", nil, nil, nil, nil)
		rr.Read()
		_, err = rr.Close()
		assert.Error(t, err)

		if !pgConn.IsClosed() {
			ensureConnValid(t, pgConn)
		}
	})

	t.Run("CopyTo", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		_, err = pgConn.CopyTo(ctx, io.Discard, "COPY (SELECT pg_sleep(10)) TO STDOUT")
		assert.Error(t, err)

		if !pgConn.IsClosed() {
			ensureConnValid(t, pgConn)
		}
	})

	t.Run("CopyFrom", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		_, err = pgConn.Exec(context.Background(), "CREATE TEMP TABLE drain_test_copyfrom (id int)").ReadAll()
		require.NoError(t, err)

		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		pr, pw := io.Pipe()
		defer pr.Close()
		defer pw.Close()

		_, err = pgConn.CopyFrom(ctx, pr, "COPY drain_test_copyfrom FROM STDIN")
		assert.Error(t, err)

		if !pgConn.IsClosed() {
			ensureConnValid(t, pgConn)
		}
	})

	t.Run("Pipeline", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		pipeline := pgConn.StartPipeline(ctx)

		pipeline.SendQueryParams("select pg_sleep(10)", nil, nil, nil, nil)
		err = pipeline.Sync()
		require.NoError(t, err)

		pipeline.Close()

		require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled pipeline with drain handler")
		ensureConnValid(t, pgConn)
	})

	t.Run("Prepare", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		for i := range 20 {
			func() {
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
				defer cancel()
				pgConn.Prepare(ctx, "", "select pg_sleep(0.010)", nil)
			}()

			if pgConn.IsClosed() {
				var err error
				pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
				require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
			}

			ensureConnValid(t, pgConn)
		}
	})

	t.Run("Deallocate", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		for i := range 20 {
			_, err := pgConn.Prepare(context.Background(), "drain_dealloc_test", "select 1", nil)
			require.NoError(t, err, "iteration %d: prepare failed", i)

			func() {
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
				defer cancel()
				pgConn.Deallocate(ctx, "drain_dealloc_test")
			}()

			if pgConn.IsClosed() {
				var err error
				pgConn, err = pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
				require.NoError(t, err, "iteration %d: failed to reconnect after closed connection", i)
			}

			ensureConnValid(t, pgConn)
		}
	})

	t.Run("WaitForNotification", func(t *testing.T) {
		t.Parallel()

		pgConn, err := pgconn.ConnectConfig(context.Background(), buildCancelAndDrainConfig(t))
		require.NoError(t, err)
		defer closeConn(t, pgConn)

		if pgConn.ParameterStatus("crdb_version") != "" {
			t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
		}

		_, err = pgConn.Exec(context.Background(), "LISTEN drain_test_channel").ReadAll()
		require.NoError(t, err)

		ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
		defer cancel()

		err = pgConn.WaitForNotification(ctx)
		require.Error(t, err)

		require.False(t, pgConn.IsClosed(), "connection should not be closed after cancelled WaitForNotification with drain handler")
		ensureConnValid(t, pgConn)
	})
}