Skip to content

Commit a979eb2

Browse files
committed
chore(wait)!: ForSQL: change url callback to accept network.Network
Commit 3e56fb9 updated this code to use the moby modules, and changed signatures to accept a string for ports ("<port>[/<protocol>]"), handling conversion to strong types (`network.Port`) internally. In this case, however, the url callback function gets passed the host and port that's already parsed internally, so downgrading the parsed `network.Port` to a string means that a custom callback would have to parse the string again. Change it to a strong-typed `network.Port`, which keeps all options open; callbacks that want to use the port in its string-format can call `port.String()`, but if only the port or proto is needed, they can use the `port.Port()`, `port.Proto()` or other methods. Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
1 parent c251392 commit a979eb2

4 files changed

Lines changed: 26 additions & 31 deletions

File tree

modules/cockroachdb/cockroachdb.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
199199
certsDir+"/client."+defaultUser+".crt",
200200
certsDir+"/client."+defaultUser+".key",
201201
).WithRootCAs(fileCACert).WithServerName("127.0.0.1"),
202-
wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port string) string {
202+
wait.ForSQL(defaultSQLPort, "pgx/v5", func(host string, port network.Port) string {
203203
connStr, err := ctr.connString(host, port)
204204
if err != nil {
205205
panic(err)
@@ -225,7 +225,7 @@ func Run(ctx context.Context, img string, opts ...testcontainers.ContainerCustom
225225
}
226226

227227
// connString returns a connection string for the given host, port and options.
228-
func (c *CockroachDBContainer) connString(host string, port string) (string, error) {
228+
func (c *CockroachDBContainer) connString(host string, port network.Port) (string, error) {
229229
cfg, err := c.connConfig(host, port)
230230
if err != nil {
231231
return "", fmt.Errorf("connection config: %w", err)
@@ -235,11 +235,7 @@ func (c *CockroachDBContainer) connString(host string, port string) (string, err
235235
}
236236

237237
// connConfig returns a [pgx.ConnConfig] for the given host, port and options.
238-
func (c *CockroachDBContainer) connConfig(host string, port string) (*pgx.ConnConfig, error) {
239-
p, err := network.ParsePort(port)
240-
if err != nil {
241-
return nil, err
242-
}
238+
func (c *CockroachDBContainer) connConfig(host string, port network.Port) (*pgx.ConnConfig, error) {
243239
var user *url.Userinfo
244240
if c.password != "" {
245241
user = url.UserPassword(c.user, c.password)
@@ -264,7 +260,7 @@ func (c *CockroachDBContainer) connConfig(host string, port string) (*pgx.ConnCo
264260
u := url.URL{
265261
Scheme: "postgres",
266262
User: user,
267-
Host: net.JoinHostPort(host, p.Port()),
263+
Host: net.JoinHostPort(host, port.Port()),
268264
Path: c.database,
269265
RawQuery: params.Encode(),
270266
}

modules/postgres/postgres_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ func TestContainerWithWaitForSQL(t *testing.T) {
143143
ctx := context.Background()
144144

145145
port := "5432/tcp"
146-
dbURL := func(host string, port string) string {
147-
p := network.MustParsePort(port)
146+
dbURL := func(host string, p network.Port) string {
148147
return fmt.Sprintf("postgres://postgres:password@%s:%s/%s?sslmode=disable", host, p.Port(), dbname)
149148
}
150149

wait/sql.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ var (
1717
const defaultForSQLQuery = "SELECT 1"
1818

1919
// ForSQL constructs a new waitForSql strategy for the given driver
20-
func ForSQL(port string, driver string, url func(host string, port string) string) *waitForSQL {
20+
func ForSQL(port string, driver string, url func(host string, port network.Port) string) *waitForSQL {
2121
return &waitForSQL{
2222
Port: port,
2323
URL: url,
@@ -31,7 +31,7 @@ func ForSQL(port string, driver string, url func(host string, port string) strin
3131
type waitForSQL struct {
3232
timeout *time.Duration
3333

34-
URL func(host string, port string) string
34+
URL func(host string, port network.Port) string
3535
Driver string
3636
Port string
3737
startupTimeout time.Duration
@@ -114,7 +114,7 @@ func (w *waitForSQL) WaitUntilReady(ctx context.Context, target StrategyTarget)
114114
}
115115
}
116116

117-
db, err := sql.Open(w.Driver, w.URL(host, port.String()))
117+
db, err := sql.Open(w.Driver, w.URL(host, port))
118118
if err != nil {
119119
return fmt.Errorf("sql.Open: %w", err)
120120
}

wait/sql_test.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414

1515
func Test_waitForSql_WithQuery(t *testing.T) {
1616
t.Run("default query", func(t *testing.T) {
17-
w := ForSQL("5432/tcp", "postgres", func(_ string, _ string) string {
17+
w := ForSQL("5432/tcp", "postgres", func(_ string, _ network.Port) string {
1818
return "fake-url"
1919
})
2020

@@ -23,7 +23,7 @@ func Test_waitForSql_WithQuery(t *testing.T) {
2323
t.Run("custom query", func(t *testing.T) {
2424
const q = "SELECT 100;"
2525

26-
w := ForSQL("5432/tcp", "postgres", func(_ string, _ string) string {
26+
w := ForSQL("5432/tcp", "postgres", func(_ string, _ network.Port) string {
2727
return "fake-url"
2828
}).WithQuery(q)
2929

@@ -39,7 +39,7 @@ type mockSQLDriver struct {
3939
driver.Driver
4040
}
4141

42-
func (sd *mockSQLDriver) Open(_ string) (driver.Conn, error) {
42+
func (sd *mockSQLDriver) Open(_ network.Port) (driver.Conn, error) {
4343
return &mockSQLConn{}, nil
4444
}
4545

@@ -53,7 +53,7 @@ func (sc *mockSQLConn) Close() error {
5353
return nil
5454
}
5555

56-
func (sc *mockSQLConn) PrepareContext(_ context.Context, _ string) (driver.Stmt, error) {
56+
func (sc *mockSQLConn) PrepareContext(_ context.Context, _ network.Port) (driver.Stmt, error) {
5757
return &mockSQLStmt{}, nil
5858
}
5959

@@ -81,7 +81,7 @@ func TestWaitForSQLSucceeds(t *testing.T) {
8181
HostImpl: func(_ context.Context) (string, error) {
8282
return "localhost", nil
8383
},
84-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
84+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
8585
defer func() { mappedPortCount++ }()
8686
if mappedPortCount == 0 {
8787
return network.Port{}, ErrPortNotFound
@@ -95,7 +95,7 @@ func TestWaitForSQLSucceeds(t *testing.T) {
9595
},
9696
}
9797

98-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
98+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
9999
WithStartupTimeout(500 * time.Millisecond).
100100
WithPollInterval(100 * time.Millisecond)
101101

@@ -109,7 +109,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) {
109109
HostImpl: func(_ context.Context) (string, error) {
110110
return "localhost", nil
111111
},
112-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
112+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
113113
defer func() { mappedPortCount++ }()
114114
if mappedPortCount == 0 {
115115
return network.Port{}, ErrPortNotFound
@@ -123,7 +123,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) {
123123
},
124124
}
125125

126-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
126+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
127127
WithStartupTimeout(500 * time.Millisecond).
128128
WithPollInterval(100 * time.Millisecond)
129129

@@ -139,7 +139,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToExitedContainer(t *testing.T) {
139139
HostImpl: func(_ context.Context) (string, error) {
140140
return "localhost", nil
141141
},
142-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
142+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
143143
defer func() { mappedPortCount++ }()
144144
if mappedPortCount == 0 {
145145
return network.Port{}, ErrPortNotFound
@@ -154,7 +154,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToExitedContainer(t *testing.T) {
154154
},
155155
}
156156

157-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
157+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
158158
WithStartupTimeout(500 * time.Millisecond).
159159
WithPollInterval(100 * time.Millisecond)
160160

@@ -170,7 +170,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testin
170170
HostImpl: func(_ context.Context) (string, error) {
171171
return "localhost", nil
172172
},
173-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
173+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
174174
defer func() { mappedPortCount++ }()
175175
if mappedPortCount == 0 {
176176
return network.Port{}, ErrPortNotFound
@@ -184,7 +184,7 @@ func TestWaitForSQLFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testin
184184
},
185185
}
186186

187-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
187+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
188188
WithStartupTimeout(500 * time.Millisecond).
189189
WithPollInterval(100 * time.Millisecond)
190190

@@ -199,7 +199,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToOOMKilledContainer(t *testing.T)
199199
HostImpl: func(_ context.Context) (string, error) {
200200
return "localhost", nil
201201
},
202-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
202+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
203203
return network.MustParsePort("49152"), nil
204204
},
205205
StateImpl: func(_ context.Context) (*container.State, error) {
@@ -209,7 +209,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToOOMKilledContainer(t *testing.T)
209209
},
210210
}
211211

212-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
212+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
213213
WithStartupTimeout(500 * time.Millisecond).
214214
WithPollInterval(100 * time.Millisecond)
215215

@@ -224,7 +224,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToExitedContainer(t *testing.T) {
224224
HostImpl: func(_ context.Context) (string, error) {
225225
return "localhost", nil
226226
},
227-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
227+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
228228
return network.MustParsePort("49152"), nil
229229
},
230230
StateImpl: func(_ context.Context) (*container.State, error) {
@@ -235,7 +235,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToExitedContainer(t *testing.T) {
235235
},
236236
}
237237

238-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
238+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
239239
WithStartupTimeout(500 * time.Millisecond).
240240
WithPollInterval(100 * time.Millisecond)
241241

@@ -250,7 +250,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToUnexpectedContainerStatus(t *tes
250250
HostImpl: func(_ context.Context) (string, error) {
251251
return "localhost", nil
252252
},
253-
MappedPortImpl: func(_ context.Context, _ string) (network.Port, error) {
253+
MappedPortImpl: func(_ context.Context, _ network.Port) (network.Port, error) {
254254
return network.MustParsePort("49152"), nil
255255
},
256256
StateImpl: func(_ context.Context) (*container.State, error) {
@@ -260,7 +260,7 @@ func TestWaitForSQLFailsWhileQueryExecutingDueToUnexpectedContainerStatus(t *tes
260260
},
261261
}
262262

263-
wg := ForSQL("3306", "mock", func(_ string, _ string) string { return "" }).
263+
wg := ForSQL("3306", "mock", func(_ string, _ network.Port) string { return "" }).
264264
WithStartupTimeout(500 * time.Millisecond).
265265
WithPollInterval(100 * time.Millisecond)
266266

0 commit comments

Comments
 (0)