Skip to content

Commit 57d67c7

Browse files
committed
ssl, schema, etc
1 parent 63da979 commit 57d67c7

6 files changed

Lines changed: 270 additions & 26 deletions

File tree

.idea/sqldialects.xml

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pg.go

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
)
1818

1919
type migrationStatus = string
20+
type SslMode = string
2021

2122
const (
2223
EnvDatabaseAddress = "DB_ADDRESS"
@@ -31,7 +32,26 @@ const (
3132
EnvDatabaseName = "DB_NAME"
3233
EnvDatabaseNameDefault = "postgres"
3334

34-
EnvMigrationsEnabled = "DB_MIGRATIONS_ENABLED"
35+
EnvDatabaseSchema = "DB_SCHEMA"
36+
EnvDatabaseSchemaDefault = ""
37+
38+
EnvDatabaseMigrationSchema = "DB_MIGRATION_SCHEMA"
39+
EnvDatabaseMigrationSchemaDefault = ""
40+
41+
EnvDatabaseSslMode = "DB_SSL_MODE"
42+
EnvDatabaseSslModeDefault = SslModeDisable
43+
44+
EnvDatabaseSslRootCert = "DB_SSL_ROOT_CERT"
45+
EnvDatabaseSslRootCertDefault = ""
46+
47+
EnvDatabaseSslCert = "DB_SSL_CERT"
48+
EnvDatabaseSslCertDefault = ""
49+
50+
EnvDatabaseSslKey = "DB_SSL_KEY"
51+
EnvDatabaseSslKeyDefault = ""
52+
53+
EnvMigrationsEnabled = "DB_MIGRATIONS_ENABLED"
54+
EnvMigrationsEnabledDefault = true
3555

3656
EnvChangelogSchema = "DB_CHANGELOG_SCHEMA"
3757
EnvChangelogSchemaDefault = "public"
@@ -45,13 +65,26 @@ const (
4565
statusCompleted migrationStatus = "COMPLETED"
4666
statusError migrationStatus = "ERROR"
4767
statusNew migrationStatus = "NEW"
68+
69+
SslModeDisable SslMode = "disable"
70+
SslModeRequire SslMode = "require"
71+
SslModeVerifyFull SslMode = "verify-full"
72+
SslModeVerifyCA SslMode = "verify-ca"
73+
SslModePrefer SslMode = "prefer"
74+
SslModeAllow SslMode = "allow"
4875
)
4976

5077
type Configuration struct {
51-
Address string
52-
Username string
53-
Password string
54-
Name string
78+
Address string
79+
Username string
80+
Password string
81+
Name string
82+
Schema string
83+
MigrationSchema string
84+
SslMode SslMode
85+
SslRootCert string
86+
SslCert string
87+
SslKey string
5588

5689
MigrationsEnabled bool
5790
ChangelogSchema string
@@ -76,10 +109,34 @@ func CreateConfigurationFromEnv() Configuration {
76109
if name == "" {
77110
name = EnvDatabaseNameDefault
78111
}
112+
schema := os.Getenv(EnvDatabaseSchema)
113+
if schema == "" {
114+
schema = EnvDatabaseSchemaDefault
115+
}
116+
migrationSchema := os.Getenv(EnvDatabaseMigrationSchema)
117+
if migrationSchema == "" {
118+
migrationSchema = EnvDatabaseMigrationSchemaDefault
119+
}
120+
sslMode := os.Getenv(EnvDatabaseSslMode)
121+
if sslMode == "" {
122+
sslMode = EnvDatabaseSslModeDefault
123+
}
124+
sslRootCert := os.Getenv(EnvDatabaseSslRootCert)
125+
if sslRootCert == "" {
126+
sslRootCert = EnvDatabaseSslRootCertDefault
127+
}
128+
sslCert := os.Getenv(EnvDatabaseSslCert)
129+
if sslCert == "" {
130+
sslCert = EnvDatabaseSslCertDefault
131+
}
132+
sslKey := os.Getenv(EnvDatabaseSslKey)
133+
if sslKey == "" {
134+
sslKey = EnvDatabaseSslKeyDefault
135+
}
79136

80137
migrationsEnabled, err := strconv.ParseBool(os.Getenv(EnvMigrationsEnabled))
81138
if err != nil {
82-
migrationsEnabled = false
139+
migrationsEnabled = EnvMigrationsEnabledDefault
83140
}
84141

85142
changelogSchema := os.Getenv(EnvChangelogSchema)
@@ -99,6 +156,12 @@ func CreateConfigurationFromEnv() Configuration {
99156
Username: username,
100157
Password: password,
101158
Name: name,
159+
Schema: schema,
160+
MigrationSchema: migrationSchema,
161+
SslMode: sslMode,
162+
SslRootCert: sslRootCert,
163+
SslCert: sslCert,
164+
SslKey: sslKey,
102165
MigrationsEnabled: migrationsEnabled,
103166
ChangelogSchema: changelogSchema,
104167
ChangelogTable: changelogTable,
@@ -107,6 +170,9 @@ func CreateConfigurationFromEnv() Configuration {
107170
}
108171

109172
func (c Configuration) schemaTable() string {
173+
if c.ChangelogSchema == "" {
174+
return c.ChangelogTable
175+
}
110176
return c.ChangelogSchema + "." + c.ChangelogTable
111177
}
112178

@@ -116,7 +182,16 @@ func Connect() (*pgxpool.Pool, error) {
116182
}
117183

118184
func ConnectWithConfig(c Configuration) (*pgxpool.Pool, error) {
119-
url := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", c.Username, c.Password, c.Address, c.Name)
185+
url := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=%s", c.Username, c.Password, c.Address, c.Name, c.SslMode)
186+
if c.SslRootCert != "" {
187+
url += "&sslrootcert=" + c.SslRootCert
188+
}
189+
if c.SslCert != "" {
190+
url += "&sslcert=" + c.SslCert
191+
}
192+
if c.SslKey != "" {
193+
url += "&sslkey=" + c.SslKey
194+
}
120195
config, err := pgxpool.ParseConfig(url)
121196
if err != nil {
122197
return nil, err
@@ -132,6 +207,12 @@ func ConnectWithConfig(c Configuration) (*pgxpool.Pool, error) {
132207
return nil, err
133208
}
134209
}
210+
if c.Schema != "" {
211+
_, err := pool.Exec(context.Background(), "SET search_path TO "+c.Schema)
212+
if err != nil {
213+
return nil, err
214+
}
215+
}
135216
return pool, nil
136217
}
137218

@@ -176,6 +257,22 @@ func (dbm *databaseMigrator) Migrate() error {
176257
if err != nil {
177258
return err
178259
}
260+
if dbm.Configuration.MigrationSchema != "" {
261+
exists, err := dbm.schemaExists(dbm.Configuration.MigrationSchema)
262+
if err != nil {
263+
return err
264+
}
265+
if !exists {
266+
err = dbm.createSchema(dbm.Configuration.MigrationSchema)
267+
if err != nil {
268+
return err
269+
}
270+
}
271+
_, err = tx.Exec(context.Background(), "SET search_path TO "+dbm.Configuration.MigrationSchema)
272+
if err != nil {
273+
return err
274+
}
275+
}
179276
for _, migration := range migrations {
180277
err = dbm.applyMigration(migration, tx)
181278
if err != nil {
@@ -336,6 +433,17 @@ func (dbm *databaseMigrator) initChangelogTable() error {
336433
return nil
337434
}
338435

436+
func (dbm *databaseMigrator) schemaExists(schema string) (bool, error) {
437+
querySql := "SELECT EXISTS (SELECT FROM information_schema.schemata WHERE schemata.schema_name = $1)"
438+
row := dbm.PgxPool.QueryRow(context.Background(), querySql, schema)
439+
var exists bool
440+
err := row.Scan(&exists)
441+
if err != nil {
442+
return false, err
443+
}
444+
return exists, nil
445+
}
446+
339447
func (dbm *databaseMigrator) tableExists(schema string, table string) (bool, error) {
340448
//goland:noinspection SqlResolve
341449
querySql := "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname = $1 AND tablename = $2)"
@@ -348,6 +456,14 @@ func (dbm *databaseMigrator) tableExists(schema string, table string) (bool, err
348456
return exists, nil
349457
}
350458

459+
func (dbm *databaseMigrator) createSchema(schema string) error {
460+
_, err := dbm.PgxPool.Exec(context.Background(), "CREATE SCHEMA IF NOT EXISTS "+schema)
461+
if err != nil {
462+
return err
463+
}
464+
return nil
465+
}
466+
351467
func (dbm *databaseMigrator) createChangelogTable() error {
352468
tx, err := dbm.PgxPool.Begin(context.Background())
353469
if err != nil {

pg_test.go

Lines changed: 124 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,37 +9,108 @@ import (
99
)
1010

1111
func TestConnect(t *testing.T) {
12-
containerRequest := testcontainers.ContainerRequest{
13-
Image: "postgres:17",
14-
ExposedPorts: []string{"5432/tcp"},
15-
Env: map[string]string{
16-
"POSTGRES_USER": "test_user",
17-
"POSTGRES_PASSWORD": "test_password",
18-
"POSTGRES_DB": "test_db",
19-
},
20-
WaitingFor: wait.ForListeningPort("5432/tcp").WithPollInterval(time.Second),
12+
postgres := createContainer(t)
13+
defer func(postgres testcontainers.Container, ctx context.Context) {
14+
_ = postgres.Terminate(ctx)
15+
}(postgres, context.Background())
16+
17+
t.Setenv(EnvMigrationsDirectory, "testdb")
18+
t.Setenv(EnvChangelogSchema, "testschema")
19+
pool, err := Connect()
20+
if err != nil {
21+
t.Error(err)
2122
}
22-
postgres, err := testcontainers.GenericContainer(context.Background(), testcontainers.GenericContainerRequest{
23-
ContainerRequest: containerRequest,
24-
Started: true,
25-
})
23+
24+
rows, err := pool.Query(context.Background(), "SELECT id, name, description FROM testtable")
2625
if err != nil {
2726
t.Error(err)
2827
}
28+
data := make([]map[string]interface{}, 0)
29+
for rows.Next() {
30+
var id int
31+
var name string
32+
var description string
33+
err = rows.Scan(&id, &name, &description)
34+
if err != nil {
35+
t.Error(err)
36+
}
37+
data = append(data, map[string]interface{}{
38+
"id": id,
39+
"name": name,
40+
"description": description,
41+
})
42+
}
43+
if len(data) != 1 {
44+
t.Error("data len is not 1")
45+
}
46+
if data[0]["id"] != 1 {
47+
t.Error("id should be 1")
48+
}
49+
if data[0]["name"] != "name1" {
50+
t.Error("name should be name1")
51+
}
52+
if data[0]["description"] != "name1" {
53+
t.Error("description should be name1")
54+
}
55+
}
56+
57+
func TestConnectOtherSchema(t *testing.T) {
58+
postgres := createContainer(t)
2959
defer func(postgres testcontainers.Container, ctx context.Context) {
3060
_ = postgres.Terminate(ctx)
3161
}(postgres, context.Background())
3262

33-
port, err := postgres.MappedPort(context.Background(), "5432")
63+
t.Setenv(EnvMigrationsDirectory, "testdb_schema")
64+
t.Setenv(EnvChangelogSchema, "testschema")
65+
t.Setenv(EnvDatabaseSchema, "dbschema")
66+
pool, err := Connect()
3467
if err != nil {
3568
t.Error(err)
3669
}
37-
t.Setenv(EnvDatabaseAddress, "localhost:"+port.Port())
38-
t.Setenv(EnvDatabaseUsername, "test_user")
39-
t.Setenv(EnvDatabasePassword, "test_password")
40-
t.Setenv(EnvDatabaseName, "test_db")
70+
71+
rows, err := pool.Query(context.Background(), "SELECT id, name, description FROM testtable")
72+
if err != nil {
73+
t.Error(err)
74+
}
75+
data := make([]map[string]interface{}, 0)
76+
for rows.Next() {
77+
var id int
78+
var name string
79+
var description string
80+
err = rows.Scan(&id, &name, &description)
81+
if err != nil {
82+
t.Error(err)
83+
}
84+
data = append(data, map[string]interface{}{
85+
"id": id,
86+
"name": name,
87+
"description": description,
88+
})
89+
}
90+
if len(data) != 1 {
91+
t.Error("data len is not 1")
92+
}
93+
if data[0]["id"] != 1 {
94+
t.Error("id should be 1")
95+
}
96+
if data[0]["name"] != "name1" {
97+
t.Error("name should be name1")
98+
}
99+
if data[0]["description"] != "name1" {
100+
t.Error("description should be name1")
101+
}
102+
}
103+
104+
func TestConnectSchemaInMigration(t *testing.T) {
105+
postgres := createContainer(t)
106+
defer func(postgres testcontainers.Container, ctx context.Context) {
107+
_ = postgres.Terminate(ctx)
108+
}(postgres, context.Background())
109+
41110
t.Setenv(EnvMigrationsDirectory, "testdb")
42-
t.Setenv(EnvMigrationsEnabled, "true")
111+
t.Setenv(EnvChangelogSchema, "testschema")
112+
t.Setenv(EnvDatabaseSchema, "dbschema")
113+
t.Setenv(EnvDatabaseMigrationSchema, "dbschema")
43114
pool, err := Connect()
44115
if err != nil {
45116
t.Error(err)
@@ -77,3 +148,37 @@ func TestConnect(t *testing.T) {
77148
t.Error("description should be name1")
78149
}
79150
}
151+
152+
func createContainer(t *testing.T) testcontainers.Container {
153+
containerRequest := testcontainers.ContainerRequest{
154+
Image: "postgres:17",
155+
ExposedPorts: []string{"5432/tcp"},
156+
Env: map[string]string{
157+
"POSTGRES_USER": "test_user",
158+
"POSTGRES_PASSWORD": "test_password",
159+
"POSTGRES_DB": "test_db",
160+
},
161+
WaitingFor: wait.ForListeningPort("5432/tcp").WithPollInterval(time.Second),
162+
}
163+
postgres, err := testcontainers.GenericContainer(context.Background(), testcontainers.GenericContainerRequest{
164+
ContainerRequest: containerRequest,
165+
Started: true,
166+
})
167+
if err != nil {
168+
t.Error(err)
169+
}
170+
configure(t, postgres)
171+
return postgres
172+
}
173+
174+
func configure(t *testing.T, postgres testcontainers.Container) {
175+
port, err := postgres.MappedPort(context.Background(), "5432")
176+
if err != nil {
177+
t.Error(err)
178+
}
179+
t.Setenv(EnvDatabaseAddress, "localhost:"+port.Port())
180+
t.Setenv(EnvDatabaseUsername, "test_user")
181+
t.Setenv(EnvDatabasePassword, "test_password")
182+
t.Setenv(EnvDatabaseName, "test_db")
183+
t.Setenv(EnvMigrationsEnabled, "true")
184+
}

testdb_schema/0_1_init_data.sql

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
3+
INSERT INTO dbschema.testtable (id, name) VALUES (1, 'name1');

0 commit comments

Comments
 (0)