Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgconn/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func normalizeTimeoutError(ctx context.Context, err error) error {
} else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()}
} else {
return &errTimeout{err: netErr}
return &errTimeout{err: err}
}
}
return err
Expand Down
8 changes: 7 additions & 1 deletion pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,13 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*
}

if fallbackConnectOneConfig != nil {
pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true)
fallbackCtx := octx
if config.ConnectTimeout != 0 {
var cancel context.CancelFunc
fallbackCtx, cancel = context.WithTimeout(octx, config.ConnectTimeout)
defer cancel()
}
pgConn, err := connectOne(fallbackCtx, config, fallbackConnectOneConfig, true)
if err == nil {
return pgConn, nil
}
Expand Down
44 changes: 44 additions & 0 deletions pgconn/pgconn_private_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package pgconn

import (
"context"
"errors"
"fmt"
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCommandTag(t *testing.T) {
Expand Down Expand Up @@ -39,3 +44,42 @@ func TestCommandTag(t *testing.T) {
assert.Equalf(t, tt.isSelect, ct.Select(), "%d. %v", i, tt.commandTag)
}
}

// timeoutNetError is a net.Error that always reports Timeout() == true.
type timeoutNetError struct{ msg string }

func (e *timeoutNetError) Error() string { return e.msg }
func (e *timeoutNetError) Timeout() bool { return true }
func (e *timeoutNetError) Temporary() bool { return false }

// wrappedDialError simulates a DialError wrapping a net.Error, like the error
// produced when a dial deadline is exceeded on a custom dialer.
type wrappedDialError struct {
addr string
err error
}

func (e *wrappedDialError) Error() string { return fmt.Sprintf("dial %s: %s", e.addr, e.err) }
func (e *wrappedDialError) Unwrap() error { return e.err }

func TestNormalizeTimeoutError_PreservesErrorChain(t *testing.T) {
t.Parallel()

inner := &timeoutNetError{msg: "i/o timeout"}
outer := &wrappedDialError{addr: "192.0.2.1:5432", err: inner}

// Sanity check: errors.As finds the net.Error through the wrapper.
var netErr net.Error
require.True(t, errors.As(outer, &netErr))

result := normalizeTimeoutError(context.Background(), outer)

// The result should be an errTimeout and should still unwrap to the
// original wrappedDialError so callers can inspect dial context (address, etc.).
var te *errTimeout
require.True(t, errors.As(result, &te))

var dial *wrappedDialError
assert.True(t, errors.As(result, &dial), "original dial error should be preserved in the chain")
assert.Equal(t, "192.0.2.1:5432", dial.addr)
}