Skip to content

Commit 355b582

Browse files
committed
docs: improve shared secrets documentation and harden security
Fixed 14 issues across Broker, Gateway, Bridge, and SDK: - Hardened CORS configuration in Gateway (#12) - Enforced shared STATE_KEY and ENCRYPTION_KEY across services (#13) - Fixed duplicate route registration in Broker (#15) - Optimized HTTP Client usage in Broker (#16) - Fixed Open Redirect vulnerability in SaveCredential (#17) - Corrected misleading indentation and fixed tests for chi router (#18, #24) - Added proper seeding for math/rand in Bridge and SDK (#19) - Implemented exponential backoff for Bridge reconnections (#20) - Removed dead code in Gateway (#22) - Fixed fragile URL path parsing in Gateway (#23) - Fixed flaky tests and ensured proper cleanup (#28) - Implemented Provider Name Cache in Gateway (#29) - Captured and logged ignored json.Encoder errors (#32) All tests passed successfully.
1 parent 3c3e1bb commit 355b582

16 files changed

Lines changed: 244 additions & 94 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33
The Nexus Framework is a provider-agnostic, secure integration layer for managing OAuth 2.0 and OIDC connections. It abstracts away the complexity of managing tokens, refreshes, and provider quirks, allowing your agents and services to focus on business logic.
44

5+
## ⚠️ Critical Configuration
6+
7+
The Nexus Framework requires two primary shared secrets to operate securely:
8+
9+
1. **`ENCRYPTION_KEY`**: A 32-byte key used by the Broker to encrypt tokens at rest.
10+
2. **`STATE_KEY`**: A 32-byte key shared between the Broker and Gateway to sign and verify the OAuth `state` parameter.
11+
12+
**Both services will refuse to start if these variables are missing or invalid.** In distributed deployments, the `STATE_KEY` **must** be identical across all Broker and Gateway instances, or OAuth callbacks will fail with "Invalid state" errors.
13+
14+
Generate a secure key with: `openssl rand -base64 32`
15+
516
## Quick Start
617

718
The fastest way to get started is with Docker Compose. This will spin up the Broker, Gateway, Postgres, and Redis.

docs/deployment.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ Required — the broker will not start without these:
2121
- `STATE_KEY` **(REQUIRED)**: Same as Broker — must match exactly.
2222
- `BROKER_API_KEY`: Key to authenticate with the Broker.
2323

24+
## Shared Secrets Management
25+
26+
The Nexus Framework relies on two primary symmetric keys. Proper management of these keys is critical for security and availability.
27+
28+
### 1. `ENCRYPTION_KEY` (AES-256-GCM)
29+
Used to encrypt OAuth tokens before they are stored in the database.
30+
- **Risk**: If this key is changed or lost, all existing connections in the database will become unreadable. The stored tokens cannot be recovered, so you will have to issue a new key and ask all users to re-authenticate.
31+
- **Guidance**: Store this in a secure vault (Azure Key Vault, AWS Secrets Manager, HashiCorp Vault). It should be stable across deployments.
32+
33+
### 2. `STATE_KEY` (HMAC-SHA256)
34+
Used to sign the `state` parameter during the initial redirect and verify it on callback.
35+
- **Risk**: If the Broker and Gateway have different keys, all OAuth callbacks will fail with "Invalid state" errors.
36+
- **Guidance**: Both the Broker and Gateway instances must receive the exact same value. In orchestrated environments (Kubernetes, Docker Swarm), use a shared Secret object.
37+
2438
## Local Development (Quickstart)
2539

2640
### 0. Generate required keys

nexus-bridge/bridge.go

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type Bridge struct {
3737
messageSizeLimit int64
3838
writeTimeout time.Duration
3939
pingInterval time.Duration
40+
randSource *rand.Rand
4041
}
4142

4243
// New creates a new Bridge with optional configurations.
@@ -56,6 +57,7 @@ func New(oauthClient OAuthClient, opts ...Option) *Bridge {
5657
messageSizeLimit: 65536, // 64KB
5758
writeTimeout: 10 * time.Second,
5859
pingInterval: 30 * time.Second,
60+
randSource: rand.New(rand.NewSource(time.Now().UnixNano())),
5961
}
6062

6163
// Apply all the functional options provided by the user
@@ -83,7 +85,9 @@ func NewStandard(oauthClient OAuthClient, agentLabels map[string]string, opts ..
8385
// MaintainWebSocket is the main entry point. It runs a loop that attempts
8486
// to establish and manage a connection, with a backoff policy for retries.
8587
func (b *Bridge) MaintainWebSocket(ctx context.Context, connectionID string, endpointURL string, handler Handler) error {
88+
attempt := 0
8689
for {
90+
start := time.Now()
8791
err := b.manageConnection(ctx, connectionID, endpointURL, handler)
8892
if err != nil {
8993
var permanentErr *PermanentError
@@ -95,15 +99,21 @@ func (b *Bridge) MaintainWebSocket(ctx context.Context, connectionID string, end
9599
b.logger.Error(err, "Connection manager exited with recoverable error", "connectionID", connectionID)
96100
}
97101

102+
// Reset attempt counter if the connection was stable for a while (e.g., 1 minute)
103+
if time.Since(start) > 1*time.Minute {
104+
attempt = 0
105+
}
106+
98107
select {
99108
case <-ctx.Done():
100109
b.logger.Info("Context cancelled; shutting down bridge", "connectionID", connectionID)
101110
b.metrics.SetConnectionStatus(0)
102111
return ctx.Err()
103112
default:
104113
// Connection dropped for a recoverable reason, wait and retry.
105-
backoff := b.calculateBackoff()
106-
b.logger.Info("Reconnecting", "connectionID", connectionID, "after", backoff)
114+
backoff := b.calculateBackoff(attempt)
115+
attempt++
116+
b.logger.Info("Reconnecting", "connectionID", connectionID, "after", backoff, "attempt", attempt)
107117
time.Sleep(backoff)
108118
}
109119
}
@@ -119,33 +129,21 @@ func (b *Bridge) MaintainGRPCConnection(
119129
run func(ctx context.Context, conn *grpc.ClientConn) error,
120130
opts ...grpc.DialOption,
121131
) error {
132+
attempt := 0
122133
for {
134+
start := time.Now()
123135
// 1. Prepare Credentials
124136
// We use our custom PerRPCCredentials implementation
125137
creds := NewBridgeCredentials(b.oauthClient, connectionID, b.refreshBuffer, b.logger)
126138

127139
// 2. Dial Options
128-
// Default to TransportCredentials (TLS) usually, but allow insecure via opts if needed.
129-
// However, PerRPCCredentials usually requires TLS.
130-
// For simplicity/testing, we default to insecure if no transport creds provided,
131-
// BUT BridgeCredentials.RequireTransportSecurity returns true, so gRPC will fail if insecure.
132-
// We append our creds to the user provided options.
133140
dialOpts := append(opts, grpc.WithPerRPCCredentials(creds))
134141

135-
// If user didn't provide transport credentials, we might need to add insecure for testing
136-
// OR we assume user provides WithTransportCredentials.
137-
// Let's assume user provides transport security options in 'opts'.
138-
// But for a robust default, we check? No, we can't easily inspect DialOptions.
139-
// We rely on the user to provide transport security (e.g. credentials.NewTLS) in 'opts'
140-
// if they are connecting to a secure endpoint.
141-
142142
// 3. Dial
143143
b.logger.Info("Dialing gRPC target", "target", target)
144-
// Note: grpc.NewClient is the modern API, but Dial is still common. Using NewClient.
145144
conn, err := grpc.NewClient(target, dialOpts...)
146145
if err != nil {
147146
b.logger.Error(err, "Failed to dial gRPC target", "target", target)
148-
// Dial errors are usually retryable (e.g. DNS)
149147
goto Retry
150148
}
151149

@@ -180,13 +178,19 @@ func (b *Bridge) MaintainGRPCConnection(
180178
b.logger.Info("gRPC run loop exited cleanly", "connectionID", connectionID)
181179
}
182180

181+
// Reset attempt counter if the connection was stable for a while
182+
if time.Since(start) > 1*time.Minute {
183+
attempt = 0
184+
}
185+
183186
Retry:
184187
select {
185188
case <-ctx.Done():
186189
return ctx.Err()
187190
default:
188-
backoff := b.calculateBackoff()
189-
b.logger.Info("Reconnecting gRPC", "after", backoff)
191+
backoff := b.calculateBackoff(attempt)
192+
attempt++
193+
b.logger.Info("Reconnecting gRPC", "after", backoff, "attempt", attempt)
190194
time.Sleep(backoff)
191195
}
192196
}
@@ -378,10 +382,18 @@ func (b *Bridge) manageConnection(ctx context.Context, connectionID string, endp
378382
}
379383

380384
// NEW: Helper function for calculating backoff with jitter.
381-
func (b *Bridge) calculateBackoff() time.Duration {
382-
backoff := b.retryPolicy.MinBackoff + time.Duration(rand.Int63n(int64(b.retryPolicy.Jitter)))
383-
if backoff > b.retryPolicy.MaxBackoff {
384-
return b.retryPolicy.MaxBackoff
385+
func (b *Bridge) calculateBackoff(attempt int) time.Duration {
386+
if attempt < 0 {
387+
attempt = 0
388+
}
389+
if attempt > 10 {
390+
attempt = 10
391+
}
392+
factor := 1 << uint(attempt)
393+
base := float64(b.retryPolicy.MinBackoff) * float64(factor)
394+
if base > float64(b.retryPolicy.MaxBackoff) {
395+
base = float64(b.retryPolicy.MaxBackoff)
385396
}
386-
return backoff
397+
jitter := 0.2 + b.randSource.Float64()*0.6 // 0.2..0.8
398+
return time.Duration(base * jitter)
387399
}

nexus-bridge/bridge_test.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,31 @@ func (m *mockMetrics) SetConnectionStatus(status float64) { m.connectionStatus.S
7474

7575
// testLogger is a mock implementation of the Logger interface for testing.
7676
type testLogger struct {
77-
t *testing.T
77+
t *testing.T
78+
mu sync.RWMutex
79+
closed bool
80+
}
81+
82+
func (l *testLogger) stop() {
83+
l.mu.Lock()
84+
defer l.mu.Unlock()
85+
l.closed = true
7886
}
7987

8088
func (l *testLogger) Info(msg string, keysAndValues ...interface{}) {
81-
l.t.Logf("INFO: %s %v", msg, keysAndValues)
89+
l.mu.RLock()
90+
defer l.mu.RUnlock()
91+
if !l.closed {
92+
l.t.Logf("INFO: %s %v", msg, keysAndValues)
93+
}
8294
}
8395

8496
func (l *testLogger) Error(err error, msg string, keysAndValues ...interface{}) {
85-
l.t.Logf("ERROR: %s %v err: %v", msg, keysAndValues, err)
97+
l.mu.RLock()
98+
defer l.mu.RUnlock()
99+
if !l.closed {
100+
l.t.Logf("ERROR: %s %v err: %v", msg, keysAndValues, err)
101+
}
86102
}
87103

88104
var upgrader = websocket.Upgrader{}
@@ -442,4 +458,6 @@ func TestBridge_TokenRefreshWithoutDisconnect(t *testing.T) {
442458
if metrics.connectionStatus.Load() != 1.0 {
443459
t.Errorf("Expected connection status to be 1, but got %v", metrics.connectionStatus.Load())
444460
}
461+
462+
logger.stop()
445463
}

nexus-broker/cmd/nexus-broker/main.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ func main() {
115115
r.Get("/by-name/{name}", providersHandler.GetByName)
116116
r.Delete("/by-name/{name}", providersHandler.DeleteByName)
117117
r.Get("/{id}", providersHandler.Get)
118-
r.Get("/{id}", providersHandler.Get)
119118
r.Put("/{id}", providersHandler.Update)
120119
r.Patch("/{id}", providersHandler.Patch)
121120
r.Delete("/{id}", providersHandler.Delete)

nexus-broker/pkg/handlers/callback.go

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9+
"log"
910
"net"
1011
"net/http"
1112
"net/url"
1213
"os"
1314
"strings"
1415
"time"
1516

17+
"github.com/go-chi/chi/v5"
1618
"github.com/google/uuid"
1719
"github.com/jmoiron/sqlx"
1820
"github.com/lib/pq"
@@ -301,7 +303,9 @@ func (h *CallbackHandler) GetCaptureSchema(w http.ResponseWriter, r *http.Reques
301303
}
302304

303305
w.Header().Set("Content-Type", "application/json")
304-
json.NewEncoder(w).Encode(response)
306+
if err := json.NewEncoder(w).Encode(response); err != nil {
307+
log.Printf("encode response: %v", err)
308+
}
305309
}
306310

307311
// SaveCredential handles the submission of the credential capture form.
@@ -348,7 +352,7 @@ func (h *CallbackHandler) SaveCredential(w http.ResponseWriter, r *http.Request)
348352
}
349353

350354
if userInfoEndpoint != "" && apiBaseURL != "" {
351-
if err := validateCredentials(authType, authHeader, apiBaseURL, userInfoEndpoint, reqBody.Credentials); err != nil {
355+
if err := h.validateCredentials(authType, authHeader, apiBaseURL, userInfoEndpoint, reqBody.Credentials); err != nil {
352356
http.Error(w, "Invalid credentials: "+err.Error(), http.StatusBadRequest)
353357
return
354358
}
@@ -365,11 +369,16 @@ func (h *CallbackHandler) SaveCredential(w http.ResponseWriter, r *http.Request)
365369
return
366370
}
367371

372+
if !server.IsReturnURLAllowed(returnURL) {
373+
http.Error(w, "return_url not allowed", http.StatusBadRequest)
374+
return
375+
}
376+
368377
http.Redirect(w, r, returnURL+"?status=success&connection_id="+connectionID.String(), http.StatusFound)
369378
}
370379

371380
// validateCredentials makes a test call to the provider's user_info_endpoint to verify the submitted credentials.
372-
func validateCredentials(authType, authHeader, apiBaseURL, userInfoEndpoint string, credentials map[string]interface{}) error {
381+
func (h *CallbackHandler) validateCredentials(authType, authHeader, apiBaseURL, userInfoEndpoint string, credentials map[string]interface{}) error {
373382
testURL := strings.TrimRight(apiBaseURL, "/") + "/" + strings.TrimLeft(userInfoEndpoint, "/")
374383

375384
req, err := http.NewRequest(http.MethodGet, testURL, nil)
@@ -404,8 +413,7 @@ func validateCredentials(authType, authHeader, apiBaseURL, userInfoEndpoint stri
404413
return nil
405414
}
406415

407-
client := &http.Client{Timeout: 10 * time.Second}
408-
resp, err := client.Do(req)
416+
resp, err := h.httpClient.Do(req)
409417
if err != nil {
410418
return fmt.Errorf("could not reach provider to validate credentials")
411419
}
@@ -432,12 +440,7 @@ func containsScope(scopes []string, target string) bool {
432440
// GetToken handles GET /connections/{connection_id}/token
433441
func (h *CallbackHandler) GetToken(w http.ResponseWriter, r *http.Request) {
434442
// Extract connection ID from URL path
435-
pathParts := strings.Split(r.URL.Path, "/")
436-
if len(pathParts) < 3 {
437-
http.Error(w, "Invalid path", http.StatusBadRequest)
438-
return
439-
}
440-
connectionIDStr := pathParts[len(pathParts)-2] // /connections/{id}/token
443+
connectionIDStr := chi.URLParam(r, "connectionID")
441444

442445
connectionID, err := uuid.Parse(connectionIDStr)
443446
if err != nil {
@@ -472,10 +475,12 @@ func (h *CallbackHandler) GetToken(w http.ResponseWriter, r *http.Request) {
472475
if connection.Status == "attention" {
473476
w.Header().Set("Content-Type", "application/json")
474477
w.WriteHeader(http.StatusConflict)
475-
json.NewEncoder(w).Encode(map[string]string{
478+
if err := json.NewEncoder(w).Encode(map[string]string{
476479
"error": "attention_required",
477480
"detail": "Connection requires attention. The user must re-authenticate.",
478-
})
481+
}); err != nil {
482+
log.Printf("encode response: %v", err)
483+
}
479484
return
480485
}
481486

@@ -574,7 +579,9 @@ func (h *CallbackHandler) GetToken(w http.ResponseWriter, r *http.Request) {
574579

575580
// Return the response
576581
w.Header().Set("Content-Type", "application/json")
577-
json.NewEncoder(w).Encode(response)
582+
if err := json.NewEncoder(w).Encode(response); err != nil {
583+
log.Printf("encode response: %v", err)
584+
}
578585
}
579586

580587
// exchangeCodeForTokens exchanges authorization code for access tokens
@@ -621,8 +628,7 @@ func (h *CallbackHandler) exchangeCodeForTokens(tokenURL, clientID, clientSecret
621628
req.SetBasicAuth(clientID, clientSecret)
622629
}
623630

624-
client := &http.Client{Timeout: 30 * time.Second}
625-
resp, err := client.Do(req)
631+
resp, err := h.httpClient.Do(req)
626632
if err != nil {
627633
return nil, err
628634
}
@@ -656,8 +662,7 @@ func (h *CallbackHandler) refreshTokens(tokenURL, clientID, clientSecret, refres
656662
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
657663
req.Header.Set("Accept", "application/json") // Ensure JSON response
658664

659-
client := &http.Client{Timeout: 30 * time.Second}
660-
resp, err := client.Do(req)
665+
resp, err := h.httpClient.Do(req)
661666
if err != nil {
662667
return nil, 0, err
663668
}
@@ -678,13 +683,8 @@ func (h *CallbackHandler) refreshTokens(tokenURL, clientID, clientSecret, refres
678683
// Refresh handles POST /connections/{connection_id}/refresh
679684
func (h *CallbackHandler) Refresh(w http.ResponseWriter, r *http.Request) {
680685
// Extract connection ID
681-
parts := strings.Split(r.URL.Path, "/")
682-
if len(parts) < 3 {
683-
http.Error(w, "Invalid path", http.StatusBadRequest)
684-
return
685-
}
686-
idStr := parts[len(parts)-2]
687-
connectionID, err := uuid.Parse(idStr)
686+
connectionIDStr := chi.URLParam(r, "connectionID")
687+
connectionID, err := uuid.Parse(connectionIDStr)
688688
if err != nil {
689689
http.Error(w, "Invalid connection ID", http.StatusBadRequest)
690690
return
@@ -756,10 +756,12 @@ func (h *CallbackHandler) Refresh(w http.ResponseWriter, r *http.Request) {
756756

757757
w.Header().Set("Content-Type", "application/json")
758758
w.WriteHeader(http.StatusConflict) // 409 Conflict is a good signal for "state issue"
759-
json.NewEncoder(w).Encode(map[string]string{
759+
if err := json.NewEncoder(w).Encode(map[string]string{
760760
"error": "attention_required",
761761
"detail": "The connection credentials are invalid or expired and cannot be refreshed. User re-consent is required.",
762-
})
762+
}); err != nil {
763+
log.Printf("encode response: %v", err)
764+
}
763765
return
764766
}
765767

@@ -773,7 +775,9 @@ func (h *CallbackHandler) Refresh(w http.ResponseWriter, r *http.Request) {
773775
return
774776
}
775777
w.Header().Set("Content-Type", "application/json")
776-
json.NewEncoder(w).Encode(newTokens)
778+
if err := json.NewEncoder(w).Encode(newTokens); err != nil {
779+
log.Printf("encode response: %v", err)
780+
}
777781
default:
778782
http.Error(w, "Unsupported provider auth_type", http.StatusInternalServerError)
779783
return

0 commit comments

Comments
 (0)