Skip to content

Commit bf6f920

Browse files
committed
Added dynamic port allocation, GetPort helper, and test
1 parent 0c7d03d commit bf6f920

2 files changed

Lines changed: 84 additions & 0 deletions

File tree

embedded_postgres.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,15 @@ func (ep *EmbeddedPostgres) Start() error {
7171
return ErrServerAlreadyStarted
7272
}
7373

74+
if ep.config.port == 0 {
75+
port, err := getFreePort()
76+
if err != nil {
77+
return err
78+
}
79+
80+
ep.config.port = port
81+
}
82+
7483
if err := ensurePortAvailable(ep.config.port); err != nil {
7584
return err
7685
}
@@ -147,6 +156,10 @@ func (ep *EmbeddedPostgres) Start() error {
147156
return nil
148157
}
149158

159+
func (ep *EmbeddedPostgres) GetPort() uint32 {
160+
return ep.config.port
161+
}
162+
150163
func (ep *EmbeddedPostgres) downloadAndExtractBinary(cacheExists bool, cacheLocation string) error {
151164
// lock to prevent collisions with duplicate downloads
152165
mu.Lock()
@@ -242,6 +255,10 @@ func stopPostgres(ep *EmbeddedPostgres) error {
242255
}
243256

244257
func ensurePortAvailable(port uint32) error {
258+
if port == 0 {
259+
260+
}
261+
245262
conn, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
246263
if err != nil {
247264
return fmt.Errorf("process already listening on port %d", port)
@@ -254,6 +271,21 @@ func ensurePortAvailable(port uint32) error {
254271
return nil
255272
}
256273

274+
func getFreePort() (uint32, error) {
275+
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
276+
if err != nil {
277+
return 0, fmt.Errorf("failed to resolve TCP address: %w", err)
278+
}
279+
280+
l, err := net.ListenTCP("tcp", addr)
281+
if err != nil {
282+
return 0, fmt.Errorf("failed to listen on TCP: %w", err)
283+
}
284+
defer l.Close()
285+
286+
return uint32(l.Addr().(*net.TCPAddr).Port), nil
287+
}
288+
257289
func dataDirIsValid(dataDir string, version PostgresVersion) bool {
258290
pgVersion := filepath.Join(dataDir, "PG_VERSION")
259291

embedded_postgres_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,3 +833,55 @@ func Test_RunningInParallel(t *testing.T) {
833833

834834
waitGroup.Wait()
835835
}
836+
837+
func Test_DynamicallyAllocatingPort(t *testing.T) {
838+
tempDir, err := os.MkdirTemp("", "embedded_postgres_test")
839+
if err != nil {
840+
panic(err)
841+
}
842+
843+
defer func() {
844+
if err := os.RemoveAll(tempDir); err != nil {
845+
panic(err)
846+
}
847+
}()
848+
849+
database := NewDatabase(DefaultConfig().
850+
Username("gin").
851+
Password("wine").
852+
Database("beer").
853+
Version(V15).
854+
RuntimePath(tempDir).
855+
Port(0).
856+
StartTimeout(10 * time.Second).
857+
Locale("C").
858+
Encoding("UTF8").
859+
Logger(nil))
860+
861+
if err := database.Start(); err != nil {
862+
shutdownDBAndFail(t, err, database)
863+
}
864+
865+
port := database.GetPort()
866+
867+
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d user=gin password=wine dbname=beer sslmode=disable", port))
868+
if err != nil {
869+
shutdownDBAndFail(t, err, database)
870+
}
871+
872+
if !strings.Contains(database.config.GetConnectionURL(), fmt.Sprint(port)) {
873+
shutdownDBAndFail(t, errors.New("wrong port in connection url"), database)
874+
}
875+
876+
if err = db.Ping(); err != nil {
877+
shutdownDBAndFail(t, err, database)
878+
}
879+
880+
if err := db.Close(); err != nil {
881+
shutdownDBAndFail(t, err, database)
882+
}
883+
884+
if err := database.Stop(); err != nil {
885+
shutdownDBAndFail(t, err, database)
886+
}
887+
}

0 commit comments

Comments
 (0)