|
37 | 37 | CloudSQLPrefix = "cloudsql+" |
38 | 38 | ) |
39 | 39 |
|
| 40 | +func handleRDSMySQLIAMAuth(conn string) (string, time.Time, error) { |
| 41 | + dsn := strings.TrimPrefix(conn, "rds-mysql://") |
| 42 | + config, err := mysql.ParseDSN(dsn) |
| 43 | + if err != nil { |
| 44 | + return "", time.Time{}, fmt.Errorf("failed to parse MySQL DSN: %v", err) |
| 45 | + } |
| 46 | + |
| 47 | + sess := session.Must(session.NewSessionWithOptions(session.Options{ |
| 48 | + SharedConfigState: session.SharedConfigEnable, |
| 49 | + })) |
| 50 | + |
| 51 | + token, err := rdsutils.BuildAuthToken(config.Addr, os.Getenv("AWS_REGION"), config.User, sess.Config.Credentials) |
| 52 | + if err != nil { |
| 53 | + return "", time.Time{}, fmt.Errorf("failed to build RDS auth token: %v", err) |
| 54 | + } |
| 55 | + |
| 56 | + expirationTime := time.Now().Add(14 * time.Minute) |
| 57 | + |
| 58 | + return token, expirationTime, nil |
| 59 | +} |
| 60 | + |
40 | 61 | // Init will initialize the metric descriptors |
41 | 62 | func (j *Job) Init(logger log.Logger, queries map[string]string) error { |
42 | 63 | j.log = log.With(logger, "job", j.Name) |
@@ -207,23 +228,53 @@ func (j *Job) updateConnections() { |
207 | 228 | continue |
208 | 229 | } |
209 | 230 |
|
210 | | - // MySQL DSNs do not parse cleanly as URLs as of Go 1.12.8+ |
211 | | - if strings.HasPrefix(conn, "mysql://") { |
212 | | - config, err := mysql.ParseDSN(strings.TrimPrefix(conn, "mysql://")) |
| 231 | + // Handle both RDS MySQL and regular MySQL connections |
| 232 | + if strings.HasPrefix(conn, "rds-mysql://") || strings.HasPrefix(conn, "mysql://") { |
| 233 | + isRDS := strings.HasPrefix(conn, "rds-mysql://") |
| 234 | + var dsn string |
| 235 | + var expirationTime time.Time |
| 236 | + |
| 237 | + trimmedConn := conn |
| 238 | + if isRDS { |
| 239 | + trimmedConn = strings.TrimPrefix(conn, "rds-mysql://") |
| 240 | + } else { |
| 241 | + trimmedConn = strings.TrimPrefix(conn, "mysql://") |
| 242 | + } |
| 243 | + |
| 244 | + config, err := mysql.ParseDSN(trimmedConn) |
213 | 245 | if err != nil { |
214 | 246 | level.Error(j.log).Log("msg", "Failed to parse MySQL DSN", "url", conn, "err", err) |
| 247 | + continue |
| 248 | + } |
| 249 | + |
| 250 | + if isRDS { |
| 251 | + authToken, tokenExpiration, err := handleRDSMySQLIAMAuth(conn) |
| 252 | + if err != nil { |
| 253 | + level.Error(j.log).Log("msg", "Failed to build RDS auth token", "url", conn, "err", err) |
| 254 | + continue |
| 255 | + } |
| 256 | + config.Passwd = authToken |
| 257 | + config.AllowCleartextPasswords = true |
| 258 | + expirationTime = tokenExpiration |
| 259 | + } |
| 260 | + |
| 261 | + dsn = config.FormatDSN() |
| 262 | + if isRDS { |
| 263 | + dsn = "rds-mysql://" + dsn |
215 | 264 | } |
216 | 265 |
|
217 | 266 | j.conns = append(j.conns, &connection{ |
218 | | - conn: nil, |
219 | | - url: conn, |
220 | | - driver: "mysql", |
221 | | - host: config.Addr, |
222 | | - database: config.DBName, |
223 | | - user: config.User, |
| 267 | + conn: nil, |
| 268 | + url: dsn, |
| 269 | + driver: "mysql", |
| 270 | + host: config.Addr, |
| 271 | + database: config.DBName, |
| 272 | + user: config.User, |
| 273 | + tokenExpirationTime: expirationTime, |
224 | 274 | }) |
225 | 275 | continue |
226 | 276 | } |
| 277 | + |
227 | 278 | if strings.HasPrefix(conn, "rds-postgres://") { |
228 | 279 | // Reuse Postgres driver by stripping "rds-" from connection URL after building the RDS authentication token |
229 | 280 | conn = strings.TrimPrefix(conn, "rds-") |
@@ -438,12 +489,45 @@ func (j *Job) runOnce() error { |
438 | 489 | func (c *connection) connect(job *Job) error { |
439 | 490 | // already connected |
440 | 491 | if c.conn != nil { |
| 492 | + if strings.HasPrefix(c.url, "rds-mysql://") && time.Now().After(c.tokenExpirationTime) { |
| 493 | + level.Warn(job.log).Log("msg", "Connection token expired, reconnecting") |
| 494 | + |
| 495 | + authToken, expirationTime, err := handleRDSMySQLIAMAuth(c.url) |
| 496 | + if err != nil { |
| 497 | + return fmt.Errorf("failed to refresh RDS MySQL IAM Auth token: %w", err) |
| 498 | + } |
| 499 | + |
| 500 | + config, err := mysql.ParseDSN(strings.TrimPrefix(c.url, "rds-mysql://")) |
| 501 | + if err != nil { |
| 502 | + return fmt.Errorf("failed to parse MySQL DSN: %w", err) |
| 503 | + } |
| 504 | + |
| 505 | + config.Passwd = authToken |
| 506 | + dsn := "rds-mysql://" + config.FormatDSN() |
| 507 | + |
| 508 | + // Close the existing connection |
| 509 | + c.conn.Close() |
| 510 | + c.conn = nil |
| 511 | + |
| 512 | + // Update the connection details |
| 513 | + c.tokenExpirationTime = expirationTime |
| 514 | + c.url = dsn |
| 515 | + |
| 516 | + // Connect to the database with the new token |
| 517 | + conn, err := sqlx.Connect(c.driver, strings.TrimPrefix(dsn, "rds-mysql://")) |
| 518 | + if err != nil { |
| 519 | + return fmt.Errorf("failed to connect to the database: %w", err) |
| 520 | + } |
| 521 | + c.conn = conn |
| 522 | + return nil |
| 523 | + } |
441 | 524 | return nil |
442 | 525 | } |
443 | 526 | dsn := c.url |
444 | 527 | switch c.driver { |
445 | 528 | case "mysql": |
446 | 529 | dsn = strings.TrimPrefix(dsn, "mysql://") |
| 530 | + dsn = strings.TrimPrefix(dsn, "rds-mysql://") |
447 | 531 | case "clickhouse+tcp", "clickhouse+http": // Support both http and tcp connections |
448 | 532 | dsn = strings.TrimPrefix(dsn, "clickhouse+") |
449 | 533 | c.driver = "clickhouse" |
|
0 commit comments