|
| 1 | +package pgconn |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "time" |
| 6 | + |
| 7 | + "github.com/jackc/pgx/v5/pgconn/ctxwatch" |
| 8 | + "github.com/jackc/pgx/v5/pgproto3" |
| 9 | +) |
| 10 | + |
| 11 | +const ( |
| 12 | + defaultDeadlineDelay = time.Second |
| 13 | + defaultDrainTimeout = 5 * time.Second |
| 14 | + |
| 15 | + queryCanceledSQLStateCode = "57014" |
| 16 | + |
| 17 | + cancelStateIdle = 0 |
| 18 | + cancelStateInFlight = 1 |
| 19 | + cancelStateSent = 2 |
| 20 | +) |
| 21 | + |
| 22 | +// CancelAndDrainContextWatcherHandler handles cancelled contexts by first sending a cancel request, then draining any |
| 23 | +// pending SQLSTATE 57014 (query_canceled) with a single ";" round-trip. |
| 24 | +// |
| 25 | +// Correctness depends on at most one cancel request being in flight per connection at any time. Each cancel request |
| 26 | +// causes the server to set QueryCancelPending, which produces exactly one 57014. If two cancel requests were sent, |
| 27 | +// two 57014s could arrive -- the first absorbed by the drain, the second bleeding into the next real query. This |
| 28 | +// invariant is enforced by [PgConn.CancelRequest]'s mutex-guarded state machine, which blocks concurrent callers |
| 29 | +// until the in-flight cancel completes. |
| 30 | +type CancelAndDrainContextWatcherHandler struct { |
| 31 | + Conn *PgConn |
| 32 | + |
| 33 | + // DeadlineDelay is a net.Conn deadline set when the context is cancelled, used as a fallback to unblock blocked |
| 34 | + // reads. Defaults to defaultDeadlineDelay (1s). |
| 35 | + DeadlineDelay time.Duration |
| 36 | + |
| 37 | + // DrainTimeout caps the single drain round-trip. Defaults to defaultDrainTimeout (5s). |
| 38 | + DrainTimeout time.Duration |
| 39 | + |
| 40 | + doneCtx context.Context //nolint:containedctx // synchronization primitive, not a request-scoped context |
| 41 | + doneFn context.CancelFunc |
| 42 | + stopFn context.CancelFunc |
| 43 | +} |
| 44 | + |
| 45 | +var _ ctxwatch.Handler = (*CancelAndDrainContextWatcherHandler)(nil) |
| 46 | + |
| 47 | +func (h *CancelAndDrainContextWatcherHandler) deadlineDelay() time.Duration { |
| 48 | + if h.DeadlineDelay == 0 { |
| 49 | + return defaultDeadlineDelay |
| 50 | + } |
| 51 | + return h.DeadlineDelay |
| 52 | +} |
| 53 | + |
| 54 | +func (h *CancelAndDrainContextWatcherHandler) drainTimeout() time.Duration { |
| 55 | + if h.DrainTimeout == 0 { |
| 56 | + return defaultDrainTimeout |
| 57 | + } |
| 58 | + return h.DrainTimeout |
| 59 | +} |
| 60 | + |
| 61 | +// HandleCancel is called when the watched context is cancelled. It applies a net.Conn deadline as a fallback and fires |
| 62 | +// a cancel request in a goroutine. Mutual exclusion (at most one cancel in flight) is enforced by |
| 63 | +// [PgConn.CancelRequest], not here -- the ctxwatch.Handler interface does not permit a return value, but CancelRequest |
| 64 | +// will block if another cancel is already in progress. |
| 65 | +// |
| 66 | +// The parent context is inherited (via WithoutCancel) so that values like trace IDs propagate into the cancel request |
| 67 | +// without inheriting its already-fired cancellation. |
| 68 | +func (h *CancelAndDrainContextWatcherHandler) HandleCancel(ctx context.Context) { |
| 69 | + baseCtx := context.WithoutCancel(ctx) |
| 70 | + cancelCtx, stop := context.WithCancel(baseCtx) |
| 71 | + h.stopFn = stop |
| 72 | + |
| 73 | + h.doneCtx, h.doneFn = context.WithCancel(context.Background()) |
| 74 | + |
| 75 | + deadline := time.Now().Add(h.deadlineDelay()) |
| 76 | + h.Conn.conn.SetDeadline(deadline) |
| 77 | + |
| 78 | + go func() { |
| 79 | + defer h.doneFn() |
| 80 | + reqCtx, cancel := context.WithDeadline(cancelCtx, deadline) |
| 81 | + defer cancel() |
| 82 | + h.Conn.CancelRequest(reqCtx) |
| 83 | + }() |
| 84 | +} |
| 85 | + |
| 86 | +// HandleUnwatchAfterCancel is called after the cancelled query returns. It waits for the cancel goroutine, clears the |
| 87 | +// deadline, and -- if the cancel was successfully sent (cancelStateSent) -- sends exactly one ";" to absorb any pending |
| 88 | +// 57014. Finally it transitions back to idle. |
| 89 | +func (h *CancelAndDrainContextWatcherHandler) HandleUnwatchAfterCancel() { |
| 90 | + if h.stopFn != nil { |
| 91 | + h.stopFn() |
| 92 | + } |
| 93 | + if h.doneCtx != nil { |
| 94 | + <-h.doneCtx.Done() |
| 95 | + } |
| 96 | + h.Conn.conn.SetDeadline(time.Time{}) |
| 97 | + h.doneCtx = nil |
| 98 | + h.doneFn = nil |
| 99 | + h.stopFn = nil |
| 100 | + |
| 101 | + h.Conn.cancelMu.Lock() |
| 102 | + needsDrain := h.Conn.cancelMu.state == cancelStateSent |
| 103 | + if needsDrain { |
| 104 | + h.Conn.cancelMu.state = cancelStateIdle |
| 105 | + } |
| 106 | + h.Conn.cancelMu.Unlock() |
| 107 | + |
| 108 | + if !h.Conn.IsClosed() && needsDrain { |
| 109 | + ctx, cancel := context.WithTimeout(context.Background(), h.drainTimeout()) |
| 110 | + defer cancel() |
| 111 | + h.Conn.drainOnce(ctx) |
| 112 | + } |
| 113 | +} |
| 114 | + |
| 115 | +// drainOnce sends a single ";" and reads the response. If the server returns 57014, the cancel was still pending and is |
| 116 | +// now consumed. If the server returns a clean EmptyQueryResponse, the cancel was already consumed by the original query. |
| 117 | +// Either way the connection is clean after one round-trip -- no loop required. |
| 118 | +// |
| 119 | +// This design assumes at most one cancel is in flight per connection (enforced by [PgConn.CancelRequest]). A single |
| 120 | +// cancel produces at most one QueryCancelPending flag on the server, which yields at most one 57014. |
| 121 | +func (pgConn *PgConn) drainOnce(ctx context.Context) { |
| 122 | + if deadline, ok := ctx.Deadline(); ok { |
| 123 | + pgConn.conn.SetDeadline(deadline) |
| 124 | + defer pgConn.conn.SetDeadline(time.Time{}) |
| 125 | + } |
| 126 | + |
| 127 | + pgConn.frontend.Send(&pgproto3.Query{String: ";"}) |
| 128 | + if err := pgConn.frontend.Flush(); err != nil { |
| 129 | + pgConn.asyncClose() |
| 130 | + return |
| 131 | + } |
| 132 | + |
| 133 | + for { |
| 134 | + msg, err := pgConn.receiveMessage() |
| 135 | + if err != nil { |
| 136 | + pgConn.asyncClose() |
| 137 | + return |
| 138 | + } |
| 139 | + |
| 140 | + switch msg := msg.(type) { |
| 141 | + case *pgproto3.ReadyForQuery: |
| 142 | + return |
| 143 | + case *pgproto3.ErrorResponse: |
| 144 | + pgErr := ErrorResponseToPgError(msg) |
| 145 | + if pgErr.Code != queryCanceledSQLStateCode { |
| 146 | + pgConn.asyncClose() |
| 147 | + return |
| 148 | + } |
| 149 | + // 57014 absorbed -- continue reading until ReadyForQuery |
| 150 | + case *pgproto3.EmptyQueryResponse: |
| 151 | + // Expected response for ";" -- continue reading until ReadyForQuery |
| 152 | + } |
| 153 | + } |
| 154 | +} |
0 commit comments