Skip to content

Commit 1a9519e

Browse files
committed
Refactor: address remaining technical debt items
Fixes #30 Fixes #27 Fixes #21
1 parent 4f71ec6 commit 1a9519e

8 files changed

Lines changed: 61 additions & 53 deletions

File tree

nexus-bridge/bridge_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,15 @@ func TestBridge_ContextCancellation(t *testing.T) {
236236
errChan <- bridge.MaintainWebSocket(ctx, "conn-123", "ws"+server.URL[4:], handler)
237237
}()
238238

239-
time.Sleep(100 * time.Millisecond)
239+
time.Sleep(500 * time.Millisecond)
240240
cancel()
241241

242242
select {
243243
case err := <-errChan:
244244
if !errors.Is(err, context.Canceled) {
245245
t.Errorf("Expected context.Canceled error, but got %v", err)
246246
}
247-
case <-time.After(1 * time.Second):
247+
case <-time.After(5 * time.Second):
248248
t.Fatal("Bridge did not exit after context cancellation")
249249
}
250250
if metrics.connectionStatus.Load() != 0.0 {

nexus-broker/cmd/nexus-broker/main.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,13 @@ func main() {
9494
store := provider.NewStore(db)
9595

9696
// Setup handlers
97+
redirectPath := os.Getenv("REDIRECT_PATH")
98+
if redirectPath == "" {
99+
redirectPath = "/auth/callback"
100+
}
97101
providersHandler := handlers.NewProvidersHandler(store)
98-
consentHandler := handlers.NewConsentHandler(db, baseURL, stateKey, cachingClient)
99-
callbackHandler := handlers.NewCallbackHandler(db, encryptionKey, stateKey, cachingClient)
102+
consentHandler := handlers.NewConsentHandler(db, baseURL, redirectPath, stateKey, cachingClient)
103+
callbackHandler := handlers.NewCallbackHandler(db, baseURL, redirectPath, encryptionKey, stateKey, cachingClient)
100104

101105
// Setup routes
102106
router := srv.Router()

nexus-broker/pkg/handlers/callback.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"net"
1010
"net/http"
1111
"net/url"
12-
"os"
1312
"strings"
1413
"time"
1514

@@ -29,6 +28,8 @@ import (
2928
// CallbackHandler handles OAuth callback and token exchange
3029
type CallbackHandler struct {
3130
db *sqlx.DB
31+
baseURL string
32+
redirectPath string
3233
encryptionKey []byte
3334
stateKey []byte
3435
httpClient *http.Client
@@ -40,7 +41,7 @@ type CallbackHandler struct {
4041
}
4142

4243
// NewCallbackHandler creates a new callback handler
43-
func NewCallbackHandler(db *sqlx.DB, encryptionKey, stateKey []byte, httpClient *http.Client) *CallbackHandler {
44+
func NewCallbackHandler(db *sqlx.DB, baseURL, redirectPath string, encryptionKey, stateKey []byte, httpClient *http.Client) *CallbackHandler {
4445
success := prometheus.NewCounter(prometheus.CounterOpts{
4546
Name: "oauth_token_exchanges_total",
4647
Help: "Total OAuth token exchanges",
@@ -76,6 +77,8 @@ func NewCallbackHandler(db *sqlx.DB, encryptionKey, stateKey []byte, httpClient
7677

7778
return &CallbackHandler{
7879
db: db,
80+
baseURL: baseURL,
81+
redirectPath: redirectPath,
7982
encryptionKey: encryptionKey,
8083
stateKey: stateKey,
8184
httpClient: httpClient,
@@ -161,11 +164,8 @@ func (h *CallbackHandler) Handle(w http.ResponseWriter, r *http.Request) {
161164
}
162165

163166
// Compute redirect_uri to match the auth request
164-
redirectPath := os.Getenv("REDIRECT_PATH")
165-
if redirectPath == "" {
166-
redirectPath = "/auth/callback"
167-
}
168-
base := strings.TrimSuffix(os.Getenv("BASE_URL"), "/")
167+
redirectPath := h.redirectPath
168+
base := strings.TrimSuffix(h.baseURL, "/")
169169
redirectURI := base + redirectPath
170170

171171
// Check if provider wants to skip scope on token exchange (e.g., Salesforce rejects it)

nexus-broker/pkg/handlers/callback_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestRefresh_StaticKeyProvider(t *testing.T) {
2626
defer db.Close()
2727

2828
sqlxDB := sqlx.NewDb(db, "sqlmock")
29-
handler := NewCallbackHandler(sqlxDB, []byte("test-key"), []byte("test-key"), http.DefaultClient)
29+
handler := NewCallbackHandler(sqlxDB, "http://localhost:8080", "/auth/callback", []byte("test-key"), []byte("test-key"), http.DefaultClient)
3030

3131
// Mock the initial query to find the connection
3232

@@ -66,7 +66,7 @@ func TestRefresh_OAuth2Provider(t *testing.T) {
6666
}))
6767
defer mockProviderServer.Close()
6868

69-
handler := NewCallbackHandler(sqlxDB, []byte("01234567890123456789012345678901"), []byte("01234567890123456789012345678901"), mockProviderServer.Client())
69+
handler := NewCallbackHandler(sqlxDB, "http://localhost:8080", "/auth/callback", []byte("01234567890123456789012345678901"), []byte("01234567890123456789012345678901"), mockProviderServer.Client())
7070

7171
// Mock the initial query to find the connection
7272

@@ -117,7 +117,7 @@ func TestGetCaptureSchema(t *testing.T) {
117117
sqlxDB := sqlx.NewDb(db, "sqlmock")
118118
// Use a real key for signing/verifying state
119119
stateKey := []byte("01234567890123456789012345678901")
120-
handler := NewCallbackHandler(sqlxDB, nil, stateKey, http.DefaultClient)
120+
handler := NewCallbackHandler(sqlxDB, "http://localhost:8080", "/auth/callback", nil, stateKey, http.DefaultClient)
121121

122122
providerID := uuid.New()
123123
stateData := auth.StateData{
@@ -170,7 +170,7 @@ func TestSaveCredential_ValidState(t *testing.T) {
170170
sqlxDB := sqlx.NewDb(db, "sqlmock")
171171
stateKey := []byte("01234567890123456789012345678901")
172172
encryptionKey := []byte("01234567890123456789012345678901")
173-
handler := NewCallbackHandler(sqlxDB, encryptionKey, stateKey, http.DefaultClient)
173+
handler := NewCallbackHandler(sqlxDB, "http://localhost:8080", "/auth/callback", encryptionKey, stateKey, http.DefaultClient)
174174

175175
connectionID := uuid.New()
176176
stateData := auth.StateData{
@@ -224,7 +224,7 @@ func TestSaveCredential_ValidState(t *testing.T) {
224224
}
225225

226226
func TestSaveCredential_InvalidState(t *testing.T) {
227-
handler := NewCallbackHandler(nil, nil, []byte("test-key"), http.DefaultClient)
227+
handler := NewCallbackHandler(nil, "http://localhost:8080", "/auth/callback", nil, []byte("test-key"), http.DefaultClient)
228228

229229
creds := map[string]interface{}{"api_key": "test-key"}
230230
body := map[string]interface{}{
@@ -245,7 +245,7 @@ func TestSaveCredential_InvalidState(t *testing.T) {
245245
}
246246

247247
func TestSaveCredential_InvalidJSON(t *testing.T) {
248-
handler := NewCallbackHandler(nil, nil, nil, http.DefaultClient)
248+
handler := NewCallbackHandler(nil, "http://localhost:8080", "/auth/callback", nil, nil, http.DefaultClient)
249249

250250
req, err := http.NewRequest("POST", "/auth/capture-credential", bytes.NewBuffer([]byte("not-json")))
251251
assert.NoError(t, err)

nexus-broker/pkg/handlers/consent.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"log"
88
"net/http"
99
"net/url"
10-
"os"
1110
"strings"
1211
"time"
1312

@@ -34,14 +33,15 @@ type ConsentSpec struct {
3433
type ConsentHandler struct {
3534
db *sqlx.DB
3635
baseURL string
36+
redirectPath string
3737
stateKey []byte
3838
httpClient *http.Client
3939
consentsMetric prometheus.Counter
4040
consentsOpenID prometheus.Counter
4141
}
4242

4343
// NewConsentHandler creates a new consent handler
44-
func NewConsentHandler(db *sqlx.DB, baseURL string, stateKey []byte, httpClient *http.Client) *ConsentHandler {
44+
func NewConsentHandler(db *sqlx.DB, baseURL, redirectPath string, stateKey []byte, httpClient *http.Client) *ConsentHandler {
4545
metric := prometheus.NewCounter(prometheus.CounterOpts{
4646
Name: "oauth_consents_created_total",
4747
Help: "Total OAuth consents created",
@@ -63,6 +63,7 @@ func NewConsentHandler(db *sqlx.DB, baseURL string, stateKey []byte, httpClient
6363
return &ConsentHandler{
6464
db: db,
6565
baseURL: baseURL,
66+
redirectPath: redirectPath,
6667
stateKey: stateKey,
6768
httpClient: httpClient,
6869
consentsMetric: metric,
@@ -246,10 +247,7 @@ func (h *ConsentHandler) GetSpec(w http.ResponseWriter, r *http.Request) {
246247
// buildAuthURL constructs the OAuth authorization URL
247248
func (h *ConsentHandler) buildAuthURL(providerAuthURL, clientID, state, codeChallenge string, scopes []string, providerParams *json.RawMessage) (string, error) {
248249
baseURL := strings.TrimSuffix(h.baseURL, "/")
249-
redirectPath := os.Getenv("REDIRECT_PATH")
250-
if redirectPath == "" {
251-
redirectPath = "/auth/callback"
252-
}
250+
redirectPath := h.redirectPath
253251

254252
if providerAuthURL == "" {
255253
return "", fmt.Errorf("provider auth_url is required for OAuth2")

nexus-broker/pkg/handlers/consent_test.go

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,13 @@ func TestGetSpec_OAuth2(t *testing.T) {
3232
defer mockProviderServer.Close()
3333

3434
// Pass the test server's client to the handler
35-
handler := NewConsentHandler(sqlxDB, "http://localhost:8080", []byte("test-key"), mockProviderServer.Client())
35+
handler := NewConsentHandler(sqlxDB, "http://localhost:8080", "/auth/callback", []byte("test-key"), mockProviderServer.Client())
3636

3737
paramsJSON := []byte(`{"access_type": "offline", "prompt": "consent"}`)
3838

39-
40-
rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_id", "scopes", "params"}).
39+
rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_id", "scopes", "params"}).
4140
AddRow("a0a0a0a0-a0a0-a0a0-a0a0-a0a0a0a0a0a0", "Test OAuth2 Provider", "oauth2", "http://provider.com/auth", "test-client-id", "{openid}", paramsJSON)
42-
mock.ExpectQuery("SELECT id, name, auth_type, auth_url, client_id, scopes, params FROM provider_profiles WHERE id = \\$1").
41+
mock.ExpectQuery("SELECT id, name, auth_type, auth_url, client_id, scopes, params FROM provider_profiles WHERE id = \\$1").
4342
WithArgs("a0a0a0a0-a0a0-a0a0-a0a0-a0a0a0a0a0a0").
4443
WillReturnRows(rows)
4544

@@ -82,12 +81,11 @@ func TestGetSpec_StaticKey(t *testing.T) {
8281

8382
sqlxDB := sqlx.NewDb(db, "sqlmock")
8483
// For static key tests, we can pass a default client as no external calls are made.
85-
handler := NewConsentHandler(sqlxDB, "http://localhost:8080", []byte("test-key"), http.DefaultClient)
86-
84+
handler := NewConsentHandler(sqlxDB, "http://localhost:8080", "/auth/callback", []byte("test-key"), http.DefaultClient)
8785

88-
rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_id", "scopes", "params"}).
86+
rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_id", "scopes", "params"}).
8987
AddRow("b1b1b1b1-b1b1-b1b1-b1b1-b1b1b1b1b1b1", "Test API", "api_key", nil, nil, "{}", []byte("{}"))
90-
mock.ExpectQuery("SELECT id, name, auth_type, auth_url, client_id, scopes, params FROM provider_profiles WHERE id = \\$1").
88+
mock.ExpectQuery("SELECT id, name, auth_type, auth_url, client_id, scopes, params FROM provider_profiles WHERE id = \\$1").
9189
WithArgs("b1b1b1b1-b1b1-b1b1-b1b1-b1b1b1b1b1b1").
9290
WillReturnRows(rows)
9391

@@ -134,7 +132,7 @@ func TestGetSpec_MixedOAuth2_Discovery(t *testing.T) {
134132
"authorization_endpoint": "http://%s/openid/connect/authorize",
135133
"jwks_uri": "http://%s/jwks"
136134
}`,
137-
r.Host, r.Host, r.Host)
135+
r.Host, r.Host, r.Host)
138136
w.Write([]byte(oidcConfig))
139137
return
140138
}
@@ -143,16 +141,16 @@ func TestGetSpec_MixedOAuth2_Discovery(t *testing.T) {
143141
defer ts.Close()
144142

145143
// Handler under test
146-
handler := NewConsentHandler(sqlxDB, "http://localhost:8080", []byte("test-key"), ts.Client())
144+
handler := NewConsentHandler(sqlxDB, "http://localhost:8080", "/auth/callback", []byte("test-key"), ts.Client())
147145

148146
// Define the configured (legacy) auth URL
149147
configuredAuthURL := ts.URL + "/oauth/v2/authorize"
150148

151149
// 1. Mock DB Provider Query
152150

153-
rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_id", "scopes", "params"}).
151+
rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_id", "scopes", "params"}).
154152
AddRow("00000000-0000-0000-0000-000000000000", "Slack", "oauth2", configuredAuthURL, "slack-client", "{chat:write}", []byte("{}"))
155-
153+
156154
// Use regex to avoid strict string matching issues with sqlmock
157155
mock.ExpectQuery("SELECT .* FROM provider_profiles WHERE id = .*").
158156
WithArgs("00000000-0000-0000-0000-000000000000").
@@ -192,4 +190,4 @@ rows := sqlmock.NewRows([]string{"id", "name", "auth_type", "auth_url", "client_
192190
if !strings.HasPrefix(response.AuthURL, configuredAuthURL) {
193191
t.Errorf("Expected AuthURL to start with configured URL %s, but got %s", configuredAuthURL, response.AuthURL)
194192
}
195-
}
193+
}

nexus-gateway/pkg/grpc/server_grpc.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ func NewService(handler *usecase.Handler) *Service {
3131
return &Service{usecaseHandler: handler}
3232
}
3333

34+
func mapUsecaseError(err error, msg string) error {
35+
switch {
36+
case errors.Is(err, usecase.ErrProviderNotFound):
37+
return status.Errorf(codes.NotFound, "%s: %v", msg, err)
38+
case errors.Is(err, usecase.ErrInvalidState):
39+
return status.Errorf(codes.InvalidArgument, "%s: %v", msg, err)
40+
case errors.Is(err, usecase.ErrProviderAmbiguous):
41+
return status.Errorf(codes.FailedPrecondition, "%s: %v", msg, err)
42+
case errors.Is(err, usecase.ErrBrokerUnavailable):
43+
return status.Errorf(codes.Unavailable, "%s: %v", msg, err)
44+
default:
45+
return status.Errorf(codes.Internal, "%s: %v", msg, err)
46+
}
47+
}
48+
3449
// RequestConnection implements NexusServiceServer.RequestConnection.
3550
func (s *Service) RequestConnection(ctx context.Context, req *nexuspb.RequestConnectionRequest) (*nexuspb.RequestConnectionResponse, error) {
3651
if req == nil {
@@ -45,7 +60,7 @@ func (s *Service) RequestConnection(ctx context.Context, req *nexuspb.RequestCon
4560
Action: req.GetAction(),
4661
})
4762
if err != nil {
48-
return nil, status.Errorf(codes.Internal, "request connection failed: %v", err)
63+
return nil, mapUsecaseError(err, "request connection failed")
4964
}
5065
return &nexuspb.RequestConnectionResponse{
5166
AuthUrl: out.AuthURL,
@@ -64,7 +79,7 @@ func (s *Service) CheckConnection(ctx context.Context, req *nexuspb.CheckConnect
6479
}
6580
statusStr, err := s.usecaseHandler.CheckConnectionCore(ctx, req.GetConnectionId())
6681
if err != nil {
67-
return nil, status.Errorf(codes.Internal, "check connection failed: %v", err)
82+
return nil, mapUsecaseError(err, "check connection failed")
6883
}
6984
return &nexuspb.CheckConnectionResponse{Status: statusStr}, nil
7085
}
@@ -77,7 +92,7 @@ func (s *Service) GetToken(ctx context.Context, req *nexuspb.GetTokenRequest) (*
7792
data, code, err := s.usecaseHandler.GetTokenCore(ctx, req.GetConnectionId())
7893
if err != nil {
7994
_ = code // keep the HTTP status for potential mapping if needed later
80-
return nil, status.Errorf(codes.Internal, "get token failed: %v", err)
95+
return nil, mapUsecaseError(err, "get token failed")
8196
}
8297
st, err := structpb.NewStruct(data)
8398
if err != nil {
@@ -94,7 +109,7 @@ func (s *Service) RefreshConnection(ctx context.Context, req *nexuspb.RefreshCon
94109
data, code, err := s.usecaseHandler.RefreshConnectionCore(ctx, req.GetConnectionId())
95110
if err != nil {
96111
_ = code // unused
97-
return nil, status.Errorf(codes.Internal, "refresh connection failed: %v", err)
112+
return nil, mapUsecaseError(err, "refresh connection failed")
98113
}
99114
st, err := structpb.NewStruct(data)
100115
if err != nil {

nexus-sdk/client.go

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -149,18 +149,7 @@ func (c *Client) GetToken(ctx context.Context, connectionID string) (*TokenRespo
149149
defer resp.Body.Close()
150150
var raw map[string]any
151151
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { return nil, err }
152-
tr := &TokenResponse{Raw: raw}
153-
if v, ok := raw["access_token"].(string); ok { tr.AccessToken = v }
154-
if v, ok := raw["token_type"].(string); ok { tr.TokenType = &v }
155-
if v, ok := raw["expires_in"].(float64); ok { vv := int64(v); tr.ExpiresIn = &vv }
156-
if v, ok := raw["expires_at"]; ok { tr.ExpiresAt = v }
157-
if v, ok := raw["scope"].(string); ok { tr.Scope = &v }
158-
if v, ok := raw["id_token"].(string); ok { tr.IDToken = &v }
159-
if v, ok := raw["refresh_token"].(string); ok { tr.RefreshToken = &v }
160-
if v, ok := raw["provider"].(string); ok { tr.Provider = &v }
161-
if v, ok := raw["strategy"].(map[string]interface{}); ok { tr.Strategy = v }
162-
if v, ok := raw["credentials"].(map[string]interface{}); ok { tr.Credentials = v }
163-
return tr, nil
152+
return parseTokenResponse(raw), nil
164153
}
165154

166155
// RefreshConnection calls the Gateway to force a token refresh.
@@ -171,6 +160,10 @@ func (c *Client) RefreshConnection(ctx context.Context, connectionID string) (*T
171160
defer resp.Body.Close()
172161
var raw map[string]any
173162
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { return nil, err }
163+
return parseTokenResponse(raw), nil
164+
}
165+
166+
func parseTokenResponse(raw map[string]any) *TokenResponse {
174167
tr := &TokenResponse{Raw: raw}
175168
if v, ok := raw["access_token"].(string); ok { tr.AccessToken = v }
176169
if v, ok := raw["token_type"].(string); ok { tr.TokenType = &v }
@@ -182,7 +175,7 @@ func (c *Client) RefreshConnection(ctx context.Context, connectionID string) (*T
182175
if v, ok := raw["provider"].(string); ok { tr.Provider = &v }
183176
if v, ok := raw["strategy"].(map[string]interface{}); ok { tr.Strategy = v }
184177
if v, ok := raw["credentials"].(map[string]interface{}); ok { tr.Credentials = v }
185-
return tr, nil
178+
return tr
186179
}
187180

188181
// RefreshViaBroker calls RefreshConnection (Gateway Proxy).

0 commit comments

Comments
 (0)