From 88fdff77b2557914448ef78b55d4a663e3c96685 Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Wed, 15 Apr 2026 23:34:56 +0200 Subject: [PATCH] chore!: wait.ForSQL: change url callback to accept network.Network Commit 3e56fb95cf7337bf1f262eaea23db7857a2c7894 updated this code to use the moby modules, and changed signatures to accept a string for ports ("[/]"), handling conversion to strong types (`network.Port`) internally. In this case, however, the url callback function gets passed the host and port that's already parsed internally, so downgrading the parsed `network.Port` to a string means that a custom callback would have to parse the string again. Change it to a strong-typed `network.Port`, which keeps all options open; callbacks that want to use the port in its string-format can call `port.String()`, but if only the port or proto is needed, they can use the `port.Port()`, `port.Proto()` or other methods. Signed-off-by: Sebastiaan van Stijn --- modules/cockroachdb/cockroachdb.go | 14 +++++--------- modules/postgres/postgres_test.go | 3 +-- wait/sql.go | 6 +++--- wait/sql_test.go | 18 +++++++++--------- 4 files changed, 18 insertions(+), 23 deletions(-) diff --git a/modules/cockroachdb/cockroachdb.go b/modules/cockroachdb/cockroachdb.go index a9e0fdce97..86d6269cd2 100644 --- a/modules/cockroachdb/cockroachdb.go +++ b/modules/cockroachdb/cockroachdb.go @@ -120,7 +120,7 @@ func (c *CockroachDBContainer) ConnectionConfig(ctx context.Context) (*pgx.ConnC return nil, fmt.Errorf("host: %w", err) } - return c.connConfig(host, port.String()) + return c.connConfig(host, port) } // TLSConfig returns config necessary to connect to CockroachDB over TLS. @@ -199,7 +199,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom certsDir+"/client."+defaultUser+".crt", certsDir+"/client."+defaultUser+".key", ).WithRootCAs(fileCACert).WithServerName("127.0.0.1"), - wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port string) string { + wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port network.Port) string { connStr, err := ctr.connString(host, port) if err != nil { panic(err) @@ -225,7 +225,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom } // connString returns a connection string for the given host, port and options. -func (c *CockroachDBContainer) connString(host string, port string) (string, error) { +func (c *CockroachDBContainer) connString(host string, port network.Port) (string, error) { cfg, err := c.connConfig(host, port) if err != nil { return "", fmt.Errorf("connection config: %w", err) @@ -235,11 +235,7 @@ func (c *CockroachDBContainer) connString(host string, port string) (string, err } // connConfig returns a [pgx.ConnConfig] for the given host, port and options. -func (c *CockroachDBContainer) connConfig(host string, port string) (*pgx.ConnConfig, error) { - p, err := network.ParsePort(port) - if err != nil { - return nil, err - } +func (c *CockroachDBContainer) connConfig(host string, port network.Port) (*pgx.ConnConfig, error) { var user *url.Userinfo if c.password != "" { user = url.UserPassword(c.user, c.password) @@ -264,7 +260,7 @@ func (c *CockroachDBContainer) connConfig(host string, port string) (*pgx.ConnCo u := url.URL{ Scheme: "postgres", User: user, - Host: net.JoinHostPort(host, p.Port()), + Host: net.JoinHostPort(host, port.Port()), Path: c.database, RawQuery: params.Encode(), } diff --git a/modules/postgres/postgres_test.go b/modules/postgres/postgres_test.go index 4bce0bb07e..c58fe5013f 100644 --- a/modules/postgres/postgres_test.go +++ b/modules/postgres/postgres_test.go @@ -143,8 +143,7 @@ func TestContainerWithWaitForSQL(t *testing.T) { ctx := context.Background() port := "5432/tcp" - dbURL := func(host string, port string) string { - p := network.MustParsePort(port) + dbURL := func(host string, p network.Port) string { return fmt.Sprintf("postgres://postgres:password@%s:%s/%s?sslmode=disable", host, p.Port(), dbname) } diff --git a/wait/sql.go b/wait/sql.go index 5d0228a075..e328be2885 100644 --- a/wait/sql.go +++ b/wait/sql.go @@ -17,7 +17,7 @@ var ( const defaultForSQLQuery = "SELECT 1" // ForSQL constructs a new waitForSql strategy for the given driver -func ForSQL(port string, driver string, url func(host string, port string) string) *waitForSQL { +func ForSQL(port string, driver string, url func(host string, port network.Port) string) *waitForSQL { return &waitForSQL{ Port: port, URL: url, @@ -31,7 +31,7 @@ func ForSQL(port string, driver string, url func(host string, port string) strin type waitForSQL struct { timeout *time.Duration - URL func(host string, port string) string + URL func(host string, port network.Port) string Driver string Port string startupTimeout time.Duration @@ -114,7 +114,7 @@ func (w *waitForSQL) WaitUntilReady(ctx context.Context, target StrategyTarget) } } - db, err := sql.Open(w.Driver, w.URL(host, port.String())) + db, err := sql.Open(w.Driver, w.URL(host, port)) if err != nil { return fmt.Errorf("sql.Open: %w", err) } diff --git a/wait/sql_test.go b/wait/sql_test.go index 9cba02630a..651cc64e4f 100644 --- a/wait/sql_test.go +++ b/wait/sql_test.go @@ -14,7 +14,7 @@ import ( func Test_waitForSql_WithQuery(t *testing.T) { t.Run("default query", func(t *testing.T) { - w := ForSQL("5432/tcp", "postgres", func(_ string, _ string) string { + w := ForSQL("5432/tcp", "postgres", func(_ string, _ network.Port) string { return "fake-url" }) @@ -23,7 +23,7 @@ func Test_waitForSql_WithQuery(t *testing.T) { t.Run("custom query", func(t *testing.T) { const q = "SELECT 100;" - w := ForSQL("5432/tcp", "postgres", func(_ string, _ string) string { + w := ForSQL("5432/tcp", "postgres", func(_ string, _ network.Port) string { return "fake-url" }).WithQuery(q) @@ -95,7 +95,7 @@ func TestWaitForSQLSucceeds(t *testing.T) { }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -123,7 +123,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) { }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -154,7 +154,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToExitedContainer(t *testing.T) { }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -184,7 +184,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testin }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -209,7 +209,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToOOMKilledContainer(t *testing.T) }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -235,7 +235,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToExitedContainer(t *testing.T) { }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -260,7 +260,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToUnexpectedContainerStatus(t *tes }, } - wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }). + wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond)