Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ func removeAuthEnvVar(t *testing.T) (*oauth2.Token, string, func()) {
if err != nil {
t.Errorf("failed to get token: %v", err)
}
if *ipType != "public" {
return tok, "", func() {}
}
path, ok := os.LookupEnv("GOOGLE_APPLICATION_CREDENTIALS")
if !ok {
t.Fatalf("GOOGLE_APPLICATION_CREDENTIALS was not set in the environment")
Expand Down
3 changes: 3 additions & 0 deletions tests/fuse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ import (
)

func TestPostgresFUSEConnect(t *testing.T) {
if v := os.Getenv("IP_TYPE"); v == "private" || v == "psc" {
t.Skipf("skipping test because IP_TYPE is set to %v", v)
}
if testing.Short() {
t.Skip("skipping Postgres integration tests")
}
Expand Down
110 changes: 79 additions & 31 deletions tests/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ var (
mysqlUser = flag.String("mysql_user", os.Getenv("MYSQL_USER"), "Name of database user.")
mysqlPass = flag.String("mysql_pass", os.Getenv("MYSQL_PASS"), "Password for the database user; be careful when entering a password on the command line (it may go into your terminal's history).")
mysqlDB = flag.String("mysql_db", os.Getenv("MYSQL_DB"), "Name of the database to connect to.")
ipType = flag.String("ip_type", func() string {
if v := os.Getenv("IP_TYPE"); v != "" {
return v
}
return "public"
}(), "IP type of the instance to connect to, can be public, private or psc")
)

func requireMySQLVars(t *testing.T) {
Expand Down Expand Up @@ -56,12 +62,29 @@ func mysqlDSN() string {
return cfg.FormatDSN()
}

// AddIPTypeFlag appends the correct flag based on the ipType variable.
func AddIPTypeFlag(args []string) []string {
switch *ipType {
case "private":
return append(args, "--private-ip")
case "psc":
return append(args, "--psc")
default:
return args
}
}

func TestMySQLTCP(t *testing.T) {
if testing.Short() {
t.Skip("skipping MySQL integration tests")
}
requireMySQLVars(t)
proxyConnTest(t, []string{*mysqlConnName}, "mysql", mysqlDSN())
// Prepare the initial arguments
args := []string{*mysqlConnName}
// Add the IP type flag using the helper
args = AddIPTypeFlag(args)
// Run the test
proxyConnTest(t, args, "mysql", mysqlDSN())
}

func TestMySQLUnix(t *testing.T) {
Expand All @@ -82,8 +105,12 @@ func TestMySQLUnix(t *testing.T) {
Addr: proxy.UnixAddress(tmpDir, *mysqlConnName),
Net: "unix",
}
proxyConnTest(t,
[]string{"--unix-socket", tmpDir, *mysqlConnName}, "mysql", cfg.FormatDSN())
// Prepare the initial arguments
args := []string{"--unix-socket", tmpDir, *mysqlConnName}
// Add the IP type flag using the helper
args = AddIPTypeFlag(args)
// Run the test
proxyConnTest(t, args, "mysql", cfg.FormatDSN())
}

func TestMySQLImpersonation(t *testing.T) {
Expand All @@ -92,10 +119,15 @@ func TestMySQLImpersonation(t *testing.T) {
}
requireMySQLVars(t)

proxyConnTest(t, []string{
// Prepare the initial arguments
args := []string{
"--impersonate-service-account", *impersonatedUser,
*mysqlConnName},
"mysql", mysqlDSN())
*mysqlConnName,
}
// Add the IP type flag using the helper
args = AddIPTypeFlag(args)
// Run the test
proxyConnTest(t, args, "mysql", mysqlDSN())
}

func TestMySQLAuthentication(t *testing.T) {
Expand All @@ -104,7 +136,10 @@ func TestMySQLAuthentication(t *testing.T) {
}
requireMySQLVars(t)

creds := keyfile(t)
var creds string
if *ipType == "public" {
creds = keyfile(t)
}
tok, path, cleanup := removeAuthEnvVar(t)
defer cleanup()

Expand All @@ -123,32 +158,42 @@ func TestMySQLAuthentication(t *testing.T) {
"--impersonate-service-account", *impersonatedUser,
*mysqlConnName},
},
{
desc: "with credentials file",
args: []string{"--credentials-file", path, *mysqlConnName},
},
{
desc: "with credentials file and impersonation",
args: []string{
"--credentials-file", path,
"--impersonate-service-account", *impersonatedUser,
*mysqlConnName},
},
{
desc: "with credentials JSON",
args: []string{"--json-credentials", string(creds), *mysqlConnName},
},
{
desc: "with credentials JSON and impersonation",
args: []string{
"--json-credentials", string(creds),
"--impersonate-service-account", *impersonatedUser,
*mysqlConnName},
},
}
if *ipType == "public" {
additionaTcs := []struct {
desc string
args []string
}{
{
desc: "with credentials file",
args: []string{"--credentials-file", path, *mysqlConnName},
},
{
desc: "with credentials file and impersonation",
args: []string{
"--credentials-file", path,
"--impersonate-service-account", *impersonatedUser,
*mysqlConnName,
},
},
{
desc: "with credentials JSON",
args: []string{"--json-credentials", string(creds), *mysqlConnName},
},
{
desc: "with credentials JSON and impersonation",
args: []string{
"--json-credentials", string(creds),
"--impersonate-service-account", *impersonatedUser,
*mysqlConnName,
},
},
}
tcs = append(tcs, additionaTcs...)
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTest(t, tc.args, "mysql", mysqlDSN())
proxyConnTest(t, AddIPTypeFlag(tc.args), "mysql", mysqlDSN())
})
}
}
Expand All @@ -157,6 +202,9 @@ func TestMySQLGcloudAuth(t *testing.T) {
if testing.Short() {
t.Skip("skipping MySQL integration tests")
}
if v := os.Getenv("IP_TYPE"); v == "private" || v == "psc" {
t.Skipf("skipping test because IP_TYPE is set to %v", v)
}
requireMySQLVars(t)

tcs := []struct {
Expand All @@ -177,7 +225,7 @@ func TestMySQLGcloudAuth(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTest(t, tc.args, "mysql", mysqlDSN())
proxyConnTest(t, AddIPTypeFlag(tc.args), "mysql", mysqlDSN())
})
}
}
Expand Down
98 changes: 63 additions & 35 deletions tests/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,12 @@ func TestPostgresTCP(t *testing.T) {
t.Skip("skipping Postgres integration tests")
}
requirePostgresVars(t)

proxyConnTest(t, []string{*postgresConnName}, "pgx", postgresDSN())
// Prepare the initial arguments
args := []string{*postgresConnName}
// Add the IP type flag using the helper
args = AddIPTypeFlag(args)
// Run the test
proxyConnTest(t, args, "pgx", postgresDSN())
}

func TestPostgresUnix(t *testing.T) {
Expand All @@ -78,8 +82,12 @@ func TestPostgresUnix(t *testing.T) {
proxy.UnixAddress(tmpDir, *postgresConnName),
*postgresUser, *postgresPass, *postgresDB)

proxyConnTest(t,
[]string{"--unix-socket", tmpDir, *postgresConnName}, "pgx", dsn)
// Prepare the initial arguments
args := []string{"--unix-socket", tmpDir, *postgresConnName}
// Add the IP type flag using the helper
args = AddIPTypeFlag(args)
// Run the test
proxyConnTest(t, args, "pgx", dsn)
}

func createTempDir(t *testing.T) (string, func()) {
Expand All @@ -99,11 +107,15 @@ func TestPostgresImpersonation(t *testing.T) {
t.Skip("skipping Postgres integration tests")
}
requirePostgresVars(t)

proxyConnTest(t, []string{
// Prepare the initial arguments
args := []string{
"--impersonate-service-account", *impersonatedUser,
*postgresConnName},
"pgx", postgresDSN())
*postgresConnName,
}
// Add the IP type flag using the helper
args = AddIPTypeFlag(args)
// Run the test
proxyConnTest(t, args, "pgx", postgresDSN())
}

func TestPostgresAuthentication(t *testing.T) {
Expand All @@ -112,7 +124,10 @@ func TestPostgresAuthentication(t *testing.T) {
}
requirePostgresVars(t)

creds := keyfile(t)
var creds string
if *ipType == "public" {
creds = keyfile(t)
}
tok, path, cleanup := removeAuthEnvVar(t)
defer cleanup()

Expand All @@ -131,32 +146,42 @@ func TestPostgresAuthentication(t *testing.T) {
"--impersonate-service-account", *impersonatedUser,
*postgresConnName},
},
{
desc: "with credentials file",
args: []string{"--credentials-file", path, *postgresConnName},
},
{
desc: "with credentials file and impersonation",
args: []string{
"--credentials-file", path,
"--impersonate-service-account", *impersonatedUser,
*postgresConnName},
},
{
desc: "with credentials JSON",
args: []string{"--json-credentials", string(creds), *postgresConnName},
},
{
desc: "with credentials JSON and impersonation",
args: []string{
"--json-credentials", string(creds),
"--impersonate-service-account", *impersonatedUser,
*postgresConnName},
},
}
if *ipType == "public" {
additionalTcs := []struct {
desc string
args []string
}{
{
desc: "with credentials file",
args: []string{"--credentials-file", path, *postgresConnName},
},
{
desc: "with credentials file and impersonation",
args: []string{
"--credentials-file", path,
"--impersonate-service-account", *impersonatedUser,
*postgresConnName,
},
},
{
desc: "with credentials JSON",
args: []string{"--json-credentials", string(creds), *postgresConnName},
},
{
desc: "with credentials JSON and impersonation",
args: []string{
"--json-credentials", string(creds),
"--impersonate-service-account", *impersonatedUser,
*postgresConnName,
},
},
}
tcs = append(tcs, additionalTcs...)
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTest(t, tc.args, "pgx", postgresDSN())
proxyConnTest(t, AddIPTypeFlag(tc.args), "pgx", postgresDSN())
})
}
}
Expand All @@ -165,6 +190,9 @@ func TestPostgresGcloudAuth(t *testing.T) {
if testing.Short() {
t.Skip("skipping Postgres integration tests")
}
if v := os.Getenv("IP_TYPE"); v == "private" || v == "psc" {
t.Skipf("skipping test because IP_TYPE is set to %v", v)
}
requirePostgresVars(t)

tcs := []struct {
Expand All @@ -185,7 +213,7 @@ func TestPostgresGcloudAuth(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTest(t, tc.args, "pgx", postgresDSN())
proxyConnTest(t, AddIPTypeFlag(tc.args), "pgx", postgresDSN())
})
}

Expand Down Expand Up @@ -231,7 +259,7 @@ func TestPostgresIAMDBAuthn(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTest(t, tc.args, "pgx", tc.dsn)
proxyConnTest(t, AddIPTypeFlag(tc.args), "pgx", tc.dsn)
})
}
}
Expand Down Expand Up @@ -272,7 +300,7 @@ func TestPostgresCustomerCAS(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
proxyConnTest(t, tc.args, "pgx", tc.dsn)
proxyConnTest(t, AddIPTypeFlag(tc.args), "pgx", tc.dsn)
})
}
}
Expand Down
Loading
Loading