diff --git a/.mockery.yaml b/.mockery.yaml index 212c96cd32..1d438df01c 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -45,6 +45,9 @@ packages: ExecutionHelper: config: mockname: "Mock{{.InterfaceName}}" + ExecutionHelperWithRawSecrets: + config: + mockname: "Mock{{.InterfaceName}}" github.com/smartcontractkit/chainlink-common/pkg/custmsg: interfaces: MessageEmitter: diff --git a/go.mod b/go.mod index 7a36bf6c3b..cb3990e1e0 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.100 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.11-0.20260528204832-58c7145c53f8 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260618082634-432eb85805e7 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260622152157-c8e129347b8b github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index d94a28f55c..f7e3962cf2 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.11-0.202605282 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.11-0.20260528204832-58c7145c53f8/go.mod h1:HmUyH2oD9m+GRpKq7q3vuRnm1F2Uczf/Nd1v3ipMSK8= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260618082634-432eb85805e7 h1:iRFmfMFQtcnhGDjCuARQG4MPbcmbbJDDw7MUH3GcGy8= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260618082634-432eb85805e7/go.mod h1:vTFHTCbLui4Vn8fTmAadfE3rdnvfrDwOmMujmW857D0= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260622152157-c8e129347b8b h1:VDgJWDipihV9f7M5+d21d1RzSsg5rEv+iI12oN1VQbo= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260622152157-c8e129347b8b/go.mod h1:vTFHTCbLui4Vn8fTmAadfE3rdnvfrDwOmMujmW857D0= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go index d017e47f94..570f6c486e 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go @@ -114,8 +114,11 @@ type WorkflowExecution struct { // the other). Consumers that want the typed message read this; legacy // consumers continue to unmarshal execute_request. SdkExecuteRequest *sdk.ExecuteRequest `protobuf:"bytes,9,opt,name=sdk_execute_request,json=sdkExecuteRequest,proto3" json:"sdk_execute_request,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // restrictions on the capabilities and the secrets.bool + // This is sent to avoid overhead when a TEE is not compromised, the DON will verify the restrictions on its end as well. + Restrictions *sdk.Restrictions `protobuf:"bytes,10,opt,name=restrictions,proto3" json:"restrictions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *WorkflowExecution) Reset() { @@ -211,6 +214,13 @@ func (x *WorkflowExecution) GetSdkExecuteRequest() *sdk.ExecuteRequest { return nil } +func (x *WorkflowExecution) GetRestrictions() *sdk.Restrictions { + if x != nil { + return x.Restrictions + } + return nil +} + // ConfidentialWorkflowRequest is the input provided to the confidential workflows capability. // It combines a WorkflowExecution with secrets from VaultDON. type ConfidentialWorkflowRequest struct { @@ -390,7 +400,7 @@ const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDes "\x03key\x18\x01 \x01(\tR\x03key\x12!\n" + "\tnamespace\x18\x02 \x01(\tH\x00R\tnamespace\x88\x01\x01B\f\n" + "\n" + - "_namespace\"\xf9\x02\n" + + "_namespace\"\xb8\x03\n" + "\x11WorkflowExecution\x12\x1f\n" + "\vworkflow_id\x18\x01 \x01(\tR\n" + "workflowId\x12\x1d\n" + @@ -403,7 +413,9 @@ const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDes "\fexecution_id\x18\x06 \x01(\tR\vexecutionId\x12\x15\n" + "\x06org_id\x18\a \x01(\tR\x05orgId\x12=\n" + "\frequirements\x18\b \x01(\v2\x19.sdk.v1alpha.RequirementsR\frequirements\x12K\n" + - "\x13sdk_execute_request\x18\t \x01(\v2\x1b.sdk.v1alpha.ExecuteRequestR\x11sdkExecuteRequest\"\x95\x02\n" + + "\x13sdk_execute_request\x18\t \x01(\v2\x1b.sdk.v1alpha.ExecuteRequestR\x11sdkExecuteRequest\x12=\n" + + "\frestrictions\x18\n" + + " \x01(\v2\x19.sdk.v1alpha.RestrictionsR\frestrictions\"\x95\x02\n" + "\x1bConfidentialWorkflowRequest\x12o\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2C.capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12b\n" + "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\x12!\n" + @@ -439,26 +451,28 @@ var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_goTypes (*ProvidedTeesResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse (*sdk.Requirements)(nil), // 5: sdk.v1alpha.Requirements (*sdk.ExecuteRequest)(nil), // 6: sdk.v1alpha.ExecuteRequest - (*sdk.ExecutionResult)(nil), // 7: sdk.v1alpha.ExecutionResult - (*sdk.TeeTypeAndRegions)(nil), // 8: sdk.v1alpha.TeeTypeAndRegions - (*emptypb.Empty)(nil), // 9: google.protobuf.Empty + (*sdk.Restrictions)(nil), // 7: sdk.v1alpha.Restrictions + (*sdk.ExecutionResult)(nil), // 8: sdk.v1alpha.ExecutionResult + (*sdk.TeeTypeAndRegions)(nil), // 9: sdk.v1alpha.TeeTypeAndRegions + (*emptypb.Empty)(nil), // 10: google.protobuf.Empty } var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs = []int32{ - 5, // 0: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.requirements:type_name -> sdk.v1alpha.Requirements - 6, // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.sdk_execute_request:type_name -> sdk.v1alpha.ExecuteRequest - 0, // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier - 1, // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution - 7, // 4: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse.sdk_execution_result:type_name -> sdk.v1alpha.ExecutionResult - 8, // 5: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse.tee:type_name -> sdk.v1alpha.TeeTypeAndRegions - 2, // 6: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest - 9, // 7: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:input_type -> google.protobuf.Empty - 3, // 8: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse - 4, // 9: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse - 8, // [8:10] is the sub-list for method output_type - 6, // [6:8] is the sub-list for method input_type - 6, // [6:6] is the sub-list for extension type_name - 6, // [6:6] is the sub-list for extension extendee - 0, // [0:6] is the sub-list for field type_name + 5, // 0: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.requirements:type_name -> sdk.v1alpha.Requirements + 6, // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.sdk_execute_request:type_name -> sdk.v1alpha.ExecuteRequest + 7, // 2: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.restrictions:type_name -> sdk.v1alpha.Restrictions + 0, // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier + 1, // 4: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution + 8, // 5: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse.sdk_execution_result:type_name -> sdk.v1alpha.ExecutionResult + 9, // 6: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse.tee:type_name -> sdk.v1alpha.TeeTypeAndRegions + 2, // 7: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest + 10, // 8: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:input_type -> google.protobuf.Empty + 3, // 9: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + 4, // 10: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse + 9, // [9:11] is the sub-list for method output_type + 7, // [7:9] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name } func init() { file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_init() } diff --git a/pkg/workflows/host/encryption_key_fetcher.go b/pkg/workflows/host/encryption_key_fetcher.go new file mode 100644 index 0000000000..ee2e494cbc --- /dev/null +++ b/pkg/workflows/host/encryption_key_fetcher.go @@ -0,0 +1,7 @@ +package host + +import "context" + +type EncryptionKeyFetcher interface { + GetEncryptionKeys(ctx context.Context) ([]string, error) +} diff --git a/pkg/workflows/host/execution_restrictions.go b/pkg/workflows/host/execution_restrictions.go new file mode 100644 index 0000000000..04c36dc02f --- /dev/null +++ b/pkg/workflows/host/execution_restrictions.go @@ -0,0 +1,285 @@ +package host + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/actions/confidentialhttp" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type methodKey struct { + id string + method string +} + +type secretKey struct { + id string + namespace string +} + +type prefixRestriction struct { + prefix string + namespace string + maxCalls int32 +} + +// TODO refactor to instead be injected INTO the hepler +// this would allow raw secrets to call the same restriction check +// don't make it an execution helper itself +type executionRestrictions struct { + ExecutionHelper + mu sync.Mutex + + hasCaps bool + capType sdk.CapabilityRestrictionType + maxTotalCalls int32 + methods map[methodKey]int32 + + hasSecrets bool + maxSecrets int32 + exactSecrets map[secretKey]bool + prefixSecrets []prefixRestriction +} + +type executionRestrictionsWithRawSecrets struct { + *executionRestrictions +} + +func (e *executionRestrictionsWithRawSecrets) GetOwner() string { + return e.ExecutionHelper.(ExecutionHelperWithRawSecrets).GetOwner() +} + +func (e *executionRestrictionsWithRawSecrets) GetRawSecrets(ctx context.Context, request *sdk.GetSecretsRequest, fetcher EncryptionKeyFetcher) ([]*vault.SecretResponse, error) { + rawSecretsHelper := e.ExecutionHelper.(ExecutionHelperWithRawSecrets) + owner := rawSecretsHelper.GetOwner() + + e.mu.Lock() + var allowed []*sdk.SecretRequest + var responses []*vault.SecretResponse + for _, req := range request.Requests { + if e.reserveSecret(req) { + allowed = append(allowed, req) + } else { + responses = append(responses, &vault.SecretResponse{ + Id: &vault.SecretIdentifier{ + Key: req.Id, + Namespace: req.Namespace, + Owner: owner, + }, + Result: &vault.SecretResponse_Error{ + Error: fmt.Sprintf("secret %q in namespace %q denied by user pre-hook restrictions", req.Id, req.Namespace), + }, + }) + } + } + e.mu.Unlock() + + if len(allowed) == 0 { + return responses, nil + } + + inner, err := rawSecretsHelper.GetRawSecrets(ctx, &sdk.GetSecretsRequest{Requests: allowed}, fetcher) + if err != nil { + return nil, err + } + return append(responses, inner...), nil +} + +var _ ExecutionHelperWithRawSecrets = (*executionRestrictionsWithRawSecrets)(nil) + +// NewRestrictedExecutionHelper wraps ExecutionHelper with restriction enforcement derived from r. +// If inner implements ExecutionHelperWithRawSecrets, the returned value will as well. +// If r is nil, ExecutionHelper is returned unchanged. +func NewRestrictedExecutionHelper(inner ExecutionHelper, r *sdk.Restrictions) ExecutionHelper { + if r == nil { + return inner + } + + er := &executionRestrictions{ExecutionHelper: inner} + + if caps := r.Capabilities; caps != nil { + er.hasCaps = true + er.capType = caps.Type + er.maxTotalCalls = caps.MaxTotalCalls + er.methods = make(map[methodKey]int32) + for _, cr := range caps.Restrictions { + m, ok := cr.Restriction.(*sdk.CapabilityRestriction_Method) + if !ok || m.Method == nil { + continue + } + mr := m.Method + key := methodKey{id: mr.Id, method: mr.Method} + existing, found := er.methods[key] + if !found || (mr.MaxCalls >= 0 && (existing < 0 || mr.MaxCalls < existing)) { + er.methods[key] = mr.MaxCalls + } + } + } + + if secrets := r.Secrets; secrets != nil { + er.hasSecrets = true + er.maxSecrets = secrets.MaxSecrets + er.exactSecrets = make(map[secretKey]bool) + for _, sr := range secrets.Restrictions { + switch v := sr.Restriction.(type) { + case *sdk.SecretRestriction_ExactSecret: + s := v.ExactSecret + er.exactSecrets[secretKey{id: s.Id, namespace: s.Namespace}] = true + case *sdk.SecretRestriction_PrefixedSecret: + p := v.PrefixedSecret + er.prefixSecrets = append(er.prefixSecrets, prefixRestriction{ + prefix: p.Prefix, + namespace: p.Namespace, + maxCalls: p.MaxSecrets, + }) + } + } + } + + if _, ok := inner.(ExecutionHelperWithRawSecrets); ok { + return &executionRestrictionsWithRawSecrets{executionRestrictions: er} + } + return er +} + +var confHttpRequest = (&confidentialhttp.ConfidentialHTTPRequest{}).ProtoReflect().Descriptor().FullName() + +func (e *executionRestrictions) reserveCapabilityCall(request *sdk.CapabilityRequest) bool { + if e == nil || !e.hasCaps { + return true + } + + if e.maxTotalCalls == 0 { + return false + } + + if request.Payload != nil { + switch request.Payload.MessageName() { + case confHttpRequest: + conf := &confidentialhttp.ConfidentialHTTPRequest{} + if err := request.Payload.UnmarshalTo(conf); err != nil { + return false + } + + secrets := conf.GetVaultDonSecrets() + for _, secret := range secrets { + if !e.reserveSecret(&sdk.SecretRequest{ + Id: secret.Key, + Namespace: secret.Namespace, + }) { + return false + } + } + } + } + + key := methodKey{id: request.Id, method: request.Method} + remaining, found := e.methods[key] + + if !found { + if e.capType == sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED { + return false + } + if e.maxTotalCalls > 0 { + e.maxTotalCalls-- + } + return true + } + + if remaining == 0 { + return false + } + + if remaining > 0 { + e.methods[key] = remaining - 1 + } + if e.maxTotalCalls > 0 { + e.maxTotalCalls-- + } + return true +} + +func (e *executionRestrictions) reserveSecret(request *sdk.SecretRequest) bool { + if !e.hasSecrets { + return true + } + + if e.maxSecrets == 0 { + return false + } + + key := secretKey{id: request.Id, namespace: request.Namespace} + exactMatch := e.exactSecrets[key] + + var matchedPrefixes []*prefixRestriction + for i := range e.prefixSecrets { + p := &e.prefixSecrets[i] + if p.namespace == request.Namespace && strings.HasPrefix(request.Id, p.prefix) { + if p.maxCalls == 0 { + return false + } + matchedPrefixes = append(matchedPrefixes, p) + } + } + + if !exactMatch && len(matchedPrefixes) == 0 { + return false + } + + for _, p := range matchedPrefixes { + if p.maxCalls > 0 { + p.maxCalls-- + } + } + if e.maxSecrets > 0 { + e.maxSecrets-- + } + return true +} + +func (e *executionRestrictions) CallCapability(ctx context.Context, request *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error) { + e.mu.Lock() + allowed := e.reserveCapabilityCall(request) + e.mu.Unlock() + if !allowed { + return nil, caperrors.NewLimitExceededError("call denied by user pre-hook restrictions", fmt.Errorf("%s %s", request.Id, request.Method)) + } + return e.ExecutionHelper.CallCapability(ctx, request) +} + +func (e *executionRestrictions) GetSecrets(ctx context.Context, request *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error) { + e.mu.Lock() + var allowed []*sdk.SecretRequest + var responses []*sdk.SecretResponse + for _, req := range request.Requests { + if e.reserveSecret(req) { + allowed = append(allowed, req) + } else { + responses = append(responses, &sdk.SecretResponse{ + Response: &sdk.SecretResponse_Error{ + Error: &sdk.SecretError{ + Id: req.Id, + Namespace: req.Namespace, + Error: fmt.Sprintf("secret %q in namespace %q denied by user pre-hook restrictions", req.Id, req.Namespace), + }, + }, + }) + } + } + e.mu.Unlock() + + if len(allowed) == 0 { + return responses, nil + } + + inner, err := e.ExecutionHelper.GetSecrets(ctx, &sdk.GetSecretsRequest{Requests: allowed}) + if err != nil { + return nil, err + } + return append(responses, inner...), nil +} diff --git a/pkg/workflows/host/execution_restrictions_test.go b/pkg/workflows/host/execution_restrictions_test.go new file mode 100644 index 0000000000..63a303db50 --- /dev/null +++ b/pkg/workflows/host/execution_restrictions_test.go @@ -0,0 +1,962 @@ +package host_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/actions/confidentialhttp" + "github.com/smartcontractkit/chainlink-common/pkg/utils/matches" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +// stubEncryptionKeyFetcher is a no-op EncryptionKeyFetcher used to verify the fetcher is +// delegated through to the inner helper. +type stubEncryptionKeyFetcher struct{} + +func (stubEncryptionKeyFetcher) GetEncryptionKeys(context.Context) ([]string, error) { + return nil, nil +} + +// capabilitySequence drives CallCapability through the public API, in order, against a +// single restricted helper. It returns, for each request, whether the call was allowed +// through to the inner helper (true) or denied by the restrictions (false). +func capabilitySequence(t *testing.T, r *sdk.Restrictions, reqs ...*sdk.CapabilityRequest) []bool { + t.Helper() + inner := mocks.NewMockExecutionHelper(t) + inner.EXPECT().CallCapability(matches.AnyContext, mock.Anything). + Return(&sdk.CapabilityResponse{}, nil).Maybe() + h := host.NewRestrictedExecutionHelper(inner, r) + + allowed := make([]bool, len(reqs)) + for i, req := range reqs { + _, err := h.CallCapability(t.Context(), req) + allowed[i] = err == nil + } + return allowed +} + +// secretSequence drives GetSecrets (one request per call) through the public API, in +// order, against a single restricted helper. It returns, for each request, whether the +// secret was allowed through to the inner helper (true) or denied by the restrictions +// (false). +func secretSequence(t *testing.T, r *sdk.Restrictions, reqs ...*sdk.SecretRequest) []bool { + t.Helper() + inner := mocks.NewMockExecutionHelper(t) + inner.EXPECT().GetSecrets(matches.AnyContext, mock.Anything). + Return([]*sdk.SecretResponse{{}}, nil).Maybe() + h := host.NewRestrictedExecutionHelper(inner, r) + + allowed := make([]bool, len(reqs)) + for i, req := range reqs { + resp, err := h.GetSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{req}, + }) + require.NoError(t, err) + require.Len(t, resp, 1) + // A denied secret is short-circuited into an error response; an allowed one is + // forwarded to the inner helper which returns a non-error response. + allowed[i] = resp[0].GetError() == nil + } + return allowed +} + +func TestRequirementSelectingModule_CallCapWithRestrictions(t *testing.T) { + restrictions := &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "allowed@1.0.0", Method: "Foo", MaxCalls: 5}, + }}, + }, + }, + } + + t.Run("denied call returns a limit-exceeded error without calling inner", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) // no expectations: inner must not be called + h := host.NewRestrictedExecutionHelper(inner, restrictions) + _, err := h.CallCapability(t.Context(), &sdk.CapabilityRequest{Id: "blocked@1.0.0", Method: "Bar"}) + var capErr caperrors.Error + require.True(t, errors.As(err, &capErr)) + assert.Contains(t, capErr.Error(), "denied by user pre-hook restrictions") + assert.Equal(t, caperrors.LimitExceeded, capErr.Code()) + }) + + t.Run("allowed call reaches inner and returns its response", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) + want := &sdk.CapabilityResponse{} + inner.EXPECT().CallCapability(matches.AnyContext, mock.Anything).Return(want, nil) + h := host.NewRestrictedExecutionHelper(inner, restrictions) + got, err := h.CallCapability(t.Context(), &sdk.CapabilityRequest{Id: "allowed@1.0.0", Method: "Foo"}) + require.NoError(t, err) + assert.Same(t, want, got) + }) + + t.Run("no restrictions allows everything", func(t *testing.T) { + got := capabilitySequence(t, nil, &sdk.CapabilityRequest{Id: "anything@1.0.0", Method: "Whatever"}) + assert.Equal(t, []bool{true}, got) + }) + + t.Run("allows when no capabilities restrictions set", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{}, &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}) + assert.Equal(t, []bool{true}, got) + }) + + t.Run("closed denies unmatched capability", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 5}, + }}, + }, + }, + }, &sdk.CapabilityRequest{Id: "other-cap@1.0.0", Method: "Bar"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("closed allows matched capability until method limit is reached", func(t *testing.T) { + req := &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"} + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 2}, + }}, + }, + }, + }, req, req, req) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("denies when max total calls is zero", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 0, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 5}, + }}, + }, + }, + }, &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("open allows unmatched capability", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_OPEN, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 2}, + }}, + }, + }, + }, &sdk.CapabilityRequest{Id: "other-cap@1.0.0", Method: "Bar"}) + assert.Equal(t, []bool{true}, got) + }) + + t.Run("denies when matched method has zero calls remaining", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_OPEN, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 0}, + }}, + }, + }, + }, &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("matches by both id and method", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 5}, + }}, + }, + }, + }, + &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Bar"}, + &sdk.CapabilityRequest{Id: "cap@2.0.0", Method: "Foo"}, + ) + assert.Equal(t, []bool{false, false}, got) + }) + + t.Run("multiple different methods match independently", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 1}, + }}, + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Bar", MaxCalls: 1}, + }}, + }, + }, + }, + &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}, + &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}, + &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Bar"}, + &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Bar"}, + ) + assert.Equal(t, []bool{true, false, true, false}, got) + }) + + t.Run("total calls limit reached before method limit", func(t *testing.T) { + req := &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"} + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 2, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 100}, + }}, + }, + }, + }, req, req, req) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("negative max total calls means unlimited (method limit still applies)", func(t *testing.T) { + req := &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"} + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: -1, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 2}, + }}, + }, + }, + }, req, req, req) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("negative max calls on method means unlimited (total limit still applies)", func(t *testing.T) { + req := &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"} + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 3, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: -1}, + }}, + }, + }, + }, req, req, req, req) + assert.Equal(t, []bool{true, true, true, false}, got) + }) + + t.Run("duplicate restrictions keep smallest non-negative value", func(t *testing.T) { + req := &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"} + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 5}, + }}, + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 2}, + }}, + }, + }, + }, req, req, req) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("duplicate restrictions non-negative overrides negative", func(t *testing.T) { + req := &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"} + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: -1}, + }}, + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 3}, + }}, + }, + }, + }, req, req, req, req) + assert.Equal(t, []bool{true, true, true, false}, got) + }) + + t.Run("duplicate restrictions zero overrides positive", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 5}, + }}, + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "cap@1.0.0", Method: "Foo", MaxCalls: 0}, + }}, + }, + }, + }, &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("closed with no methods denies all", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: -1, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + }, + }, &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("open with no methods respects max total calls", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 2, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_OPEN, + }, + }, + &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}, + &sdk.CapabilityRequest{Id: "cap@2.0.0", Method: "Bar"}, + &sdk.CapabilityRequest{Id: "cap@3.0.0", Method: "Baz"}, + ) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("open with zero max total calls denies all", func(t *testing.T) { + got := capabilitySequence(t, &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 0, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_OPEN, + }, + }, &sdk.CapabilityRequest{Id: "cap@1.0.0", Method: "Foo"}) + assert.Equal(t, []bool{false}, got) + }) +} + +// confidentialHTTPRequest builds a CapabilityRequest whose payload is a +// ConfidentialHTTPRequest referencing the given vault DON secrets. The payload's +// type URL is what drives the secret-reservation branch in reserveCapabilityCall. +func confidentialHTTPRequest(t *testing.T, id, method string, secrets ...*confidentialhttp.SecretIdentifier) *sdk.CapabilityRequest { + t.Helper() + payload, err := anypb.New(&confidentialhttp.ConfidentialHTTPRequest{VaultDonSecrets: secrets}) + require.NoError(t, err) + return &sdk.CapabilityRequest{Id: id, Method: method, Payload: payload} +} + +func TestRequirementSelectingModule_ConfidentialHTTPWithRestrictions(t *testing.T) { + // Restrictions that allow the confidential HTTP capability method and a single + // exact secret. A confidential HTTP call only succeeds if both the capability + // method and every vault DON secret it references are permitted. + restrictions := func() *sdk.Restrictions { + return &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 10, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_CLOSED, + Restrictions: []*sdk.CapabilityRestriction{ + {Restriction: &sdk.CapabilityRestriction_Method{ + Method: &sdk.MethodRestriction{Id: "confhttp@1.0.0", Method: "Call", MaxCalls: 5}, + }}, + }, + }, + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "allowed-secret", Namespace: "ns"}, + }}, + }, + }, + } + } + + t.Run("allowed confidential http call reaches inner", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) + want := &sdk.CapabilityResponse{} + inner.EXPECT().CallCapability(matches.AnyContext, mock.Anything).Return(want, nil) + h := host.NewRestrictedExecutionHelper(inner, restrictions()) + + req := confidentialHTTPRequest(t, "confhttp@1.0.0", "Call", + &confidentialhttp.SecretIdentifier{Key: "allowed-secret", Namespace: "ns"}) + got, err := h.CallCapability(t.Context(), req) + require.NoError(t, err) + assert.Same(t, want, got) + }) + + t.Run("nil payload is not treated as confidential http and reaches inner", func(t *testing.T) { + // A capability call sharing the confidential-http method id but carrying no + // payload must skip the vault-secret reservation branch entirely (guarded by + // request.Payload != nil) and fall through to the normal method check. + inner := mocks.NewMockExecutionHelper(t) + want := &sdk.CapabilityResponse{} + inner.EXPECT().CallCapability(matches.AnyContext, mock.Anything).Return(want, nil) + h := host.NewRestrictedExecutionHelper(inner, restrictions()) + + got, err := h.CallCapability(t.Context(), &sdk.CapabilityRequest{Id: "confhttp@1.0.0", Method: "Call"}) + require.NoError(t, err) + assert.Same(t, want, got) + }) + + t.Run("disallowed confidential http call is denied without calling inner", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) // no expectations: inner must not be called + h := host.NewRestrictedExecutionHelper(inner, restrictions()) + + req := confidentialHTTPRequest(t, "confhttp@1.0.0", "Call", + &confidentialhttp.SecretIdentifier{Key: "blocked-secret", Namespace: "ns"}) + _, err := h.CallCapability(t.Context(), req) + var capErr caperrors.Error + require.True(t, errors.As(err, &capErr)) + assert.Contains(t, capErr.Error(), "denied by user pre-hook restrictions") + assert.Equal(t, caperrors.LimitExceeded, capErr.Code()) + }) +} + +func TestRequirementSelectingModule_GetSecretsWithRestrictions(t *testing.T) { + restrictions := &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "allowed-secret", Namespace: "ns"}, + }}, + }, + }, + } + + t.Run("blocked secret returns error response without calling inner", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) // no expectations: inner must not be called + h := host.NewRestrictedExecutionHelper(inner, restrictions) + resp, err := h.GetSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{{Id: "blocked-secret", Namespace: "ns"}}, + }) + require.NoError(t, err) + require.Len(t, resp, 1) + errResp := resp[0].GetError() + require.NotNil(t, errResp) + assert.Contains(t, errResp.Error, "denied by user pre-hook restrictions") + }) + + t.Run("allows permitted secret", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) + inner.EXPECT().GetSecrets(matches.AnyContext, mock.Anything).Return([]*sdk.SecretResponse{}, nil) + h := host.NewRestrictedExecutionHelper(inner, restrictions) + _, err := h.GetSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{{Id: "allowed-secret", Namespace: "ns"}}, + }) + require.NoError(t, err) + }) + + t.Run("mixed batch: blocked gets error response, allowed goes to inner", func(t *testing.T) { + inner := mocks.NewMockExecutionHelper(t) + inner.EXPECT().GetSecrets(matches.AnyContext, mock.MatchedBy(func(r *sdk.GetSecretsRequest) bool { + return len(r.Requests) == 1 && r.Requests[0].Id == "allowed-secret" + })).Return([]*sdk.SecretResponse{{}}, nil) + h := host.NewRestrictedExecutionHelper(inner, restrictions) + resp, err := h.GetSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{ + {Id: "allowed-secret", Namespace: "ns"}, + {Id: "blocked-secret", Namespace: "ns"}, + }, + }) + require.NoError(t, err) + require.Len(t, resp, 2) + }) + + t.Run("allows when nil restrictions", func(t *testing.T) { + got := secretSequence(t, nil, &sdk.SecretRequest{Id: "my-secret", Namespace: "ns"}) + assert.Equal(t, []bool{true}, got) + }) + + t.Run("allows when no secrets restrictions set", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{}, &sdk.SecretRequest{Id: "my-secret", Namespace: "ns"}) + assert.Equal(t, []bool{true}, got) + }) + + t.Run("denies when max secrets is zero", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 0, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "my-secret", Namespace: "ns"}, + }}, + }, + }, + }, &sdk.SecretRequest{Id: "my-secret", Namespace: "ns"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("exact match allows until max secrets is reached", func(t *testing.T) { + req := &sdk.SecretRequest{Id: "db-password", Namespace: "infra"} + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 2, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + }, + }, + }, req, req, req) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("exact match requires both id and namespace", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + }, + }, + }, + &sdk.SecretRequest{Id: "db-password", Namespace: "other"}, + &sdk.SecretRequest{Id: "other", Namespace: "infra"}, + &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}, + ) + assert.Equal(t, []bool{false, false, true}, got) + }) + + t.Run("prefix match allows until prefix limit is reached", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: 2, + }, + }}, + }, + }, + }, + &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}, + &sdk.SecretRequest{Id: "db-host", Namespace: "infra"}, + &sdk.SecretRequest{Id: "db-port", Namespace: "infra"}, + ) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("prefix match requires namespace match", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: 5, + }, + }}, + }, + }, + }, &sdk.SecretRequest{Id: "db-password", Namespace: "other"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("prefix match denied when global max secrets hits zero", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 1, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: 5, + }, + }}, + }, + }, + }, + &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}, + &sdk.SecretRequest{Id: "db-host", Namespace: "infra"}, + ) + assert.Equal(t, []bool{true, false}, got) + }) + + t.Run("denies unmatched secret", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + }, + }, + }, &sdk.SecretRequest{Id: "api-key", Namespace: "external"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("multiple restrictions match independently", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "api-", Namespace: "external", MaxSecrets: 5, + }, + }}, + }, + }, + }, + &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}, + &sdk.SecretRequest{Id: "api-key", Namespace: "external"}, + &sdk.SecretRequest{Id: "api-key", Namespace: "infra"}, + &sdk.SecretRequest{Id: "other", Namespace: "external"}, + ) + assert.Equal(t, []bool{true, true, false, false}, got) + }) + + t.Run("global max secrets reached before individual limit", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 1, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "secret-a", Namespace: "ns"}, + }}, + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "secret-b", Namespace: "ns"}, + }}, + }, + }, + }, + &sdk.SecretRequest{Id: "secret-a", Namespace: "ns"}, + &sdk.SecretRequest{Id: "secret-b", Namespace: "ns"}, + ) + assert.Equal(t, []bool{true, false}, got) + }) + + t.Run("negative max secrets means unlimited", func(t *testing.T) { + reqs := make([]*sdk.SecretRequest, 100) + for i := range reqs { + reqs[i] = &sdk.SecretRequest{Id: "db-password", Namespace: "infra"} + } + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: -1, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + }, + }, + }, reqs...) + for i, allowed := range got { + assert.Truef(t, allowed, "call %d should be allowed", i) + } + }) + + t.Run("negative prefix max secrets means unlimited for that prefix (global limit applies)", func(t *testing.T) { + req := &sdk.SecretRequest{Id: "db-password", Namespace: "infra"} + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 3, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: -1, + }, + }}, + }, + }, + }, req, req, req, req) + assert.Equal(t, []bool{true, true, true, false}, got) + }) + + t.Run("secrets configured with only max secrets denies unmatched", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + }, + }, &sdk.SecretRequest{Id: "any-secret", Namespace: "ns"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("secrets configured with zero max secrets denies even matched", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 0, + }, + }, &sdk.SecretRequest{Id: "any-secret", Namespace: "ns"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("exact match still respects and decrements covering prefix limits", func(t *testing.T) { + req := &sdk.SecretRequest{Id: "db-password", Namespace: "infra"} + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: -1, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: 2, + }, + }}, + }, + }, + }, req, req, req) + assert.Equal(t, []bool{true, true, false}, got) + }) + + t.Run("exact match denied when covering prefix has zero calls", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: 0, + }, + }}, + }, + }, + }, &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}) + assert.Equal(t, []bool{false}, got) + }) + + t.Run("exact match without covering prefix still works", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "db-password", Namespace: "infra"}, + }}, + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "api-", Namespace: "external", MaxSecrets: 5, + }, + }}, + }, + }, + }, &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}) + assert.Equal(t, []bool{true}, got) + }) + + t.Run("multiple overlapping prefixes all decrement on match", func(t *testing.T) { + got := secretSequence(t, &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: -1, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-", Namespace: "infra", MaxSecrets: 3, + }, + }}, + {Restriction: &sdk.SecretRestriction_PrefixedSecret{ + PrefixedSecret: &sdk.SecretPrefixRestriction{ + Prefix: "db-pass", Namespace: "infra", MaxSecrets: 1, + }, + }}, + }, + }, + }, + // First db-password matches both prefixes; the narrower db-pass prefix is then + // exhausted, so a second db-password is denied while a db-host (only the broader + // prefix) is still allowed. + &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}, + &sdk.SecretRequest{Id: "db-password", Namespace: "infra"}, + &sdk.SecretRequest{Id: "db-host", Namespace: "infra"}, + ) + assert.Equal(t, []bool{true, false, true}, got) + }) +} + +func TestRequirementSelectingModule_GetRawSecretsWithRestrictions(t *testing.T) { + restrictions := &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "allowed-secret", Namespace: "ns"}, + }}, + }, + }, + } + + fetcher := &stubEncryptionKeyFetcher{} + + newHelper := func(t *testing.T) (*mocks.MockExecutionHelperWithRawSecrets, host.ExecutionHelperWithRawSecrets) { + inner := mocks.NewMockExecutionHelperWithRawSecrets(t) + h := host.NewRestrictedExecutionHelper(inner, restrictions).(host.ExecutionHelperWithRawSecrets) + return inner, h + } + + t.Run("blocked secret returns error response without calling inner", func(t *testing.T) { + inner, h := newHelper(t) + inner.EXPECT().GetOwner().Return("owner-1") + + resp, err := h.GetRawSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{{Id: "blocked-secret", Namespace: "ns"}}, + }, fetcher) + require.NoError(t, err) + require.Len(t, resp, 1) + assert.Contains(t, resp[0].GetError(), "denied by user pre-hook restrictions") + assert.Equal(t, "blocked-secret", resp[0].GetId().GetKey()) + assert.Equal(t, "ns", resp[0].GetId().GetNamespace()) + assert.Equal(t, "owner-1", resp[0].GetId().GetOwner()) + }) + + t.Run("allows permitted secret", func(t *testing.T) { + inner, h := newHelper(t) + inner.EXPECT().GetOwner().Return("owner-1") + inner.EXPECT().GetRawSecrets(matches.AnyContext, mock.MatchedBy(func(r *sdk.GetSecretsRequest) bool { + return len(r.Requests) == 1 && r.Requests[0].Id == "allowed-secret" + }), mock.Anything).Return([]*vault.SecretResponse{{}}, nil) + + resp, err := h.GetRawSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{{Id: "allowed-secret", Namespace: "ns"}}, + }, fetcher) + require.NoError(t, err) + require.Len(t, resp, 1) + }) + + t.Run("mixed batch: blocked gets error response, allowed goes to inner", func(t *testing.T) { + inner, h := newHelper(t) + inner.EXPECT().GetOwner().Return("owner-1") + inner.EXPECT().GetRawSecrets(matches.AnyContext, mock.MatchedBy(func(r *sdk.GetSecretsRequest) bool { + return len(r.Requests) == 1 && r.Requests[0].Id == "allowed-secret" + }), mock.Anything).Return([]*vault.SecretResponse{{}}, nil) + + resp, err := h.GetRawSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{ + {Id: "allowed-secret", Namespace: "ns"}, + {Id: "blocked-secret", Namespace: "ns"}, + }, + }, fetcher) + require.NoError(t, err) + require.Len(t, resp, 2) + }) + + t.Run("delegates the encryption key fetcher to inner", func(t *testing.T) { + inner, h := newHelper(t) + inner.EXPECT().GetOwner().Return("owner-1") + var gotFetcher host.EncryptionKeyFetcher + inner.EXPECT().GetRawSecrets(matches.AnyContext, mock.Anything, mock.Anything). + RunAndReturn(func(_ context.Context, _ *sdk.GetSecretsRequest, f host.EncryptionKeyFetcher) ([]*vault.SecretResponse, error) { + gotFetcher = f + return []*vault.SecretResponse{{}}, nil + }) + + _, err := h.GetRawSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{{Id: "allowed-secret", Namespace: "ns"}}, + }, fetcher) + require.NoError(t, err) + assert.Same(t, fetcher, gotFetcher, "the fetcher passed in must be delegated unchanged to the inner helper") + }) + + t.Run("all blocked does not call inner", func(t *testing.T) { + inner, h := newHelper(t) + inner.EXPECT().GetOwner().Return("owner-1") + + resp, err := h.GetRawSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{ + {Id: "blocked-a", Namespace: "ns"}, + {Id: "blocked-b", Namespace: "ns"}, + }, + }, fetcher) + require.NoError(t, err) + require.Len(t, resp, 2) + assert.Contains(t, resp[0].GetError(), "denied by user pre-hook restrictions") + assert.Contains(t, resp[1].GetError(), "denied by user pre-hook restrictions") + }) + + t.Run("inner error is propagated", func(t *testing.T) { + inner, h := newHelper(t) + inner.EXPECT().GetOwner().Return("owner-1") + inner.EXPECT().GetRawSecrets(matches.AnyContext, mock.Anything, mock.Anything).Return(nil, errors.New("boom")) + + resp, err := h.GetRawSecrets(t.Context(), &sdk.GetSecretsRequest{ + Requests: []*sdk.SecretRequest{{Id: "allowed-secret", Namespace: "ns"}}, + }, fetcher) + require.Error(t, err) + assert.Nil(t, resp) + }) +} + +func TestRequirementSelectingModule_GetOwner(t *testing.T) { + restrictions := &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + }, + } + + inner := mocks.NewMockExecutionHelperWithRawSecrets(t) + inner.EXPECT().GetOwner().Return("owner-123") + h := host.NewRestrictedExecutionHelper(inner, restrictions).(host.ExecutionHelperWithRawSecrets) + + owner := h.GetOwner() + assert.Equal(t, "owner-123", owner) + +} + +func TestRequirementSelectingModule_NewCreatesTheRightInterface(t *testing.T) { + restrictions := &sdk.Restrictions{ + Secrets: &sdk.SecretsRestritions{ + MaxSecrets: 10, + Restrictions: []*sdk.SecretRestriction{ + {Restriction: &sdk.SecretRestriction_ExactSecret{ + ExactSecret: &sdk.Secret{Id: "allowed-secret", Namespace: "ns"}, + }}, + }, + }, + } + + t.Run("normal ExecutionHelper doesn't return ExecutionHelperWithRawSecrets", func(t *testing.T) { + result := host.NewRestrictedExecutionHelper(mocks.NewMockExecutionHelper(t), restrictions) + assert.NotImplements(t, (*host.ExecutionHelperWithRawSecrets)(nil), result) + }) + + t.Run("ExecutionHelperWithRawSecrets returns ExecutionHelperWithRawSecrets", func(t *testing.T) { + result := host.NewRestrictedExecutionHelper(mocks.NewMockExecutionHelperWithRawSecrets(t), restrictions) + assert.Implements(t, (*host.ExecutionHelperWithRawSecrets)(nil), result) + }) +} diff --git a/pkg/workflows/host/mocks/execution_helper_with_raw_secrets.go b/pkg/workflows/host/mocks/execution_helper_with_raw_secrets.go new file mode 100644 index 0000000000..b822b2ef22 --- /dev/null +++ b/pkg/workflows/host/mocks/execution_helper_with_raw_secrets.go @@ -0,0 +1,502 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + time "time" + + vault "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" + host "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + v2 "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" + mock "github.com/stretchr/testify/mock" +) + +// MockExecutionHelperWithRawSecrets is an autogenerated mock type for the ExecutionHelperWithRawSecrets type +type MockExecutionHelperWithRawSecrets struct { + mock.Mock +} + +type MockExecutionHelperWithRawSecrets_Expecter struct { + mock *mock.Mock +} + +func (_m *MockExecutionHelperWithRawSecrets) EXPECT() *MockExecutionHelperWithRawSecrets_Expecter { + return &MockExecutionHelperWithRawSecrets_Expecter{mock: &_m.Mock} +} + +// CallCapability provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelperWithRawSecrets) CallCapability(ctx context.Context, request *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CallCapability") + } + + var r0 *sdk.CapabilityResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.CapabilityRequest) *sdk.CapabilityResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sdk.CapabilityResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.CapabilityRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelperWithRawSecrets_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' +type MockExecutionHelperWithRawSecrets_CallCapability_Call struct { + *mock.Call +} + +// CallCapability is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.CapabilityRequest +func (_e *MockExecutionHelperWithRawSecrets_Expecter) CallCapability(ctx interface{}, request interface{}) *MockExecutionHelperWithRawSecrets_CallCapability_Call { + return &MockExecutionHelperWithRawSecrets_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} +} + +func (_c *MockExecutionHelperWithRawSecrets_CallCapability_Call) Run(run func(ctx context.Context, request *sdk.CapabilityRequest)) *MockExecutionHelperWithRawSecrets_CallCapability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.CapabilityRequest)) + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_CallCapability_Call) Return(_a0 *sdk.CapabilityResponse, _a1 error) *MockExecutionHelperWithRawSecrets_CallCapability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_CallCapability_Call) RunAndReturn(run func(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error)) *MockExecutionHelperWithRawSecrets_CallCapability_Call { + _c.Call.Return(run) + return _c +} + +// EmitUserLog provides a mock function with given fields: log +func (_m *MockExecutionHelperWithRawSecrets) EmitUserLog(log string) error { + ret := _m.Called(log) + + if len(ret) == 0 { + panic("no return value specified for EmitUserLog") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(log) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockExecutionHelperWithRawSecrets_EmitUserLog_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmitUserLog' +type MockExecutionHelperWithRawSecrets_EmitUserLog_Call struct { + *mock.Call +} + +// EmitUserLog is a helper method to define mock.On call +// - log string +func (_e *MockExecutionHelperWithRawSecrets_Expecter) EmitUserLog(log interface{}) *MockExecutionHelperWithRawSecrets_EmitUserLog_Call { + return &MockExecutionHelperWithRawSecrets_EmitUserLog_Call{Call: _e.mock.On("EmitUserLog", log)} +} + +func (_c *MockExecutionHelperWithRawSecrets_EmitUserLog_Call) Run(run func(log string)) *MockExecutionHelperWithRawSecrets_EmitUserLog_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_EmitUserLog_Call) Return(_a0 error) *MockExecutionHelperWithRawSecrets_EmitUserLog_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_EmitUserLog_Call) RunAndReturn(run func(string) error) *MockExecutionHelperWithRawSecrets_EmitUserLog_Call { + _c.Call.Return(run) + return _c +} + +// EmitUserMetric provides a mock function with given fields: ctx, metric +func (_m *MockExecutionHelperWithRawSecrets) EmitUserMetric(ctx context.Context, metric *v2.WorkflowUserMetric) error { + ret := _m.Called(ctx, metric) + + if len(ret) == 0 { + panic("no return value specified for EmitUserMetric") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.WorkflowUserMetric) error); ok { + r0 = rf(ctx, metric) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockExecutionHelperWithRawSecrets_EmitUserMetric_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmitUserMetric' +type MockExecutionHelperWithRawSecrets_EmitUserMetric_Call struct { + *mock.Call +} + +// EmitUserMetric is a helper method to define mock.On call +// - ctx context.Context +// - metric *v2.WorkflowUserMetric +func (_e *MockExecutionHelperWithRawSecrets_Expecter) EmitUserMetric(ctx interface{}, metric interface{}) *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call { + return &MockExecutionHelperWithRawSecrets_EmitUserMetric_Call{Call: _e.mock.On("EmitUserMetric", ctx, metric)} +} + +func (_c *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call) Run(run func(ctx context.Context, metric *v2.WorkflowUserMetric)) *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*v2.WorkflowUserMetric)) + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call) Return(_a0 error) *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call) RunAndReturn(run func(context.Context, *v2.WorkflowUserMetric) error) *MockExecutionHelperWithRawSecrets_EmitUserMetric_Call { + _c.Call.Return(run) + return _c +} + +// GetDONTime provides a mock function with no fields +func (_m *MockExecutionHelperWithRawSecrets) GetDONTime() (time.Time, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetDONTime") + } + + var r0 time.Time + var r1 error + if rf, ok := ret.Get(0).(func() (time.Time, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelperWithRawSecrets_GetDONTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDONTime' +type MockExecutionHelperWithRawSecrets_GetDONTime_Call struct { + *mock.Call +} + +// GetDONTime is a helper method to define mock.On call +func (_e *MockExecutionHelperWithRawSecrets_Expecter) GetDONTime() *MockExecutionHelperWithRawSecrets_GetDONTime_Call { + return &MockExecutionHelperWithRawSecrets_GetDONTime_Call{Call: _e.mock.On("GetDONTime")} +} + +func (_c *MockExecutionHelperWithRawSecrets_GetDONTime_Call) Run(run func()) *MockExecutionHelperWithRawSecrets_GetDONTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetDONTime_Call) Return(_a0 time.Time, _a1 error) *MockExecutionHelperWithRawSecrets_GetDONTime_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetDONTime_Call) RunAndReturn(run func() (time.Time, error)) *MockExecutionHelperWithRawSecrets_GetDONTime_Call { + _c.Call.Return(run) + return _c +} + +// GetNodeTime provides a mock function with no fields +func (_m *MockExecutionHelperWithRawSecrets) GetNodeTime() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetNodeTime") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +// MockExecutionHelperWithRawSecrets_GetNodeTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeTime' +type MockExecutionHelperWithRawSecrets_GetNodeTime_Call struct { + *mock.Call +} + +// GetNodeTime is a helper method to define mock.On call +func (_e *MockExecutionHelperWithRawSecrets_Expecter) GetNodeTime() *MockExecutionHelperWithRawSecrets_GetNodeTime_Call { + return &MockExecutionHelperWithRawSecrets_GetNodeTime_Call{Call: _e.mock.On("GetNodeTime")} +} + +func (_c *MockExecutionHelperWithRawSecrets_GetNodeTime_Call) Run(run func()) *MockExecutionHelperWithRawSecrets_GetNodeTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetNodeTime_Call) Return(_a0 time.Time) *MockExecutionHelperWithRawSecrets_GetNodeTime_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetNodeTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelperWithRawSecrets_GetNodeTime_Call { + _c.Call.Return(run) + return _c +} + +// GetOwner provides a mock function with no fields +func (_m *MockExecutionHelperWithRawSecrets) GetOwner() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetOwner") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockExecutionHelperWithRawSecrets_GetOwner_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOwner' +type MockExecutionHelperWithRawSecrets_GetOwner_Call struct { + *mock.Call +} + +// GetOwner is a helper method to define mock.On call +func (_e *MockExecutionHelperWithRawSecrets_Expecter) GetOwner() *MockExecutionHelperWithRawSecrets_GetOwner_Call { + return &MockExecutionHelperWithRawSecrets_GetOwner_Call{Call: _e.mock.On("GetOwner")} +} + +func (_c *MockExecutionHelperWithRawSecrets_GetOwner_Call) Run(run func()) *MockExecutionHelperWithRawSecrets_GetOwner_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetOwner_Call) Return(_a0 string) *MockExecutionHelperWithRawSecrets_GetOwner_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetOwner_Call) RunAndReturn(run func() string) *MockExecutionHelperWithRawSecrets_GetOwner_Call { + _c.Call.Return(run) + return _c +} + +// GetRawSecrets provides a mock function with given fields: ctx, request, fetcher +func (_m *MockExecutionHelperWithRawSecrets) GetRawSecrets(ctx context.Context, request *sdk.GetSecretsRequest, fetcher host.EncryptionKeyFetcher) ([]*vault.SecretResponse, error) { + ret := _m.Called(ctx, request, fetcher) + + if len(ret) == 0 { + panic("no return value specified for GetRawSecrets") + } + + var r0 []*vault.SecretResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest, host.EncryptionKeyFetcher) ([]*vault.SecretResponse, error)); ok { + return rf(ctx, request, fetcher) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest, host.EncryptionKeyFetcher) []*vault.SecretResponse); ok { + r0 = rf(ctx, request, fetcher) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*vault.SecretResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.GetSecretsRequest, host.EncryptionKeyFetcher) error); ok { + r1 = rf(ctx, request, fetcher) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelperWithRawSecrets_GetRawSecrets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawSecrets' +type MockExecutionHelperWithRawSecrets_GetRawSecrets_Call struct { + *mock.Call +} + +// GetRawSecrets is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.GetSecretsRequest +// - fetcher host.EncryptionKeyFetcher +func (_e *MockExecutionHelperWithRawSecrets_Expecter) GetRawSecrets(ctx interface{}, request interface{}, fetcher interface{}) *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call { + return &MockExecutionHelperWithRawSecrets_GetRawSecrets_Call{Call: _e.mock.On("GetRawSecrets", ctx, request, fetcher)} +} + +func (_c *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call) Run(run func(ctx context.Context, request *sdk.GetSecretsRequest, fetcher host.EncryptionKeyFetcher)) *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.GetSecretsRequest), args[2].(host.EncryptionKeyFetcher)) + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call) Return(_a0 []*vault.SecretResponse, _a1 error) *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call) RunAndReturn(run func(context.Context, *sdk.GetSecretsRequest, host.EncryptionKeyFetcher) ([]*vault.SecretResponse, error)) *MockExecutionHelperWithRawSecrets_GetRawSecrets_Call { + _c.Call.Return(run) + return _c +} + +// GetSecrets provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelperWithRawSecrets) GetSecrets(ctx context.Context, request *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for GetSecrets") + } + + var r0 []*sdk.SecretResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest) []*sdk.SecretResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*sdk.SecretResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.GetSecretsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelperWithRawSecrets_GetSecrets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSecrets' +type MockExecutionHelperWithRawSecrets_GetSecrets_Call struct { + *mock.Call +} + +// GetSecrets is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.GetSecretsRequest +func (_e *MockExecutionHelperWithRawSecrets_Expecter) GetSecrets(ctx interface{}, request interface{}) *MockExecutionHelperWithRawSecrets_GetSecrets_Call { + return &MockExecutionHelperWithRawSecrets_GetSecrets_Call{Call: _e.mock.On("GetSecrets", ctx, request)} +} + +func (_c *MockExecutionHelperWithRawSecrets_GetSecrets_Call) Run(run func(ctx context.Context, request *sdk.GetSecretsRequest)) *MockExecutionHelperWithRawSecrets_GetSecrets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.GetSecretsRequest)) + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetSecrets_Call) Return(_a0 []*sdk.SecretResponse, _a1 error) *MockExecutionHelperWithRawSecrets_GetSecrets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetSecrets_Call) RunAndReturn(run func(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error)) *MockExecutionHelperWithRawSecrets_GetSecrets_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowExecutionID provides a mock function with no fields +func (_m *MockExecutionHelperWithRawSecrets) GetWorkflowExecutionID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowExecutionID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowExecutionID' +type MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call struct { + *mock.Call +} + +// GetWorkflowExecutionID is a helper method to define mock.On call +func (_e *MockExecutionHelperWithRawSecrets_Expecter) GetWorkflowExecutionID() *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call { + return &MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call{Call: _e.mock.On("GetWorkflowExecutionID")} +} + +func (_c *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call) Run(run func()) *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call) Return(_a0 string) *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call) RunAndReturn(run func() string) *MockExecutionHelperWithRawSecrets_GetWorkflowExecutionID_Call { + _c.Call.Return(run) + return _c +} + +// NewMockExecutionHelperWithRawSecrets creates a new instance of MockExecutionHelperWithRawSecrets. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockExecutionHelperWithRawSecrets(t interface { + mock.TestingT + Cleanup(func()) +}) *MockExecutionHelperWithRawSecrets { + mock := &MockExecutionHelperWithRawSecrets{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/host/module.go b/pkg/workflows/host/module.go index f4debbb922..e16cb1597a 100644 --- a/pkg/workflows/host/module.go +++ b/pkg/workflows/host/module.go @@ -6,6 +6,7 @@ import ( "context" "time" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) @@ -46,3 +47,19 @@ type ExecutionHelper interface { EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error } + +type ExecutionHelperWithRawSecrets interface { + ExecutionHelper + GetRawSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest, fetcher EncryptionKeyFetcher) ([]*vault.SecretResponse, error) + GetOwner() string +} + +// RestrictionAwareModule allows the module to know of the user-enforced restrictions. +// Enforcement by this module is NOT to be trusted by the host, +// however a violation is considered an indicator of a serious issue, such as compromise. +type RestrictionAwareModule interface { + Module + + // SetRestrictions must respect the restrictions for the execution until it completes + SetRestrictions(executionId string, restrictions *sdkpb.Restrictions) +} diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go index 30cd5ba6a5..3ce23f7e49 100644 --- a/pkg/workflows/host/requirement_selecting_module.go +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -61,6 +61,7 @@ func NewRequirementSelectingModule(main ModuleAndHandler, additional []ModuleAnd type triggerInfo struct { moduleIdx int + preHook bool requirements *sdk.Requirements } @@ -102,7 +103,7 @@ func (r *requirementSelectingModule) subscribe(ctx context.Context, request *sdk for j, m := range r.modules { if CheckRequirements(ctx, m.RequirementsHandler, sub.Requirements) { m.ensureStarted() - r.cache.Store(uint64(i), triggerInfo{moduleIdx: j, requirements: sub.Requirements}) + r.cache.Store(uint64(i), triggerInfo{moduleIdx: j, requirements: sub.Requirements, preHook: sub.PreHook}) matched = true break } @@ -119,13 +120,35 @@ func (r *requirementSelectingModule) trigger(ctx context.Context, request *sdk.E trigger := request.GetTrigger() if val, cached := r.cache.Load(trigger.Id); cached { info := val.(triggerInfo) + m := r.modules[info.moduleIdx] + if info.preHook { + prehook := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_PreHook{PreHook: trigger}} + preHookResult, err := r.modules[0].Execute(ctx, prehook, handler) + if err != nil { + return nil, fmt.Errorf("pre-hook execution failed: %w", err) + } + + switch preHookResult.Result.(type) { + case *sdk.ExecutionResult_Error: + return preHookResult, nil + } + + restrictions := preHookResult.GetRestrictions() + + handler = NewRestrictedExecutionHelper(handler, restrictions) + if rem, ok := m.Module.(RestrictionAwareModule); ok { + rem.SetRestrictions(handler.GetWorkflowExecutionID(), restrictions) + } + } + if rem, ok := m.Module.(RequirementEnforcingModule); ok && info.requirements != nil { rem.SetRequirements(handler.GetWorkflowExecutionID(), info.requirements) } return m.Execute(ctx, request, handler) } + return nil, errors.New("cannot trigger before gathering subscriptions") } diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index f248330e8c..f91f539046 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -1,33 +1,32 @@ -package host_test +package host import ( "context" "errors" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" - "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) type stubModule struct { - startFn func() - closeFn func() - legacyFn func() bool - executeFn func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) + executeFn func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) + startCount atomic.Int32 + closeCount atomic.Int32 + legacy bool } -func (s *stubModule) Start() { s.startFn() } -func (s *stubModule) Close() { s.closeFn() } -func (s *stubModule) IsLegacyDAG() bool { return s.legacyFn() } -func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h host.ExecutionHelper) (*sdk.ExecutionResult, error) { +func (s *stubModule) Start() { s.startCount.Add(1) } +func (s *stubModule) Close() { s.closeCount.Add(1) } +func (s *stubModule) IsLegacyDAG() bool { return s.legacy } +func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { return s.executeFn(ctx, req, h) } @@ -40,8 +39,14 @@ func (s *requirementEnforcingStub) SetRequirements(executionID string, requireme s.setRequirementsFn(executionID, requirements) } -func noop() {} -func noopClose() {} +type restrictionAwareStub struct { + *stubModule + setRestrictionsFn func(string, *sdk.Restrictions) +} + +func (s *restrictionAwareStub) SetRestrictions(executionID string, restrictions *sdk.Restrictions) { + s.setRestrictionsFn(executionID, restrictions) +} func triggerRequest(id uint64) *sdk.ExecuteRequest { return &sdk.ExecuteRequest{ @@ -73,63 +78,60 @@ func subWithReqs(reqs *sdk.Requirements) *sdk.TriggerSubscription { func TestRequirementSelectingModule_Start(t *testing.T) { t.Run("starts only main module", func(t *testing.T) { - var mainStarted, additionalStarted bool - main := host.ModuleAndHandler{Module: &stubModule{startFn: func() { mainStarted = true }}} - add := host.ModuleAndHandler{Module: &stubModule{startFn: func() { additionalStarted = true }}} + main := &stubModule{} + unused := &stubModule{} - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule( + ModuleAndHandler{Module: main}, + []ModuleAndHandler{{Module: unused}}, + ) m.Start() - assert.True(t, mainStarted) - assert.False(t, additionalStarted) + assert.Equal(t, int32(1), main.startCount.Load()) + assert.Equal(t, int32(0), unused.startCount.Load()) }) } func TestRequirementSelectingModule_Close(t *testing.T) { t.Run("closes main and no additional when none started", func(t *testing.T) { - var mainClosed, addClosed bool - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, closeFn: func() { mainClosed = true }, - }} - add := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, closeFn: func() { addClosed = true }, - }} + main := &stubModule{} + unused := &stubModule{} - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule( + ModuleAndHandler{Module: main}, + []ModuleAndHandler{{Module: unused}}, + ) m.Start() m.Close() - assert.True(t, mainClosed) - assert.False(t, addClosed) + assert.Equal(t, int32(1), main.closeCount.Load()) + assert.Equal(t, int32(0), unused.closeCount.Load()) }) t.Run("closes main and all started additional modules", func(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} - var mainClosed, add0Closed, add1Closed bool - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - closeFn: func() { mainClosed = true }, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := &stubModule{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, - }} - add0 := host.ModuleAndHandler{ - Module: &stubModule{ - startFn: noop, - closeFn: func() { add0Closed = true }, - }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - add1 := host.ModuleAndHandler{ - Module: &stubModule{ - startFn: noop, - closeFn: func() { add1Closed = true }, + requirementsSatisfier := &stubModule{} + nonMatcher := &stubModule{} + + m := NewRequirementSelectingModule( + ModuleAndHandler{Module: main}, + []ModuleAndHandler{ + { + Module: requirementsSatisfier, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + }, + { + Module: nonMatcher, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, - } - - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add0, add1}) + ) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -137,29 +139,28 @@ func TestRequirementSelectingModule_Close(t *testing.T) { m.Close() - assert.True(t, mainClosed, "main should be closed") - assert.True(t, add0Closed, "started additional should be closed") - assert.False(t, add1Closed, "never-started additional should not be closed") + assert.Equal(t, int32(1), main.closeCount.Load(), "main should be closed") + assert.Equal(t, int32(1), requirementsSatisfier.closeCount.Load(), "started additional should be closed") + assert.Equal(t, int32(0), nonMatcher.closeCount.Load(), "never-started additional should not be closed") }) } func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { - main := host.ModuleAndHandler{Module: &stubModule{legacyFn: func() bool { return true }}} - m := host.NewRequirementSelectingModule(main, nil) + main := &stubModule{legacy: true} + m := NewRequirementSelectingModule(ModuleAndHandler{Module: main}, nil) assert.True(t, m.IsLegacyDAG()) } func TestRequirementSelectingModule_Execute(t *testing.T) { t.Run("trigger with no cached entry errors", func(t *testing.T) { - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { assert.Fail(t, "main should not be called for trigger when no subscriptions") return nil, errors.New("unexpected callback") }, }} - m := host.NewRequirementSelectingModule(main, nil) + m := NewRequirementSelectingModule(main, nil) m.Start() _, err := m.Execute(t.Context(), triggerRequest(1), nil) @@ -167,24 +168,22 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { }) t.Run("main error on subscribe propagates", func(t *testing.T) { - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return nil, assert.AnError }, }} - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { t.Fatal("additional module should not be called") return nil, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -195,24 +194,21 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} want := &sdk.ExecutionResult{} - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, }} - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return want, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -226,18 +222,17 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { t.Run("subscribe with unmatched requirements returns error", func(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, }} - add := host.ModuleAndHandler{ - Module: &stubModule{startFn: noop}, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + add := ModuleAndHandler{ + Module: &stubModule{}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -249,28 +244,25 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} want := &sdk.ExecutionResult{} - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, }} - add0 := host.ModuleAndHandler{ - Module: &stubModule{startFn: noop}, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + add0 := ModuleAndHandler{ + Module: &stubModule{}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, } - add1 := host.ModuleAndHandler{ + add1 := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return want, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add0, add1}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -283,47 +275,42 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { t.Run("additional module started lazily during subscribe", func(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} - var addStartCount int32 - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, }} - add := host.ModuleAndHandler{ - Module: &stubModule{ - startFn: func() { atomic.AddInt32(&addStartCount, 1) }, - closeFn: noopClose, - }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + requirementsSatisfier := &stubModule{} + add := ModuleAndHandler{ + Module: requirementsSatisfier, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - assert.Equal(t, int32(0), atomic.LoadInt32(&addStartCount)) + assert.Equal(t, int32(0), requirementsSatisfier.startCount.Load()) _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + assert.Equal(t, int32(1), requirementsSatisfier.startCount.Load()) // Second subscribe does not start additional again (sync.Once). _, err = m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + assert.Equal(t, int32(1), requirementsSatisfier.startCount.Load()) }) t.Run("subscribe with no requirements returns main result", func(t *testing.T) { want := subscribeResult() - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return want, nil }, }} - m := host.NewRequirementSelectingModule(main, nil) + m := NewRequirementSelectingModule(main, nil) m.Start() got, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -336,10 +323,9 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { want := &sdk.ExecutionResult{} var mainTriggerCalls int32 - main := host.ModuleAndHandler{ + main := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { if req.GetTrigger() != nil { atomic.AddInt32(&mainTriggerCalls, 1) return want, nil @@ -347,20 +333,19 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { return subscribeResult(subWithReqs(teeReqs)), nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { t.Fatal("additional module should not be called when main satisfies requirements") return nil, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -377,14 +362,13 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { want := &sdk.ExecutionResult{} executionID := "wf-exec-1" - main := host.ModuleAndHandler{ + main := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, } var calls []string @@ -392,9 +376,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { var gotExecutionID string enforcingAdd := &requirementEnforcingStub{ stubModule: &stubModule{ - startFn: noop, - closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { calls = append(calls, "execute") return want, nil }, @@ -405,16 +387,15 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { gotReqs = requirements }, } - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: enforcingAdd, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - helper := &mocks.MockExecutionHelper{} - helper.On("GetWorkflowExecutionID").Return(executionID).Once() + helper := &stubExecutionHelper{executionID: executionID} _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) @@ -425,7 +406,6 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { assert.Equal(t, []string{"set", "execute"}, calls) assert.Equal(t, executionID, gotExecutionID) assert.Same(t, teeReqs, gotReqs) - helper.AssertExpectations(t) }) } @@ -434,27 +414,24 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} var mainTriggerCalls int32 - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { if req.GetTrigger() != nil { atomic.AddInt32(&mainTriggerCalls, 1) } return subscribeResult(subWithReqs(teeReqs)), nil }, }} - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return &sdk.ExecutionResult{}, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -473,9 +450,8 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} var mainTriggerCalls int32 - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { if req.GetTrigger() != nil { atomic.AddInt32(&mainTriggerCalls, 1) return &sdk.ExecutionResult{}, nil @@ -484,18 +460,16 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil }, }} - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, - closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return &sdk.ExecutionResult{}, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -517,9 +491,8 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { var mainTriggerCalls int32 wantAdditional := &sdk.ExecutionResult{} - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { if req.GetTrigger() != nil { atomic.AddInt32(&mainTriggerCalls, 1) return &sdk.ExecutionResult{}, nil @@ -527,17 +500,16 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil }, }} - add := host.ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ - startFn: noop, closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error) { + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { return wantAdditional, nil }, }, - RequirementsHandler: host.RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } - m := host.NewRequirementSelectingModule(main, []host.ModuleAndHandler{add}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -558,14 +530,13 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { t.Run("no additional modules when subscribe has requirements returns error", func(t *testing.T) { teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} - main := host.ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ host.ExecutionHelper) (*sdk.ExecutionResult, error) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { return subscribeResult(subWithReqs(teeReqs)), nil }, }} - m := host.NewRequirementSelectingModule(main, nil) + m := NewRequirementSelectingModule(main, nil) m.Start() _, err := m.Execute(t.Context(), subscribeRequest(), nil) @@ -573,3 +544,317 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { assert.Contains(t, err.Error(), "cannot find a runner") }) } + +func subWithReqsAndPreHook(reqs *sdk.Requirements) *sdk.TriggerSubscription { + return &sdk.TriggerSubscription{Requirements: reqs, PreHook: true} +} + +func restrictionsResult(r *sdk.Restrictions) *sdk.ExecutionResult { + return &sdk.ExecutionResult{ + Result: &sdk.ExecutionResult_Restrictions{Restrictions: r}, + } +} + +func TestRequirementSelectingModule_PreHook(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + t.Run("pre-hook runs in main, trigger runs in additional with restricted helper", func(t *testing.T) { + restrictions := &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{ + MaxTotalCalls: 1, + Type: sdk.CapabilityRestrictionType_CAPABILITY_RESTRICTION_TYPE_OPEN, + }, + } + + var helperSeenByAdditional ExecutionHelper + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if _, ok := req.Request.(*sdk.ExecuteRequest_PreHook); ok { + return restrictionsResult(restrictions), nil + } + return subscribeResult(subWithReqsAndPreHook(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + helperSeenByAdditional = h + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + + _, isRestricted := helperSeenByAdditional.(*executionRestrictions) + assert.True(t, isRestricted, "additional module should receive a restricted helper") + }) + + t.Run("pre-hook error result is returned directly without running the trigger", func(t *testing.T) { + errResult := &sdk.ExecutionResult{ + Result: &sdk.ExecutionResult_Error{Error: "denied by pre-hook"}, + } + + var helperSeenByAdditional ExecutionHelper + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if _, ok := req.Request.(*sdk.ExecuteRequest_PreHook); ok { + return errResult, nil + } + return subscribeResult(subWithReqsAndPreHook(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + helperSeenByAdditional = h + t.Fatal("additional module should not be called when pre-hook returns an error result") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), &stubExecutionHelper{}) + require.NoError(t, err) + assert.Same(t, errResult, got, "pre-hook error result should be returned unchanged") + assert.Nil(t, helperSeenByAdditional, "additional module must not be invoked") + }) + + t.Run("pre-hook error propagates", func(t *testing.T) { + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if _, ok := req.Request.(*sdk.ExecuteRequest_PreHook); ok { + return nil, assert.AnError + } + return subscribeResult(subWithReqsAndPreHook(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called when pre-hook fails") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "pre-hook execution failed") + }) + + t.Run("pre-hook on main-routed trigger applies restrictions to main", func(t *testing.T) { + restrictions := &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{MaxTotalCalls: 0}, + } + var helperSeenByMain ExecutionHelper + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + if _, ok := req.Request.(*sdk.ExecuteRequest_PreHook); ok { + return restrictionsResult(restrictions), nil + } + if req.GetTrigger() != nil { + helperSeenByMain = h + return &sdk.ExecutionResult{}, nil + } + // Subscribe: no requirements, PreHook=true + return subscribeResult(&sdk.TriggerSubscription{PreHook: true}), nil + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + + _, isRestricted := helperSeenByMain.(*executionRestrictions) + assert.True(t, isRestricted, "main should receive a restricted helper when pre-hook is set") + }) + + t.Run("no pre-hook passes original helper to additional", func(t *testing.T) { + var helperSeenByAdditional ExecutionHelper + inner := &stubExecutionHelper{} + + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + t.Fatal("main should not be called for trigger when cached in additional") + } + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + helperSeenByAdditional = h + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), inner) + require.NoError(t, err) + + assert.Same(t, inner, helperSeenByAdditional, "without pre-hook, original helper should be passed unchanged") + }) + + t.Run("pre-hook restrictions are forwarded to RestrictionAwareModule", func(t *testing.T) { + restrictions := &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{MaxTotalCalls: 3}, + } + executionID := "wf-exec-restricted" + + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if _, ok := req.Request.(*sdk.ExecuteRequest_PreHook); ok { + return restrictionsResult(restrictions), nil + } + return subscribeResult(subWithReqsAndPreHook(teeReqs)), nil + }, + }} + + var calls []string + var gotExecutionID string + var gotRestrictions *sdk.Restrictions + awareAdd := &restrictionAwareStub{ + stubModule: &stubModule{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + calls = append(calls, "execute") + return &sdk.ExecutionResult{}, nil + }, + }, + setRestrictionsFn: func(id string, r *sdk.Restrictions) { + calls = append(calls, "setRestrictions") + gotExecutionID = id + gotRestrictions = r + }, + } + add := ModuleAndHandler{ + Module: awareAdd, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + helper := &stubExecutionHelper{executionID: executionID} + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), helper) + require.NoError(t, err) + + assert.Equal(t, []string{"setRestrictions", "execute"}, calls) + assert.Equal(t, executionID, gotExecutionID) + assert.Same(t, restrictions, gotRestrictions) + }) + + t.Run("pre-hook restrictions are forwarded to module implementing both Restriction- and Requirement-aware interfaces", func(t *testing.T) { + restrictions := &sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{MaxTotalCalls: 5}, + } + executionID := "wf-exec-both" + + main := ModuleAndHandler{Module: &stubModule{ + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if _, ok := req.Request.(*sdk.ExecuteRequest_PreHook); ok { + return restrictionsResult(restrictions), nil + } + return subscribeResult(subWithReqsAndPreHook(teeReqs)), nil + }, + }} + + var calls []string + bothAware := &requirementAndRestrictionAwareStub{ + restrictionAwareStub: &restrictionAwareStub{ + stubModule: &stubModule{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + calls = append(calls, "execute") + return &sdk.ExecutionResult{}, nil + }, + }, + setRestrictionsFn: func(string, *sdk.Restrictions) { calls = append(calls, "setRestrictions") }, + }, + setRequirementsFn: func(string, *sdk.Requirements) { calls = append(calls, "setRequirements") }, + } + add := ModuleAndHandler{ + Module: bothAware, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + helper := &stubExecutionHelper{executionID: executionID} + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), helper) + require.NoError(t, err) + + // Restrictions must be set before requirements, both before execute. + assert.Equal(t, []string{"setRestrictions", "setRequirements", "execute"}, calls) + }) +} + +// requirementAndRestrictionAwareStub implements both RestrictionAwareModule and RequirementEnforcingModule. +type requirementAndRestrictionAwareStub struct { + *restrictionAwareStub + setRequirementsFn func(string, *sdk.Requirements) +} + +func (s *requirementAndRestrictionAwareStub) SetRequirements(executionID string, requirements *sdk.Requirements) { + s.setRequirementsFn(executionID, requirements) +} + +// stubExecutionHelper is a minimal ExecutionHelper implementation for testing. +type stubExecutionHelper struct{ executionID string } + +func (s *stubExecutionHelper) CallCapability(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error) { + return nil, nil +} +func (s *stubExecutionHelper) GetSecrets(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error) { + return nil, nil +} +func (s *stubExecutionHelper) GetWorkflowExecutionID() string { return s.executionID } +func (s *stubExecutionHelper) GetNodeTime() time.Time { return time.Time{} } +func (s *stubExecutionHelper) GetDONTime() (time.Time, error) { return time.Time{}, nil } +func (s *stubExecutionHelper) EmitUserLog(string) error { return nil } +func (s *stubExecutionHelper) EmitUserMetric(context.Context, *wfpb.WorkflowUserMetric) error { + return nil +} diff --git a/pkg/workflows/host/tee_provider_test.go b/pkg/workflows/host/tee_provider_test.go index 569bf138c9..5c5e5a333a 100644 --- a/pkg/workflows/host/tee_provider_test.go +++ b/pkg/workflows/host/tee_provider_test.go @@ -137,7 +137,12 @@ func TestNewTeeProvider(t *testing.T) { assert.False(t, provides(tee)) }) - t.Run("returns false when tee item is nil", func(t *testing.T) { + t.Run("returns true when tee is nil", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + assert.True(t, provides(nil)) + }) + + t.Run("returns false when tee.Item item is nil", func(t *testing.T) { provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) tee := &sdkpb.Tee{} assert.False(t, provides(tee)) diff --git a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go index e18a5fcd59..c11cd41d2a 100644 --- a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go +++ b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go @@ -61,6 +61,13 @@ func SendSubscription(subscriptions *sdk.TriggerSubscriptionRequest) { os.Exit(0) } +func SendRestrictions(restrictions *sdk.Restrictions) { + execResult := &sdk.ExecutionResult{Result: &sdk.ExecutionResult_Restrictions{Restrictions: restrictions}} + bytes := Must(proto.Marshal(execResult)) + sendResponse(BufferToPointerLen(bytes)) + os.Exit(0) +} + func Now() time.Time { var buf [8]byte // host writes UnixNano as little-endian uint64 rc := now(unsafe.Pointer(&buf[0])) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 0ddb6fc8fc..5279f92a6e 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -41,6 +41,7 @@ const v2ImportPrefix = "version_v2" var ( defaultTickInterval = 100 * time.Millisecond defaultTimeout = 10 * time.Minute + defaultPreeHookTimeout = 10 * time.Second defaultMinMemoryMBs = uint64(128) DefaultInitialFuel = uint64(100_000_000) defaultMaxFetchRequests = 5 @@ -65,6 +66,7 @@ type DeterminismConfig struct { type ModuleConfig struct { TickInterval time.Duration Timeout *time.Duration + PrehookTimeout *time.Duration MaxMemoryMBs uint64 MinMemoryMBs uint64 MemoryLimiter limits.BoundLimiter[config.Size] // supersedes Max/MinMemoryMBs if set @@ -199,6 +201,10 @@ func NewModule(ctx context.Context, modCfg *ModuleConfig, binary []byte, opts .. modCfg.Timeout = &defaultTimeout } + if modCfg.PrehookTimeout == nil { + modCfg.PrehookTimeout = &defaultPreeHookTimeout + } + if modCfg.MinMemoryMBs == 0 { modCfg.MinMemoryMBs = defaultMinMemoryMBs } @@ -575,7 +581,12 @@ func (m *module) Execute(ctx context.Context, req *sdkpb.ExecuteRequest, executo r.MaxResponseSize = maxSize } - return runWasm(ctx, m, req, setMaxResponseSize, linkNoDAG, executor) + timeout := *m.cfg.Timeout + switch req.Request.(type) { + case *sdkpb.ExecuteRequest_PreHook: + timeout = *m.cfg.PrehookTimeout + } + return runWasm(ctx, m, req, setMaxResponseSize, linkNoDAG, executor, timeout) } // Run is deprecated, use execute instead @@ -601,7 +612,7 @@ func (m *module) Run(ctx context.Context, request *wasmdagpb.Request) (*wasmdagp } } - return runWasm(ctx, m, request, setMaxResponseSize, linkLegacyDAG, nil) + return runWasm(ctx, m, request, setMaxResponseSize, linkLegacyDAG, nil, *m.cfg.Timeout) } func runWasm[I, O proto.Message]( @@ -610,17 +621,19 @@ func runWasm[I, O proto.Message]( request I, setMaxResponseSize func(i I, maxSize uint64), linkWasm linkFn[O], - helper ExecutionHelper) (O, error) { + helper ExecutionHelper, + maxTimeout time.Duration) (O, error) { var o O // No reason to run the WASM longer if the outer ctx will cancel. ctxDeadline, hasDeadline := ctx.Deadline() var ctxWithTimeout context.Context var cancel func() - if hasDeadline && ctxDeadline.Before(time.Now().Add(*m.cfg.Timeout)) { + + if hasDeadline && ctxDeadline.Before(time.Now().Add(maxTimeout)) { ctxWithTimeout, cancel = context.WithCancel(ctx) } else { - ctxWithTimeout, cancel = context.WithTimeout(ctx, *m.cfg.Timeout) + ctxWithTimeout, cancel = context.WithTimeout(ctx, maxTimeout) } defer cancel() @@ -669,7 +682,7 @@ func runWasm[I, O proto.Message]( 1, // memories ) - deadline := *m.cfg.Timeout / m.cfg.TickInterval + deadline := maxTimeout / m.cfg.TickInterval store.SetEpochDeadline(uint64(deadline)) h := fnv.New64a() @@ -731,7 +744,7 @@ func runWasm[I, O proto.Message]( // Note - there is no other reliable signal on the error that can be used to infer it is due to epoch deadline // being reached, so if an error is returned after the deadline it is assumed it is due to that and return // context.DeadlineExceeded. - if err != nil && ((executionDuration >= *m.cfg.Timeout-m.cfg.TickInterval) || ctx.Err() != nil) { // As start could be called just before epoch update 1 tick interval is deducted to account for this + if err != nil && ((executionDuration >= maxTimeout-m.cfg.TickInterval) || ctx.Err() != nil) { // As start could be called just before epoch update 1 tick interval is deducted to account for this m.cfg.Logger.Errorw("start function returned error after deadline reached, returning deadline exceeded error", "errFromStartFunction", err) return o, context.DeadlineExceeded } diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index cdc230fe3c..5690231e85 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -25,6 +25,7 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" @@ -620,6 +621,33 @@ func TestStandardTeeRuntime(t *testing.T) { assertProto(t, expected, actual.GetTriggerSubscriptions()) } +func TestStandardRestrictions(t *testing.T) { + t.Parallel() + mockExecutionHelper := mocks.NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") + // Some languages call time during initiation of the executable before the main is called. + // This would be in unknown mode, which would call Node mode by default. + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() + + // subscribe so pre-hooks are known. + // subscriptions are always done before the first trigger + subscribe := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}} + underlying := makeOptionalTestModuleWithConfig(t, nil) + m := host.NewRequirementSelectingModule(host.ModuleAndHandler{Module: underlying}, []host.ModuleAndHandler{}) + _, err := m.Execute(t.Context(), subscribe, mockExecutionHelper) + require.NoError(t, err) + + response := runWithBasicTriggerWithModule(t, mockExecutionHelper, m) + switch r := response.Result.(type) { + case *sdk.ExecutionResult_Error: + assert.Contains(t, r.Error, "call denied by user pre-hook restrictions: basic-test-action@1.0.0 PerformAction") + default: + assert.Fail(t, "Expected an error result due to restricted capability call, got %T", response.Result) + } +} + func triggerExecuteRequest(t *testing.T, id uint64, trigger proto.Message) *sdk.ExecuteRequest { wrappedTrigger, err := anypb.New(trigger) require.NoError(t, err) @@ -633,9 +661,12 @@ func triggerExecuteRequest(t *testing.T, id uint64, trigger proto.Message) *sdk. } func runWithBasicTrigger(t *testing.T, executor ExecutionHelper) *sdk.ExecutionResult { + return runWithBasicTriggerWithModule(t, executor, makeTestModule(t)) +} + +func runWithBasicTriggerWithModule(t *testing.T, executor ExecutionHelper, m ModuleV2) *sdk.ExecutionResult { trigger := &basictrigger.Outputs{CoolOutput: anyTestTriggerValue} executeRequest := triggerExecuteRequest(t, 0, trigger) - m := makeTestModule(t) response, err := m.Execute(t.Context(), executeRequest, executor) require.NoError(t, err) return response diff --git a/pkg/workflows/wasm/host/standard_tests/restrictions/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/restrictions/main_wasip1.go new file mode 100644 index 0000000000..a8bdcdd2ad --- /dev/null +++ b/pkg/workflows/wasm/host/standard_tests/restrictions/main_wasip1.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + + "google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func main() { + request := rawsdk.GetRequest() + switch request.Request.(type) { + case *sdk.ExecuteRequest_Trigger: + input := &basicaction.Inputs{InputThing: true} + err := rawsdk.DoRequestErr("basic-test-action@1.0.0", "PerformAction", sdk.Mode_MODE_DON, input) + if err != nil { + rawsdk.SendError(err) + } + rawsdk.SendResponse("should have errored out...") + case *sdk.ExecuteRequest_PreHook: + rawsdk.SendRestrictions(&sdk.Restrictions{ + Capabilities: &sdk.CapabilityRestrictions{MaxTotalCalls: 0}, + }) + case *sdk.ExecuteRequest_Subscribe: + rawsdk.SendSubscription(&sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + })), + Method: "Trigger", + PreHook: true, + }, + }, + }) + default: + rawsdk.SendError(fmt.Errorf("unexpected request type: %T", request.Request)) + } +}