From 8d34fccb9d3e6123cdaf782c3d99a5110d20375e Mon Sep 17 00:00:00 2001 From: ashioyajotham Date: Wed, 27 May 2026 18:14:57 +0300 Subject: [PATCH] refactor(broker): enforce explicit tx boundaries for multi-table writes (#46) Wrap multi-table write flows in explicit transactions so token upserts and connection status updates are atomic in both OAuth callback exchange and static credential capture paths. Add repository-level InTx support and tests for commit/rollback behavior to prevent partial writes during failures. --- .../repository/instrumented/instrumented.go | 21 +++- .../repository/postgres/connection.go | 25 +++- .../internal/repository/postgres/token.go | 6 +- .../internal/repository/postgres/tx.go | 25 ++++ .../internal/repository/postgres/tx_test.go | 109 ++++++++++++++++++ nexus-broker/internal/service/connection.go | 36 ++++-- nexus-broker/internal/service/credential.go | 25 ++-- .../pkg/handlers/soc2_compliance_test.go | 2 + 8 files changed, 216 insertions(+), 33 deletions(-) create mode 100644 nexus-broker/internal/repository/postgres/tx.go create mode 100644 nexus-broker/internal/repository/postgres/tx_test.go diff --git a/nexus-broker/internal/repository/instrumented/instrumented.go b/nexus-broker/internal/repository/instrumented/instrumented.go index b539e8e..3187afb 100644 --- a/nexus-broker/internal/repository/instrumented/instrumented.go +++ b/nexus-broker/internal/repository/instrumented/instrumented.go @@ -11,6 +11,10 @@ import ( "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository" ) +type txRunner interface { + InTx(ctx context.Context, fn func(context.Context) error) error +} + var ( dbOpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: "nexus_db_operation_duration_seconds", @@ -65,13 +69,13 @@ func (r *ConnectionRepository) UpdateStatus(ctx context.Context, id uuid.UUID, s } func (r *ConnectionRepository) CountByStatus(ctx context.Context) (map[string]int64, error) { - defer observe("connection", "CountByStatus", time.Now()) - return r.inner.CountByStatus(ctx) + defer observe("connection", "CountByStatus", time.Now()) + return r.inner.CountByStatus(ctx) } func (r *ConnectionRepository) GetActiveByWorkspaceAndProvider(ctx context.Context, workspaceID, providerName string) (*domain.ConnectionWithProvider, error) { - defer observe("connection", "GetActiveByWorkspaceAndProvider", time.Now()) - return r.inner.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName) + defer observe("connection", "GetActiveByWorkspaceAndProvider", time.Now()) + return r.inner.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName) } func (r *ConnectionRepository) GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) { @@ -89,6 +93,15 @@ func (r *ConnectionRepository) ListByWorkspace(ctx context.Context, workspaceID return r.inner.ListByWorkspace(ctx, workspaceID) } +// InTx forwards transactional execution when the wrapped repository supports it. +func (r *ConnectionRepository) InTx(ctx context.Context, fn func(context.Context) error) error { + if runner, ok := r.inner.(txRunner); ok { + defer observe("connection", "InTx", time.Now()) + return runner.InTx(ctx, fn) + } + return fn(ctx) +} + // --- TokenRepository decorator --- // TokenRepository wraps repository.TokenRepository with latency instrumentation. diff --git a/nexus-broker/internal/repository/postgres/connection.go b/nexus-broker/internal/repository/postgres/connection.go index a7326ee..8e8fe86 100644 --- a/nexus-broker/internal/repository/postgres/connection.go +++ b/nexus-broker/internal/repository/postgres/connection.go @@ -3,11 +3,11 @@ package postgres import ( "context" + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository" "github.com/google/uuid" "github.com/jmoiron/sqlx" "github.com/lib/pq" - "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" - "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository" ) type connectionRepository struct { @@ -19,8 +19,23 @@ func NewConnectionRepository(db *sqlx.DB) repository.ConnectionRepository { return &connectionRepository{db: db} } +// InTx executes fn inside a database transaction. The transaction is stored in +// context so token/connection repository writes in the same request are atomic. +func (r *connectionRepository) InTx(ctx context.Context, fn func(context.Context) error) error { + tx, err := r.db.BeginTxx(ctx, nil) + if err != nil { + return err + } + ctxWithTx := withTx(ctx, tx) + if err := fn(ctxWithTx); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + func (r *connectionRepository) Create(ctx context.Context, conn *domain.Connection) error { - _, err := r.db.ExecContext(ctx, ` + _, err := execerFromContext(ctx, r.db).ExecContext(ctx, ` INSERT INTO connections (id, workspace_id, provider_id, code_verifier, scopes, return_url, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)`, conn.ID, conn.WorkspaceID, conn.ProviderID, conn.CodeVerifier, pq.Array(conn.Scopes), conn.ReturnURL, conn.ExpiresAt) @@ -85,7 +100,7 @@ func (r *connectionRepository) GetReturnURL(ctx context.Context, id uuid.UUID) ( } func (r *connectionRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status string) error { - _, err := r.db.ExecContext(ctx, "UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2", status, id) + _, err := execerFromContext(ctx, r.db).ExecContext(ctx, "UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2", status, id) return err } @@ -156,7 +171,7 @@ func (r *connectionRepository) GetForHealthCheck(ctx context.Context, limit int) } func (r *connectionRepository) UpdateHealthStatus(ctx context.Context, id uuid.UUID, status string) error { - _, err := r.db.ExecContext(ctx, ` + _, err := execerFromContext(ctx, r.db).ExecContext(ctx, ` UPDATE connections SET health_status = $1, last_health_check_at = NOW(), updated_at = NOW() WHERE id = $2`, status, id) diff --git a/nexus-broker/internal/repository/postgres/token.go b/nexus-broker/internal/repository/postgres/token.go index 9d6c448..e45563d 100644 --- a/nexus-broker/internal/repository/postgres/token.go +++ b/nexus-broker/internal/repository/postgres/token.go @@ -3,10 +3,10 @@ package postgres import ( "context" - "github.com/google/uuid" - "github.com/jmoiron/sqlx" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository" + "github.com/google/uuid" + "github.com/jmoiron/sqlx" ) type tokenRepository struct { @@ -19,7 +19,7 @@ func NewTokenRepository(db *sqlx.DB) repository.TokenRepository { } func (r *tokenRepository) Upsert(ctx context.Context, token *domain.Token) error { - _, err := r.db.ExecContext(ctx, ` + _, err := execerFromContext(ctx, r.db).ExecContext(ctx, ` INSERT INTO tokens (connection_id, encrypted_data, expires_at) VALUES ($1, $2, $3) ON CONFLICT (connection_id) diff --git a/nexus-broker/internal/repository/postgres/tx.go b/nexus-broker/internal/repository/postgres/tx.go new file mode 100644 index 0000000..fd9c028 --- /dev/null +++ b/nexus-broker/internal/repository/postgres/tx.go @@ -0,0 +1,25 @@ +package postgres + +import ( + "context" + "database/sql" + + "github.com/jmoiron/sqlx" +) + +type txContextKey struct{} + +type execer interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +func withTx(ctx context.Context, tx *sqlx.Tx) context.Context { + return context.WithValue(ctx, txContextKey{}, tx) +} + +func execerFromContext(ctx context.Context, db *sqlx.DB) execer { + if tx, ok := ctx.Value(txContextKey{}).(*sqlx.Tx); ok && tx != nil { + return tx + } + return db +} diff --git a/nexus-broker/internal/repository/postgres/tx_test.go b/nexus-broker/internal/repository/postgres/tx_test.go new file mode 100644 index 0000000..bca4b97 --- /dev/null +++ b/nexus-broker/internal/repository/postgres/tx_test.go @@ -0,0 +1,109 @@ +package postgres + +import ( + "context" + "errors" + "regexp" + "testing" + "time" + + "github.com/google/uuid" + "github.com/jmoiron/sqlx" + "gopkg.in/DATA-DOG/go-sqlmock.v1" + + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" +) + +func TestInTx_Commit(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + connRepo := NewConnectionRepository(sqlxDB) + tokenRepo := NewTokenRepository(sqlxDB) + + connID := uuid.New() + mock.ExpectBegin() + mock.ExpectExec(regexp.QuoteMeta("UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2")). + WithArgs("active", connID). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(regexp.QuoteMeta(` + INSERT INTO tokens (connection_id, encrypted_data, expires_at) + VALUES ($1, $2, $3) + ON CONFLICT (connection_id) + DO UPDATE SET + encrypted_data = EXCLUDED.encrypted_data, + expires_at = EXCLUDED.expires_at, + created_at = NOW()`)). + WithArgs(connID, "enc", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + runner, ok := connRepo.(interface { + InTx(ctx context.Context, fn func(context.Context) error) error + }) + if !ok { + t.Fatal("connection repo does not implement InTx") + } + + err = runner.InTx(context.Background(), func(ctx context.Context) error { + if err := connRepo.UpdateStatus(ctx, connID, "active"); err != nil { + return err + } + expires := time.Now().Add(time.Hour) + return tokenRepo.Upsert(ctx, &domain.Token{ + ConnectionID: connID, + EncryptedData: "enc", + ExpiresAt: &expires, + }) + }) + if err != nil { + t.Fatalf("InTx commit flow failed: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet sql expectations: %v", err) + } +} + +func TestInTx_Rollback(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + connRepo := NewConnectionRepository(sqlxDB) + connID := uuid.New() + + mock.ExpectBegin() + mock.ExpectExec(regexp.QuoteMeta("UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2")). + WithArgs("active", connID). + WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectRollback() + + runner, ok := connRepo.(interface { + InTx(ctx context.Context, fn func(context.Context) error) error + }) + if !ok { + t.Fatal("connection repo does not implement InTx") + } + + err = runner.InTx(context.Background(), func(ctx context.Context) error { + if err := connRepo.UpdateStatus(ctx, connID, "active"); err != nil { + return err + } + return errors.New("force rollback") + }) + if err == nil { + t.Fatal("expected rollback error, got nil") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet sql expectations: %v", err) + } +} diff --git a/nexus-broker/internal/service/connection.go b/nexus-broker/internal/service/connection.go index 7c7fc34..9981221 100644 --- a/nexus-broker/internal/service/connection.go +++ b/nexus-broker/internal/service/connection.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/audit" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository" @@ -21,6 +20,7 @@ import ( "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/provider" "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/server" "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/vault" + "github.com/google/uuid" ) type ConnectionService interface { @@ -48,6 +48,10 @@ type connectionService struct { allowedReturnDomains []string } +type txRunner interface { + InTx(ctx context.Context, fn func(context.Context) error) error +} + type CreateConsentRequest struct { WorkspaceID string `json:"workspace_id"` ProviderID string `json:"provider_id"` @@ -305,17 +309,22 @@ func (s *connectionService) ExchangeCodeForTokens(ctx context.Context, state, co expiresAt = &expiry } - err = s.tokenRepo.Upsert(ctx, &domain.Token{ - ConnectionID: connID, - EncryptedData: encryptedData, - ExpiresAt: expiresAt, - }) - if err != nil { - return "", false, ErrInternalWithErr(err, "token_store_failed", "Failed to store tokens") + if err := s.inTx(ctx, func(txCtx context.Context) error { + if err := s.tokenRepo.Upsert(txCtx, &domain.Token{ + ConnectionID: connID, + EncryptedData: encryptedData, + ExpiresAt: expiresAt, + }); err != nil { + return ErrInternalWithErr(err, "token_store_failed", "Failed to store tokens") + } + if err := s.connRepo.UpdateStatus(txCtx, connID, "active"); err != nil { + return ErrInternalWithErr(err, "status_update_failed", "Failed to update status") + } + return nil + }); err != nil { + return "", false, err } - s.connRepo.UpdateStatus(ctx, connID, "active") - if !server.IsReturnURLAllowed(conn.ReturnURL, s.enforceReturnURL, s.allowedReturnDomains) { return "", false, ErrBadRequest("return_url_not_allowed", "return_url not allowed") } @@ -334,6 +343,13 @@ func (s *connectionService) ExchangeCodeForTokens(ctx context.Context, state, co return returnURL.String(), hasIDToken, nil } +func (s *connectionService) inTx(ctx context.Context, fn func(context.Context) error) error { + if runner, ok := s.connRepo.(txRunner); ok { + return runner.InTx(ctx, fn) + } + return fn(ctx) +} + func (s *connectionService) GetTokenByWorkspaceAndProvider(ctx context.Context, workspaceID, providerName string) (map[string]interface{}, string, error) { conn, err := s.connRepo.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName) if err != nil { diff --git a/nexus-broker/internal/service/credential.go b/nexus-broker/internal/service/credential.go index 63f58d1..db059da 100644 --- a/nexus-broker/internal/service/credential.go +++ b/nexus-broker/internal/service/credential.go @@ -11,10 +11,10 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/auth" "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/vault" + "github.com/google/uuid" ) type RefreshResponse struct { @@ -84,16 +84,19 @@ func (s *connectionService) SaveCredential(ctx context.Context, state string, cr return "", ErrInternalWithErr(err, "encryption_failed", "Failed to encrypt credentials") } - err = s.tokenRepo.Upsert(ctx, &domain.Token{ - ConnectionID: connID, - EncryptedData: encryptedData, - }) - if err != nil { - return "", ErrInternalWithErr(err, "credential_store_failed", "Failed to store credentials") - } - - if err := s.connRepo.UpdateStatus(ctx, connID, "active"); err != nil { - return "", ErrInternalWithErr(err, "status_update_failed", "Failed to update status") + if err := s.inTx(ctx, func(txCtx context.Context) error { + if err := s.tokenRepo.Upsert(txCtx, &domain.Token{ + ConnectionID: connID, + EncryptedData: encryptedData, + }); err != nil { + return ErrInternalWithErr(err, "credential_store_failed", "Failed to store credentials") + } + if err := s.connRepo.UpdateStatus(txCtx, connID, "active"); err != nil { + return ErrInternalWithErr(err, "status_update_failed", "Failed to update status") + } + return nil + }); err != nil { + return "", err } returnURL, err := url.Parse(conn.ReturnURL) diff --git a/nexus-broker/pkg/handlers/soc2_compliance_test.go b/nexus-broker/pkg/handlers/soc2_compliance_test.go index 10406dc..aaca941 100644 --- a/nexus-broker/pkg/handlers/soc2_compliance_test.go +++ b/nexus-broker/pkg/handlers/soc2_compliance_test.go @@ -96,6 +96,7 @@ func TestSOC2_CC61_EncryptionAtRest(t *testing.T) { WillReturnRows(sqlmock.NewRows([]string{"id", "provider_id", "status", "scopes", "return_url", "name", "auth_type", "auth_header", "api_base_url", "user_info_endpoint", "params", "health_status"}). AddRow(connID.String(), providerID.String(), "active", "{}", "http://localhost/return", "TestProvider", "api_key", "", "", "", nil, "unknown")) + mock.ExpectBegin() mock.ExpectExec("INSERT INTO tokens"). WithArgs(connID, sqlmock.AnyArg(), sqlmock.AnyArg()). WillReturnResult(sqlmock.NewResult(1, 1)) @@ -103,6 +104,7 @@ func TestSOC2_CC61_EncryptionAtRest(t *testing.T) { mock.ExpectExec("UPDATE connections SET status"). WithArgs("active", connID). WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() // 3. Fire the request creds := map[string]interface{}{"api_key": plainTextKey}