Skip to content

Commit 5c74fbd

Browse files
committed
Add test for pg and migrator
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent 91dc646 commit 5c74fbd

File tree

2 files changed

+1280
-0
lines changed

2 files changed

+1280
-0
lines changed

migrator/migrator_test.go

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
package migrator_test
2+
3+
import (
4+
"context"
5+
"io"
6+
"testing"
7+
"testing/fstest"
8+
"time"
9+
10+
"github.com/prometheus/client_golang/prometheus"
11+
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"go.gearno.de/kit/log"
14+
"go.gearno.de/kit/migrator"
15+
"go.gearno.de/kit/pg"
16+
)
17+
18+
func newTestPGClient(t *testing.T) *pg.Client {
19+
t.Helper()
20+
21+
client, err := pg.NewClient(
22+
pg.WithAddr("localhost:5432"),
23+
pg.WithUser("kit"),
24+
pg.WithPassword("kit"),
25+
pg.WithDatabase("kit_test"),
26+
pg.WithLogger(log.NewLogger(log.WithOutput(io.Discard))),
27+
pg.WithRegisterer(prometheus.NewRegistry()),
28+
)
29+
if err != nil {
30+
t.Skipf("skipping: cannot create PostgreSQL client: %v", err)
31+
}
32+
33+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
34+
defer cancel()
35+
36+
err = client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
37+
_, err := conn.Exec(ctx, "SELECT 1")
38+
return err
39+
})
40+
if err != nil {
41+
client.Close()
42+
t.Skipf("skipping: cannot connect to PostgreSQL: %v", err)
43+
}
44+
45+
t.Cleanup(client.Close)
46+
47+
return client
48+
}
49+
50+
func dropTables(t *testing.T, client *pg.Client, tables ...string) {
51+
t.Helper()
52+
ctx := context.Background()
53+
_ = client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
54+
for _, tbl := range tables {
55+
_, _ = conn.Exec(ctx, "DROP TABLE IF EXISTS "+tbl+" CASCADE")
56+
}
57+
return nil
58+
})
59+
}
60+
61+
// ---------------------------------------------------------------------------
62+
// Unit tests (no database required)
63+
// ---------------------------------------------------------------------------
64+
65+
func TestMigration_LoadFromFile(t *testing.T) {
66+
disk := fstest.MapFS{
67+
"migrations/001_create_users.sql": &fstest.MapFile{
68+
Data: []byte("CREATE TABLE users (id serial PRIMARY KEY);"),
69+
},
70+
}
71+
72+
var m migrator.Migration
73+
err := m.LoadFromFile(disk, "migrations/001_create_users.sql")
74+
require.NoError(t, err)
75+
assert.Equal(t, "001_create_users", m.Version)
76+
assert.Equal(t, "CREATE TABLE users (id serial PRIMARY KEY);", m.SQL)
77+
}
78+
79+
func TestMigrations_LoadFromDir(t *testing.T) {
80+
disk := fstest.MapFS{
81+
"migrations/002_add_email.sql": &fstest.MapFile{
82+
Data: []byte("ALTER TABLE users ADD COLUMN email TEXT;"),
83+
},
84+
"migrations/001_create_users.sql": &fstest.MapFile{
85+
Data: []byte("CREATE TABLE users (id serial PRIMARY KEY);"),
86+
},
87+
"migrations/README.md": &fstest.MapFile{
88+
Data: []byte("Not a migration"),
89+
},
90+
}
91+
92+
var ms migrator.Migrations
93+
err := ms.LoadFromDir(disk, "migrations")
94+
require.NoError(t, err)
95+
require.Len(t, ms, 2)
96+
97+
versions := map[string]bool{}
98+
for _, m := range ms {
99+
versions[m.Version] = true
100+
}
101+
assert.True(t, versions["001_create_users"])
102+
assert.True(t, versions["002_add_email"])
103+
}
104+
105+
func TestMigrations_Sort(t *testing.T) {
106+
ms := migrator.Migrations{
107+
{Version: "003_add_index", SQL: "CREATE INDEX ...;"},
108+
{Version: "001_create_users", SQL: "CREATE TABLE ...;"},
109+
{Version: "002_add_email", SQL: "ALTER TABLE ...;"},
110+
}
111+
112+
ms.Sort()
113+
114+
require.Len(t, ms, 3)
115+
assert.Equal(t, "001_create_users", ms[0].Version)
116+
assert.Equal(t, "002_add_email", ms[1].Version)
117+
assert.Equal(t, "003_add_index", ms[2].Version)
118+
}
119+
120+
// ---------------------------------------------------------------------------
121+
// Integration tests (require a running PostgreSQL instance)
122+
// ---------------------------------------------------------------------------
123+
124+
func TestMigrator_Run(t *testing.T) {
125+
client := newTestPGClient(t)
126+
ctx := context.Background()
127+
logger := log.NewLogger(log.WithOutput(io.Discard))
128+
129+
t.Run("applies migrations", func(t *testing.T) {
130+
dropTables(t, client, "schema_versions", "test_mig_users")
131+
t.Cleanup(func() { dropTables(t, client, "schema_versions", "test_mig_users") })
132+
133+
disk := fstest.MapFS{
134+
"migrations/001_create_users.sql": &fstest.MapFile{
135+
Data: []byte("CREATE TABLE test_mig_users (id serial PRIMARY KEY, name text NOT NULL);"),
136+
},
137+
}
138+
139+
m := migrator.NewMigrator(client, disk, logger)
140+
require.NoError(t, m.Run(ctx, "migrations"))
141+
142+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
143+
var exists bool
144+
err := conn.QueryRow(ctx,
145+
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name = 'test_mig_users')").
146+
Scan(&exists)
147+
require.NoError(t, err)
148+
assert.True(t, exists)
149+
150+
var version string
151+
err = conn.QueryRow(ctx, "SELECT version FROM schema_versions").Scan(&version)
152+
require.NoError(t, err)
153+
assert.Equal(t, "001_create_users", version)
154+
155+
return nil
156+
})
157+
require.NoError(t, err)
158+
})
159+
160+
t.Run("idempotent", func(t *testing.T) {
161+
dropTables(t, client, "schema_versions", "test_mig_idem")
162+
t.Cleanup(func() { dropTables(t, client, "schema_versions", "test_mig_idem") })
163+
164+
disk := fstest.MapFS{
165+
"migrations/001_create_table.sql": &fstest.MapFile{
166+
Data: []byte("CREATE TABLE test_mig_idem (id serial PRIMARY KEY);"),
167+
},
168+
}
169+
170+
m := migrator.NewMigrator(client, disk, logger)
171+
require.NoError(t, m.Run(ctx, "migrations"))
172+
require.NoError(t, m.Run(ctx, "migrations"))
173+
174+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
175+
var count int
176+
err := conn.QueryRow(ctx, "SELECT count(*) FROM schema_versions").Scan(&count)
177+
require.NoError(t, err)
178+
assert.Equal(t, 1, count)
179+
return nil
180+
})
181+
require.NoError(t, err)
182+
})
183+
184+
t.Run("applies multiple migrations in order", func(t *testing.T) {
185+
dropTables(t, client, "schema_versions", "test_mig_ordered")
186+
t.Cleanup(func() { dropTables(t, client, "schema_versions", "test_mig_ordered") })
187+
188+
disk := fstest.MapFS{
189+
"migrations/002_add_column.sql": &fstest.MapFile{
190+
Data: []byte("ALTER TABLE test_mig_ordered ADD COLUMN email TEXT;"),
191+
},
192+
"migrations/001_create_table.sql": &fstest.MapFile{
193+
Data: []byte("CREATE TABLE test_mig_ordered (id serial PRIMARY KEY, name text NOT NULL);"),
194+
},
195+
}
196+
197+
m := migrator.NewMigrator(client, disk, logger)
198+
require.NoError(t, m.Run(ctx, "migrations"))
199+
200+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
201+
var count int
202+
err := conn.QueryRow(ctx, "SELECT count(*) FROM schema_versions").Scan(&count)
203+
require.NoError(t, err)
204+
assert.Equal(t, 2, count)
205+
206+
var exists bool
207+
err = conn.QueryRow(ctx, `
208+
SELECT EXISTS(
209+
SELECT 1 FROM information_schema.columns
210+
WHERE table_name = 'test_mig_ordered' AND column_name = 'email'
211+
)
212+
`).Scan(&exists)
213+
require.NoError(t, err)
214+
assert.True(t, exists, "email column should exist after migration 002")
215+
216+
return nil
217+
})
218+
require.NoError(t, err)
219+
})
220+
221+
t.Run("no migrations to apply", func(t *testing.T) {
222+
disk := fstest.MapFS{
223+
"migrations/readme.txt": &fstest.MapFile{
224+
Data: []byte("no sql files here"),
225+
},
226+
}
227+
228+
m := migrator.NewMigrator(client, disk, logger)
229+
require.NoError(t, m.Run(ctx, "migrations"))
230+
})
231+
232+
t.Run("incremental run only applies new migrations", func(t *testing.T) {
233+
dropTables(t, client, "schema_versions", "test_mig_incr")
234+
t.Cleanup(func() { dropTables(t, client, "schema_versions", "test_mig_incr") })
235+
236+
disk1 := fstest.MapFS{
237+
"migrations/001_create_table.sql": &fstest.MapFile{
238+
Data: []byte("CREATE TABLE test_mig_incr (id serial PRIMARY KEY);"),
239+
},
240+
}
241+
242+
m1 := migrator.NewMigrator(client, disk1, logger)
243+
require.NoError(t, m1.Run(ctx, "migrations"))
244+
245+
disk2 := fstest.MapFS{
246+
"migrations/001_create_table.sql": &fstest.MapFile{
247+
Data: []byte("CREATE TABLE test_mig_incr (id serial PRIMARY KEY);"),
248+
},
249+
"migrations/002_add_name.sql": &fstest.MapFile{
250+
Data: []byte("ALTER TABLE test_mig_incr ADD COLUMN name TEXT;"),
251+
},
252+
}
253+
254+
m2 := migrator.NewMigrator(client, disk2, logger)
255+
require.NoError(t, m2.Run(ctx, "migrations"))
256+
257+
err := client.WithConn(ctx, func(ctx context.Context, conn pg.Conn) error {
258+
var count int
259+
err := conn.QueryRow(ctx, "SELECT count(*) FROM schema_versions").Scan(&count)
260+
require.NoError(t, err)
261+
assert.Equal(t, 2, count)
262+
return nil
263+
})
264+
require.NoError(t, err)
265+
})
266+
267+
t.Run("returns error on invalid SQL", func(t *testing.T) {
268+
dropTables(t, client, "schema_versions")
269+
t.Cleanup(func() { dropTables(t, client, "schema_versions") })
270+
271+
disk := fstest.MapFS{
272+
"migrations/001_bad.sql": &fstest.MapFile{
273+
Data: []byte("THIS IS NOT VALID SQL;"),
274+
},
275+
}
276+
277+
m := migrator.NewMigrator(client, disk, logger)
278+
err := m.Run(ctx, "migrations")
279+
require.Error(t, err)
280+
assert.Contains(t, err.Error(), "cannot apply migration")
281+
})
282+
}
283+
284+
func TestMigrations_LoadFromDir_InvalidDir(t *testing.T) {
285+
disk := fstest.MapFS{}
286+
287+
var ms migrator.Migrations
288+
err := ms.LoadFromDir(disk, "nonexistent")
289+
require.Error(t, err)
290+
assert.Contains(t, err.Error(), "cannot read directory")
291+
}
292+
293+
func TestMigration_LoadFromFile_NotFound(t *testing.T) {
294+
disk := fstest.MapFS{}
295+
296+
var m migrator.Migration
297+
err := m.LoadFromFile(disk, "nonexistent.sql")
298+
require.Error(t, err)
299+
}

0 commit comments

Comments
 (0)