diff --git a/pkg/util/admission/BUILD.bazel b/pkg/util/admission/BUILD.bazel index cdeab54741e6..f1d6d3fc1290 100644 --- a/pkg/util/admission/BUILD.bazel +++ b/pkg/util/admission/BUILD.bazel @@ -89,6 +89,7 @@ go_test( "scheduler_latency_listener_test.go", "sequencer_test.go", "snapshot_queue_test.go", + "sql_cpu_handle_test.go", "store_token_estimation_test.go", "tokens_linear_model_test.go", "work_queue_test.go", @@ -99,6 +100,7 @@ go_test( "//pkg/cli/exit", "//pkg/roachpb", "//pkg/settings/cluster", + "//pkg/testutils", "//pkg/testutils/datapathutils", "//pkg/testutils/echotest", "//pkg/testutils/skip", diff --git a/pkg/util/admission/sql_cpu_handle.go b/pkg/util/admission/sql_cpu_handle.go index fa16a99eeb5e..fe8bdba996ed 100644 --- a/pkg/util/admission/sql_cpu_handle.go +++ b/pkg/util/admission/sql_cpu_handle.go @@ -96,9 +96,63 @@ type SQLCPUHandle struct { p *sqlCPUProviderImpl wq *WorkQueue + // admitTurn gates the blocking WorkQueue.Admit path for this handle. A + // statement can run many DistSQL goroutines; when the token reservation is + // short, each of them could otherwise enter Admit at once. Serializing that + // path keeps at most one blocking Admit per handle at a time. + // + // Any goroutine that needs that path must take the turn first: buffer size + // 1, send to acquire and deferred receive to release. The send blocks while + // another goroutine holds the turn (so only one runs the serialized section + // at a time). A channel is used so that wait can select on <-ctx.Done(); + // Mutex.Lock cannot. + // + // Holding the turn is a precondition for blocking Admit, not a + // commitment to call it. After the turn is acquired, two conditions + // can prevent blocking Admit from running: + // 1. The handle was closed while waiting for the turn. The + // remaining deficit is accounted via BypassAdmission. + // 2. 
The previous turn-holder refilled reservation while this + // goroutine waited, and the second deductFromReservation + // covers the shortfall entirely. No Admit call is needed. + admitTurn chan struct{} + mu struct { syncutil.Mutex - closed bool + + // Once true, cannot be set to false. + closed bool + + // reservation holds tokens obtained from WorkQueue.Admit in + // excess of the current checkpoint's deficit. These surplus + // tokens cover future checkpoints without re-acquiring + // WorkQueue.mu. Without the reservation, every measureAndAdmit + // call would call WorkQueue.Admit directly, acquiring + // WorkQueue.mu on every checkpoint. The reservation allows + // lock-free CAS deductions that skip the Admit call entirely. + // + // reservation is modified in three places: + // + // - CAS decrement (tryDeductReservation): does not need + // mu. Goroutines race only with each other and with + // Close's Swap(0); a failed CAS retries and eventually + // sees 0. No tokens are lost. + // + // - Swap(0) drain (Close): does not need mu. Close sets + // closed=true under mu first, which prevents any future + // Add. After that, Swap(0) races only with CAS + // decrements, which is safe (see above). + // + // - Add increment (after Admit): needs mu. The caller + // must check closed before adding — if closed, tokens + // go to AdmittedSQLWorkDone instead. Without mu, Close + // could set closed=true and Swap(0) between the check + // and the Add, leaking tokens. + // + // INVARIANT: reservation >= 0. + // INVARIANT: closed == true => reservation == 0. + reservation atomic.Int64 + gHandles []*GoroutineCPUHandle // Backing for up to 2 goroutine handles, to avoid allocations in // gHandles when there are 2 or fewer goroutines. 
@@ -114,45 +168,190 @@ func newSQLCPUAdmissionHandle( atGateway: atGateway, p: p, wq: wq, + admitTurn: make(chan struct{}, 1), } h.mu.gHandles = h.mu.handlesBacking[:0] return h } -// reportAndAcquireConsumedCPU updates cumulative CPU counters and, if a CTT -// WorkQueue is attached, calls Admit to deduct the consumed CPU from the token -// bucket. This may block until tokens are available unless noWait is true. -func (h *SQLCPUHandle) reportAndAcquireConsumedCPU( - ctx context.Context, diff time.Duration, noWait bool, -) error { +// reportCPU atomically adds the CPU time difference to the appropriate +// cumulative counter. +func (h *SQLCPUHandle) reportCPU(diff time.Duration) { if h.atGateway { h.p.cumulativeGatewayCPUNanos.Add(diff.Nanoseconds()) } else { h.p.cumulativeDistSQLCPUNanos.Add(diff.Nanoseconds()) } +} - if h.wq == nil { - return nil +// tryDeductReservation deducts up to diffNanos from reservation via +// CAS. Returns the amount grabbed (may be less than diffNanos). +func (h *SQLCPUHandle) tryDeductReservation(diffNanos int64) int64 { + for { + current := h.mu.reservation.Load() + if current <= 0 { + return 0 + } + grab := min(current, diffNanos) + if h.mu.reservation.CompareAndSwap(current, current-grab) { + return grab + } } +} - // RequestedCount is set to the exact CPU consumed (from grunning), so the - // WorkQueue's CPU time token estimator is skipped (see Admit). Because the - // exact amount is deducted at Admit time, there is no estimate to correct, - // so AdmittedWorkDone is not called. This also avoids training the KV - // estimator with SQL CPU data, which would corrupt its estimates. - // - // TODO(wenyi): Currently we call Admit on every measureAndAdmit invocation, - // which happens every ~1024 rows. This means each SQL goroutine takes the - // WorkQueue mutex on every check. 
Consider reserving more tokens than the - // exact amount consumed (e.g., 2x the last diff, or a smoothed estimate of - // upcoming usage) and tracking remaining reservation locally. This would - // allow subsequent measureAndAdmit calls to deduct from the local - // reservation without calling Admit, reducing contention on the WorkQueue. +// maxRefillBuffer caps the reservation buffer per Admit call to +// prevent large checkpoints from holding excessive tokens idle. +const maxRefillBuffer = int64(1 * time.Millisecond) + +// refillHeuristic determines how many tokens to request from the WorkQueue when +// the reservation runs out. It requests the deficit (to cover the current +// shortfall) plus a buffer (to pre-pay future fast-path CAS deductions). A +// larger buffer means fewer blocking Admit calls and less contention on +// WorkQueue.mu, but it also means more tokens are held in this handle's +// reservation instead of the shared pool. Tokens sitting in reservation are +// unavailable to other tenants and other handles within the same tenant, which +// can reduce fairness. The buffer is capped at maxRefillBuffer to bound +// this unfairness. +// +// TODO(wenyihu6): replace this simple capped 2x heuristic with an adaptive scheme +// (e.g. exponential growth) that grows the buffer when Admit calls are too +// frequent and shrinks it when they are infrequent. +func (h *SQLCPUHandle) refillHeuristic(deficit int64) int64 { + buffer := min(deficit, maxRefillBuffer) + return deficit + buffer +} + +// constructWorkInfo returns a WorkInfo copy with the given +// RequestedCount and BypassAdmission. +// +// NB: setting RequestedCount > 0 causes WorkQueue.Admit to skip the +// cpuTimeTokenEstimator (see callerSetRequestedCount in Admit). This is +// required for SQL CPU admission, which already knows the exact CPU consumed +// from grunning and does not need/want estimation.
+// +// REQUIRES: reqCount > 0 +func (h *SQLCPUHandle) constructWorkInfo(reqCount int64, noWait bool) WorkInfo { workInfo := h.workInfo - workInfo.RequestedCount = diff.Nanoseconds() + workInfo.RequestedCount = reqCount workInfo.BypassAdmission = noWait - _, err := h.wq.Admit(ctx, workInfo) - return err + return workInfo +} + +func (h *SQLCPUHandle) isClosed() bool { + h.mu.Lock() + defer h.mu.Unlock() + return h.mu.closed +} + +// deductFromReservation deducts what it can from reservation via CAS. +// Returns the shortfall still needed after the deduction. +func (h *SQLCPUHandle) deductFromReservation(needed int64) (shortfall int64) { + grabbed := h.tryDeductReservation(needed) + return needed - grabbed +} + +// reportAndAcquireConsumedCPU acquires tokens for consumed CPU. +// +// 1. Fast path: reservation covers the deficit via CAS. No Admit. +// 2. noWait path: deduct what is available, account the rest via +// BypassAdmission. +// 3. Slow path: take a turn via admitTurn, call Admit for the +// deficit plus a buffer, store the buffer in reservation. +// +// In winding-down cases (noWait, context cancellation, Close having +// run), the goroutine deducts what it can and accounts the rest via +// BypassAdmission. It never blocks and never refills the reservation. +// +// INVARIANT: Every token obtained from WorkQueue.Admit must be accounted for: +// tokens are either +// (a) consumed to pay for measured CPU usage, +// (b) held in reservation for future usage, or +// (c) returned via AdmittedSQLWorkDone when the handle is closed. +// +// When SQLCPUHandle is closed, any remaining reservation is returned +// via AdmittedSQLWorkDone, so this code does not leak tokens. +func (h *SQLCPUHandle) reportAndAcquireConsumedCPU( + ctx context.Context, diff time.Duration, noWait bool, +) error { + h.reportCPU(diff) + + if h.wq == nil { + return nil + } + + diffNanos := diff.Nanoseconds() + + // Deduct from reservation (lock-free CAS). 
+ remaining := h.deductFromReservation(diffNanos) + + if noWait { + // Winding down: account the deficit without blocking. + if remaining > 0 { + _, _ = h.wq.Admit(ctx, h.constructWorkInfo(remaining, true /*noWait*/)) + } + return nil + } + + // Fast path: reservation covered the deficit. + if remaining == 0 { + return nil + } + + // Slow path: serialize blocking Admit calls via admitTurn. + select { + case h.admitTurn <- struct{}{}: + defer func() { <-h.admitTurn }() + case <-ctx.Done(): + // Winding down: account the deficit without blocking. + _, _ = h.wq.Admit(ctx, h.constructWorkInfo(remaining, true /*noWait*/)) + return ctx.Err() + } + + // Close may have run while waiting for the turn. + if h.isClosed() { + // Winding down: account the deficit without blocking. + _, _ = h.wq.Admit(ctx, h.constructWorkInfo(remaining, true /*noWait*/)) + return nil + } + + // The previous turn-holder may have refilled the reservation. + remaining = h.deductFromReservation(remaining) + if remaining == 0 { + return nil + } + + // Request the deficit plus a buffer (see refillHeuristic). + resp, err := h.wq.Admit(ctx, h.constructWorkInfo(h.refillHeuristic(remaining), false /*noWait*/)) + if err != nil { + // Error is only due to context cancellation. Account the deficit + // without blocking. + _, _ = h.wq.Admit(ctx, h.constructWorkInfo(remaining, true /*noWait*/)) + return err + } + if resp.Enabled { + buffer := resp.requestedCount - remaining + if buffer > 0 { + closed := func() bool { + h.mu.Lock() + defer h.mu.Unlock() + if h.mu.closed { + return true + } + // NB: closed check and reservation.Add must be atomic under mu. + // Without the lock, this goroutine could read closed=false, then + // Close sets closed=true and Swap(0) drains reservation, then this + // goroutine does Add(buffer), leaking tokens. + h.mu.reservation.Add(buffer) + return false + }() + // Close already ran and drained reservation. Return the buffer + // directly. 
+ if closed { + h.wq.AdmittedSQLWorkDone(h.workInfo.TenantID, buffer) + } + } + } + return nil } // TODO(sumeer): see the comment @@ -204,21 +403,45 @@ func (h *SQLCPUHandle) RegisterGoroutine() *GoroutineCPUHandle { return gh } -// Close is called when no more reporting is needed. It pools -// GoroutineCPUHandles that have been closed. GoroutineCPUHandles that are not -// yet closed are left for GC. +// Close sets closed=true under mu, drains reservation via Swap(0), +// and returns any remaining tokens via AdmittedSQLWorkDone. +// +// Close returns even if goroutines in reportAndAcquireConsumedCPU are still +// blocked on Admit. Calls to reportAndAcquireConsumedCPU that race with +// Close are handled in one of two ways: +// - Before taking the turn: they see closed=true after acquiring +// the turn and fall back to BypassAdmission. +// - After Admit returns: they see closed=true under mu and return +// the surplus buffer via AdmittedSQLWorkDone instead of adding +// it to reservation. +// +// Closed GoroutineCPUHandles are pooled; unclosed ones are left for GC. func (h *SQLCPUHandle) Close() { - h.mu.Lock() - defer h.mu.Unlock() - h.mu.closed = true - for i, gh := range h.mu.gHandles { - if gh.closed.Load() { - gh.reset() - goroutineCPUHandlePool.Put(gh) + // After this, goroutines in reportAndAcquireConsumedCPU observe + // closed=true when they acquire mu. + func() { + h.mu.Lock() + defer h.mu.Unlock() + h.mu.closed = true + for i, gh := range h.mu.gHandles { + if gh.closed.Load() { + gh.reset() + goroutineCPUHandlePool.Put(gh) + } + h.mu.gHandles[i] = nil } - h.mu.gHandles[i] = nil + h.mu.gHandles = nil + }() + if h.wq == nil { + return + } + // Drain reservation outside the lock. Swap(0) races safely with CAS + // deductions (CAS retries on conflict and finds 0). NB: No new tokens should + // be added to mu.reservation after this.
+ remaining := h.mu.reservation.Swap(0) + if remaining > 0 { + h.wq.AdmittedSQLWorkDone(h.workInfo.TenantID, remaining) } - h.mu.gHandles = nil } // GoroutineCPUHandle is used for CPU accounting on a single goroutine. It diff --git a/pkg/util/admission/sql_cpu_handle_test.go b/pkg/util/admission/sql_cpu_handle_test.go new file mode 100644 index 000000000000..e6e5d299eb83 --- /dev/null +++ b/pkg/util/admission/sql_cpu_handle_test.go @@ -0,0 +1,601 @@ +// Copyright 2026 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package admission + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +// TestSQLCPUHandleFastAndSlowPath walks through the full reservation +// lifecycle: slow path (Admit) → fast path (CAS) → reservation +// exhausted → slow path again. +func TestSQLCPUHandleFastAndSlowPath(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, tg, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // 1) Slow path: reservation is 0, must call Admit. + // heuristic(1ms) = 1ms + min(1ms, 1ms) = 2ms requested. + // 1ms consumed, 1ms goes to reservation. 
+ require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 1*time.Millisecond, false)) + require.Equal(t, int64(1*time.Millisecond), h.mu.reservation.Load()) + require.Contains(t, tg.buf.stringAndReset(), "tryGet", + "first call should go through Admit") + + // 2) Fast path: 500us < 1ms reservation, CAS covers it. + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 500*time.Microsecond, false)) + require.Equal(t, int64(500*time.Microsecond), h.mu.reservation.Load()) + require.Empty(t, tg.buf.stringAndReset(), + "fast path should not call Admit") + + // 3) Slow path again: 2ms > 500us reservation. + // CAS grabs 500us, remaining=1.5ms, slow path. + // heuristic(1.5ms) = 1.5ms + min(1.5ms, 1ms) = 2.5ms. + // buffer = 2.5ms - 1.5ms = 1ms added to reservation. + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 2*time.Millisecond, false)) + require.Equal(t, int64(1*time.Millisecond), h.mu.reservation.Load()) + require.Contains(t, tg.buf.stringAndReset(), "tryGet", + "exhausted reservation should fall back to Admit") + + // CPU should be fully reported across all three calls. + gw, _ := provider.GetCumulativeSQLCPUNanos() + require.Equal(t, int64(3500*time.Microsecond), gw) +} + +// TestSQLCPUHandleCloseReturnsTokens verifies that Close drains +// reservation and returns tokens via AdmittedSQLWorkDone only when +// there are tokens to return. 
+func TestSQLCPUHandleCloseReturnsTokens(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + + tests := []struct { + name string + seedReservation bool + expectedReturn bool + }{ + { + name: "non-zero reservation returns tokens", + seedReservation: true, + expectedReturn: true, + }, + { + name: "zero reservation skips return", + seedReservation: false, + expectedReturn: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + q, tg, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + if tc.seedReservation { + require.NoError(t, h.reportAndAcquireConsumedCPU( + ctx, 1*time.Millisecond, false)) + require.Equal(t, int64(1*time.Millisecond), + h.mu.reservation.Load()) + } + + _ = tg.buf.stringAndReset() + h.Close() + + require.True(t, h.isClosed()) + require.Equal(t, int64(0), h.mu.reservation.Load()) + + output := tg.buf.String() + if tc.expectedReturn { + require.Contains(t, output, "returnGrant") + } else { + require.NotContains(t, output, "returnGrant") + } + }) + } +} + +// TestSQLCPUHandleNoWaitBypassAdmission verifies that the noWait path +// uses BypassAdmission and does not block. +func TestSQLCPUHandleNoWaitBypassAdmission(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, tg, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + // Make tryGet return false — a blocking Admit would hang, but + // noWait should bypass admission entirely via BypassAdmission. 
+ tg.mu.Lock() + tg.mu.returnValueFromTryGet = false + tg.mu.Unlock() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 1*time.Millisecond, true /*noWait*/)) + + gw, _ := provider.GetCumulativeSQLCPUNanos() + require.Equal(t, int64(1*time.Millisecond), gw) +} + +// TestSQLCPUHandleNoWorkQueue verifies that when no WorkQueue is +// attached (CTT AC disabled), CPU is still reported. +func TestSQLCPUHandleNoWorkQueue(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, nil) + + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 1*time.Millisecond, false)) + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 2*time.Millisecond, true)) + + gw, _ := provider.GetCumulativeSQLCPUNanos() + require.Equal(t, int64(3*time.Millisecond), gw) + + h.Close() + require.True(t, h.isClosed()) +} + +// TestSQLCPUHandleConcurrentFastPath exercises the CAS-based fast path +// under contention from multiple goroutines. All goroutines deduct from +// the same reservation. The total deducted must be exact, and +// reservation must never go negative. +func TestSQLCPUHandleConcurrentFastPath(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, tg, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // Seed reservation via slow path. + // heuristic(50ms) = 50ms + min(50ms, 1ms) = 51ms. + // Reservation = 51ms - 50ms = 1ms. 
+ require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 50*time.Millisecond, false)) + require.Equal(t, int64(1*time.Millisecond), h.mu.reservation.Load()) + + _ = tg.buf.stringAndReset() + + // Launch goroutines that each deduct a small amount via fast path. + // Total = 20 * 10us = 200us, well within the 1ms reservation. + const numGoroutines = 20 + const perGoroutine = 10 * time.Microsecond + var wg sync.WaitGroup + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, perGoroutine, false)) + }() + } + wg.Wait() + + // Reservation should be exactly 1ms - 200us = 800us. + expected := int64(1*time.Millisecond) - int64(numGoroutines)*int64(perGoroutine) + require.Equal(t, expected, h.mu.reservation.Load(), + "CAS deductions should be exact under contention") + + // No Admit calls should have been made. + output := tg.buf.stringAndReset() + require.Empty(t, output, "all deductions should use CAS fast path") +} + +// TestSQLCPUHandleConcurrentSlowPath exercises the slow path under +// contention. When reservation is exhausted, goroutines serialize on +// admitTurn and only one calls Admit while others may find reservation +// refilled by the winner. +func TestSQLCPUHandleConcurrentSlowPath(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, _, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // No reservation seed — all goroutines hit the slow path. 
+ const numGoroutines = 10 + const perGoroutine = 1 * time.Millisecond + var wg sync.WaitGroup + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, perGoroutine, false)) + require.GreaterOrEqual(t, h.mu.reservation.Load(), int64(0)) + }() + } + wg.Wait() + + // All CPU should be reported. + gw, _ := provider.GetCumulativeSQLCPUNanos() + require.Equal(t, int64(numGoroutines)*int64(perGoroutine), gw) +} + +// TestSQLCPUHandleConcurrentCloseAndAdmit verifies that Close and +// reportAndAcquireConsumedCPU can run concurrently without races, +// panics, or token leaks. After Close, reservation is 0. +func TestSQLCPUHandleConcurrentCloseAndAdmit(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, _, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // Seed reservation: heuristic(5ms) = 5ms + min(5ms, 1ms) = 6ms. + // Reservation = 6ms - 5ms = 1ms. + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 5*time.Millisecond, false)) + require.Equal(t, int64(1*time.Millisecond), h.mu.reservation.Load()) + + var wg sync.WaitGroup + + // Goroutines with noWait=false. They hit the slow path (admitTurn + // + Admit) once the 1ms reservation is exhausted, but Admit + // returns immediately since testGranter.tryGet always succeeds. + const numBlocking = 10 + wg.Add(numBlocking) + for i := 0; i < numBlocking; i++ { + go func() { + defer wg.Done() + for j := 0; j < 50; j++ { + _ = h.reportAndAcquireConsumedCPU(ctx, 10*time.Microsecond, false) + } + }() + } + + // Goroutines calling the noWait path. 
+ const numNoWait = 10 + wg.Add(numNoWait) + for i := 0; i < numNoWait; i++ { + go func() { + defer wg.Done() + _ = h.reportAndAcquireConsumedCPU(ctx, 100*time.Microsecond, true) + }() + } + + // Close concurrently. + wg.Add(1) + go func() { + defer wg.Done() + h.Close() + }() + + wg.Wait() + + require.True(t, h.isClosed()) + // INVARIANT: closed == true => reservation == 0. + require.Equal(t, int64(0), h.mu.reservation.Load()) +} + +// TestSQLCPUHandleConcurrentCASAndSwap races tryDeductReservation +// (CAS decrement) against Close's Swap(0) (drain) on the same +// reservation atomic. Asserts reservation ends at zero and CAS'd <= +// initial regardless of interleaving. +func TestSQLCPUHandleConcurrentCASAndSwap(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, _, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + for iter := 0; iter < 100; iter++ { + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // Seed reservation: heuristic(10ms) = 10ms + min(10ms, 1ms) = 11ms. + // Reservation = 11ms - 10ms = 1ms. + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 10*time.Millisecond, false)) + initialReservation := h.mu.reservation.Load() + + var wg sync.WaitGroup + var casDeducted atomic.Int64 + + // CAS goroutines: each tries to deduct 500us from + // reservation. Depending on timing, a CAS may grab the + // full 500us, a partial amount, or nothing (if Close + // already drained reservation). + const numGoroutines = 5 + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + amount := int64(500 * time.Microsecond) + grabbed := h.tryDeductReservation(amount) + casDeducted.Add(grabbed) + }() + } + + // Close concurrently: sets closed under mu, then Swap(0) + // drains whatever CAS hasn't grabbed, returning it via + // AdmittedSQLWorkDone.
+ wg.Add(1) + go func() { + defer wg.Done() + h.Close() + }() + + wg.Wait() + + // After Close + all CAS goroutines finish, reservation + // must be zero (no slow-path goroutines to Add tokens). + require.Zero(t, h.mu.reservation.Load(), + "iter %d: reservation should be zero after Close", iter) + + // Token conservation: reservation is zero and the only + // two drains are CAS decrements and Close's Swap(0), so + // CAS'd + Swap'd == initialReservation. CAS can't grab + // more than what was available. + require.LessOrEqual(t, casDeducted.Load(), initialReservation, + "iter %d: CAS deducted more than initial reservation(%d)", + iter, initialReservation) + } +} + +// TestSQLCPUHandleAdmitVsCloseTokenConservation runs a stress test +// verifying the token conservation invariant: all tokens obtained from +// Admit are either consumed, held in reservation, or returned via +// AdmittedSQLWorkDone. This exercises the Admit-vs-Close race where +// the commit step checks closed under mu. +func TestSQLCPUHandleAdmitVsCloseTokenConservation(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, _, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + + // Run many iterations to exercise the race window between + // Admit's commit step and Close's Swap(0). + for iter := 0; iter < 200; iter++ { + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + var wg sync.WaitGroup + + // Multiple goroutines call reportAndAcquireConsumedCPU + // concurrently. + const numWorkers = 5 + wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + _ = h.reportAndAcquireConsumedCPU(ctx, 50*time.Microsecond, false) + } + }() + } + + // Close races with the workers. 
+ wg.Add(1) + go func() { + defer wg.Done() + h.Close() + }() + + wg.Wait() + + // After Close, both invariants must hold. + require.True(t, h.isClosed()) + require.Equal(t, int64(0), h.mu.reservation.Load(), + "iter %d: closed == true => reservation == 0", iter) + } +} + +// TestSQLCPUHandleContextCancellation verifies that when a goroutine's +// context is canceled while waiting for admitTurn, it falls through to +// the BypassAdmission path, accounts the deficit without blocking, and +// returns ctx.Err(). +func TestSQLCPUHandleContextCancellation(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + tenantID := roachpb.MustMakeTenantID(1) + q, _, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // Hold admitTurn so the next goroutine blocks on it. + h.admitTurn <- struct{}{} + defer func() { <-h.admitTurn }() + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + go func() { + // This will try to send to admitTurn (blocked) and fall + // through to ctx.Done(). + errCh <- h.reportAndAcquireConsumedCPU(ctx, 1*time.Millisecond, false) + }() + + // Cancel the context — the goroutine should return ctx.Err(). + cancel() + err := <-errCh + require.ErrorIs(t, err, context.Canceled) + + // CPU should still be reported despite the cancellation. + gw, _ := provider.GetCumulativeSQLCPUNanos() + require.Equal(t, int64(1*time.Millisecond), gw) + + h.Close() + + // Close should have nothing to return — the bypass path + // already accounted for the consumed CPU. + require.Zero(t, h.mu.reservation.Load()) +} + +// TestSQLCPUHandleCloseDoesNotBlockOnAdmitTurn verifies that Close +// returns immediately even when admitTurn is held, and that a +// goroutine acquiring the turn after Close sees closed=true and +// falls back to BypassAdmission without refilling reservation. 
+func TestSQLCPUHandleCloseDoesNotBlockOnAdmitTurn(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, _, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // Seed reservation: heuristic(1ms) = 1ms + min(1ms, 1ms) = 2ms. + // Reservation = 2ms - 1ms = 1ms. + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 1*time.Millisecond, false)) + require.Equal(t, int64(1*time.Millisecond), h.mu.reservation.Load()) + + // Hold admitTurn so the next slow-path goroutine blocks on it. + h.admitTurn <- struct{}{} + + // Start a goroutine that needs more than the reservation, + // forcing the slow path. It blocks waiting for admitTurn. + errCh := make(chan error, 1) + go func() { + errCh <- h.reportAndAcquireConsumedCPU(ctx, 2*time.Millisecond, false) + }() + + // Close doesn't touch admitTurn — it returns immediately. + // If it blocked, the test would time out. + h.Close() + require.True(t, h.isClosed()) + require.Zero(t, h.mu.reservation.Load()) + + // Release admitTurn — the blocked goroutine acquires the turn, + // sees closed=true, and accounts the deficit via BypassAdmission. + <-h.admitTurn + + require.NoError(t, <-errCh) + require.Zero(t, h.mu.reservation.Load()) +} + +// TestSQLCPUHandleSecondDeductionAfterTurn verifies the second +// deductFromReservation after acquiring admitTurn. When the previous +// turn-holder refills reservation before releasing the turn, the +// next turn-holder's second deduction can cover the shortfall and +// skip Admit entirely. 
+func TestSQLCPUHandleSecondDeductionAfterTurn(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + tenantID := roachpb.MustMakeTenantID(1) + q, tg, cleanup := makeCPUTimeTokenWorkQueue(t) + defer cleanup() + + provider := &sqlCPUProviderImpl{} + h := newSQLCPUAdmissionHandle( + WorkInfo{TenantID: tenantID}, true, provider, q) + + // Case 1: second deduction finds nothing, must call Admit. + // Reservation starts at 0, so both deductions get nothing. + _ = tg.buf.stringAndReset() + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 500*time.Microsecond, false)) + require.Contains(t, tg.buf.stringAndReset(), "tryGet", + "should have called Admit") + // heuristic(500us) = 500us + min(500us, 1ms) = 1ms. + // Reservation = 1ms - 500us = 500us. + require.Equal(t, int64(500*time.Microsecond), h.mu.reservation.Load()) + + // Case 2: second deduction covers the shortfall, skips Admit. + // Hold the turn, start a goroutine that blocks on it, then + // inject tokens into reservation (simulating what the previous + // turn-holder would leave) before releasing the turn. + + // Consume 400us via fast path, leaving 100us. + require.NoError(t, h.reportAndAcquireConsumedCPU(ctx, 400*time.Microsecond, false)) + require.Equal(t, int64(100*time.Microsecond), h.mu.reservation.Load()) + + // Hold admitTurn. + h.admitTurn <- struct{}{} + + // Request 200us: first CAS grabs 100us, remaining=100us, + // goes to slow path, blocks waiting for admitTurn. + errCh := make(chan error, 1) + go func() { + errCh <- h.reportAndAcquireConsumedCPU(ctx, 200*time.Microsecond, false) + }() + + // Wait until the goroutine has done its first CAS (reservation + // drops from 100us to 0) before injecting tokens. This ensures + // the goroutine is blocked on admitTurn, not still in the fast path. 
+ testutils.SucceedsSoon(t, func() error { + if h.mu.reservation.Load() != 0 { + return errors.New("waiting for goroutine to CAS reservation to 0") + } + return nil + }) + + // Inject 100us into reservation, simulating the previous + // turn-holder's Admit refilling it. + h.mu.reservation.Add(int64(100 * time.Microsecond)) + + // Release the turn. The goroutine's second deduction finds + // the 100us and covers the shortfall — no Admit needed. + _ = tg.buf.stringAndReset() + <-h.admitTurn + require.NoError(t, <-errCh) + require.NotContains(t, tg.buf.stringAndReset(), "tryGet", + "second deduction should have covered shortfall, skipping Admit") + + h.Close() +} diff --git a/pkg/util/admission/work_queue.go b/pkg/util/admission/work_queue.go index 91ed40ad5497..41346f488cbf 100644 --- a/pkg/util/admission/work_queue.go +++ b/pkg/util/admission/work_queue.go @@ -25,6 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/settings" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/admission/admissionpb" + "github.com/cockroachdb/cockroach/pkg/util/buildutil" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" "github.com/cockroachdb/cockroach/pkg/util/metric/aggmetric" @@ -1287,6 +1288,26 @@ func (q *WorkQueue) adjustTenantUsedLocked(tenant *tenantInfo, delta int64) { } } +// AdmittedSQLWorkDone returns unused reservation tokens to the granter +// when a SQL statement closes. remaining is always non-negative since +// the CAS-based deduction in SQLCPUHandle never drives reservation +// below zero. 
+func (q *WorkQueue) AdmittedSQLWorkDone(tenantID roachpb.TenantID, remaining int64) { + if remaining == 0 { + return + } + if remaining < 0 && buildutil.CrdbTestBuild { + log.Dev.Fatalf(q.ambientCtx, "AdmittedSQLWorkDone: remaining %d is negative", remaining) + } + q.adjustTenantUsed(tenantID, -remaining) + if remaining < 0 { + // Should never happen, but account for it defensively. + q.granter.tookWithoutPermission(-remaining) + } else { + q.granter.returnGrant(remaining) + } +} + // refillBurstBuckets adds tokens to all tenant burst buckets and updates // their capacity. This is called by cpuTimeTokenAllocator periodically (every // 1ms). If a tenant's burst qualification changes as a result of the refill,