Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions nexus-broker/internal/repository/instrumented/instrumented.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository"
)

type txRunner interface {
InTx(ctx context.Context, fn func(context.Context) error) error
}

var (
dbOpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "nexus_db_operation_duration_seconds",
Expand Down Expand Up @@ -65,13 +69,13 @@ func (r *ConnectionRepository) UpdateStatus(ctx context.Context, id uuid.UUID, s
}

func (r *ConnectionRepository) CountByStatus(ctx context.Context) (map[string]int64, error) {
defer observe("connection", "CountByStatus", time.Now())
return r.inner.CountByStatus(ctx)
defer observe("connection", "CountByStatus", time.Now())
return r.inner.CountByStatus(ctx)
}

func (r *ConnectionRepository) GetActiveByWorkspaceAndProvider(ctx context.Context, workspaceID, providerName string) (*domain.ConnectionWithProvider, error) {
defer observe("connection", "GetActiveByWorkspaceAndProvider", time.Now())
return r.inner.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName)
defer observe("connection", "GetActiveByWorkspaceAndProvider", time.Now())
return r.inner.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName)
}

func (r *ConnectionRepository) GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) {
Expand All @@ -89,6 +93,15 @@ func (r *ConnectionRepository) ListByWorkspace(ctx context.Context, workspaceID
return r.inner.ListByWorkspace(ctx, workspaceID)
}

// InTx forwards transactional execution when the wrapped repository supports it.
func (r *ConnectionRepository) InTx(ctx context.Context, fn func(context.Context) error) error {
if runner, ok := r.inner.(txRunner); ok {
defer observe("connection", "InTx", time.Now())
return runner.InTx(ctx, fn)
}
return fn(ctx)
}
Comment on lines +96 to +103

// --- TokenRepository decorator ---

// TokenRepository wraps repository.TokenRepository with latency instrumentation.
Expand Down
25 changes: 20 additions & 5 deletions nexus-broker/internal/repository/postgres/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package postgres
import (
"context"

"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository"
)

type connectionRepository struct {
Expand All @@ -19,8 +19,23 @@ func NewConnectionRepository(db *sqlx.DB) repository.ConnectionRepository {
return &connectionRepository{db: db}
}

// InTx executes fn inside a database transaction. The transaction is stored in
// context so token/connection repository writes in the same request are atomic.
func (r *connectionRepository) InTx(ctx context.Context, fn func(context.Context) error) error {
tx, err := r.db.BeginTxx(ctx, nil)
if err != nil {
return err
}
ctxWithTx := withTx(ctx, tx)
if err := fn(ctxWithTx); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
Comment on lines +22 to +35

func (r *connectionRepository) Create(ctx context.Context, conn *domain.Connection) error {
_, err := r.db.ExecContext(ctx, `
_, err := execerFromContext(ctx, r.db).ExecContext(ctx, `
INSERT INTO connections (id, workspace_id, provider_id, code_verifier, scopes, return_url, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)`,
conn.ID, conn.WorkspaceID, conn.ProviderID, conn.CodeVerifier, pq.Array(conn.Scopes), conn.ReturnURL, conn.ExpiresAt)
Expand Down Expand Up @@ -85,7 +100,7 @@ func (r *connectionRepository) GetReturnURL(ctx context.Context, id uuid.UUID) (
}

func (r *connectionRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status string) error {
_, err := r.db.ExecContext(ctx, "UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2", status, id)
_, err := execerFromContext(ctx, r.db).ExecContext(ctx, "UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2", status, id)
return err
}

Expand Down Expand Up @@ -156,7 +171,7 @@ func (r *connectionRepository) GetForHealthCheck(ctx context.Context, limit int)
}

func (r *connectionRepository) UpdateHealthStatus(ctx context.Context, id uuid.UUID, status string) error {
_, err := r.db.ExecContext(ctx, `
_, err := execerFromContext(ctx, r.db).ExecContext(ctx, `
UPDATE connections
SET health_status = $1, last_health_check_at = NOW(), updated_at = NOW()
WHERE id = $2`, status, id)
Expand Down
6 changes: 3 additions & 3 deletions nexus-broker/internal/repository/postgres/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package postgres
import (
"context"

"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
)

type tokenRepository struct {
Expand All @@ -19,7 +19,7 @@ func NewTokenRepository(db *sqlx.DB) repository.TokenRepository {
}

func (r *tokenRepository) Upsert(ctx context.Context, token *domain.Token) error {
_, err := r.db.ExecContext(ctx, `
_, err := execerFromContext(ctx, r.db).ExecContext(ctx, `
INSERT INTO tokens (connection_id, encrypted_data, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (connection_id)
Expand Down
25 changes: 25 additions & 0 deletions nexus-broker/internal/repository/postgres/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package postgres

import (
"context"
"database/sql"

"github.com/jmoiron/sqlx"
)

type txContextKey struct{}

type execer interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

func withTx(ctx context.Context, tx *sqlx.Tx) context.Context {
return context.WithValue(ctx, txContextKey{}, tx)
}

func execerFromContext(ctx context.Context, db *sqlx.DB) execer {
if tx, ok := ctx.Value(txContextKey{}).(*sqlx.Tx); ok && tx != nil {
return tx
}
return db
}
109 changes: 109 additions & 0 deletions nexus-broker/internal/repository/postgres/tx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package postgres

import (
"context"
"errors"
"regexp"
"testing"
"time"

"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"gopkg.in/DATA-DOG/go-sqlmock.v1"

"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain"
)

func TestInTx_Commit(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("sqlmock.New: %v", err)
}
defer db.Close()

sqlxDB := sqlx.NewDb(db, "sqlmock")
connRepo := NewConnectionRepository(sqlxDB)
tokenRepo := NewTokenRepository(sqlxDB)

connID := uuid.New()
mock.ExpectBegin()
mock.ExpectExec(regexp.QuoteMeta("UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2")).
WithArgs("active", connID).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec(regexp.QuoteMeta(`
INSERT INTO tokens (connection_id, encrypted_data, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (connection_id)
DO UPDATE SET
encrypted_data = EXCLUDED.encrypted_data,
expires_at = EXCLUDED.expires_at,
created_at = NOW()`)).
Comment on lines +33 to +40
WithArgs(connID, "enc", sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectCommit()

runner, ok := connRepo.(interface {
InTx(ctx context.Context, fn func(context.Context) error) error
})
if !ok {
t.Fatal("connection repo does not implement InTx")
}

err = runner.InTx(context.Background(), func(ctx context.Context) error {
if err := connRepo.UpdateStatus(ctx, connID, "active"); err != nil {
return err
}
expires := time.Now().Add(time.Hour)
return tokenRepo.Upsert(ctx, &domain.Token{
ConnectionID: connID,
EncryptedData: "enc",
ExpiresAt: &expires,
})
})
if err != nil {
t.Fatalf("InTx commit flow failed: %v", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Fatalf("unmet sql expectations: %v", err)
}
}

func TestInTx_Rollback(t *testing.T) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("sqlmock.New: %v", err)
}
defer db.Close()

sqlxDB := sqlx.NewDb(db, "sqlmock")
connRepo := NewConnectionRepository(sqlxDB)
connID := uuid.New()

mock.ExpectBegin()
mock.ExpectExec(regexp.QuoteMeta("UPDATE connections SET status = $1, updated_at = NOW() WHERE id = $2")).
WithArgs("active", connID).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectRollback()

runner, ok := connRepo.(interface {
InTx(ctx context.Context, fn func(context.Context) error) error
})
if !ok {
t.Fatal("connection repo does not implement InTx")
}

err = runner.InTx(context.Background(), func(ctx context.Context) error {
if err := connRepo.UpdateStatus(ctx, connID, "active"); err != nil {
return err
}
return errors.New("force rollback")
})
if err == nil {
t.Fatal("expected rollback error, got nil")
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Fatalf("unmet sql expectations: %v", err)
}
}
36 changes: 26 additions & 10 deletions nexus-broker/internal/service/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"
"time"

"github.com/google/uuid"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/audit"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository"
Expand All @@ -21,6 +20,7 @@ import (
"github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/provider"
"github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/server"
"github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/vault"
"github.com/google/uuid"
)

type ConnectionService interface {
Expand Down Expand Up @@ -48,6 +48,10 @@ type connectionService struct {
allowedReturnDomains []string
}

type txRunner interface {
InTx(ctx context.Context, fn func(context.Context) error) error
}

Comment on lines +51 to +54
type CreateConsentRequest struct {
WorkspaceID string `json:"workspace_id"`
ProviderID string `json:"provider_id"`
Expand Down Expand Up @@ -305,17 +309,22 @@ func (s *connectionService) ExchangeCodeForTokens(ctx context.Context, state, co
expiresAt = &expiry
}

err = s.tokenRepo.Upsert(ctx, &domain.Token{
ConnectionID: connID,
EncryptedData: encryptedData,
ExpiresAt: expiresAt,
})
if err != nil {
return "", false, ErrInternalWithErr(err, "token_store_failed", "Failed to store tokens")
if err := s.inTx(ctx, func(txCtx context.Context) error {
if err := s.tokenRepo.Upsert(txCtx, &domain.Token{
ConnectionID: connID,
EncryptedData: encryptedData,
ExpiresAt: expiresAt,
}); err != nil {
return ErrInternalWithErr(err, "token_store_failed", "Failed to store tokens")
}
if err := s.connRepo.UpdateStatus(txCtx, connID, "active"); err != nil {
return ErrInternalWithErr(err, "status_update_failed", "Failed to update status")
}
return nil
}); err != nil {
return "", false, err
}

s.connRepo.UpdateStatus(ctx, connID, "active")

if !server.IsReturnURLAllowed(conn.ReturnURL, s.enforceReturnURL, s.allowedReturnDomains) {
return "", false, ErrBadRequest("return_url_not_allowed", "return_url not allowed")
}
Expand All @@ -334,6 +343,13 @@ func (s *connectionService) ExchangeCodeForTokens(ctx context.Context, state, co
return returnURL.String(), hasIDToken, nil
}

func (s *connectionService) inTx(ctx context.Context, fn func(context.Context) error) error {
if runner, ok := s.connRepo.(txRunner); ok {
return runner.InTx(ctx, fn)
}
return fn(ctx)
}

func (s *connectionService) GetTokenByWorkspaceAndProvider(ctx context.Context, workspaceID, providerName string) (map[string]interface{}, string, error) {
conn, err := s.connRepo.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName)
if err != nil {
Expand Down
25 changes: 14 additions & 11 deletions nexus-broker/internal/service/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import (
"strings"
"time"

"github.com/google/uuid"
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain"
"github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/auth"
"github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/vault"
"github.com/google/uuid"
)

type RefreshResponse struct {
Expand Down Expand Up @@ -84,16 +84,19 @@ func (s *connectionService) SaveCredential(ctx context.Context, state string, cr
return "", ErrInternalWithErr(err, "encryption_failed", "Failed to encrypt credentials")
}

err = s.tokenRepo.Upsert(ctx, &domain.Token{
ConnectionID: connID,
EncryptedData: encryptedData,
})
if err != nil {
return "", ErrInternalWithErr(err, "credential_store_failed", "Failed to store credentials")
}

if err := s.connRepo.UpdateStatus(ctx, connID, "active"); err != nil {
return "", ErrInternalWithErr(err, "status_update_failed", "Failed to update status")
if err := s.inTx(ctx, func(txCtx context.Context) error {
if err := s.tokenRepo.Upsert(txCtx, &domain.Token{
ConnectionID: connID,
EncryptedData: encryptedData,
}); err != nil {
return ErrInternalWithErr(err, "credential_store_failed", "Failed to store credentials")
}
if err := s.connRepo.UpdateStatus(txCtx, connID, "active"); err != nil {
return ErrInternalWithErr(err, "status_update_failed", "Failed to update status")
}
return nil
}); err != nil {
return "", err
}

returnURL, err := url.Parse(conn.ReturnURL)
Expand Down
2 changes: 2 additions & 0 deletions nexus-broker/pkg/handlers/soc2_compliance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,15 @@ func TestSOC2_CC61_EncryptionAtRest(t *testing.T) {
WillReturnRows(sqlmock.NewRows([]string{"id", "provider_id", "status", "scopes", "return_url", "name", "auth_type", "auth_header", "api_base_url", "user_info_endpoint", "params", "health_status"}).
AddRow(connID.String(), providerID.String(), "active", "{}", "http://localhost/return", "TestProvider", "api_key", "", "", "", nil, "unknown"))

mock.ExpectBegin()
mock.ExpectExec("INSERT INTO tokens").
WithArgs(connID, sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))

mock.ExpectExec("UPDATE connections SET status").
WithArgs("active", connID).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()

// 3. Fire the request
creds := map[string]interface{}{"api_key": plainTextKey}
Expand Down
Loading