diff --git a/pgconn/errors.go b/pgconn/errors.go index bc1e31e31..5d90ad2ba 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -144,7 +144,7 @@ func normalizeTimeoutError(ctx context.Context, err error) error { } else if ctx.Err() == context.DeadlineExceeded { return &errTimeout{err: ctx.Err()} } else { - return &errTimeout{err: netErr} + return &errTimeout{err: err} } } return err diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index abf6c9d8d..45b6362da 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -292,7 +292,13 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []* } if fallbackConnectOneConfig != nil { - pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true) + fallbackCtx := octx + if config.ConnectTimeout != 0 { + var cancel context.CancelFunc + fallbackCtx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } + pgConn, err := connectOne(fallbackCtx, config, fallbackConnectOneConfig, true) if err == nil { return pgConn, nil } diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go index a0c15c27a..218071630 100644 --- a/pgconn/pgconn_private_test.go +++ b/pgconn/pgconn_private_test.go @@ -1,9 +1,14 @@ package pgconn import ( + "context" + "errors" + "fmt" + "net" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCommandTag(t *testing.T) { @@ -39,3 +44,42 @@ func TestCommandTag(t *testing.T) { assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag) } } + +// timeoutNetError is a net.Error that always reports Timeout() == true. +type timeoutNetError struct{ msg string } + +func (e *timeoutNetError) Error() string { return e.msg } +func (e *timeoutNetError) Timeout() bool { return true } +func (e *timeoutNetError) Temporary() bool { return false } + +// wrappedDialError simulates a DialError wrapping a net.Error, like the error +// produced when a dial deadline is exceeded on a custom dialer. +type wrappedDialError struct { + addr string + err error +} + +func (e *wrappedDialError) Error() string { return fmt.Sprintf("dial %s: %s", e.addr, e.err) } +func (e *wrappedDialError) Unwrap() error { return e.err } + +func TestNormalizeTimeoutError_PreservesErrorChain(t *testing.T) { + t.Parallel() + + inner := &timeoutNetError{msg: "i/o timeout"} + outer := &wrappedDialError{addr: "192.0.2.1:5432", err: inner} + + // Sanity check: errors.As finds the net.Error through the wrapper. + var netErr net.Error + require.True(t, errors.As(outer, &netErr)) + + result := normalizeTimeoutError(context.Background(), outer) + + // The result should be an errTimeout and should still unwrap to the + // original wrappedDialError so callers can inspect dial context (address, etc.). + var te *errTimeout + require.True(t, errors.As(result, &te)) + + var dial *wrappedDialError + assert.True(t, errors.As(result, &dial), "original dial error should be preserved in the chain") + assert.Equal(t, "192.0.2.1:5432", dial.addr) +}