Skip to content

Commit c4e5122

Browse files
author
Saransh Gupta
committed
feat: add tls verification to mysql dsn
- Update mysql.template.json with new TLS options - Handle custom TLS configuration - Add tests
1 parent 0d7d29c commit c4e5122

3 files changed

Lines changed: 127 additions & 11 deletions

File tree

mysql.template.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,9 @@
22
"username": "user",
33
"password": "password",
44
"server": "localhost",
5-
"database": "mydb"
5+
"database": "mydb",
6+
"tls": false,
7+
"caCertPath": "/path/to/CA/certificate",
8+
"clientCertPath": "/path/to/client/certificate",
9+
"clientKeyPath": "/path/to/client/key"
610
}

utils/utils.go

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,22 @@ package utils
22

33
import (
44
"bufio"
5+
"crypto/tls"
6+
"crypto/x509"
57
"database/sql"
68
"encoding/json"
79
"errors"
810
"fmt"
9-
"github.com/rs/zerolog"
10-
"golang.org/x/sys/unix"
1111
"io"
1212
"os"
1313
"path/filepath"
1414
"pbench/log"
1515
"reflect"
1616
"strings"
17+
18+
"github.com/go-sql-driver/mysql"
19+
"github.com/rs/zerolog"
20+
"golang.org/x/sys/unix"
1721
)
1822

1923
const (
@@ -64,6 +68,33 @@ func InitLogFile(logPath string) (finalizer func()) {
6468
}
6569
}
6670

71+
func createTLSConfig(caCertPath, clientCertPath, clientKeyPath string) (*tls.Config, error) {
72+
rootCertPool := x509.NewCertPool()
73+
pem, err := os.ReadFile(caCertPath)
74+
if err != nil {
75+
log.Error().Err(err).Msg("failed to read CA certificate")
76+
return nil, err
77+
}
78+
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
79+
log.Error().Msg("failed to append CA certificate")
80+
return nil, fmt.Errorf("failed to append CA certificate from PEM")
81+
}
82+
tlsConfig := &tls.Config{
83+
RootCAs: rootCertPool,
84+
}
85+
86+
// Check if client certificate and key are provided for mutual TLS
87+
if clientCertPath != "" && clientKeyPath != "" {
88+
certs, err := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
89+
if err != nil {
90+
log.Error().Err(err).Msg("failed to load client certificate or key")
91+
return nil, err
92+
}
93+
tlsConfig.Certificates = []tls.Certificate{certs}
94+
}
95+
return tlsConfig, nil
96+
}
97+
6798
func InitMySQLConnFromCfg(cfgPath string) *sql.DB {
6899
if cfgPath == "" {
69100
return nil
@@ -73,17 +104,33 @@ func InitMySQLConnFromCfg(cfgPath string) *sql.DB {
73104
return nil
74105
} else {
75106
mySQLCfg := &struct {
76-
Username string `json:"username"`
77-
Password string `json:"password"`
78-
Server string `json:"server"`
79-
Database string `json:"database"`
80-
}{}
107+
Username string `json:"username"`
108+
Password string `json:"password"`
109+
Server string `json:"server"`
110+
Database string `json:"database"`
111+
TLS bool `json:"tls"`
112+
CaCertPath string `json:"caCertPath"`
113+
ClientCertPath string `json:"clientCertPath"`
114+
ClientKeyPath string `json:"clientKeyPath"`
115+
}{
116+
TLS: false,
117+
}
81118
if err := json.Unmarshal(cfgBytes, mySQLCfg); err != nil {
82119
log.Error().Err(err).Msg("failed to unmarshal MySQL connection config for the run recorder")
83120
return nil
84121
}
85-
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true",
86-
mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database)); err != nil {
122+
tlsType := "false"
123+
if mySQLCfg.TLS {
124+
tlsType = "custom"
125+
tlsConfig, err := createTLSConfig(mySQLCfg.CaCertPath, mySQLCfg.ClientCertPath, mySQLCfg.ClientKeyPath)
126+
if err != nil {
127+
log.Error().Msg("TLS enabled but failed to load certificates")
128+
return nil
129+
}
130+
mysql.RegisterTLSConfig(tlsType, tlsConfig)
131+
}
132+
if db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?tls=%s&parseTime=true",
133+
mySQLCfg.Username, mySQLCfg.Password, mySQLCfg.Server, mySQLCfg.Database, tlsType)); err != nil {
87134
log.Error().Err(err).Msg("failed to initialize MySQL connection for the run recorder")
88135
return nil
89136
} else if err = db.Ping(); err != nil {

utils/utils_test.go

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package utils
22

33
import (
4-
"github.com/stretchr/testify/assert"
4+
"encoding/json"
55
"os"
66
"path/filepath"
77
"testing"
8+
9+
"github.com/stretchr/testify/assert"
810
)
911

1012
func TestExpandHomeDirectory(t *testing.T) {
@@ -25,3 +27,66 @@ func TestExpandHomeDirectory_JustTilde(t *testing.T) {
2527
ExpandHomeDirectory(&path)
2628
assert.Equal(t, os.Getenv("HOME"), path)
2729
}
30+
31+
func TestCreateTLSConfig_InvalidCAPath(t *testing.T) {
32+
// tests error handling when CA cert file doesn't exist
33+
tlsConfig, err := createTLSConfig("/nonexistent/ca.pem", "", "")
34+
35+
assert.Error(t, err, "should return error for non-existent CA certificate")
36+
assert.Nil(t, tlsConfig, "should return nil config on error")
37+
}
38+
39+
func TestCreateTLSConfig_InvalidCAPEM(t *testing.T) {
40+
// tests error handling when CA cert has invalid PEM content
41+
tmpDir := t.TempDir()
42+
caPath := filepath.Join(tmpDir, "invalid-ca.pem")
43+
err := os.WriteFile(caPath, []byte("invalid pem content"), 0644)
44+
assert.NoError(t, err)
45+
46+
tlsConfig, err := createTLSConfig(caPath, "", "")
47+
48+
assert.Error(t, err, "should return error for invalid PEM content")
49+
assert.Nil(t, tlsConfig, "should return nil config on error")
50+
}
51+
52+
func TestCreateTLSConfig_InvalidClientCert(t *testing.T) {
53+
// tests error handling with invalid client certificates
54+
tmpDir := t.TempDir()
55+
clientCertPath := filepath.Join(tmpDir, "client.pem")
56+
clientKeyPath := filepath.Join(tmpDir, "client-key.pem")
57+
58+
err := os.WriteFile(clientCertPath, []byte("dummy cert"), 0644)
59+
assert.NoError(t, err)
60+
err = os.WriteFile(clientKeyPath, []byte("dummy key"), 0644)
61+
assert.NoError(t, err)
62+
63+
// Should fail because CA cert doesn't exist
64+
tlsConfig, err := createTLSConfig("/nonexistent/ca.pem", clientCertPath, clientKeyPath)
65+
66+
assert.Error(t, err, "should fail with invalid certificates")
67+
assert.Nil(t, tlsConfig, "should return nil on error")
68+
}
69+
70+
func TestInitMySQLConnFromCfg_TLSEnabledInvalidCerts(t *testing.T) {
71+
// When TLS is enabled but certificates are invalid, function should return nil early
72+
config := map[string]interface{}{
73+
"username": "testuser",
74+
"password": "testpass",
75+
"server": "localhost:3306",
76+
"database": "testdb",
77+
"tls": true,
78+
"caCertPath": "/nonexistent/ca.pem",
79+
"clientCertPath": "/nonexistent/client.pem",
80+
"clientKeyPath": "/nonexistent/client-key.pem",
81+
}
82+
83+
tmpDir := t.TempDir()
84+
cfgPath := filepath.Join(tmpDir, "config.json")
85+
configJSON, err := json.Marshal(config)
86+
assert.NoError(t, err)
87+
err = os.WriteFile(cfgPath, configJSON, 0644)
88+
assert.NoError(t, err)
89+
90+
db := InitMySQLConnFromCfg(cfgPath)
91+
assert.Nil(t, db, "should return nil when TLS config fails")
92+
}

0 commit comments

Comments
 (0)