diff --git a/.mockery.yaml b/.mockery.yaml index f2ae8be1b93..d76cc4b6c1d 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -27,7 +27,7 @@ packages: DonSubscriber: github.com/smartcontractkit/chainlink/v2/core/capabilities/vault: interfaces: - RequestAuthorizer: + Authorizer: github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes: interfaces: SecretsService: @@ -434,4 +434,3 @@ packages: github.com/smartcontractkit/chainlink/v2/core/services/workflows/metering: interfaces: BillingClient: - diff --git a/core/capabilities/vault/allow_list_based_auth.go b/core/capabilities/vault/allow_list_based_auth.go new file mode 100644 index 00000000000..bedf02269c4 --- /dev/null +++ b/core/capabilities/vault/allow_list_based_auth.go @@ -0,0 +1,139 @@ +package vault + +import ( + "context" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-evm/gethwrappers/workflow/generated/workflow_registry_wrapper_v2" + workflowsyncerv2 "github.com/smartcontractkit/chainlink/v2/core/services/workflows/syncer/v2" +) + +const ( + allowListBasedAuthRetryCount = 3 + allowListBasedAuthRetryInterval = 3 * time.Second +) + +type allowListBasedAuth struct { + workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer + lggr logger.Logger + retryCount int + retryInterval time.Duration +} + +// AuthorizeRequest authorizes a request using AllowListBasedAuth. +// It does NOT check if the request method is allowed. +func (r *allowListBasedAuth) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) { + r.lggr.Debugw("AllowListBasedAuth authorizing request", "method", req.Method, "requestID", req.ID) + requestDigest, err := req.Digest() + if err != nil { + r.lggr.Debugw("AllowListBasedAuth failed to create digest", "method", req.Method, "requestID", req.ID, "error", err) + return nil, err + } + requestDigestBytes, err := hex.DecodeString(requestDigest) + if err != nil { + r.lggr.Debugw("AllowListBasedAuth failed to decode digest", "method", req.Method, "requestID", req.ID, "requestDigest", requestDigest, "error", err) + return nil, err + } + requestDigestBytes32 := [32]byte(requestDigestBytes) + if r.workflowRegistrySyncer == nil { + r.lggr.Errorw("AllowListBasedAuth workflowRegistrySyncer is nil", "method", req.Method, "requestID", req.ID) + return nil, errors.New("internal error: workflowRegistrySyncer is nil") + } + allowlistedRequest, allowedRequestsStrs, err := r.findAllowlistedItemWithRetry(ctx, req, requestDigest, requestDigestBytes32) + if err != nil { + return nil, err + } + if allowlistedRequest == nil { + r.lggr.Debugw("AllowListBasedAuth request digest not allowlisted", + "method", req.Method, + "requestID", req.ID, + "digestHexStr", requestDigest, + "allowedRequestsStrs", allowedRequestsStrs) + return nil, errors.New("request not allowlisted") + } + + if time.Now().UTC().Unix() > int64(allowlistedRequest.ExpiryTimestamp) { + authorizedRequestStr := string(allowlistedRequest.RequestDigest[:]) + r.lggr.Debugw("AllowListBasedAuth authorization expired", "method", req.Method, "requestID", req.ID, "authorizedRequestStr", authorizedRequestStr, "expiryTimestamp", allowlistedRequest.ExpiryTimestamp) + return nil, errors.New("request authorization expired") + } + + digestKey := string(allowlistedRequest.RequestDigest[:]) + r.lggr.Debugw("AllowListBasedAuth authorization succeeded", "method", req.Method, "requestID", req.ID, "authorizedRequestStr", digestKey, "owner", allowlistedRequest.Owner.Hex(), "expiryTimestamp", allowlistedRequest.ExpiryTimestamp) + return &AuthResult{ + workflowOwner: allowlistedRequest.Owner.Hex(), + digest: digestKey, + expiresAt: int64(allowlistedRequest.ExpiryTimestamp), + }, nil +} + +func (r *allowListBasedAuth) findAllowlistedItemWithRetry(ctx context.Context, req jsonrpc.Request[json.RawMessage], requestDigest string, requestDigestBytes32 [32]byte) (*workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest, []string, error) { + for attempt := 0; attempt <= r.retryCount; attempt++ { + allowedRequests := r.workflowRegistrySyncer.GetAllowlistedRequests(ctx) + allowedRequestsStrs := make([]string, 0, len(allowedRequests)) + for _, rr := range allowedRequests { + allowedReqStr := fmt.Sprintf("AuthorizedOwner: %s, RequestDigest: %s, ExpiryTimestamp: %d", rr.Owner.Hex(), hex.EncodeToString(rr.RequestDigest[:]), rr.ExpiryTimestamp) + allowedRequestsStrs = append(allowedRequestsStrs, allowedReqStr) + } + r.lggr.Debugw("AllowListBasedAuth loaded allowlisted requests", "method", req.Method, "requestID", req.ID, "attempt", attempt+1, "allowedRequests", allowedRequestsStrs) + + allowlistedRequest := r.fetchAllowlistedItem(allowedRequests, requestDigestBytes32) + if allowlistedRequest != nil { + return allowlistedRequest, allowedRequestsStrs, nil + } + if attempt == r.retryCount { + return nil, allowedRequestsStrs, nil + } + + r.lggr.Debugw("AllowListBasedAuth request digest not yet allowlisted, retrying", + "method", req.Method, + "requestID", req.ID, + "digestHexStr", requestDigest, + "attempt", attempt+1, + "maxAttempts", r.retryCount+1, + "retryInterval", r.retryInterval) + if err := sleepWithContext(ctx, r.retryInterval); err != nil { + r.lggr.Debugw("AllowListBasedAuth retry canceled", "method", req.Method, "requestID", req.ID, "error", err) + return nil, nil, err + } + } + + return nil, nil, nil // unreachable: loop always returns +} + +func (r *allowListBasedAuth) fetchAllowlistedItem(allowListedRequests []workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest, digest [32]byte) *workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest { + for _, item := range allowListedRequests { + if item.RequestDigest == digest { + return &item + } + } + return nil +} + +// NewAllowListBasedAuth creates the allowlist-backed Vault auth mechanism. +func NewAllowListBasedAuth(lggr logger.Logger, workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer) *allowListBasedAuth { + return &allowListBasedAuth{ + workflowRegistrySyncer: workflowRegistrySyncer, + lggr: logger.Named(lggr, "VaultAllowListBasedAuth"), + retryCount: allowListBasedAuthRetryCount, + retryInterval: allowListBasedAuthRetryInterval, + } +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/core/capabilities/vault/request_authorizer_test.go b/core/capabilities/vault/allow_list_based_auth_test.go similarity index 76% rename from core/capabilities/vault/request_authorizer_test.go rename to core/capabilities/vault/allow_list_based_auth_test.go index dd82717c9d3..89ba2494e0a 100644 --- a/core/capabilities/vault/request_authorizer_test.go +++ b/core/capabilities/vault/allow_list_based_auth_test.go @@ -19,7 +19,7 @@ import ( syncerv2mocks "github.com/smartcontractkit/chainlink/v2/core/services/workflows/syncer/v2/mocks" ) -func TestRequestAuthorizer_CreateSecrets(t *testing.T) { +func TestAllowListBasedAuth_CreateSecrets(t *testing.T) { params, err := json.Marshal(vaultcommon.CreateSecretsRequest{ EncryptedSecrets: []*vaultcommon.EncryptedSecret{ { @@ -59,7 +59,7 @@ func TestRequestAuthorizer_CreateSecrets(t *testing.T) { testAuthForRequests(t, allowListedReq, notAllowListedReq) } -func TestRequestAuthorizer_UpdateSecrets(t *testing.T) { +func TestAllowListBasedAuth_UpdateSecrets(t *testing.T) { params, err := json.Marshal(vaultcommon.UpdateSecretsRequest{ EncryptedSecrets: []*vaultcommon.EncryptedSecret{ { @@ -98,7 +98,7 @@ func TestRequestAuthorizer_UpdateSecrets(t *testing.T) { testAuthForRequests(t, allowListedReq, notAllowListedReq) } -func TestRequestAuthorizer_DeleteSecrets(t *testing.T) { +func TestAllowListBasedAuth_DeleteSecrets(t *testing.T) { params, err := json.Marshal(vaultcommon.DeleteSecretsRequest{ Ids: []*vaultcommon.SecretIdentifier{ { @@ -131,7 +131,7 @@ func TestRequestAuthorizer_DeleteSecrets(t *testing.T) { testAuthForRequests(t, allowListedReq, notAllowListedReq) } -func TestRequestAuthorizer_ListSecrets(t *testing.T) { +func TestAllowListBasedAuth_ListSecrets(t *testing.T) { params, err := json.Marshal(vaultcommon.ListSecretIdentifiersRequest{ Namespace: "b", }) @@ -159,14 +159,16 @@ func testAuthForRequests(t *testing.T, allowlistedRequest, notAllowlistedRequest owner := common.Address{1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t) - auth := NewRequestAuthorizer(lggr, mockSyncer) + auth := NewAllowListBasedAuth(lggr, mockSyncer) + auth.retryCount = 0 + auth.retryInterval = time.Millisecond // Happy path digest, err := allowlistedRequest.Digest() require.NoError(t, err) digestBytes, err := hex.DecodeString(digest) require.NoError(t, err) - expiry := uint64(time.Now().UTC().Unix() + 100) //nolint:gosec // it is a safe conversion + expiry := time.Now().UTC().Unix() + 100 allowlisted := []workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{ { RequestDigest: [32]byte(digestBytes), @@ -175,15 +177,16 @@ func testAuthForRequests(t *testing.T, allowlistedRequest, notAllowlistedRequest }, } mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return(allowlisted) - isAuthorized, gotOwner, err := auth.AuthorizeRequest(t.Context(), allowlistedRequest) - require.True(t, isAuthorized, err) - require.Equal(t, owner.Hex(), gotOwner) + authResult, err := auth.AuthorizeRequest(t.Context(), allowlistedRequest) require.NoError(t, err) + require.Equal(t, owner.Hex(), authResult.AuthorizedOwner()) + require.Equal(t, expiry, authResult.GetExpiresAt()) + require.NotEmpty(t, authResult.GetDigest()) - // Already authorized - isAuthorized, _, err = auth.AuthorizeRequest(t.Context(), allowlistedRequest) - require.False(t, isAuthorized) - require.ErrorContains(t, err, "already authorized previously") + // Same request is still authorized here; replay protection lives in the generic Authorizer. + authResult, err = auth.AuthorizeRequest(t.Context(), allowlistedRequest) + require.NoError(t, err) + require.Equal(t, owner.Hex(), authResult.AuthorizedOwner()) // Expired request allowlistedReqCopy := allowlistedRequest @@ -195,16 +198,16 @@ func testAuthForRequests(t *testing.T, allowlistedRequest, notAllowlistedRequest allowlisted[0].RequestDigest = [32]byte(allowlistedReqCopyDigestBytes) allowlisted[0].ExpiryTimestamp = uint32(time.Now().UTC().Unix() - 1) //nolint:gosec // it is a safe conversion mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return(allowlisted) - isAuthorized, _, err = auth.AuthorizeRequest(t.Context(), allowlistedReqCopy) - require.False(t, isAuthorized) + authResult, err = auth.AuthorizeRequest(t.Context(), allowlistedReqCopy) + require.Nil(t, authResult) require.ErrorContains(t, err, "authorization expired") - isAuthorized, _, err = auth.AuthorizeRequest(t.Context(), notAllowlistedRequest) - require.False(t, isAuthorized) + authResult, err = auth.AuthorizeRequest(t.Context(), notAllowlistedRequest) + require.Nil(t, authResult) require.ErrorContains(t, err, "not allowlisted") } -func TestRequestAuthorizer_RetriesAllowlistReadsUntilDigestAppears(t *testing.T) { +func TestAllowListBasedAuth_RetriesUntilRequestIsAllowlisted(t *testing.T) { lggr := logger.TestLogger(t) owner := common.Address{1, 2, 3} req := makeListSecretsRequest(t, "123", "b") @@ -213,55 +216,47 @@ func TestRequestAuthorizer_RetriesAllowlistReadsUntilDigestAppears(t *testing.T) require.NoError(t, err) digestBytes, err := hex.DecodeString(digest) require.NoError(t, err) - + expiry := time.Now().UTC().Unix() + 100 allowlisted := []workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{ { RequestDigest: [32]byte(digestBytes), Owner: owner, - ExpiryTimestamp: uint32(time.Now().UTC().Unix() + 100), //nolint:gosec // test fixture expiry is bounded and safe here + ExpiryTimestamp: uint32(expiry), //nolint:gosec // it is a safe conversion }, } mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t) + auth := NewAllowListBasedAuth(lggr, mockSyncer) + auth.retryCount = 2 + auth.retryInterval = time.Millisecond + mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Once() mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Once() mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return(allowlisted).Once() - auth := NewRequestAuthorizer(lggr, mockSyncer) - sleepCalls := 0 - auth.sleep = func(d time.Duration) { - require.Equal(t, allowlistReadRetryInterval, d) - sleepCalls++ - } - - isAuthorized, gotOwner, err := auth.AuthorizeRequest(t.Context(), req) - require.True(t, isAuthorized, err) + authResult, err := auth.AuthorizeRequest(t.Context(), req) require.NoError(t, err) - require.Equal(t, owner.Hex(), gotOwner) - require.Equal(t, 2, sleepCalls) + require.Equal(t, owner.Hex(), authResult.AuthorizedOwner()) + require.Equal(t, expiry, authResult.GetExpiresAt()) } -func TestRequestAuthorizer_FailsAfterAllowlistReadRetries(t *testing.T) { +func TestAllowListBasedAuth_FailsAfterAllowlistReadRetries(t *testing.T) { lggr := logger.TestLogger(t) req := makeListSecretsRequest(t, "123", "b") mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t) - mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Times(allowlistReadRetryCount + 1) + mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Times(3) - auth := NewRequestAuthorizer(lggr, mockSyncer) - sleepCalls := 0 - auth.sleep = func(d time.Duration) { - require.Equal(t, allowlistReadRetryInterval, d) - sleepCalls++ - } + auth := NewAllowListBasedAuth(lggr, mockSyncer) + auth.retryCount = 2 + auth.retryInterval = time.Millisecond - isAuthorized, _, err := auth.AuthorizeRequest(t.Context(), req) - require.False(t, isAuthorized) + authResult, err := auth.AuthorizeRequest(t.Context(), req) + require.Nil(t, authResult) require.ErrorContains(t, err, "not allowlisted") - require.Equal(t, allowlistReadRetryCount, sleepCalls) } -func TestRequestAuthorizer_StopsRetriesWhenContextCanceled(t *testing.T) { +func TestAllowListBasedAuth_StopsRetriesWhenContextCanceled(t *testing.T) { lggr := logger.TestLogger(t) req := makeListSecretsRequest(t, "123", "b") @@ -271,16 +266,13 @@ func TestRequestAuthorizer_StopsRetriesWhenContextCanceled(t *testing.T) { mockSyncer := syncerv2mocks.NewWorkflowRegistrySyncer(t) mockSyncer.On("GetAllowlistedRequests", mock.Anything).Return([]workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest{}).Once() - auth := NewRequestAuthorizer(lggr, mockSyncer) - sleepCalls := 0 - auth.sleep = func(time.Duration) { - sleepCalls++ - } + auth := NewAllowListBasedAuth(lggr, mockSyncer) + auth.retryCount = 2 + auth.retryInterval = time.Second - isAuthorized, _, err := auth.AuthorizeRequest(ctx, req) - require.False(t, isAuthorized) - require.ErrorContains(t, err, "not allowlisted") - require.Zero(t, sleepCalls) + authResult, err := auth.AuthorizeRequest(ctx, req) + require.Nil(t, authResult) + require.ErrorIs(t, err, context.Canceled) } func makeListSecretsRequest(t *testing.T, id, namespace string) jsonrpc.Request[json.RawMessage] { diff --git a/core/capabilities/vault/authorizer.go b/core/capabilities/vault/authorizer.go new file mode 100644 index 00000000000..718747f6cad --- /dev/null +++ b/core/capabilities/vault/authorizer.go @@ -0,0 +1,135 @@ +package vault + +import ( + "context" + "encoding/json" + "errors" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// AuthResult is the normalized authorization output shared by +// AllowListBasedAuth and JWTBasedAuth. +type AuthResult struct { + orgID string + workflowOwner string + digest string + expiresAt int64 +} + +// NewAuthResult remains exported for cross-package tests that cannot construct +// AuthResult directly because its fields are intentionally private. +func NewAuthResult(orgID, workflowOwner, digest string, expiresAt int64) *AuthResult { + return &AuthResult{ + orgID: orgID, + workflowOwner: workflowOwner, + digest: digest, + expiresAt: expiresAt, + } +} + +// AuthorizedOwner returns the canonical owner to use for request scoping. +func (a *AuthResult) AuthorizedOwner() string { + if a == nil { + return "" + } + if a.orgID != "" { + return a.orgID + } + return a.workflowOwner +} + +// GetDigest returns the request digest used for replay protection. +func (a *AuthResult) GetDigest() string { + if a == nil { + return "" + } + return a.digest +} + +// GetExpiresAt returns the unix timestamp (UTC) after which this +// authorization is no longer valid. +func (a *AuthResult) GetExpiresAt() int64 { + if a == nil { + return 0 + } + return a.expiresAt +} + +// GetUntrustedWorkflowOwner returns the workflow owner only for JWTBasedAuth results. +func (a *AuthResult) GetUntrustedWorkflowOwner() (string, error) { + if a == nil { + return "", errors.New("auth result is nil") + } + if a.orgID == "" { + return "", errors.New("untrusted workflow owner only applies to JWTBasedAuth results") + } + return a.workflowOwner, nil +} + +// Authorizer selects the applicable auth mechanism for a Vault request. +type Authorizer interface { + AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) +} + +type authorizer struct { + allowListBasedAuth Authorizer + jwtBasedAuth Authorizer + replayGuard *RequestReplayGuard + lggr logger.Logger +} + +func NewAuthorizer(allowListBasedAuth Authorizer, jwtBasedAuth Authorizer, lggr logger.Logger) Authorizer { + return &authorizer{ + allowListBasedAuth: allowListBasedAuth, + jwtBasedAuth: jwtBasedAuth, + replayGuard: NewRequestReplayGuard(), + lggr: logger.Named(lggr, "VaultAuthorizer"), + } +} + +func (a *authorizer) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) { + authResult, err := a.authorizeRequest(ctx, req) + if err != nil { + return nil, err + } + if authResult == nil { + err = errors.New("auth mechanism returned nil auth result") + a.lggr.Errorw("auth mechanism returned nil auth result", "method", req.Method, "requestID", req.ID, "hasAuth", req.Auth != "") + return nil, err + } + if err := a.replayGuard.CheckAndRecord(authResult.GetDigest(), authResult.GetExpiresAt()); err != nil { + a.lggr.Debugw("replay guard rejected request", "method", req.Method, "requestID", req.ID, "owner", authResult.AuthorizedOwner(), "digest", authResult.GetDigest(), "expiresAt", authResult.GetExpiresAt(), "hasAuth", req.Auth != "", "error", err) + return nil, err + } + a.lggr.Debugw("request authorized", "method", req.Method, "requestID", req.ID, "owner", authResult.AuthorizedOwner(), "digest", authResult.GetDigest(), "expiresAt", authResult.GetExpiresAt(), "hasAuth", req.Auth != "") + return authResult, nil +} + +func (a *authorizer) authorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) { + // Requests without req.Auth continue using the allowlist-based path for backwards compatibility. + // Existing clients do not populate the auth field yet, so treating an empty value as JWT would break them. + if req.Auth == "" { + return a.authorizeAllowListBasedAuth(ctx, req) + } + return a.authorizeJWTBasedAuth(ctx, req) +} + +func (a *authorizer) authorizeAllowListBasedAuth(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) { + if a.allowListBasedAuth == nil { + err := errors.New("AllowListBasedAuth authorizer is nil") + a.lggr.Errorw("AllowListBasedAuth unavailable", "method", req.Method, "requestID", req.ID, "error", err) + return nil, err + } + return a.allowListBasedAuth.AuthorizeRequest(ctx, req) +} + +func (a *authorizer) authorizeJWTBasedAuth(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) { + if a.jwtBasedAuth == nil { + err := errors.New("JWTBasedAuth is nil") + a.lggr.Errorw("JWTBasedAuth unavailable", "method", req.Method, "requestID", req.ID, "error", err) + return nil, err + } + return a.jwtBasedAuth.AuthorizeRequest(ctx, req) +} diff --git a/core/capabilities/vault/authorizer_test.go b/core/capabilities/vault/authorizer_test.go new file mode 100644 index 00000000000..3e02e75708d --- /dev/null +++ b/core/capabilities/vault/authorizer_test.go @@ -0,0 +1,162 @@ +package vault_test + +import ( + "encoding/json" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + vaultcommon "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" + vault "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault" + vaultmocks "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/mocks" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +func testLimitsFactory() limits.Factory { + return limits.Factory{Settings: cresettings.DefaultGetter} +} + +func TestAuthorizer_RejectsJWTBasedAuthWhenDisabled(t *testing.T) { + params, err := json.Marshal(vaultcommon.CreateSecretsRequest{}) + require.NoError(t, err) + + allowListBasedAuth := vaultmocks.NewAuthorizer(t) + allowListBasedAuth.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Maybe() + + jwtBasedAuth, err := vault.NewJWTBasedAuth(vault.JWTBasedAuthConfig{}, testLimitsFactory(), logger.TestLogger(t), vault.WithDisabledJWTBasedAuth()) + require.NoError(t, err) + + a := vault.NewAuthorizer(allowListBasedAuth, jwtBasedAuth, logger.TestLogger(t)) + + authResult, err := a.AuthorizeRequest(t.Context(), jsonrpc.Request[json.RawMessage]{ + ID: "1", + Method: vaulttypes.MethodSecretsCreate, + Params: (*json.RawMessage)(¶ms), + Auth: "jwt-token", + }) + require.Nil(t, authResult) + require.ErrorContains(t, err, "JWTBasedAuth is disabled") + allowListBasedAuth.AssertNotCalled(t, "AuthorizeRequest", mock.Anything, mock.Anything) +} + +func TestAuthorizer_UsesJWTWhenGateEnabled(t *testing.T) { + params, err := json.Marshal(vaultcommon.CreateSecretsRequest{}) + require.NoError(t, err) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "1", + Method: vaulttypes.MethodSecretsCreate, + Params: (*json.RawMessage)(¶ms), + Auth: "jwt-token", + } + digest, err := req.Digest() + require.NoError(t, err) + + jwtBasedAuth := vaultmocks.NewAuthorizer(t) + jwtBasedAuth.EXPECT().AuthorizeRequest(mock.Anything, req).Return(vault.NewAuthResult("org-1", "0xworkflow", digest, time.Now().Add(time.Minute).Unix()), nil).Once() + + a := vault.NewAuthorizer(nil, jwtBasedAuth, logger.TestLogger(t)) + + authResult, err := a.AuthorizeRequest(t.Context(), req) + require.NoError(t, err) + require.Equal(t, "org-1", authResult.AuthorizedOwner()) + + untrustedWorkflowOwner, err := authResult.GetUntrustedWorkflowOwner() + require.NoError(t, err) + require.Equal(t, "0xworkflow", untrustedWorkflowOwner) +} + +func TestAuthorizer_DelegatesDigestVerificationToJWTAuth(t *testing.T) { + params, err := json.Marshal(vaultcommon.CreateSecretsRequest{}) + require.NoError(t, err) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "1", + Method: vaulttypes.MethodSecretsCreate, + Params: (*json.RawMessage)(¶ms), + Auth: "jwt-token", + } + + jwtBasedAuth := vaultmocks.NewAuthorizer(t) + jwtBasedAuth.EXPECT().AuthorizeRequest(mock.Anything, req).Return(vault.NewAuthResult("org-1", "", "wrong-digest", time.Now().Add(time.Minute).Unix()), nil).Once() + + a := vault.NewAuthorizer(nil, jwtBasedAuth, logger.TestLogger(t)) + + authResult, err := a.AuthorizeRequest(t.Context(), req) + require.NoError(t, err) + require.Equal(t, "org-1", authResult.AuthorizedOwner()) +} + +func TestAuthorizer_RejectsJWTReplay(t *testing.T) { + params, err := json.Marshal(vaultcommon.CreateSecretsRequest{}) + require.NoError(t, err) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "1", + Method: vaulttypes.MethodSecretsCreate, + Params: (*json.RawMessage)(¶ms), + Auth: "jwt-token", + } + digest, err := req.Digest() + require.NoError(t, err) + + jwtBasedAuth := vaultmocks.NewAuthorizer(t) + jwtBasedAuth.EXPECT().AuthorizeRequest(mock.Anything, req).Return(vault.NewAuthResult("org-1", "", digest, time.Now().Add(time.Minute).Unix()), nil).Twice() + + a := vault.NewAuthorizer(nil, jwtBasedAuth, logger.TestLogger(t)) + + authResult, err := a.AuthorizeRequest(t.Context(), req) + require.NoError(t, err) + require.Equal(t, "org-1", authResult.AuthorizedOwner()) + + authResult, err = a.AuthorizeRequest(t.Context(), req) + require.Nil(t, authResult) + require.ErrorIs(t, err, vault.ErrRequestAlreadySeen) +} + +func TestAuthorizer_RejectsAllowListBasedAuthReplay(t *testing.T) { + allowListBasedAuth := vaultmocks.NewAuthorizer(t) + req := jsonrpc.Request[json.RawMessage]{ID: "1", Method: vaulttypes.MethodSecretsCreate} + allowListBasedAuth.EXPECT().AuthorizeRequest(mock.Anything, req).Return(vault.NewAuthResult("", "0xabc", "digest-1", time.Now().Add(time.Minute).Unix()), nil).Twice() + + jwtBasedAuth, err := vault.NewJWTBasedAuth(vault.JWTBasedAuthConfig{}, testLimitsFactory(), logger.TestLogger(t), vault.WithDisabledJWTBasedAuth()) + require.NoError(t, err) + + a := vault.NewAuthorizer(allowListBasedAuth, jwtBasedAuth, logger.TestLogger(t)) + + authResult, err := a.AuthorizeRequest(t.Context(), req) + require.NoError(t, err) + require.Equal(t, "0xabc", authResult.AuthorizedOwner()) + + authResult, err = a.AuthorizeRequest(t.Context(), req) + require.Nil(t, authResult) + require.ErrorIs(t, err, vault.ErrRequestAlreadySeen) +} + +func TestAuthorizer_PropagatesJWTValidationErrors(t *testing.T) { + params, err := json.Marshal(vaultcommon.CreateSecretsRequest{}) + require.NoError(t, err) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "1", + Method: vaulttypes.MethodSecretsCreate, + Params: (*json.RawMessage)(¶ms), + Auth: "jwt-token", + } + + jwtBasedAuth := vaultmocks.NewAuthorizer(t) + jwtBasedAuth.EXPECT().AuthorizeRequest(mock.Anything, req).Return(nil, errors.New("bad token")).Once() + + a := vault.NewAuthorizer(nil, jwtBasedAuth, logger.TestLogger(t)) + + authResult, err := a.AuthorizeRequest(t.Context(), req) + require.Nil(t, authResult) + require.ErrorContains(t, err, "bad token") +} diff --git a/core/capabilities/vault/gw_handler.go b/core/capabilities/vault/gw_handler.go index 94b1563142e..ec2b2a093d0 100644 --- a/core/capabilities/vault/gw_handler.go +++ b/core/capabilities/vault/gw_handler.go @@ -14,11 +14,13 @@ import ( vaultcommon "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/connector" + workflowsyncerv2 "github.com/smartcontractkit/chainlink/v2/core/services/workflows/syncer/v2" ) var ( @@ -56,29 +58,64 @@ type gatewayConnector interface { RemoveHandler(ctx context.Context, methods []string) error } +type gatewayHandlerConfig struct { + authorizer Authorizer +} + +// GatewayHandlerOption customizes GatewayHandler construction for tests and future auth extensions. +type GatewayHandlerOption func(*gatewayHandlerConfig) + +// WithAuthorizer overrides the default Vault request authorizer. +func WithAuthorizer(authorizer Authorizer) GatewayHandlerOption { + return func(cfg *gatewayHandlerConfig) { + cfg.authorizer = authorizer + } +} + +// GatewayHandler serves Vault requests received from the gateway on the node side. type GatewayHandler struct { services.Service eng *services.Engine - secretsService vaulttypes.SecretsService - gatewayConnector gatewayConnector - requestAuthorizer RequestAuthorizer - lggr logger.Logger - metrics *metrics + secretsService vaulttypes.SecretsService + gatewayConnector gatewayConnector + authorizer Authorizer + jwtAuthService services.Service + lggr logger.Logger + metrics *metrics } -func NewGatewayHandler(secretsService vaulttypes.SecretsService, connector gatewayConnector, requestAuthorizer RequestAuthorizer, lggr logger.Logger) (*GatewayHandler, error) { +// NewGatewayHandler creates a Vault gateway connector handler with internal auth wiring. +func NewGatewayHandler(secretsService vaulttypes.SecretsService, connector gatewayConnector, workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer, lggr logger.Logger, limitsFactory limits.Factory, opts ...GatewayHandlerOption) (*GatewayHandler, error) { + cfg := gatewayHandlerConfig{} + for _, opt := range opts { + opt(&cfg) + } + if cfg.authorizer == nil { + allowListBasedAuth := NewAllowListBasedAuth(lggr, workflowRegistrySyncer) + jwtBasedAuth, err := NewJWTBasedAuth(JWTBasedAuthConfig{}, limitsFactory, lggr, WithDisabledJWTBasedAuth()) + if err != nil { + return nil, fmt.Errorf("failed to create JWTBasedAuth: %w", err) + } + cfg.authorizer = NewAuthorizer(allowListBasedAuth, jwtBasedAuth, lggr) + return newGatewayHandlerWithAuthorizer(secretsService, connector, cfg.authorizer, jwtBasedAuth, lggr) + } + return newGatewayHandlerWithAuthorizer(secretsService, connector, cfg.authorizer, nil, lggr) +} + +func newGatewayHandlerWithAuthorizer(secretsService vaulttypes.SecretsService, connector gatewayConnector, authorizer Authorizer, jwtAuthService services.Service, lggr logger.Logger) (*GatewayHandler, error) { metrics, err := newMetrics() if err != nil { return nil, fmt.Errorf("failed to create metrics: %w", err) } gh := &GatewayHandler{ - secretsService: secretsService, - gatewayConnector: connector, - requestAuthorizer: requestAuthorizer, - lggr: lggr.Named(HandlerName), - metrics: metrics, + secretsService: secretsService, + gatewayConnector: connector, + authorizer: authorizer, + jwtAuthService: jwtAuthService, + lggr: lggr.Named(HandlerName), + metrics: metrics, } gh.Service, gh.eng = services.Config{ Name: "GatewayHandler", @@ -89,6 +126,11 @@ func NewGatewayHandler(secretsService vaulttypes.SecretsService, connector gatew } func (h *GatewayHandler) start(ctx context.Context) error { + if h.jwtAuthService != nil { + if err := h.jwtAuthService.Start(ctx); err != nil { + return fmt.Errorf("failed to start JWTBasedAuth: %w", err) + } + } if gwerr := h.gatewayConnector.AddHandler(ctx, h.Methods(), h); gwerr != nil { return fmt.Errorf("failed to add vault handler to connector: %w", gwerr) } @@ -96,10 +138,14 @@ func (h *GatewayHandler) start(ctx context.Context) error { } func (h *GatewayHandler) close() error { + var jwtAuthErr error + if h.jwtAuthService != nil { + jwtAuthErr = h.jwtAuthService.Close() + } if gwerr := h.gatewayConnector.RemoveHandler(context.Background(), h.Methods()); gwerr != nil { - return fmt.Errorf("failed to remove vault handler from connector: %w", gwerr) + return errors.Join(fmt.Errorf("failed to remove vault handler from connector: %w", gwerr), jwtAuthErr) } - return nil + return jwtAuthErr } func (h *GatewayHandler) ID(ctx context.Context) (string, error) { @@ -118,28 +164,28 @@ func (h *GatewayHandler) HandleGatewayMessage(ctx context.Context, gatewayID str case vaulttypes.MethodSecretsCreate: owner, authErr := h.authorizeAndPrefixRequest(ctx, req) if authErr != nil { - response = h.errorResponse(ctx, gatewayID, req, api.FatalError, authErr) + response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr) break } response = h.handleSecretsCreate(ctx, gatewayID, req, owner) case vaulttypes.MethodSecretsUpdate: owner, authErr := h.authorizeAndPrefixRequest(ctx, req) if authErr != nil { - response = h.errorResponse(ctx, gatewayID, req, api.FatalError, authErr) + response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr) break } response = h.handleSecretsUpdate(ctx, gatewayID, req, owner) case vaulttypes.MethodSecretsDelete: owner, authErr := h.authorizeAndPrefixRequest(ctx, req) if authErr != nil { - response = h.errorResponse(ctx, gatewayID, req, api.FatalError, authErr) + response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr) break } response = h.handleSecretsDelete(ctx, gatewayID, req, owner) case vaulttypes.MethodSecretsList: owner, authErr := h.authorizeAndPrefixRequest(ctx, req) if authErr != nil { - response = h.errorResponse(ctx, gatewayID, req, api.FatalError, authErr) + response = h.errorResponse(ctx, gatewayID, req, api.HandlerError, authErr) break } response = h.handleSecretsList(ctx, gatewayID, req, owner) @@ -162,8 +208,8 @@ func (h *GatewayHandler) HandleGatewayMessage(ctx context.Context, gatewayID str } func (h *GatewayHandler) authorizeAndPrefixRequest(ctx context.Context, req *jsonrpc.Request[json.RawMessage]) (string, error) { - if h.requestAuthorizer == nil { - err := errors.New("request authorizer is nil") + if h.authorizer == nil { + err := errors.New("authorizer is nil") h.lggr.Errorw("failed to authorize gateway request", "method", req.Method, "requestID", req.ID, "error", err) return "", err } @@ -183,21 +229,22 @@ func (h *GatewayHandler) authorizeAndPrefixRequest(ctx context.Context, req *jso } h.lggr.Debugw("authorizing gateway request", "method", req.Method, "requestID", originalRequestID) - isAuthorized, owner, err := h.requestAuthorizer.AuthorizeRequest(ctx, authReq) - if !isAuthorized { + authResult, err := h.authorizer.AuthorizeRequest(ctx, authReq) + if err != nil { authErr := fmt.Errorf("request not authorized: %w", err) - h.lggr.Errorw("gateway request authorization failed", "method", req.Method, "requestID", originalRequestID, "owner", owner, "error", authErr) + h.lggr.Errorw("gateway request authorization failed", "method", req.Method, "requestID", originalRequestID, "hasAuth", req.Auth != "", "incomingOwner", incomingOwner, "error", authErr) return "", authErr } - if incomingOwner != "" && normalizeOwner(incomingOwner) != normalizeOwner(owner) { - prefixErr := fmt.Errorf("request owner prefix %q does not match authorized owner %q", incomingOwner, owner) - h.lggr.Errorw("gateway request owner prefix mismatch", "method", req.Method, "requestID", originalRequestID, "incomingOwner", incomingOwner, "authorizedOwner", owner, "error", prefixErr) + authorizedOwner := authResult.AuthorizedOwner() + if incomingOwner != "" && normalizeOwner(incomingOwner) != normalizeOwner(authorizedOwner) { + prefixErr := fmt.Errorf("request owner prefix %q does not match authorized owner %q", incomingOwner, authorizedOwner) + h.lggr.Errorw("gateway request owner prefix mismatch", "method", req.Method, "requestID", originalRequestID, "incomingOwner", incomingOwner, "authorizedOwner", authorizedOwner, "error", prefixErr) return "", prefixErr } - req.ID = owner + vaulttypes.RequestIDSeparator + originalRequestID - h.lggr.Debugw("authorized gateway request", "method", req.Method, "requestID", req.ID, "owner", owner) - return owner, nil + req.ID = authorizedOwner + vaulttypes.RequestIDSeparator + originalRequestID + h.lggr.Debugw("authorized gateway request", "method", req.Method, "requestID", req.ID, "owner", authorizedOwner) + return authorizedOwner, nil } func stripPrefixedRequestIDFromParams(req *jsonrpc.Request[json.RawMessage], originalRequestID string) error { diff --git a/core/capabilities/vault/gw_handler_test.go b/core/capabilities/vault/gw_handler_test.go index 293544e9d0a..43eb9763bb1 100644 --- a/core/capabilities/vault/gw_handler_test.go +++ b/core/capabilities/vault/gw_handler_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -11,6 +12,8 @@ import ( vaultcommon "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" vaultcap "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault" vaultcapmocks "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/mocks" "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" @@ -23,19 +26,22 @@ import ( func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { lggr := logger.TestLogger(t) ctx := t.Context() + authResult := func(owner string) *vaultcap.AuthResult { + return vaultcap.NewAuthResult("", owner, "digest-"+owner, time.Now().Add(time.Minute).Unix()) + } tests := []struct { name string - setupMocks func(*vaulttypesmocks.SecretsService, *connector_mocks.GatewayConnector, *vaultcapmocks.RequestAuthorizer) + setupMocks func(*vaulttypesmocks.SecretsService, *connector_mocks.GatewayConnector, *vaultcapmocks.Authorizer) request *jsonrpc.Request[json.RawMessage] expectedError bool }{ { name: "success - create secrets", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { ra.EXPECT().AuthorizeRequest(mock.Anything, mock.MatchedBy(func(req jsonrpc.Request[json.RawMessage]) bool { return req.Method == vaulttypes.MethodSecretsCreate && req.ID == "1" - })).Return(true, "0xabc", nil) + })).Return(authResult("0xabc"), nil) ss.EXPECT().CreateSecrets(mock.Anything, mock.MatchedBy(func(req *vaultcommon.CreateSecretsRequest) bool { return len(req.EncryptedSecrets) == 1 && req.EncryptedSecrets[0].Id.Key == "test-secret" && @@ -71,8 +77,8 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "failure - service error", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { - ra.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Return(true, "0xabc", nil) + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { + ra.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Return(authResult("0xabc"), nil) ss.EXPECT().CreateSecrets(mock.Anything, mock.Anything).Return(nil, errors.New("service error")) gc.On("SendToGateway", mock.Anything, "gateway-1", mock.MatchedBy(func(resp *jsonrpc.Response[json.RawMessage]) bool { @@ -104,7 +110,7 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "failure - invalid method", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { gc.On("SendToGateway", mock.Anything, "gateway-1", mock.MatchedBy(func(resp *jsonrpc.Response[json.RawMessage]) bool { return resp.Error != nil && resp.Error.Code == api.ToJSONRPCErrorCode(api.UnsupportedMethodError) @@ -118,10 +124,10 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "failure - invalid request params", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { gc.On("SendToGateway", mock.Anything, "gateway-1", mock.MatchedBy(func(resp *jsonrpc.Response[json.RawMessage]) bool { return resp.Error != nil && - resp.Error.Code == api.ToJSONRPCErrorCode(api.FatalError) + resp.Error.Code == api.ToJSONRPCErrorCode(api.HandlerError) })).Return(nil) }, request: &jsonrpc.Request[json.RawMessage]{ @@ -136,10 +142,10 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "success - delete secrets", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { ra.EXPECT().AuthorizeRequest(mock.Anything, mock.MatchedBy(func(req jsonrpc.Request[json.RawMessage]) bool { return req.Method == vaulttypes.MethodSecretsDelete && req.ID == "1" - })).Return(true, "0xabc", nil) + })).Return(authResult("0xabc"), nil) ss.EXPECT().DeleteSecrets(mock.Anything, mock.MatchedBy(func(req *vaultcommon.DeleteSecretsRequest) bool { return len(req.Ids) == 1 && req.Ids[0].Key == "Foo" && @@ -174,11 +180,11 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "failure - unauthorized request", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { - ra.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Return(false, "", errors.New("not allowlisted")) + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { + ra.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Return(nil, errors.New("not allowlisted")) gc.On("SendToGateway", mock.Anything, "gateway-1", mock.MatchedBy(func(resp *jsonrpc.Response[json.RawMessage]) bool { return resp.Error != nil && - resp.Error.Code == api.ToJSONRPCErrorCode(api.FatalError) && + resp.Error.Code == api.ToJSONRPCErrorCode(api.HandlerError) && resp.Error.Message == "request not authorized: not allowlisted" })).Return(nil) }, @@ -205,7 +211,7 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "success - strips owner prefix from forwarded request before authorization", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { ra.EXPECT().AuthorizeRequest(mock.Anything, mock.MatchedBy(func(req jsonrpc.Request[json.RawMessage]) bool { if req.Method != vaulttypes.MethodSecretsCreate || req.ID != "1" || req.Params == nil { return false @@ -220,7 +226,7 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { len(parsed.EncryptedSecrets) == 1 && parsed.EncryptedSecrets[0].Id != nil && parsed.EncryptedSecrets[0].Id.Owner == "0xAbC" - })).Return(true, "0xabc", nil) + })).Return(authResult("0xabc"), nil) ss.EXPECT().CreateSecrets(mock.Anything, mock.MatchedBy(func(req *vaultcommon.CreateSecretsRequest) bool { return req.RequestId == "0xabc"+vaulttypes.RequestIDSeparator+"1" })).Return(&vaulttypes.Response{ID: "test-secret"}, nil) @@ -253,8 +259,8 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { }, { name: "failure - owner mismatch against authorized owner", - setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.RequestAuthorizer) { - ra.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Return(true, "0xdef", nil) + setupMocks: func(ss *vaulttypesmocks.SecretsService, gc *connector_mocks.GatewayConnector, ra *vaultcapmocks.Authorizer) { + ra.EXPECT().AuthorizeRequest(mock.Anything, mock.Anything).Return(authResult("0xdef"), nil) gc.On("SendToGateway", mock.Anything, "gateway-1", mock.MatchedBy(func(resp *jsonrpc.Response[json.RawMessage]) bool { return resp.Error != nil && resp.Error.Code == api.ToJSONRPCErrorCode(api.FatalError) && @@ -288,11 +294,21 @@ func TestGatewayHandler_HandleGatewayMessage(t *testing.T) { t.Run(tt.name, func(t *testing.T) { secretsService := vaulttypesmocks.NewSecretsService(t) gwConnector := connector_mocks.NewGatewayConnector(t) - requestAuthorizer := vaultcapmocks.NewRequestAuthorizer(t) + allowListBasedAuth := vaultcapmocks.NewAuthorizer(t) + limitsFactory := limits.Factory{Settings: cresettings.DefaultGetter} + jwtBasedAuth, err := vaultcap.NewJWTBasedAuth(vaultcap.JWTBasedAuthConfig{}, limitsFactory, lggr, vaultcap.WithDisabledJWTBasedAuth()) + require.NoError(t, err) - tt.setupMocks(secretsService, gwConnector, requestAuthorizer) + tt.setupMocks(secretsService, gwConnector, allowListBasedAuth) - handler, err := vaultcap.NewGatewayHandler(secretsService, gwConnector, requestAuthorizer, lggr) + handler, err := vaultcap.NewGatewayHandler( + secretsService, + gwConnector, + nil, + lggr, + limitsFactory, + vaultcap.WithAuthorizer(vaultcap.NewAuthorizer(allowListBasedAuth, jwtBasedAuth, lggr)), + ) require.NoError(t, err) err = handler.HandleGatewayMessage(ctx, "gateway-1", tt.request) @@ -312,9 +328,19 @@ func TestGatewayHandler_Lifecycle(t *testing.T) { secretsService := vaulttypesmocks.NewSecretsService(t) gwConnector := connector_mocks.NewGatewayConnector(t) - requestAuthorizer := vaultcapmocks.NewRequestAuthorizer(t) + allowListBasedAuth := vaultcapmocks.NewAuthorizer(t) + limitsFactory := limits.Factory{Settings: cresettings.DefaultGetter} + jwtBasedAuth, err := vaultcap.NewJWTBasedAuth(vaultcap.JWTBasedAuthConfig{}, limitsFactory, lggr, vaultcap.WithDisabledJWTBasedAuth()) + require.NoError(t, err) - handler, err := vaultcap.NewGatewayHandler(secretsService, gwConnector, requestAuthorizer, lggr) + handler, err := vaultcap.NewGatewayHandler( + secretsService, + gwConnector, + nil, + lggr, + limitsFactory, + vaultcap.WithAuthorizer(vaultcap.NewAuthorizer(allowListBasedAuth, jwtBasedAuth, lggr)), + ) require.NoError(t, err) t.Run("start", func(t *testing.T) { diff --git a/core/capabilities/vault/jwt_based_auth.go b/core/capabilities/vault/jwt_based_auth.go new file mode 100644 index 00000000000..69edcc8e4cc --- /dev/null +++ b/core/capabilities/vault/jwt_based_auth.go @@ -0,0 +1,458 @@ +package vault + +import ( + "context" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" +) + +var ( + ErrMissingToken = errors.New("missing JWT token") + ErrInvalidToken = errors.New("invalid JWT token") + ErrMissingOrgID = errors.New("missing org_id claim") + ErrMissingRequestDigest = errors.New("missing request_digest in authorization_details") + ErrJWKSFetchFailed = errors.New("failed to fetch JWKS") + ErrJWKSKeyNotFound = errors.New("signing key not found in JWKS") +) + +const ( + defaultJWKSRefreshInterval = 15 * time.Minute + defaultHTTPTimeout = 5 * time.Second +) + +// JWTBasedAuthConfig holds the configuration for JWTBasedAuth validation. +type JWTBasedAuthConfig struct { + IssuerURL string + Audience string + JWKSRefreshInterval time.Duration // minimum interval between JWKS fetches; 0 uses default (30s) + HTTPClient *http.Client // nil uses a default client with 5s timeout +} + +// JWTClaims contains the validated claims extracted from an Auth0 JWT +// relevant to Vault request authorization. +type JWTClaims struct { + OrgID string + WorkflowOwner string // from authorization_details; may be empty for new JWT-only clients + RequestDigest string // from authorization_details + ExpiresAt time.Time +} + +type jsonWebKey struct { + Kid string `json:"kid"` + Alg string `json:"alg"` + Kty string `json:"kty"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` +} + +type jsonWebKeySet struct { + Keys []jsonWebKey `json:"keys"` +} + +// JWTBasedAuth verifies Auth0-issued RS256 JWTs using the provider's +// public JWKS endpoint and extracts Vault-specific claims (org_id, +// workflow_owner, request_digest). It is safe for concurrent use. +// +// JWKS keys are fetched lazily on the first token validation and refreshed +// on key-ID misses, rate-limited to at most once per JWKSRefreshInterval. +// +// Reference: cre-platform-graphql/internal/auth/jwt_auth0.go +type jwtBasedAuth struct { + services.Service + eng *services.Engine + + issuerURL string + audience string + jwksURL string + refreshInterval time.Duration + authEnabledGate limits.GateLimiter + refreshEnabled bool + + mu sync.RWMutex + keySet *jsonWebKeySet + lastRefreshed time.Time + + refreshMu sync.Mutex // serializes JWKS refresh attempts + + httpClient *http.Client + lggr logger.Logger +} + +type jwtBasedAuthOptions struct { + authEnabledGate limits.GateLimiter + skipConfigChecks bool +} + +// JWTBasedAuthOption customizes JWTBasedAuth construction without multiplying constructors. +type JWTBasedAuthOption func(*jwtBasedAuthOptions) + +// WithJWTBasedAuthGateLimiter overrides the gate limiter that decides whether JWT-based auth is enabled. +func WithJWTBasedAuthGateLimiter(gateLimiter limits.GateLimiter) JWTBasedAuthOption { + return func(opts *jwtBasedAuthOptions) { + opts.authEnabledGate = gateLimiter + } +} + +// WithDisabledJWTBasedAuth makes the constructed JWTBasedAuth fail closed without requiring issuer config. +func WithDisabledJWTBasedAuth() JWTBasedAuthOption { + return func(opts *jwtBasedAuthOptions) { + opts.authEnabledGate = limits.NewGateLimiter(false) + opts.skipConfigChecks = true + } +} + +// NewJWTBasedAuth creates a JWTBasedAuth authorizer that verifies Auth0-issued JWTs +// against the provider's JWKS endpoint. The JWKS is fetched lazily on first +// use and refreshed on key-ID cache misses (rate-limited). +func NewJWTBasedAuth(cfg JWTBasedAuthConfig, limitsFactory limits.Factory, lggr logger.Logger, opts ...JWTBasedAuthOption) (*jwtBasedAuth, error) { + options := jwtBasedAuthOptions{} + for _, opt := range opts { + opt(&options) + } + if options.authEnabledGate == nil { + options.authEnabledGate = newVaultJWTAuthEnabledGateLimiter(limitsFactory, lggr) + } + if !options.skipConfigChecks && cfg.IssuerURL == "" { + return nil, errors.New("issuer URL is required") + } + if !options.skipConfigChecks && cfg.Audience == "" { + return nil, errors.New("audience is required") + } + + trimmedIssuer := strings.TrimSuffix(cfg.IssuerURL, "/") + jwksURL := trimmedIssuer + "/.well-known/jwks.json" + + refreshInterval := cfg.JWKSRefreshInterval + if refreshInterval == 0 { + refreshInterval = defaultJWKSRefreshInterval + } + + httpClient := cfg.HTTPClient + if httpClient == nil { + httpClient = &http.Client{Timeout: defaultHTTPTimeout} + } + + v := &jwtBasedAuth{ + issuerURL: cfg.IssuerURL, + audience: cfg.Audience, + jwksURL: jwksURL, + refreshInterval: refreshInterval, + authEnabledGate: options.authEnabledGate, + refreshEnabled: !options.skipConfigChecks, + httpClient: httpClient, + lggr: logger.Named(lggr, "VaultJWTBasedAuth"), + } + v.Service, v.eng = services.Config{ + Name: "VaultJWTBasedAuth", + Start: v.start, + Close: v.close, + }.NewServiceEngine(v.lggr) + + return v, nil +} + +func newVaultJWTAuthEnabledGateLimiter(limitsFactory limits.Factory, lggr logger.Logger) limits.GateLimiter { + limiter, err := limits.MakeGateLimiter(limitsFactory, cresettings.Default.VaultJWTAuthEnabled) + if err != nil { + logger.Named(lggr, "VaultJWTBasedAuth").Errorw("failed to create VaultJWTAuthEnabled limiter", "error", err) + return limits.NewGateLimiter(false) + } + + return limiter +} + +func (v *jwtBasedAuth) start(context.Context) error { + if !v.refreshEnabled { + v.lggr.Debug("JWTBasedAuth periodic JWKS refresh disabled") + return nil + } + + v.eng.GoTick(services.NewTicker(v.refreshInterval), func(ctx context.Context) { + if err := v.refreshJWKS(ctx); err != nil { + v.lggr.Warnw("periodic JWKS refresh failed", "error", err) + } + }) + return nil +} + +func (v *jwtBasedAuth) close() error { + return v.authEnabledGate.Close() +} + +// AuthorizeRequest verifies JWTBasedAuth state and token claims, and returns a common AuthResult. +func (v *jwtBasedAuth) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (*AuthResult, error) { + isEnabled, err := v.authEnabledGate.Limit(ctx) + if err != nil { + v.lggr.Errorw("failed to resolve JWTBasedAuth gate", "method", req.Method, "requestID", req.ID, "error", err) + return nil, fmt.Errorf("failed to resolve JWTBasedAuth gate: %w", err) + } + if !isEnabled { + v.lggr.Debugw("JWTBasedAuth rejected request because it is disabled", "method", req.Method, "requestID", req.ID) + return nil, errors.New("JWTBasedAuth is disabled") + } + + requestDigest, err := req.Digest() + if err != nil { + v.lggr.Debugw("JWTBasedAuth failed to compute request digest", "method", req.Method, "requestID", req.ID, "error", err) + return nil, fmt.Errorf("failed to compute request digest: %w", err) + } + + claims, err := v.validateToken(ctx, req.Auth) + if err != nil { + v.lggr.Debugw("JWTBasedAuth token validation failed", "method", req.Method, "requestID", req.ID, "error", err) + return nil, fmt.Errorf("invalid JWT auth token: %w", err) + } + + if !strings.EqualFold(requestDigest, claims.RequestDigest) { + v.lggr.Debugw("JWTBasedAuth request digest mismatch", "method", req.Method, "requestID", req.ID, "orgID", claims.OrgID, "workflowOwner", claims.WorkflowOwner, "computedDigest", requestDigest, "claimedDigest", claims.RequestDigest) + return nil, errors.New("request digest mismatch") + } + + v.lggr.Debugw("JWTBasedAuth authorization succeeded", "method", req.Method, "requestID", req.ID, "orgID", claims.OrgID, "workflowOwner", claims.WorkflowOwner, "digest", requestDigest, "expiresAt", claims.ExpiresAt.UTC().Unix()) + return &AuthResult{ + orgID: claims.OrgID, + workflowOwner: claims.WorkflowOwner, + digest: requestDigest, + expiresAt: claims.ExpiresAt.UTC().Unix(), + }, nil +} + +// validateToken verifies the JWT signature via Auth0 JWKS, validates +// standard claims (iss, aud, exp), and extracts Vault-specific claims +// (org_id, workflow_owner, request_digest). +func (v *jwtBasedAuth) validateToken(ctx context.Context, tokenString string) (*JWTClaims, error) { + if tokenString == "" { + return nil, ErrMissingToken + } + + unverified, _, err := jwt.NewParser().ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrInvalidToken, err) + } + + kid, ok := unverified.Header["kid"].(string) + if !ok || kid == "" { + return nil, fmt.Errorf("%w: missing kid header", ErrInvalidToken) + } + + rsaKey, err := v.resolveSigningKey(ctx, kid) + if err != nil { + return nil, err + } + + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, methodOK := token.Method.(*jwt.SigningMethodRSA); !methodOK { + return nil, fmt.Errorf("%w: unsupported alg %v", ErrInvalidToken, token.Header["alg"]) + } + return rsaKey, nil + }, + jwt.WithIssuer(v.issuerURL), + jwt.WithAudience(v.audience), + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + ) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrInvalidToken, err) + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, ErrInvalidToken + } + + return extractVaultClaims(claims) +} + +func extractVaultClaims(claims jwt.MapClaims) (*JWTClaims, error) { + orgID, _ := claims["org_id"].(string) + if orgID == "" { + return nil, ErrMissingOrgID + } + + exp, err := claims.GetExpirationTime() + if err != nil { + return nil, fmt.Errorf("%w: invalid exp claim", ErrInvalidToken) + } + + workflowOwner, requestDigest, err := extractAuthorizationDetails(claims) + if err != nil { + return nil, err + } + + return &JWTClaims{ + OrgID: orgID, + WorkflowOwner: workflowOwner, + RequestDigest: requestDigest, + ExpiresAt: exp.Time, + }, nil +} + +func extractAuthorizationDetails(claims jwt.MapClaims) (workflowOwner, requestDigest string, err error) { + rawDetails, ok := claims["authorization_details"] + if !ok { + return "", "", ErrMissingRequestDigest + } + + details, ok := rawDetails.([]interface{}) + if !ok { + return "", "", fmt.Errorf("%w: authorization_details must be an array", ErrInvalidToken) + } + + for _, rawDetail := range details { + detail, ok := rawDetail.(map[string]interface{}) + if !ok { + continue + } + authDetailType, _ := detail["type"].(string) + authDetailValue, _ := detail["value"].(string) + switch authDetailType { + case "request_digest": + requestDigest = authDetailValue + case "workflow_owner": + workflowOwner = authDetailValue + } + } + + if requestDigest == "" { + return "", "", ErrMissingRequestDigest + } + + return workflowOwner, requestDigest, nil +} + +// resolveSigningKey looks up the RSA public key for the given kid from the +// JWKS cache, refreshing the cache if necessary. +func (v *jwtBasedAuth) resolveSigningKey(ctx context.Context, kid string) (*rsa.PublicKey, error) { + key, err := v.findCachedKey(kid) + if err != nil { + return nil, err + } + if key != nil { + return key, nil + } + + if refreshErr := v.refreshJWKS(ctx); refreshErr != nil { + v.lggr.Warnw("JWKS refresh failed", "error", refreshErr, "kid", kid) + return nil, fmt.Errorf("%w: kid=%s", ErrJWKSKeyNotFound, kid) + } + + key, err = v.findCachedKey(kid) + if err != nil { + return nil, err + } + if key == nil { + return nil, fmt.Errorf("%w: kid=%s", ErrJWKSKeyNotFound, kid) + } + + return key, nil +} + +func (v *jwtBasedAuth) findCachedKey(kid string) (*rsa.PublicKey, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.keySet == nil { + return nil, nil + } + + for _, key := range v.keySet.Keys { + if key.Kid == kid { + return parseRSAPublicKey(key) + } + } + + return nil, nil +} + +// refreshJWKS fetches the JWKS from Auth0. Concurrent callers are serialized +// via refreshMu; if a recent fetch already happened within refreshInterval +// the call is a no-op. +func (v *jwtBasedAuth) refreshJWKS(ctx context.Context) error { + v.refreshMu.Lock() + defer v.refreshMu.Unlock() + + v.mu.RLock() + if v.keySet != nil && time.Since(v.lastRefreshed) < v.refreshInterval { + v.mu.RUnlock() + return nil + } + v.mu.RUnlock() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, v.jwksURL, nil) + if err != nil { + return fmt.Errorf("%w: %w", ErrJWKSFetchFailed, err) + } + + resp, err := v.httpClient.Do(req) + if err != nil { + return fmt.Errorf("%w: %w", ErrJWKSFetchFailed, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%w: HTTP %d", ErrJWKSFetchFailed, resp.StatusCode) + } + + const maxJWKSBodySize = 1 << 20 // 1 MB + body, err := io.ReadAll(io.LimitReader(resp.Body, maxJWKSBodySize)) + if err != nil { + return fmt.Errorf("%w: %w", ErrJWKSFetchFailed, err) + } + + var keySet jsonWebKeySet + if err := json.Unmarshal(body, &keySet); err != nil { + return fmt.Errorf("%w: invalid JWKS: %w", ErrJWKSFetchFailed, err) + } + + v.mu.Lock() + v.keySet = &keySet + v.lastRefreshed = time.Now() + v.mu.Unlock() + + v.lggr.Infow("Refreshed JWKS", "numKeys", len(keySet.Keys), "url", v.jwksURL) + return nil +} + +func parseRSAPublicKey(key jsonWebKey) (*rsa.PublicKey, error) { + if key.Kty != "RSA" { + return nil, fmt.Errorf("unsupported key type: %s", key.Kty) + } + + nBytes, err := base64.RawURLEncoding.DecodeString(key.N) + if err != nil { + return nil, fmt.Errorf("failed to decode RSA modulus: %w", err) + } + + eBytes, err := base64.RawURLEncoding.DecodeString(key.E) + if err != nil { + return nil, fmt.Errorf("failed to decode RSA exponent: %w", err) + } + + var eInt int + for _, b := range eBytes { + eInt = eInt<<8 + int(b) + } + + return &rsa.PublicKey{ + N: new(big.Int).SetBytes(nBytes), + E: eInt, + }, nil +} diff --git a/core/capabilities/vault/jwt_based_auth_test.go b/core/capabilities/vault/jwt_based_auth_test.go new file mode 100644 index 00000000000..a7e1c890379 --- /dev/null +++ b/core/capabilities/vault/jwt_based_auth_test.go @@ -0,0 +1,533 @@ +package vault + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/settings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +// --- test helpers --- + +type testRSAKey struct { + kid string + privateKey *rsa.PrivateKey +} + +type testJWKSServer struct { + server *httptest.Server + mu sync.Mutex + keys []testRSAKey + hits chan struct{} +} + +func newTestJWKSServer(t *testing.T, keys ...testRSAKey) *testJWKSServer { + t.Helper() + s := &testJWKSServer{keys: keys, hits: make(chan struct{}, 32)} + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + currentKeys := s.keys + s.mu.Unlock() + select { + case s.hits <- struct{}{}: + default: + } + + ks := jsonWebKeySet{} + for _, k := range currentKeys { + ks.Keys = append(ks.Keys, testRSAKeyToJWK(k.kid, &k.privateKey.PublicKey)) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(ks) + }) + s.server = httptest.NewServer(mux) + t.Cleanup(s.server.Close) + return s +} + +func (s *testJWKSServer) URL() string { return s.server.URL } + +func (s *testJWKSServer) waitForHits(t *testing.T, count int, timeout time.Duration) { + t.Helper() + + deadline := time.NewTimer(timeout) + defer deadline.Stop() + + for range count { + select { + case <-s.hits: + case <-deadline.C: + t.Fatalf("timed out waiting for %d JWKS hits", count) + } + } +} + +func (s *testJWKSServer) setKeys(keys ...testRSAKey) { + s.mu.Lock() + defer s.mu.Unlock() + s.keys = keys +} + +func testRSAKeyToJWK(kid string, pub *rsa.PublicKey) jsonWebKey { + return jsonWebKey{ + Kid: kid, + Alg: "RS256", + Kty: "RSA", + Use: "sig", + N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()), + } +} + +func generateTestRSAKey(t *testing.T, kid string) testRSAKey { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return testRSAKey{kid: kid, privateKey: key} +} + +func createTestJWT(t *testing.T, key testRSAKey, claims jwt.MapClaims) string { + t.Helper() + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = key.kid + tokenString, err := token.SignedString(key.privateKey) + require.NoError(t, err) + return tokenString +} + +func validTestClaims(issuer, audience string) jwt.MapClaims { + return jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + "iat": jwt.NewNumericDate(time.Now()), + "org_id": "org_test123", + "authorization_details": []interface{}{ + map[string]interface{}{ + "type": "request_digest", + "value": "abc123def456", + }, + map[string]interface{}{ + "type": "workflow_owner", + "value": "0xAbCdEf0123456789AbCdEf0123456789AbCdEf01", + }, + }, + } +} + +func newTestValidator(t *testing.T, issuer, audience string) *jwtBasedAuth { + t.Helper() + v, err := NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: issuer, + Audience: audience, + JWKSRefreshInterval: time.Millisecond, + }, limits.Factory{Settings: cresettings.DefaultGetter}, logger.TestLogger(t), WithJWTBasedAuthGateLimiter(limits.NewGateLimiter(true))) + require.NoError(t, err) + return v +} + +// --- tests --- + +func TestJWTBasedAuth_ValidToken(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + tokenString := createTestJWT(t, rsaKey, validTestClaims(issuer, audience)) + + result, err := v.validateToken(context.Background(), tokenString) + require.NoError(t, err) + assert.Equal(t, "org_test123", result.OrgID) + assert.Equal(t, "0xAbCdEf0123456789AbCdEf0123456789AbCdEf01", result.WorkflowOwner) + assert.Equal(t, "abc123def456", result.RequestDigest) + assert.False(t, result.ExpiresAt.IsZero()) +} + +func TestJWTBasedAuth_ValidToken_NoWorkflowOwner(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + "iat": jwt.NewNumericDate(time.Now()), + "org_id": "org_no_wfowner", + "authorization_details": []interface{}{ + map[string]interface{}{ + "type": "request_digest", + "value": "digest456", + }, + }, + } + tokenString := createTestJWT(t, rsaKey, claims) + + result, err := v.validateToken(context.Background(), tokenString) + require.NoError(t, err) + assert.Equal(t, "org_no_wfowner", result.OrgID) + assert.Empty(t, result.WorkflowOwner) + assert.Equal(t, "digest456", result.RequestDigest) +} + +func TestJWTBasedAuth_ExpiredToken(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims(issuer, audience) + claims["exp"] = jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)) + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + require.ErrorIs(t, err, ErrInvalidToken) + assert.Contains(t, err.Error(), "expired") +} + +func TestJWTBasedAuth_WrongIssuer(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims("https://wrong-issuer.auth0.com/", audience) + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidToken) +} + +func TestJWTBasedAuth_WrongAudience(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims(issuer, "https://wrong-audience.com") + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidToken) +} + +func TestJWTBasedAuth_MissingOrgID(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims(issuer, audience) + delete(claims, "org_id") + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrMissingOrgID) +} + +func TestJWTBasedAuth_MissingRequestDigest(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims(issuer, audience) + claims["authorization_details"] = []interface{}{ + map[string]interface{}{ + "type": "workflow_owner", + "value": "0xAbCd", + }, + } + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrMissingRequestDigest) +} + +func TestJWTBasedAuth_MissingAuthorizationDetails(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims(issuer, audience) + delete(claims, "authorization_details") + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrMissingRequestDigest) +} + +func TestJWTBasedAuth_InvalidSignature(t *testing.T) { + goodKey := generateTestRSAKey(t, "key-1") + badKey := generateTestRSAKey(t, "key-1") // same kid, different key material + jwksServer := newTestJWKSServer(t, goodKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := validTestClaims(issuer, audience) + tokenString := createTestJWT(t, badKey, claims) // signed with wrong private key + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidToken) +} + +func TestJWTBasedAuth_EmptyToken(t *testing.T) { + v, err := NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: "https://example.auth0.com/", + Audience: "https://api.test.chain.link", + }, limits.Factory{Settings: cresettings.DefaultGetter}, logger.TestLogger(t), WithJWTBasedAuthGateLimiter(limits.NewGateLimiter(true))) + require.NoError(t, err) + + _, err = v.validateToken(context.Background(), "") + require.Error(t, err) + assert.ErrorIs(t, err, ErrMissingToken) +} + +func TestJWTBasedAuth_JWKSKeyRotation(t *testing.T) { + keyA := generateTestRSAKey(t, "key-A") + keyB := generateTestRSAKey(t, "key-B") + + jwksServer := newTestJWKSServer(t, keyA) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + // Token signed with key-A succeeds + claimsA := validTestClaims(issuer, audience) + tokenA := createTestJWT(t, keyA, claimsA) + resultA, err := v.validateToken(context.Background(), tokenA) + require.NoError(t, err) + assert.Equal(t, "org_test123", resultA.OrgID) + + // Simulate key rotation: JWKS now serves only key-B + jwksServer.setKeys(keyB) + + // Allow the refresh interval to elapse so the next miss triggers a fetch + time.Sleep(2 * time.Millisecond) + + // Token signed with key-B succeeds after JWKS refresh + claimsB := validTestClaims(issuer, audience) + claimsB["org_id"] = "org_after_rotation" + tokenB := createTestJWT(t, keyB, claimsB) + resultB, err := v.validateToken(context.Background(), tokenB) + require.NoError(t, err) + assert.Equal(t, "org_after_rotation", resultB.OrgID) +} + +func TestJWTBasedAuth_AuthorizationDetailsFromTypedArray(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + claims := jwt.MapClaims{ + "iss": issuer, + "aud": audience, + "exp": jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + "iat": jwt.NewNumericDate(time.Now()), + "org_id": "org_single", + "authorization_details": []interface{}{ + map[string]interface{}{"type": "request_digest", "value": "single_digest"}, + map[string]interface{}{"type": "workflow_owner", "value": "0x1111"}, + }, + } + tokenString := createTestJWT(t, rsaKey, claims) + + result, err := v.validateToken(context.Background(), tokenString) + require.NoError(t, err) + assert.Equal(t, "org_single", result.OrgID) + assert.Equal(t, "single_digest", result.RequestDigest) + assert.Equal(t, "0x1111", result.WorkflowOwner) +} + +func TestJWTBasedAuth_UnsupportedAlgorithm(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + issuer := jwksServer.URL() + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + // Create a token signed with HMAC instead of RSA + claims := validTestClaims(issuer, audience) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = rsaKey.kid + tokenString, err := token.SignedString([]byte("hmac-secret")) + require.NoError(t, err) + + _, err = v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrInvalidToken) +} + +func TestJWTBasedAuth_JWKSServerUnavailable(t *testing.T) { + // Start a server that always returns 500 + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + issuer := server.URL + "/" + audience := "https://api.test.chain.link" + v := newTestValidator(t, issuer, audience) + + rsaKey := generateTestRSAKey(t, "key-1") + claims := validTestClaims(issuer, audience) + tokenString := createTestJWT(t, rsaKey, claims) + + _, err := v.validateToken(context.Background(), tokenString) + require.Error(t, err) + assert.ErrorIs(t, err, ErrJWKSKeyNotFound) +} + +func TestJWTBasedAuth_StartRefreshesJWKSPeriodically(t *testing.T) { + rsaKey := generateTestRSAKey(t, "key-1") + jwksServer := newTestJWKSServer(t, rsaKey) + + v, err := NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: jwksServer.URL() + "/", + Audience: "https://api.test.chain.link", + JWKSRefreshInterval: 10 * time.Millisecond, + }, limits.Factory{Settings: cresettings.DefaultGetter}, logger.TestLogger(t), WithJWTBasedAuthGateLimiter(limits.NewGateLimiter(true))) + require.NoError(t, err) + + require.NoError(t, v.Start(t.Context())) + jwksServer.waitForHits(t, 2, time.Second) + require.NoError(t, v.Close()) +} + +func TestJWTBasedAuth_DisabledStartSkipsPeriodicRefresh(t *testing.T) { + v, err := NewJWTBasedAuth( + JWTBasedAuthConfig{}, + limits.Factory{Settings: cresettings.DefaultGetter}, + logger.TestLogger(t), + WithDisabledJWTBasedAuth(), + ) + require.NoError(t, err) + require.NoError(t, v.Start(t.Context())) + require.NoError(t, v.Close()) +} + +func TestNewJWTBasedAuth_InvalidConfig(t *testing.T) { + lggr := logger.TestLogger(t) + + _, err := NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: "", + Audience: "https://api.test.chain.link", + }, limits.Factory{Settings: cresettings.DefaultGetter}, lggr, WithJWTBasedAuthGateLimiter(limits.NewGateLimiter(true))) + require.Error(t, err) + assert.Contains(t, err.Error(), "issuer URL is required") + + _, err = NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: "https://example.auth0.com/", + Audience: "", + }, limits.Factory{Settings: cresettings.DefaultGetter}, lggr, WithJWTBasedAuthGateLimiter(limits.NewGateLimiter(true))) + require.Error(t, err) + assert.Contains(t, err.Error(), "audience is required") +} + +func TestNewJWTBasedAuth_UsesVaultJWTAuthEnabledLimiter_Disabled(t *testing.T) { + setDefaultGetter(t, `{}`) + + v, err := NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: "https://example.auth0.com/", + Audience: "https://api.test.chain.link", + }, limits.Factory{Settings: cresettings.DefaultGetter}, logger.TestLogger(t)) + require.NoError(t, err) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-1", + Method: vaulttypes.MethodSecretsList, + Auth: "token", + } + + _, err = v.AuthorizeRequest(t.Context(), req) + require.Error(t, err) + require.ErrorContains(t, err, "JWTBasedAuth is disabled") +} + +func TestNewJWTBasedAuth_UsesVaultJWTAuthEnabledLimiter_Enabled(t *testing.T) { + setDefaultGetter(t, `{"global":{"VaultJWTAuthEnabled":true}}`) + + v, err := NewJWTBasedAuth(JWTBasedAuthConfig{ + IssuerURL: "https://example.auth0.com/", + Audience: "https://api.test.chain.link", + }, limits.Factory{Settings: cresettings.DefaultGetter}, logger.TestLogger(t)) + require.NoError(t, err) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-1", + Method: vaulttypes.MethodSecretsList, + } + + _, err = v.AuthorizeRequest(t.Context(), req) + require.Error(t, err) + require.ErrorContains(t, err, "invalid JWT auth token") + require.ErrorContains(t, err, ErrMissingToken.Error()) +} + +func setDefaultGetter(t *testing.T, payload string) { + t.Helper() + + prev := cresettings.DefaultGetter + t.Cleanup(func() { + cresettings.DefaultGetter = prev + }) + + getter, err := settings.NewJSONGetter([]byte(payload)) + require.NoError(t, err) + cresettings.DefaultGetter = getter +} diff --git a/core/capabilities/vault/mocks/authorizer.go b/core/capabilities/vault/mocks/authorizer.go new file mode 100644 index 00000000000..fbb57897079 --- /dev/null +++ b/core/capabilities/vault/mocks/authorizer.go @@ -0,0 +1,99 @@ +// Code generated by mockery v2.53.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + json "encoding/json" + + jsonrpc2 "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + mock "github.com/stretchr/testify/mock" + + vault "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault" +) + +// Authorizer is an autogenerated mock type for the Authorizer type +type Authorizer struct { + mock.Mock +} + +type Authorizer_Expecter struct { + mock *mock.Mock +} + +func (_m *Authorizer) EXPECT() *Authorizer_Expecter { + return &Authorizer_Expecter{mock: &_m.Mock} +} + +// AuthorizeRequest provides a mock function with given fields: ctx, req +func (_m *Authorizer) AuthorizeRequest(ctx context.Context, req jsonrpc2.Request[json.RawMessage]) (*vault.AuthResult, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for AuthorizeRequest") + } + + var r0 *vault.AuthResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, jsonrpc2.Request[json.RawMessage]) (*vault.AuthResult, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, jsonrpc2.Request[json.RawMessage]) *vault.AuthResult); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*vault.AuthResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, jsonrpc2.Request[json.RawMessage]) error); ok { + r1 = rf(ctx, req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Authorizer_AuthorizeRequest_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AuthorizeRequest' +type Authorizer_AuthorizeRequest_Call struct { + *mock.Call +} + +// AuthorizeRequest is a helper method to define mock.On call +// - ctx context.Context +// - req jsonrpc2.Request[json.RawMessage] +func (_e *Authorizer_Expecter) AuthorizeRequest(ctx interface{}, req interface{}) *Authorizer_AuthorizeRequest_Call { + return &Authorizer_AuthorizeRequest_Call{Call: _e.mock.On("AuthorizeRequest", ctx, req)} +} + +func (_c *Authorizer_AuthorizeRequest_Call) Run(run func(ctx context.Context, req jsonrpc2.Request[json.RawMessage])) *Authorizer_AuthorizeRequest_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(jsonrpc2.Request[json.RawMessage])) + }) + return _c +} + +func (_c *Authorizer_AuthorizeRequest_Call) Return(_a0 *vault.AuthResult, _a1 error) *Authorizer_AuthorizeRequest_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Authorizer_AuthorizeRequest_Call) RunAndReturn(run func(context.Context, jsonrpc2.Request[json.RawMessage]) (*vault.AuthResult, error)) *Authorizer_AuthorizeRequest_Call { + _c.Call.Return(run) + return _c +} + +// NewAuthorizer creates a new instance of Authorizer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthorizer(t interface { + mock.TestingT + Cleanup(func()) +}) *Authorizer { + mock := &Authorizer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/capabilities/vault/mocks/request_authorizer.go b/core/capabilities/vault/mocks/request_authorizer.go deleted file mode 100644 index cead13f6b65..00000000000 --- a/core/capabilities/vault/mocks/request_authorizer.go +++ /dev/null @@ -1,102 +0,0 @@ -// Code generated by mockery v2.53.0. DO NOT EDIT. - -package mocks - -import ( - context "context" - json "encoding/json" - - jsonrpc2 "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" - mock "github.com/stretchr/testify/mock" -) - -// RequestAuthorizer is an autogenerated mock type for the RequestAuthorizer type -type RequestAuthorizer struct { - mock.Mock -} - -type RequestAuthorizer_Expecter struct { - mock *mock.Mock -} - -func (_m *RequestAuthorizer) EXPECT() *RequestAuthorizer_Expecter { - return &RequestAuthorizer_Expecter{mock: &_m.Mock} -} - -// AuthorizeRequest provides a mock function with given fields: ctx, req -func (_m *RequestAuthorizer) AuthorizeRequest(ctx context.Context, req jsonrpc2.Request[json.RawMessage]) (bool, string, error) { - ret := _m.Called(ctx, req) - - if len(ret) == 0 { - panic("no return value specified for AuthorizeRequest") - } - - var r0 bool - var r1 string - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, jsonrpc2.Request[json.RawMessage]) (bool, string, error)); ok { - return rf(ctx, req) - } - if rf, ok := ret.Get(0).(func(context.Context, jsonrpc2.Request[json.RawMessage]) bool); ok { - r0 = rf(ctx, req) - } else { - r0 = ret.Get(0).(bool) - } - - if rf, ok := ret.Get(1).(func(context.Context, jsonrpc2.Request[json.RawMessage]) string); ok { - r1 = rf(ctx, req) - } else { - r1 = ret.Get(1).(string) - } - - if rf, ok := ret.Get(2).(func(context.Context, jsonrpc2.Request[json.RawMessage]) error); ok { - r2 = rf(ctx, req) - } else { - r2 = ret.Error(2) - } - - return r0, r1, r2 -} - -// RequestAuthorizer_AuthorizeRequest_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AuthorizeRequest' -type RequestAuthorizer_AuthorizeRequest_Call struct { - *mock.Call -} - -// AuthorizeRequest is a helper method to define mock.On call -// - ctx context.Context -// - req jsonrpc2.Request[json.RawMessage] -func (_e *RequestAuthorizer_Expecter) AuthorizeRequest(ctx interface{}, req interface{}) *RequestAuthorizer_AuthorizeRequest_Call { - return &RequestAuthorizer_AuthorizeRequest_Call{Call: _e.mock.On("AuthorizeRequest", ctx, req)} -} - -func (_c *RequestAuthorizer_AuthorizeRequest_Call) Run(run func(ctx context.Context, req jsonrpc2.Request[json.RawMessage])) *RequestAuthorizer_AuthorizeRequest_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(jsonrpc2.Request[json.RawMessage])) - }) - return _c -} - -func (_c *RequestAuthorizer_AuthorizeRequest_Call) Return(isAuthorized bool, owner string, err error) *RequestAuthorizer_AuthorizeRequest_Call { - _c.Call.Return(isAuthorized, owner, err) - return _c -} - -func (_c *RequestAuthorizer_AuthorizeRequest_Call) RunAndReturn(run func(context.Context, jsonrpc2.Request[json.RawMessage]) (bool, string, error)) *RequestAuthorizer_AuthorizeRequest_Call { - _c.Call.Return(run) - return _c -} - -// NewRequestAuthorizer creates a new instance of RequestAuthorizer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewRequestAuthorizer(t interface { - mock.TestingT - Cleanup(func()) -}) *RequestAuthorizer { - mock := &RequestAuthorizer{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/core/capabilities/vault/request_authorizer.go b/core/capabilities/vault/request_authorizer.go deleted file mode 100644 index 96e22396831..00000000000 --- a/core/capabilities/vault/request_authorizer.go +++ /dev/null @@ -1,139 +0,0 @@ -package vault - -import ( - "context" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "time" - - jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-evm/gethwrappers/workflow/generated/workflow_registry_wrapper_v2" - workflowsyncerv2 "github.com/smartcontractkit/chainlink/v2/core/services/workflows/syncer/v2" -) - -type RequestAuthorizer interface { - AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (isAuthorized bool, owner string, err error) -} -type requestAuthorizer struct { - workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer - replayGuard *DigestReplayGuard - lggr logger.Logger - sleep func(time.Duration) -} - -const ( - allowlistReadRetryCount = 3 - allowlistReadRetryInterval = 3 * time.Second -) - -// AuthorizeRequest authorizes a request based on the request digest and the allowlisted requests. -// It does NOT check if the request method is allowed. -func (r *requestAuthorizer) AuthorizeRequest(ctx context.Context, req jsonrpc.Request[json.RawMessage]) (isAuthorized bool, owner string, err error) { - defer r.replayGuard.ClearExpired() - r.lggr.Infow("AuthorizeRequest", "method", req.Method, "requestID", req.ID) - requestDigest, err := req.Digest() - if err != nil { - r.lggr.Infow("AuthorizeRequest failed to create digest", "method", req.Method, "requestID", req.ID) - return false, "", err - } - requestDigestBytes, err := hex.DecodeString(requestDigest) - if err != nil { - r.lggr.Infow("AuthorizeRequest failed to decode digest", "method", req.Method, "requestID", req.ID) - return false, "", err - } - requestDigestBytes32 := [32]byte(requestDigestBytes) - if r.workflowRegistrySyncer == nil { - r.lggr.Errorw("AuthorizeRequest workflowRegistrySyncer is nil", "method", req.Method, "requestID", req.ID) - return false, "", errors.New("internal error: workflowRegistrySyncer is nil") - } - allowlistedRequest, _ := r.fetchAllowlistedItemWithRetry(ctx, req.Method, req.ID, requestDigest, requestDigestBytes32) - if allowlistedRequest == nil { - return false, "", errors.New("request not allowlisted") - } - - if time.Now().UTC().Unix() > int64(allowlistedRequest.ExpiryTimestamp) { - authorizedRequestStr := string(allowlistedRequest.RequestDigest[:]) - r.lggr.Infow("AuthorizeRequest expired authorization", "method", req.Method, "requestID", req.ID, "authorizedRequestStr", authorizedRequestStr) - return false, "", errors.New("request authorization expired") - } - - digestKey := string(allowlistedRequest.RequestDigest[:]) - if err := r.replayGuard.CheckAndRecord(digestKey, int64(allowlistedRequest.ExpiryTimestamp)); err != nil { - r.lggr.Infow("AuthorizeRequest already authorized previously", "method", req.Method, "requestID", req.ID, "authorizedRequestStr", digestKey) - return false, "", err - } - - r.lggr.Infow("AuthorizeRequest success in auth", "method", req.Method, "requestID", req.ID, "authorizedRequestStr", digestKey) - return true, allowlistedRequest.Owner.Hex(), nil -} - -func (r *requestAuthorizer) fetchAllowlistedItemWithRetry(ctx context.Context, method string, requestID interface{}, requestDigest string, digest [32]byte) (*workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest, []string) { - var allowedRequestsStrs []string - for attempt := 0; attempt <= allowlistReadRetryCount; attempt++ { - allowedRequests := r.workflowRegistrySyncer.GetAllowlistedRequests(ctx) - allowedRequestsStrs = make([]string, 0, len(allowedRequests)) - for _, rr := range allowedRequests { - allowedReqStr := fmt.Sprintf("Owner: %s, RequestDigest: %s, ExpiryTimestamp: %d", rr.Owner.Hex(), hex.EncodeToString(rr.RequestDigest[:]), rr.ExpiryTimestamp) - allowedRequestsStrs = append(allowedRequestsStrs, allowedReqStr) - } - r.lggr.Infow("AuthorizeRequest GetAllowlistedRequests", "method", method, "requestID", requestID, "attempt", attempt+1, "allowedRequests", allowedRequestsStrs) - - allowlistedRequest := r.fetchAllowlistedItem(allowedRequests, digest) - if allowlistedRequest != nil { - return allowlistedRequest, allowedRequestsStrs - } - - if attempt == allowlistReadRetryCount { - break - } - - r.lggr.Warnw("AuthorizeRequest request not found in allowlist, retrying", - "method", method, - "requestID", requestID, - "digestHexStr", requestDigest, - "attempt", attempt+1, - "retryInterval", allowlistReadRetryInterval, - "allowedRequestsStrs", allowedRequestsStrs) - - select { - case <-ctx.Done(): - r.lggr.Warnw("AuthorizeRequest allowlist retry canceled", - "method", method, - "requestID", requestID, - "digestHexStr", requestDigest, - "attempt", attempt+1) - return nil, allowedRequestsStrs - default: - } - - r.sleep(allowlistReadRetryInterval) - } - - r.lggr.Infow("AuthorizeRequest fetchAllowlistedItem request not allowlisted", - "method", method, - "requestID", requestID, - "digestHexStr", requestDigest, - "allowedRequestsStrs", allowedRequestsStrs) - return nil, allowedRequestsStrs -} - -func (r *requestAuthorizer) fetchAllowlistedItem(allowListedRequests []workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest, digest [32]byte) *workflow_registry_wrapper_v2.WorkflowRegistryOwnerAllowlistedRequest { - for _, item := range allowListedRequests { - if item.RequestDigest == digest { - return &item - } - } - return nil -} - -func NewRequestAuthorizer(lggr logger.Logger, workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer) *requestAuthorizer { - return &requestAuthorizer{ - workflowRegistrySyncer: workflowRegistrySyncer, - lggr: logger.Named(lggr, "VaultRequestAuthorizer"), - replayGuard: NewDigestReplayGuard(), - sleep: time.Sleep, - } -} diff --git a/core/capabilities/vault/digest_replay_guard.go b/core/capabilities/vault/request_replay_guard.go similarity index 59% rename from core/capabilities/vault/digest_replay_guard.go rename to core/capabilities/vault/request_replay_guard.go index 2ef15221a6f..7ce5dfa1643 100644 --- a/core/capabilities/vault/digest_replay_guard.go +++ b/core/capabilities/vault/request_replay_guard.go @@ -6,39 +6,40 @@ import ( "time" ) -var ErrDigestAlreadySeen = errors.New("request already authorized previously") +var ErrRequestAlreadySeen = errors.New("request was already authorized previously") -// DigestReplayGuard prevents replay of already-processed requests by tracking +// RequestReplayGuard prevents replay of already-processed requests by tracking // request digests with expiry timestamps. It is safe for concurrent use. // -// Used by both the on-chain allowlist flow and the JWT auth flow to ensure +// Used by both the AllowListBasedAuth flow and the JWTBasedAuth flow to ensure // that a given request digest is only accepted once. -type DigestReplayGuard struct { +type RequestReplayGuard struct { mu sync.Mutex seen map[string]int64 // digest → unix expiry timestamp nowFunc func() time.Time // injectable for testing } -func NewDigestReplayGuard() *DigestReplayGuard { - return &DigestReplayGuard{ +// NewRequestReplayGuard creates a replay guard for authorized Vault requests. +func NewRequestReplayGuard() *RequestReplayGuard { + return &RequestReplayGuard{ seen: make(map[string]int64), nowFunc: time.Now, } } -// CheckAndRecord returns ErrDigestAlreadySeen if the digest was previously +// CheckAndRecord returns ErrRequestAlreadySeen if the digest was previously // recorded and has not yet expired. Otherwise it records the digest with // the given expiry timestamp (unix seconds, UTC). // // Expired entries are cleaned up on every call. -func (g *DigestReplayGuard) CheckAndRecord(digest string, expiresAtUnix int64) error { +func (g *RequestReplayGuard) CheckAndRecord(digest string, expiresAtUnix int64) error { g.mu.Lock() defer g.mu.Unlock() g.clearExpiredLocked() if _, exists := g.seen[digest]; exists { - return ErrDigestAlreadySeen + return ErrRequestAlreadySeen } g.seen[digest] = expiresAtUnix @@ -47,13 +48,13 @@ func (g *DigestReplayGuard) CheckAndRecord(digest string, expiresAtUnix int64) e // ClearExpired removes all entries whose expiry timestamp is in the past. // Call this to eagerly reclaim memory even when CheckAndRecord is not invoked. -func (g *DigestReplayGuard) ClearExpired() { +func (g *RequestReplayGuard) ClearExpired() { g.mu.Lock() defer g.mu.Unlock() g.clearExpiredLocked() } -func (g *DigestReplayGuard) clearExpiredLocked() { +func (g *RequestReplayGuard) clearExpiredLocked() { now := g.nowFunc().UTC().Unix() for digest, expiry := range g.seen { if now > expiry { diff --git a/core/capabilities/vault/digest_replay_guard_test.go b/core/capabilities/vault/request_replay_guard_test.go similarity index 77% rename from core/capabilities/vault/digest_replay_guard_test.go rename to core/capabilities/vault/request_replay_guard_test.go index 5d9c9cf64c8..07e4b1ee6a5 100644 --- a/core/capabilities/vault/digest_replay_guard_test.go +++ b/core/capabilities/vault/request_replay_guard_test.go @@ -9,39 +9,39 @@ import ( "github.com/stretchr/testify/require" ) -func TestDigestReplayGuard_FirstCallSucceeds(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_FirstCallSucceeds(t *testing.T) { + guard := NewRequestReplayGuard() futureExpiry := time.Now().UTC().Unix() + 100 err := guard.CheckAndRecord("digest-1", futureExpiry) require.NoError(t, err) } -func TestDigestReplayGuard_DuplicateRejected(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_DuplicateRejected(t *testing.T) { + guard := NewRequestReplayGuard() futureExpiry := time.Now().UTC().Unix() + 100 err := guard.CheckAndRecord("digest-1", futureExpiry) require.NoError(t, err) err = guard.CheckAndRecord("digest-1", futureExpiry) - require.ErrorIs(t, err, ErrDigestAlreadySeen) + require.ErrorIs(t, err, ErrRequestAlreadySeen) } -func TestDigestReplayGuard_DifferentDigestsIndependent(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_DifferentDigestsIndependent(t *testing.T) { + guard := NewRequestReplayGuard() futureExpiry := time.Now().UTC().Unix() + 100 require.NoError(t, guard.CheckAndRecord("digest-1", futureExpiry)) require.NoError(t, guard.CheckAndRecord("digest-2", futureExpiry)) require.NoError(t, guard.CheckAndRecord("digest-3", futureExpiry)) - require.ErrorIs(t, guard.CheckAndRecord("digest-1", futureExpiry), ErrDigestAlreadySeen) - require.ErrorIs(t, guard.CheckAndRecord("digest-2", futureExpiry), ErrDigestAlreadySeen) + require.ErrorIs(t, guard.CheckAndRecord("digest-1", futureExpiry), ErrRequestAlreadySeen) + require.ErrorIs(t, guard.CheckAndRecord("digest-2", futureExpiry), ErrRequestAlreadySeen) } -func TestDigestReplayGuard_ExpiredEntryCleaned(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_ExpiredEntryCleaned(t *testing.T) { + guard := NewRequestReplayGuard() now := time.Now() guard.nowFunc = func() time.Time { return now } @@ -57,8 +57,8 @@ func TestDigestReplayGuard_ExpiredEntryCleaned(t *testing.T) { require.NoError(t, err) } -func TestDigestReplayGuard_NonExpiredEntryRetained(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_NonExpiredEntryRetained(t *testing.T) { + guard := NewRequestReplayGuard() now := time.Now() guard.nowFunc = func() time.Time { return now } @@ -69,11 +69,11 @@ func TestDigestReplayGuard_NonExpiredEntryRetained(t *testing.T) { guard.nowFunc = func() time.Time { return now.Add(50 * time.Second) } err := guard.CheckAndRecord("digest-1", futureExpiry) - require.ErrorIs(t, err, ErrDigestAlreadySeen) + require.ErrorIs(t, err, ErrRequestAlreadySeen) } -func TestDigestReplayGuard_MixedExpiryCleanup(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_MixedExpiryCleanup(t *testing.T) { + guard := NewRequestReplayGuard() now := time.Now() guard.nowFunc = func() time.Time { return now } @@ -90,11 +90,11 @@ func TestDigestReplayGuard_MixedExpiryCleanup(t *testing.T) { require.NoError(t, guard.CheckAndRecord("short-lived", now.Add(50*time.Second).UTC().Unix()+100)) // Long-lived should still be rejected - require.ErrorIs(t, guard.CheckAndRecord("long-lived", longExpiry), ErrDigestAlreadySeen) + require.ErrorIs(t, guard.CheckAndRecord("long-lived", longExpiry), ErrRequestAlreadySeen) } -func TestDigestReplayGuard_ConcurrentAccess(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_ConcurrentAccess(t *testing.T) { + guard := NewRequestReplayGuard() futureExpiry := time.Now().UTC().Unix() + 100 const goroutines = 100 @@ -116,7 +116,7 @@ func TestDigestReplayGuard_ConcurrentAccess(t *testing.T) { if err == nil { successCount++ } else { - require.ErrorIs(t, err, ErrDigestAlreadySeen) + require.ErrorIs(t, err, ErrRequestAlreadySeen) duplicateCount++ } } @@ -125,8 +125,8 @@ func TestDigestReplayGuard_ConcurrentAccess(t *testing.T) { assert.Equal(t, goroutines-1, duplicateCount, "all others should be rejected as duplicates") } -func TestDigestReplayGuard_ConcurrentDifferentDigests(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_ConcurrentDifferentDigests(t *testing.T) { + guard := NewRequestReplayGuard() futureExpiry := time.Now().UTC().Unix() + 100 const goroutines = 50 @@ -148,8 +148,8 @@ func TestDigestReplayGuard_ConcurrentDifferentDigests(t *testing.T) { } } -func TestDigestReplayGuard_ClearExpiredIndependently(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_ClearExpiredIndependently(t *testing.T) { + guard := NewRequestReplayGuard() now := time.Now() guard.nowFunc = func() time.Time { return now } @@ -174,10 +174,10 @@ func TestDigestReplayGuard_ClearExpiredIndependently(t *testing.T) { assert.True(t, durablePresent, "non-expired entry should remain") } -func TestDigestReplayGuard_EmptyDigest(t *testing.T) { - guard := NewDigestReplayGuard() +func TestRequestReplayGuard_EmptyDigest(t *testing.T) { + guard := NewRequestReplayGuard() futureExpiry := time.Now().UTC().Unix() + 100 require.NoError(t, guard.CheckAndRecord("", futureExpiry)) - require.ErrorIs(t, guard.CheckAndRecord("", futureExpiry), ErrDigestAlreadySeen) + require.ErrorIs(t, guard.CheckAndRecord("", futureExpiry), ErrRequestAlreadySeen) } diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 76172b3dc9b..784c3497924 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink-evm/pkg/chains/legacyevm" - vaultcap "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" @@ -85,8 +84,7 @@ func (hf *handlerFactory) NewHandler( case HTTPCapabilityType: return v2.NewGatewayHandler(handlerConfig, donConfig, don, hf.httpClient, hf.lggr, hf.lf) case VaultHandlerType: - requestAuthorizer := vaultcap.NewRequestAuthorizer(hf.lggr, hf.workflowRegistrySyncer) - return vault.NewHandler(handlerConfig, donConfig, don, hf.capabilitiesRegistry, requestAuthorizer, hf.lggr, clockwork.NewRealClock(), hf.lf) + return vault.NewHandler(handlerConfig, donConfig, don, hf.capabilitiesRegistry, hf.workflowRegistrySyncer, hf.lggr, clockwork.NewRealClock(), hf.lf) default: return nil, fmt.Errorf("unsupported handler type %s", handlerType) } diff --git a/core/services/gateway/handlers/vault/handler.go b/core/services/gateway/handlers/vault/handler.go index ceb95f92a49..a1335bfb0e7 100644 --- a/core/services/gateway/handlers/vault/handler.go +++ b/core/services/gateway/handlers/vault/handler.go @@ -34,6 +34,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" gwhandlers "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" handlerscommon "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" + workflowsyncerv2 "github.com/smartcontractkit/chainlink/v2/core/services/workflows/syncer/v2" ) const ( @@ -129,14 +130,15 @@ type aggregator interface { type handler struct { services.StateMachine - methodConfig Config - donConfig *config.DONConfig - don gwhandlers.DON - lggr logger.Logger - codec api.JsonRPCCodec - mu sync.RWMutex - stopCh services.StopChan - requestAuthorizer vaultcap.RequestAuthorizer + methodConfig Config + donConfig *config.DONConfig + don gwhandlers.DON + lggr logger.Logger + codec api.JsonRPCCodec + mu sync.RWMutex + stopCh services.StopChan + authorizer vaultcap.Authorizer + jwtAuth services.Service *vaultcap.RequestValidator nodeRateLimiter *ratelimit.RateLimiter @@ -162,18 +164,35 @@ func (h *handler) Name() string { return h.lggr.Name() } +// SecretEntry is the user-facing shape returned by list operations. type SecretEntry struct { ID string `json:"id"` Value string `json:"value"` CreatedAt int64 `json:"created_at"` } +// Config configures the gateway-side Vault handler. type Config struct { NodeRateLimiter ratelimit.RateLimiterConfig `json:"nodeRateLimiter"` RequestTimeoutSec int `json:"requestTimeoutSec"` } -func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, capabilitiesRegistry capabilitiesRegistry, requestAuthorizer vaultcap.RequestAuthorizer, lggr logger.Logger, clock clockwork.Clock, limitsFactory limits.Factory) (*handler, error) { +// NewHandler creates the gateway-side Vault handler with internal auth wiring. +func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, capabilitiesRegistry capabilitiesRegistry, workflowRegistrySyncer workflowsyncerv2.WorkflowRegistrySyncer, lggr logger.Logger, clock clockwork.Clock, limitsFactory limits.Factory) (*handler, error) { + allowListBasedAuth := vaultcap.NewAllowListBasedAuth(lggr, workflowRegistrySyncer) + jwtBasedAuth, err := vaultcap.NewJWTBasedAuth(vaultcap.JWTBasedAuthConfig{}, limitsFactory, lggr, vaultcap.WithDisabledJWTBasedAuth()) + if err != nil { + return nil, fmt.Errorf("failed to create JWTBasedAuth: %w", err) + } + authorizer := vaultcap.NewAuthorizer(allowListBasedAuth, jwtBasedAuth, lggr) + return newHandlerWithJWTAuth(methodConfig, donConfig, don, capabilitiesRegistry, authorizer, jwtBasedAuth, lggr, clock, limitsFactory) +} + +func newHandlerWithAuthorizer(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, capabilitiesRegistry capabilitiesRegistry, authorizer vaultcap.Authorizer, lggr logger.Logger, clock clockwork.Clock, limitsFactory limits.Factory) (*handler, error) { + return newHandlerWithJWTAuth(methodConfig, donConfig, don, capabilitiesRegistry, authorizer, nil, lggr, clock, limitsFactory) +} + +func newHandlerWithJWTAuth(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, capabilitiesRegistry capabilitiesRegistry, authorizer vaultcap.Authorizer, jwtAuth services.Service, lggr logger.Logger, clock clockwork.Clock, limitsFactory limits.Factory) (*handler, error) { var cfg Config if err := json.Unmarshal(methodConfig, &cfg); err != nil { return nil, fmt.Errorf("failed to unmarshal method config: %w", err) @@ -193,7 +212,7 @@ func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don g return nil, fmt.Errorf("failed to create metrics: %w", err) } - limiter, err := limits.MakeBoundLimiter(limitsFactory, cresettings.Default.VaultRequestBatchSizeLimit) + limiter, err := limits.MakeUpperBoundLimiter(limitsFactory, cresettings.Default.VaultRequestBatchSizeLimit) if err != nil { return nil, fmt.Errorf("could not create request batch size limiter: %w", err) } @@ -217,7 +236,8 @@ func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don g writeMethodsEnabled: writeMethodsEnabled, activeRequests: make(map[string]*activeRequest), mu: sync.RWMutex{}, - requestAuthorizer: requestAuthorizer, + authorizer: authorizer, + jwtAuth: jwtAuth, stopCh: make(services.StopChan), metrics: metrics, aggregator: &baseAggregator{capabilitiesRegistry: capabilitiesRegistry}, @@ -229,6 +249,11 @@ func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don g func (h *handler) Start(_ context.Context) error { return h.StartOnce("VaultHandler", func() error { h.lggr.Debug("starting vault handler") + if h.jwtAuth != nil { + if err := h.jwtAuth.Start(context.Background()); err != nil { + return fmt.Errorf("failed to start JWTBasedAuth: %w", err) + } + } go func() { ctx, cancel := h.stopCh.NewCtx() defer cancel() @@ -256,7 +281,12 @@ func (h *handler) Close() error { return h.StopOnce("VaultHandler", func() error { h.lggr.Debug("closing vault handler") close(h.stopCh) + var jwtAuthErr error + if h.jwtAuth != nil { + jwtAuthErr = h.jwtAuth.Close() + } return errors.Join( + jwtAuthErr, h.writeMethodsEnabled.Close(), h.MaxRequestBatchSizeLimiter.Close(), ) @@ -318,11 +348,13 @@ func (h *handler) removeExpiredRequests(ctx context.Context) { h.mu.RUnlock() for _, er := range expiredRequests { - var nodeResponses string - for nodeKey, nodeResponse := range er.responses { - nodeResponses += fmt.Sprintf("%s ---::: %v ", nodeKey, nodeResponse) + responses := er.copiedResponses() + var nodeResponses strings.Builder + for nodeKey, nodeResponse := range responses { + _, _ = fmt.Fprintf(&nodeResponses, "%s ---::: %v ", nodeKey, nodeResponse) } - err := h.sendResponse(ctx, er, h.errorResponse(er.req, api.RequestTimeoutError, errors.New("request expired without getting quorum of responses from nodes. Available responses: "+nodeResponses), []byte(nodeResponses))) + nodeResponsesStr := nodeResponses.String() + err := h.sendResponse(ctx, er, h.errorResponse(er.req, api.RequestTimeoutError, errors.New("request expired without getting quorum of responses from nodes. Available responses: "+nodeResponsesStr), []byte(nodeResponsesStr))) if err != nil { h.lggr.Errorw("error sending response to user", "requestID", er.req.ID, "error", err) } @@ -347,8 +379,7 @@ func (h *handler) HandleJSONRPCUserMessage(ctx context.Context, req jsonrpc.Requ } h.lggr.Debugw("handling vault request", "method", req.Method, "requestID", req.ID, "request", req) - switch req.Method { - case vaulttypes.MethodPublicKeyGet: + if req.Method == vaulttypes.MethodPublicKeyGet { // Public key requests don't require authorization, // Let's process this request right away. // Note we cache this value quite aggressively so don't need to worry about DoS. @@ -364,20 +395,20 @@ func (h *handler) HandleJSONRPCUserMessage(ctx context.Context, req jsonrpc.Requ } h.lggr.Debugw("returning cached public key response") return h.handlePublicKeyGetSynchronously(ctx, req, publicKeyResponseBytes, callback) - } - isAuthorized, owner, err := h.requestAuthorizer.AuthorizeRequest(ctx, req) - if !isAuthorized { - h.lggr.Errorw("request not authorized", "requestID", req.ID, "owner", owner, "reason:", err) + authResult, err := h.authorizer.AuthorizeRequest(ctx, req) + if err != nil { + h.lggr.Errorw("request not authorized", "method", req.Method, "requestID", req.ID, "hasAuth", req.Auth != "", "error", err) return errors.New("request not authorized: " + err.Error()) } + authorizedOwner := authResult.AuthorizedOwner() // Generate a unique ID for the request. - // Prefix request id with owner, to ensure uniqueness across different owners + // Prefix request id with authorizedOwner, to ensure uniqueness across different owners // We do this ourselves to ensure the ID is unique and can't be tampered with by the user. - req.ID = owner + vaulttypes.RequestIDSeparator + req.ID + req.ID = authorizedOwner + vaulttypes.RequestIDSeparator + req.ID - h.lggr.Debugw("handling authorized vault request", "method", req.Method, "requestID", req.ID, "owner", owner) + h.lggr.Debugw("handling authorized vault request", "method", req.Method, "requestID", req.ID, "authorizedOwner", authorizedOwner) ar, err := h.newActiveRequest(req, callback) if err != nil { return err @@ -450,7 +481,7 @@ func (h *handler) HandleNodeMessage(ctx context.Context, resp *jsonrpc.Response[ l.Debugw("aggregating responses, waiting for other nodes...", "error", err) return nil case err != nil: - l.Error("quorum unobtainable, returning response to user...", "error", err, "responses", maps.Values(ar.responses)) + l.Error("quorum unobtainable, returning response to user...", "error", err, "responses", maps.Values(copiedResponses)) return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.FatalError, err, nil)) } @@ -541,8 +572,8 @@ func (h *handler) handleSecretsCreate(ctx context.Context, ar *activeRequest) er } createSecretsRequest := &vaultcommon.CreateSecretsRequest{} - if err := json.Unmarshal(*ar.req.Params, &createSecretsRequest); err != nil { - return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.UserMessageParseError, err, nil)) + if unmarshalErr := json.Unmarshal(*ar.req.Params, &createSecretsRequest); unmarshalErr != nil { + return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.UserMessageParseError, unmarshalErr, nil)) } createSecretsRequest.RequestId = ar.req.ID for _, secretItem := range createSecretsRequest.EncryptedSecrets { @@ -581,8 +612,8 @@ func (h *handler) handleSecretsUpdate(ctx context.Context, ar *activeRequest) er } updateSecretsRequest := &vaultcommon.UpdateSecretsRequest{} - if err := json.Unmarshal(*ar.req.Params, updateSecretsRequest); err != nil { - return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.UserMessageParseError, err, nil)) + if unmarshalErr := json.Unmarshal(*ar.req.Params, updateSecretsRequest); unmarshalErr != nil { + return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.UserMessageParseError, unmarshalErr, nil)) } updateSecretsRequest.RequestId = ar.req.ID @@ -621,8 +652,8 @@ func (h *handler) handleSecretsDelete(ctx context.Context, ar *activeRequest) er } deleteSecretsRequest := &vaultcommon.DeleteSecretsRequest{} - if err := json.Unmarshal(*ar.req.Params, deleteSecretsRequest); err != nil { - return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.UserMessageParseError, err, nil)) + if unmarshalErr := json.Unmarshal(*ar.req.Params, deleteSecretsRequest); unmarshalErr != nil { + return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.UserMessageParseError, unmarshalErr, nil)) } deleteSecretsRequest.RequestId = ar.req.ID @@ -772,7 +803,9 @@ func (h *handler) errorResponse( err = errors.New("user message parse error: " + err.Error()) case api.NoError: case api.UnsupportedDONIdError: + case api.ConflictError: case api.HandlerError: + case api.LimitExceededError: case api.RequestTimeoutError: case api.StaleNodeResponseError: // Unused in this handler @@ -803,6 +836,8 @@ func (h *handler) sendResponse(ctx context.Context, userRequest *activeRequest, case api.NodeReponseEncodingError: case api.RequestTimeoutError: case api.HandlerError: + case api.ConflictError: + case api.LimitExceededError: h.metrics.requestInternalError.Add(ctx, 1, metric.WithAttributes( attribute.String("don_id", h.donConfig.DonId), attribute.String("error", resp.ErrorCode.String()), diff --git a/core/services/gateway/handlers/vault/handler_test.go b/core/services/gateway/handlers/vault/handler_test.go index 353e739b3b1..91a3df576ee 100644 --- a/core/services/gateway/handlers/vault/handler_test.go +++ b/core/services/gateway/handlers/vault/handler_test.go @@ -25,7 +25,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/ratelimit" "github.com/smartcontractkit/chainlink-common/pkg/settings/cresettings" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" - vaultcapmocks "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/mocks" + vaultcap "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault" "github.com/smartcontractkit/chainlink/v2/core/capabilities/vault/vaulttypes" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" @@ -59,17 +59,26 @@ func setupHandler(t *testing.T) (handlers.Handler, *common.Callback, *mocks.DON, methodConfig, err := json.Marshal(handlerConfig) require.NoError(t, err) - requestAuthorizer := vaultcapmocks.NewRequestAuthorizer(t) - requestAuthorizer.On("AuthorizeRequest", mock.Anything, mock.Anything).Return(true, owner, nil).Maybe() clock := clockwork.NewFakeClock() limitsFactory := limits.Factory{Settings: cresettings.DefaultGetter} - handler, err := NewHandler(methodConfig, donConfig, don, nil, requestAuthorizer, lggr, clock, limitsFactory) + jwtBasedAuth, err := vaultcap.NewJWTBasedAuth(vaultcap.JWTBasedAuthConfig{}, limitsFactory, lggr, vaultcap.WithDisabledJWTBasedAuth()) + require.NoError(t, err) + authorizer := vaultcap.NewAuthorizer(&stubAllowListBasedAuth{clock: clock}, jwtBasedAuth, lggr) + handler, err := newHandlerWithAuthorizer(methodConfig, donConfig, don, nil, authorizer, lggr, clock, limitsFactory) require.NoError(t, err) handler.aggregator = &mockAggregator{} cb := common.NewCallback() return handler, cb, don, clock } +type stubAllowListBasedAuth struct { + clock clockwork.Clock +} + +func (s *stubAllowListBasedAuth) AuthorizeRequest(_ context.Context, req jsonrpc.Request[json.RawMessage]) (*vaultcap.AuthResult, error) { + return vaultcap.NewAuthResult("", owner, "digest-"+req.ID, s.clock.Now().Add(time.Minute).Unix()), nil +} + type mockAggregator struct { err error } @@ -416,7 +425,6 @@ func TestVaultHandler_HandleJSONRPCUserMessage(t *testing.T) { }) t.Run("happy path - list secret identifiers", func(t *testing.T) { - var wg sync.WaitGroup h, callback, don, _ := setupHandler(t) don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -454,28 +462,22 @@ func TestVaultHandler_HandleJSONRPCUserMessage(t *testing.T) { resultBytes, err = json.Marshal(responseData) require.NoError(t, err) - wg.Add(1) - go func() { - defer wg.Done() - resp, err2 := callback.Wait(t.Context()) - assert.NoError(t, err2) - var secretsResponse jsonrpc.Response[vaultcommon.ListSecretIdentifiersResponse] - err2 = json.Unmarshal(resp.RawResponse, &secretsResponse) - assert.NoError(t, err2) - assert.Equal(t, validJSONRequest.ID, secretsResponse.ID, "Request ID should match") - assert.True(t, proto.Equal(secretsResponse.Result, responseData), "Response data should match") - }() - err = h.HandleJSONRPCUserMessage(t.Context(), validJSONRequest, callback) require.NoError(t, err) err = h.HandleNodeMessage(t.Context(), &response, NodeOne.Address) require.NoError(t, err) - wg.Wait() + + resp, err := callback.Wait(t.Context()) + require.NoError(t, err) + var secretsResponse jsonrpc.Response[vaultcommon.ListSecretIdentifiersResponse] + err = json.Unmarshal(resp.RawResponse, &secretsResponse) + require.NoError(t, err) + assert.Equal(t, validJSONRequest.ID, secretsResponse.ID, "Request ID should match") + assert.True(t, proto.Equal(secretsResponse.Result, responseData), "Response data should match") }) t.Run("unhappy path - duplicate requestId", func(t *testing.T) { - var wg sync.WaitGroup h, callback, don, _ := setupHandler(t) don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) @@ -513,28 +515,23 @@ func TestVaultHandler_HandleJSONRPCUserMessage(t *testing.T) { resultBytes, err = json.Marshal(responseData) require.NoError(t, err) - wg.Add(1) - go func() { - defer wg.Done() - resp, err2 := callback.Wait(t.Context()) - assert.NoError(t, err2) - var secretsResponse jsonrpc.Response[vaultcommon.ListSecretIdentifiersResponse] - err2 = json.Unmarshal(resp.RawResponse, &secretsResponse) - assert.NoError(t, err2) - assert.Equal(t, validJSONRequest.ID, secretsResponse.ID, "Request ID should match") - assert.True(t, proto.Equal(secretsResponse.Result, responseData), "Response data should match") - }() - err = h.HandleJSONRPCUserMessage(t.Context(), validJSONRequest, callback) require.NoError(t, err) // send duplicate request err = h.HandleJSONRPCUserMessage(t.Context(), validJSONRequest, callback) - require.ErrorContains(t, err, "request ID already exists") + require.ErrorContains(t, err, "request was already authorized previously") err = h.HandleNodeMessage(t.Context(), &response, NodeOne.Address) require.NoError(t, err) - wg.Wait() + + resp, err := callback.Wait(t.Context()) + require.NoError(t, err) + var secretsResponse jsonrpc.Response[vaultcommon.ListSecretIdentifiersResponse] + err = json.Unmarshal(resp.RawResponse, &secretsResponse) + require.NoError(t, err) + assert.Equal(t, validJSONRequest.ID, secretsResponse.ID, "Request ID should match") + assert.True(t, proto.Equal(secretsResponse.Result, responseData), "Response data should match") }) t.Run("unhappy path - quorum unobtainable", func(t *testing.T) { diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index ca1c7f226ba..bdeb8869bd7 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -7,6 +7,7 @@ import ( stderrors "errors" "fmt" "log" + "math" "os" "path/filepath" "strconv" @@ -78,7 +79,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr/capregconfig" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipcommit" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/ccipexec" - "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" ccipconfig "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/generic" @@ -110,26 +110,26 @@ const ( dontimeCapabilityID = "dontime@1.0.0" ) -type ErrJobSpecNoRelayer struct { +type JobSpecNoRelayerError struct { PluginName string Err error } -func (e ErrJobSpecNoRelayer) Unwrap() error { return e.Err } +func (e JobSpecNoRelayerError) Unwrap() error { return e.Err } -func (e ErrJobSpecNoRelayer) Error() string { +func (e JobSpecNoRelayerError) Error() string { return fmt.Sprintf("%s services: OCR2 job spec could not get relayer ID: %s", e.PluginName, e.Err) } -type ErrRelayNotEnabled struct { +type RelayNotEnabledError struct { PluginName string Relay string Err error } -func (e ErrRelayNotEnabled) Unwrap() error { return e.Err } +func (e RelayNotEnabledError) Unwrap() error { return e.Err } -func (e ErrRelayNotEnabled) Error() string { +func (e RelayNotEnabledError) Error() string { return fmt.Sprintf("%s services: failed to get relay %s, is it enabled? %s", e.PluginName, e.Relay, e.Err) } @@ -354,7 +354,7 @@ func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job) error { rid, err := spec.RelayID() if err != nil { - d.lggr.Errorw("DeleteJob", "err", ErrJobSpecNoRelayer{Err: err, PluginName: string(spec.PluginType)}) + d.lggr.Errorw("DeleteJob", "err", JobSpecNoRelayerError{Err: err, PluginName: string(spec.PluginType)}) return nil } // we only have clean to do for the EVM @@ -516,7 +516,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: string(spec.PluginType)} + return nil, JobSpecNoRelayerError{Err: err, PluginName: string(spec.PluginType)} } if rid.Network == relay.NetworkEVM { @@ -601,11 +601,11 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi case types.Functions: const ( _ int32 = iota - thresholdPluginId - s4PluginId + thresholdPluginID + s4PluginID ) - thresholdPluginDB := NewDB(d.ds, spec.ID, thresholdPluginId, lggr) - s4PluginDB := NewDB(d.ds, spec.ID, s4PluginId, lggr) + thresholdPluginDB := NewDB(d.ds, spec.ID, thresholdPluginID, lggr) + s4PluginDB := NewDB(d.ds, spec.ID, s4PluginID, lggr) return d.newServicesOCR2Functions(ctx, lggr, jb, bootstrapPeers, kb, ocrDB, thresholdPluginDB, s4PluginDB, lc) case types.GenericPlugin: @@ -726,8 +726,7 @@ func (d *Delegate) newServicesVaultPlugin( } srvs = append(srvs, vaultCapability) - requestAuthorizer := vaultcap.NewRequestAuthorizer(lggr, syncer) - handler, err := vaultcap.NewGatewayHandler(vaultCapability, gwconnector, requestAuthorizer, d.lggr) + handler, err := vaultcap.NewGatewayHandler(vaultCapability, gwconnector, syncer, d.lggr, limitsFactory) if err != nil { return nil, fmt.Errorf("failed to instantiate vault plugin: failed to create vault handler: %w", err) } @@ -735,12 +734,12 @@ func (d *Delegate) newServicesVaultPlugin( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{PluginName: string(types.VaultPlugin), Err: err} + return nil, JobSpecNoRelayerError{PluginName: string(types.VaultPlugin), Err: err} } relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: string(types.VaultPlugin)} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: string(types.VaultPlugin)} } provider, err := relayer.NewPluginProvider(ctx, types.RelayArgs{ @@ -940,12 +939,12 @@ func (d *Delegate) newDonTimePlugin( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{PluginName: "dontime", Err: err} + return nil, JobSpecNoRelayerError{PluginName: "dontime", Err: err} } relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: "dontime"} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: "dontime"} } provider, err := relayer.NewPluginProvider(ctx, types.RelayArgs{ @@ -1067,12 +1066,12 @@ func (d *Delegate) newServicesRing( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{PluginName: "ring", Err: err} + return nil, JobSpecNoRelayerError{PluginName: "ring", Err: err} } relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: "ring"} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: "ring"} } provider, err := relayer.NewPluginProvider(ctx, types.RelayArgs{ @@ -1259,7 +1258,7 @@ func (d *Delegate) newServicesGenericPlugin( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{PluginName: pCfg.PluginName, Err: err} + return nil, JobSpecNoRelayerError{PluginName: pCfg.PluginName, Err: err} } relayerSet, err := generic.NewRelayerSet(d.RelayGetter, jb.ExternalJobID, jb.ID, d.isNewlyCreatedJob) @@ -1269,7 +1268,7 @@ func (d *Delegate) newServicesGenericPlugin( relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: pCfg.PluginName} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: pCfg.PluginName} } provider, err := relayer.NewPluginProvider(ctx, types.RelayArgs{ @@ -1485,14 +1484,14 @@ func (d *Delegate) newServicesMercury( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: "mercury"} + return nil, JobSpecNoRelayerError{Err: err, PluginName: "mercury"} } if rid.Network != relay.NetworkEVM { return nil, fmt.Errorf("mercury services: expected EVM relayer got %q", rid.Network) } relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: "mercury"} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: "mercury"} } provider, err2 := relayer.NewPluginProvider(ctx, @@ -1591,11 +1590,11 @@ func (d *Delegate) newServicesLLO( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: "streams"} + return nil, JobSpecNoRelayerError{Err: err, PluginName: "streams"} } relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: "streams"} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: "streams"} } provider, err2 := relayer.NewLLOProvider(ctx, @@ -1735,7 +1734,7 @@ func (d *Delegate) newServicesMedian( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: "median"} + return nil, JobSpecNoRelayerError{Err: err, PluginName: "median"} } ocrLogger := ocrcommon.NewOCRWrapper(lggr, d.cfg.OCR2().TraceLogging(), func(ctx context.Context, msg string) { @@ -1764,7 +1763,7 @@ func (d *Delegate) newServicesMedian( relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, PluginName: "median", Relay: spec.Relay} + return nil, RelayNotEnabledError{Err: err, PluginName: "median", Relay: spec.Relay} } medianServices, err2 := median.NewMedianServices(ctx, jb, d.isNewlyCreatedJob, relayer, kvStore, d.pipelineRunner, lggr, oracleArgsNoPlugin, mConfig, enhancedTelemChan, errorLog) @@ -1832,7 +1831,7 @@ func (d *Delegate) newServicesOCR2Keepers21( mc := d.cfg.Mercury().Credentials(credName) rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: "keeper2"} + return nil, JobSpecNoRelayerError{Err: err, PluginName: "keeper2"} } if rid.Network != relay.NetworkEVM { return nil, fmt.Errorf("keeper2 services: expected EVM relayer got %q", rid.Network) @@ -1841,7 +1840,7 @@ func (d *Delegate) newServicesOCR2Keepers21( transmitterID := spec.TransmitterID.String relayer, err := d.Get(rid) if err != nil { - return nil, ErrRelayNotEnabled{Err: err, Relay: spec.Relay, PluginName: "ocr2keepers"} + return nil, RelayNotEnabledError{Err: err, Relay: spec.Relay, PluginName: "ocr2keepers"} } provider, err := relayer.NewPluginProvider(ctx, @@ -1977,7 +1976,7 @@ func (d *Delegate) newServicesOCR2Keepers20( ) ([]job.ServiceCtx, error) { rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: "keepers2.0"} + return nil, JobSpecNoRelayerError{Err: err, PluginName: "keepers2.0"} } if rid.Network != relay.NetworkEVM { return nil, fmt.Errorf("keepers2.0 services: expected EVM relayer got %q", rid.Network) @@ -2111,7 +2110,7 @@ func (d *Delegate) newServicesOCR2Functions( rid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: "functions"} + return nil, JobSpecNoRelayerError{Err: err, PluginName: "functions"} } if rid.Network != relay.NetworkEVM { return nil, fmt.Errorf("functions services: expected EVM relayer got %q", rid.Network) @@ -2264,7 +2263,7 @@ func (d *Delegate) newServicesCCIPCommit(ctx context.Context, lggr logger.Sugare } dstRid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: string(spec.PluginType)} + return nil, JobSpecNoRelayerError{Err: err, PluginName: string(spec.PluginType)} } logError := func(msg string) { @@ -2332,8 +2331,8 @@ func (d *Delegate) newServicesCCIPCommit(ctx context.Context, lggr logger.Sugare ) } -func newCCIPCommitPluginBytes(isSourceProvider bool, sourceStartBlock uint64, destStartBlock uint64) config.CommitPluginConfig { - return config.CommitPluginConfig{ +func newCCIPCommitPluginBytes(isSourceProvider bool, sourceStartBlock uint64, destStartBlock uint64) ccipconfig.CommitPluginConfig { + return ccipconfig.CommitPluginConfig{ IsSourceProvider: isSourceProvider, SourceStartBlock: sourceStartBlock, DestStartBlock: destStartBlock, @@ -2348,7 +2347,7 @@ func (d *Delegate) ccipCommitGetDstProvider(ctx context.Context, jb job.Job, plu dstRid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: string(spec.PluginType)} + return nil, JobSpecNoRelayerError{Err: err, PluginName: string(spec.PluginType)} } // Write PluginConfig bytes to send source/dest relayer provider + info outside of top level rargs/pargs over the wire @@ -2448,7 +2447,7 @@ func (d *Delegate) newServicesCCIPExecution(ctx context.Context, lggr logger.Sug dstRid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: string(spec.PluginType)} + return nil, JobSpecNoRelayerError{Err: err, PluginName: string(spec.PluginType)} } logError := func(msg string) { @@ -2498,6 +2497,10 @@ func (d *Delegate) newServicesCCIPExecution(ctx context.Context, lggr logger.Sug MetricsRegisterer: prometheus.WrapRegistererWith(map[string]string{"job_name": jb.Name.ValueOrZero()}, prometheus.DefaultRegisterer), } + if srcChainID > math.MaxInt64 { + return nil, fmt.Errorf("source chain ID %d overflows int64", srcChainID) + } + return ccipexec.NewExecServices(ctx, lggr, jb, srcProvider, dstProvider, int64(srcChainID), dstChainID, d.isNewlyCreatedJob, oracleArgsNoPlugin2, logError) } @@ -2509,7 +2512,7 @@ func (d *Delegate) ccipExecGetDstProvider(ctx context.Context, jb job.Job, plugi dstRid, err := spec.RelayID() if err != nil { - return nil, ErrJobSpecNoRelayer{Err: err, PluginName: string(spec.PluginType)} + return nil, JobSpecNoRelayerError{Err: err, PluginName: string(spec.PluginType)} } // PROVIDER BASED ARG CONSTRUCTION @@ -2596,8 +2599,8 @@ func (d *Delegate) ccipExecGetSrcProvider(ctx context.Context, jb job.Job, plugi return } -func newExecPluginConfig(isSourceProvider bool, srcStartBlock uint64, dstStartBlock uint64, usdcConfig ccipconfig.USDCConfig, lbtcConfigs []ccipconfig.LBTCConfig, jobID string) config.ExecPluginConfig { - return config.ExecPluginConfig{ +func newExecPluginConfig(isSourceProvider bool, srcStartBlock uint64, dstStartBlock uint64, usdcConfig ccipconfig.USDCConfig, lbtcConfigs []ccipconfig.LBTCConfig, jobID string) ccipconfig.ExecPluginConfig { + return ccipconfig.ExecPluginConfig{ IsSourceProvider: isSourceProvider, SourceStartBlock: srcStartBlock, DestStartBlock: dstStartBlock,