diff --git a/go/mysql/sqlerror/constants.go b/go/mysql/sqlerror/constants.go index 45a051a4939..4edf033eebf 100644 --- a/go/mysql/sqlerror/constants.go +++ b/go/mysql/sqlerror/constants.go @@ -636,10 +636,18 @@ func IsEphemeralError(err error) bool { } // IsTooManyConnectionsErr returns true if the error is due to too many connections. +// This covers three cases: +// - Global max_connections exceeded: MySQL rejects during handshake with CRServerHandshakeErr +// and a "Too many connections" message. +// - Per-user max_user_connections exceeded (errno 1203 ERTooManyUserConnections). +// - Per-user resource limit reached (errno 1226 ERUserLimitReached) for max_user_connections. func IsTooManyConnectionsErr(err error) bool { if sqlErr, ok := err.(*SQLError); ok { - if sqlErr.Number() == CRServerHandshakeErr && strings.Contains(sqlErr.Message, "Too many connections") { + switch sqlErr.Number() { + case ERTooManyUserConnections, ERUserLimitReached: return true + case CRServerHandshakeErr: + return strings.Contains(sqlErr.Message, "Too many connections") } } return false diff --git a/go/mysql/sqlerror/sql_error_test.go b/go/mysql/sqlerror/sql_error_test.go index 9b88803f29a..ec9cab0e4a8 100644 --- a/go/mysql/sqlerror/sql_error_test.go +++ b/go/mysql/sqlerror/sql_error_test.go @@ -221,3 +221,52 @@ func TestNewSQLErrorFromError(t *testing.T) { }) } } + +func TestIsTooManyConnectionsErr(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "global max_connections exceeded", + err: NewSQLError(CRServerHandshakeErr, "", "Too many connections"), + want: true, + }, + { + name: "per-user max_user_connections exceeded (errno 1203)", + err: NewSQLError(ERTooManyUserConnections, SSUnknownSQLState, "Too many connections for user"), + want: true, + }, + { + name: "per-user resource limit reached (errno 1226)", + err: NewSQLError(ERUserLimitReached, SSUnknownSQLState, "User 'vt_app' has exceeded the 'max_user_connections' resource (current value: 1000)"), + want: true, + }, + { + name: "handshake error without too many connections message", + err: NewSQLError(CRServerHandshakeErr, "", "SSL connection error"), + want: false, + }, + { + name: "access denied is not a connection limit error", + err: NewSQLError(ERAccessDeniedError, SSAccessDeniedError, "access denied"), + want: false, + }, + { + name: "non-SQL error", + err: errors.New("some random error"), + want: false, + }, + { + name: "nil error", + err: nil, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsTooManyConnectionsErr(tt.err)) + }) + } +} diff --git a/go/vt/vttablet/tabletserver/query_engine.go b/go/vt/vttablet/tabletserver/query_engine.go index 3896a0986d2..17a1bd15175 100644 --- a/go/vt/vttablet/tabletserver/query_engine.go +++ b/go/vt/vttablet/tabletserver/query_engine.go @@ -538,17 +538,54 @@ func (qe *QueryEngine) ForEachPlan(each func(plan *TabletPlan) bool) { // IsMySQLReachable returns an error if it cannot connect to MySQL. // This can be called before opening the QueryEngine. func (qe *QueryEngine) IsMySQLReachable() error { - conn, err := dbconnpool.NewDBConnection(context.TODO(), qe.env.Config().DB.AppWithDB()) - if err != nil { - if sqlerror.IsTooManyConnectionsErr(err) { + return isMySQLReachable(func() error { + conn, err := dbconnpool.NewDBConnection(context.TODO(), qe.env.Config().DB.DbaWithDB()) + if err != nil { + return err + } + conn.Close() + return nil + }) +} + +func isMySQLReachable(connect func() error) error { + var lastErr error + for attempt := range healthCheckMaxRetries { + lastErr = connect() + if lastErr == nil { return nil } - return err + if sqlerror.IsTooManyConnectionsErr(lastErr) { + return nil + } + if !isTransientConnErr(lastErr) { + return lastErr + } + if attempt < healthCheckMaxRetries-1 { + time.Sleep(healthCheckRetryBaseDelay << attempt) + } + } + return lastErr +} + +func isTransientConnErr(err error) bool { + sqlErr, ok := err.(*sqlerror.SQLError) + if !ok { + return false + } + switch sqlErr.Number() { + case sqlerror.CRConnectionError, sqlerror.CRConnHostError: + return true + default: + return false } - conn.Close() - return nil } +var ( + healthCheckMaxRetries = 3 + healthCheckRetryBaseDelay = 100 * time.Millisecond +) + func (qe *QueryEngine) schemaChanged(tables map[string]*schema.Table, created, altered, dropped []*schema.Table, _ bool) { qe.schemaMu.Lock() defer qe.schemaMu.Unlock() diff --git a/go/vt/vttablet/tabletserver/query_engine_test.go b/go/vt/vttablet/tabletserver/query_engine_test.go index c2ed1b0ea3d..c566b935038 100644 --- a/go/vt/vttablet/tabletserver/query_engine_test.go +++ b/go/vt/vttablet/tabletserver/query_engine_test.go @@ -43,6 +43,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/fakesqldb" + "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/streamlog" "vitess.io/vitess/go/vt/dbconfigs" @@ -55,6 +56,121 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) +func TestIsMySQLReachable_SucceedsOnFirstAttempt(t *testing.T) { + calls := 0 + err := isMySQLReachable(func() error { + calls++ + return nil + }) + assert.NoError(t, err) + assert.Equal(t, 1, calls) +} + +func TestIsMySQLReachable_RetriesTransientConnError(t *testing.T) { + defer func(d time.Duration) { healthCheckRetryBaseDelay = d }(healthCheckRetryBaseDelay) + healthCheckRetryBaseDelay = time.Millisecond + + calls := 0 + err := isMySQLReachable(func() error { + calls++ + if calls < 3 { + return sqlerror.NewSQLError(sqlerror.CRConnectionError, "", "socket backlog full") + } + return nil + }) + assert.NoError(t, err) + assert.Equal(t, 3, calls) +} + +func TestIsMySQLReachable_RetriesTCPConnError(t *testing.T) { + defer func(d time.Duration) { healthCheckRetryBaseDelay = d }(healthCheckRetryBaseDelay) + healthCheckRetryBaseDelay = time.Millisecond + + calls := 0 + err := isMySQLReachable(func() error { + calls++ + if calls < 2 { + return sqlerror.NewSQLError(sqlerror.CRConnHostError, "", "connection refused") + } + return nil + }) + assert.NoError(t, err) + assert.Equal(t, 2, calls) +} + +func TestIsMySQLReachable_FailsAfterAllRetries(t *testing.T) { + defer func(d time.Duration) { healthCheckRetryBaseDelay = d }(healthCheckRetryBaseDelay) + healthCheckRetryBaseDelay = time.Millisecond + + calls := 0 + err := isMySQLReachable(func() error { + calls++ + return sqlerror.NewSQLError(sqlerror.CRConnectionError, "", "socket backlog full") + }) + assert.Error(t, err) + assert.Equal(t, healthCheckMaxRetries, calls) + assert.Contains(t, err.Error(), "socket backlog full") +} + +func TestIsMySQLReachable_NoRetryOnNonTransientError(t *testing.T) { + calls := 0 + err := isMySQLReachable(func() error { + calls++ + return sqlerror.NewSQLError(sqlerror.ERAccessDeniedError, "", "access denied") + }) + assert.Error(t, err) + assert.Equal(t, 1, calls) +} + +func TestIsMySQLReachable_TooManyConnectionsTreatedAsReachable(t *testing.T) { + calls := 0 + err := isMySQLReachable(func() error { + calls++ + return sqlerror.NewSQLError(sqlerror.CRServerHandshakeErr, "", "Too many connections") + }) + assert.NoError(t, err) + assert.Equal(t, 1, calls) +} + +func TestIsMySQLReachable_MaxUserConnectionsTreatedAsReachable(t *testing.T) { + calls := 0 + err := isMySQLReachable(func() error { + calls++ + return sqlerror.NewSQLError(sqlerror.ERUserLimitReached, sqlerror.SSUnknownSQLState, "User 'vt_app' has exceeded the 'max_user_connections' resource (current value: 1000)") + }) + assert.NoError(t, err) + assert.Equal(t, 1, calls) +} + +func TestIsMySQLReachable_TooManyUserConnectionsTreatedAsReachable(t *testing.T) { + calls := 0 + err := isMySQLReachable(func() error { + calls++ + return sqlerror.NewSQLError(sqlerror.ERTooManyUserConnections, sqlerror.SSUnknownSQLState, "Too many connections for user") + }) + assert.NoError(t, err) + assert.Equal(t, 1, calls) +} + +func TestIsMySQLReachable_ExponentialBackoff(t *testing.T) { + defer func(d time.Duration) { healthCheckRetryBaseDelay = d }(healthCheckRetryBaseDelay) + healthCheckRetryBaseDelay = 50 * time.Millisecond + + var timestamps []time.Time + err := isMySQLReachable(func() error { + timestamps = append(timestamps, time.Now()) + return sqlerror.NewSQLError(sqlerror.CRConnectionError, "", "socket backlog full") + }) + assert.Error(t, err) + assert.Equal(t, healthCheckMaxRetries, len(timestamps)) + + gap1 := timestamps[1].Sub(timestamps[0]) + gap2 := timestamps[2].Sub(timestamps[1]) + assert.True(t, gap1 >= 40*time.Millisecond, "first retry delay too short: %v", gap1) + assert.True(t, gap2 >= 80*time.Millisecond, "second retry delay too short: %v", gap2) + assert.True(t, gap2 > gap1, "backoff should be exponential: gap1=%v gap2=%v", gap1, gap2) +} + func TestStrictMode(t *testing.T) { db := fakesqldb.New(t) defer db.Close()