Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions modules/cockroachdb/cockroachdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnC
return nil, fmt.Errorf("host: %w", err)
}

return c.connConfig(host, port.String())
return c.connConfig(host, port)
}

// TLSConfig returns config necessary to connect to CockroachDB over TLS.
Expand Down Expand Up @@ -199,7 +199,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
certsDir+"/client."+defaultUser+".crt",
certsDir+"/client."+defaultUser+".key",
).WithRootCAs(fileCACert).WithServerName("127.0.0.1"),
wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port string) string {
wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port network.Port) string {
connStr, err := ctr.connString(host, port)
if err != nil {
panic(err)
Expand All @@ -225,7 +225,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
}

// connString returns a connection string for the given host, port and options.
func (c *CockroachDBContainer) connString(host string, port string) (string, error) {
func (c *CockroachDBContainer) connString(host string, port network.Port) (string, error) {
cfg, err := c.connConfig(host, port)
if err != nil {
return "", fmt.Errorf("connection config: %w", err)
Expand All @@ -235,11 +235,7 @@ func (c *CockroachDBContainer) connString(host string, port string) (string, err
}

// connConfig returns a [pgx.ConnConfig] for the given host, port and options.
func (c *CockroachDBContainer) connConfig(host string, port string) (*pgx.ConnConfig, error) {
p, err := network.ParsePort(port)
if err != nil {
return nil, err
}
func (c *CockroachDBContainer) connConfig(host string, port network.Port) (*pgx.ConnConfig, error) {
var user *url.Userinfo
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if c.password != "" {
user = url.UserPassword(c.user, c.password)
Expand All @@ -264,7 +260,7 @@ func (c *CockroachDBContainer) connConfig(host string, port string) (*pgx.ConnCo
u := url.URL{
Scheme: "postgres",
User: user,
Host: net.JoinHostPort(host, p.Port()),
Host: net.JoinHostPort(host, port.Port()),
Path: c.database,
RawQuery: params.Encode(),
}
Expand Down
3 changes: 1 addition & 2 deletions modules/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ func TestContainerWithWaitForSQL(t *testing.T) {
ctx := context.Background()

port := "5432/tcp"
dbURL := func(host string, port string) string {
p := network.MustParsePort(port)
dbURL := func(host string, p network.Port) string {
return fmt.Sprintf("postgres://postgres:password@%s:%s/%s?sslmode=disable", host, p.Port(), dbname)
}

Expand Down
6 changes: 3 additions & 3 deletions wait/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var (
const defaultForSQLQuery = "SELECT 1"

// ForSQL constructs a new waitForSql strategy for the given driver
func ForSQL(port string, driver string, url func(host string, port string) string) *waitForSQL {
func ForSQL(port string, driver string, url func(host string, port network.Port) string) *waitForSQL {
return &waitForSQL{
Port: port,
URL: url,
Expand All @@ -31,7 +31,7 @@ func ForSQL(port string, driver string, url func(host string, port string) strin
type waitForSQL struct {
timeout *time.Duration

URL func(host string, port string) string
URL func(host string, port network.Port) string
Driver string
Port string
startupTimeout time.Duration
Expand Down Expand Up @@ -114,7 +114,7 @@ func (w *waitForSQL) WaitUntilReady(ctx context.Context, target StrategyTarget)
}
}

db, err := sql.Open(w.Driver, w.URL(host, port.String()))
db, err := sql.Open(w.Driver, w.URL(host, port))
if err != nil {
return fmt.Errorf("sql.Open: %w", err)
}
Expand Down
18 changes: 9 additions & 9 deletions wait/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func Test_waitForSql_WithQuery(t *testing.T) {
t.Run("default query", func(t *testing.T) {
w := ForSQL("5432/tcp", "postgres", func(_ string, _ string) string {
w := ForSQL("5432/tcp", "postgres", func(_ string, _ network.Port) string {
return "fake-url"
})

Expand All @@ -23,7 +23,7 @@ func Test_waitForSql_WithQuery(t *testing.T) {
t.Run("custom query", func(t *testing.T) {
const q = "SELECT 100;"

w := ForSQL("5432/tcp", "postgres", func(_ string, _ string) string {
w := ForSQL("5432/tcp", "postgres", func(_ string, _ network.Port) string {
return "fake-url"
}).WithQuery(q)

Expand Down Expand Up @@ -95,7 +95,7 @@ func TestWaitForSQLSucceeds(t *testing.T) {
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand Down Expand Up @@ -123,7 +123,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) {
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand Down Expand Up @@ -154,7 +154,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToExitedContainer(t *testing.T) {
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand Down Expand Up @@ -184,7 +184,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testin
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand All @@ -209,7 +209,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToOOMKilledContainer(t *testing.T)
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand All @@ -235,7 +235,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToExitedContainer(t *testing.T) {
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand All @@ -260,7 +260,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToUnexpectedContainerStatus(t *tes
},
}

wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
WithStartupTimeout(500 * time.Millisecond).
WithPollInterval(100 * time.Millisecond)

Expand Down
Loading