diff --git a/.changeset/confidential-relay-wiring.md b/.changeset/confidential-relay-wiring.md new file mode 100644 index 00000000000..b97e0305da3 --- /dev/null +++ b/.changeset/confidential-relay-wiring.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +Wire confidential relay handler as a CRE subservice for enclave secrets and capability dispatch #added #wip diff --git a/core/capabilities/confidentialrelay/handler.go b/core/capabilities/confidentialrelay/handler.go new file mode 100644 index 00000000000..7b5766491f9 --- /dev/null +++ b/core/capabilities/confidentialrelay/handler.go @@ -0,0 +1,466 @@ +package confidentialrelay + +import ( + "context" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/beholder" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + confidentialrelaytypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/confidentialrelay" + vault "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/vault" + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/values" + valuespb "github.com/smartcontractkit/chainlink-protos/cre/go/values/pb" + + "github.com/smartcontractkit/chainlink-common/pkg/teeattestation" + "github.com/smartcontractkit/chainlink-common/pkg/teeattestation/nitro" +) + +var _ core.GatewayConnectorHandler = (*Handler)(nil) + +const HandlerName = "EnclaveRelayHandler" + +type handlerMetrics struct { + requestInternalError 
metric.Int64Counter + requestSuccess metric.Int64Counter +} + +func newMetrics() (*handlerMetrics, error) { + requestInternalError, err := beholder.GetMeter().Int64Counter("enclave_relay_request_internal_error") + if err != nil { + return nil, fmt.Errorf("failed to register internal error counter: %w", err) + } + requestSuccess, err := beholder.GetMeter().Int64Counter("enclave_relay_request_success") + if err != nil { + return nil, fmt.Errorf("failed to register success counter: %w", err) + } + return &handlerMetrics{ + requestInternalError: requestInternalError, + requestSuccess: requestSuccess, + }, nil +} + +type gatewayConnector interface { + SendToGateway(ctx context.Context, gatewayID string, resp *jsonrpc.Response[json.RawMessage]) error + AddHandler(ctx context.Context, methods []string, handler core.GatewayConnectorHandler) error + RemoveHandler(ctx context.Context, methods []string) error +} + +// attestationValidatorFunc validates a Nitro attestation document. +type attestationValidatorFunc func(attestation []byte, expectedUserData []byte, trustedMeasurements []byte, caRootsPEM string) error + +// Handler processes enclave relay requests from the gateway. +// It validates Nitro attestations and proxies requests to VaultDON or capability DONs. +type Handler struct { + services.Service + eng *services.Engine + + capRegistry core.CapabilitiesRegistry + gatewayConnector gatewayConnector + trustedPCRs []byte + lggr logger.Logger + metrics *handlerMetrics + + // validateAttestation validates Nitro attestation documents. + // Defaults to the real nitro validator; overridden in tests. + validateAttestation attestationValidatorFunc + + // caRootsPEM is an optional PEM-encoded CA root certificate for attestation + // validation. Empty string uses DefaultCARoots (production default). 
+ caRootsPEM string +} + +func NewHandler(capRegistry core.CapabilitiesRegistry, conn gatewayConnector, trustedPCRs []byte, lggr logger.Logger, caRootsPEM ...string) (*Handler, error) { + m, err := newMetrics() + if err != nil { + return nil, fmt.Errorf("failed to create metrics: %w", err) + } + + var roots string + if len(caRootsPEM) > 0 { + roots = caRootsPEM[0] + } + h := &Handler{ + capRegistry: capRegistry, + gatewayConnector: conn, + trustedPCRs: trustedPCRs, + lggr: logger.Named(lggr, HandlerName), + metrics: m, + validateAttestation: nitro.ValidateAttestation, + caRootsPEM: roots, + } + h.Service, h.eng = services.Config{ + Name: HandlerName, + Start: h.start, + Close: h.close, + }.NewServiceEngine(lggr) + return h, nil +} + +func (h *Handler) start(ctx context.Context) error { + if err := h.gatewayConnector.AddHandler(ctx, h.Methods(), h); err != nil { + return fmt.Errorf("failed to add enclave relay handler to connector: %w", err) + } + return nil +} + +func (h *Handler) close() error { + if err := h.gatewayConnector.RemoveHandler(context.Background(), h.Methods()); err != nil { + return fmt.Errorf("failed to remove enclave relay handler from connector: %w", err) + } + return nil +} + +func (h *Handler) ID(_ context.Context) (string, error) { + return HandlerName, nil +} + +func (h *Handler) Methods() []string { + return []string{confidentialrelaytypes.MethodSecretsGet, confidentialrelaytypes.MethodCapabilityExec} +} + +func (h *Handler) HandleGatewayMessage(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage]) error { + h.lggr.Debugw("received message from gateway", "gatewayID", gatewayID, "requestID", req.ID) + + var response *jsonrpc.Response[json.RawMessage] + switch req.Method { + case confidentialrelaytypes.MethodSecretsGet: + response = h.handleSecretsGet(ctx, gatewayID, req) + case confidentialrelaytypes.MethodCapabilityExec: + response = h.handleCapabilityExecute(ctx, gatewayID, req) + default: + response = 
h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrMethodNotFound, errors.New("unsupported method: "+req.Method)) + } + + if err := h.gatewayConnector.SendToGateway(ctx, gatewayID, response); err != nil { + h.lggr.Errorw("failed to send message to gateway", "gatewayID", gatewayID, "err", err) + return err + } + + h.lggr.Infow("sent message to gateway", "gatewayID", gatewayID, "requestID", req.ID) + h.metrics.requestSuccess.Add(ctx, 1, metric.WithAttributes( + attribute.String("gateway_id", gatewayID), + )) + return nil +} + +func (h *Handler) handleSecretsGet(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage]) *jsonrpc.Response[json.RawMessage] { + var params confidentialrelaytypes.SecretsRequestParams + if err := json.Unmarshal(*req.Params, &params); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInvalidParams, err) + } + + att := params.Attestation + params.Attestation = "" + if err := h.verifyAttestationHash(att, params, confidentialrelaytypes.DomainSecretsGet); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, err) + } + + vaultCap, err := h.capRegistry.GetExecutable(ctx, vault.CapabilityID) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("failed to get vault capability: %w", err)) + } + + donID, err := h.resolveDONID(ctx, vaultCap) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, err) + } + + capConfig, err := h.capRegistry.ConfigForCapability(ctx, vault.CapabilityID, donID) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("failed to get vault config: %w", err)) + } + var cfg vaultCapConfig + if err := capConfig.DefaultConfig.UnwrapTo(&cfg); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("failed to unwrap vault config: %w", err)) + } + + vaultReq := &vault.GetSecretsRequest{ + Requests: make([]*vault.SecretRequest, 0, 
len(params.Secrets)), + } + for _, s := range params.Secrets { + vaultReq.Requests = append(vaultReq.Requests, &vault.SecretRequest{ + Id: &vault.SecretIdentifier{ + Key: s.Key, + Namespace: s.Namespace, + }, + EncryptionKeys: []string{params.EnclavePublicKey}, + }) + } + + anypbReq, err := anypb.New(vaultReq) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("failed to wrap vault request: %w", err)) + } + + capResp, err := vaultCap.Execute(ctx, capabilities.CapabilityRequest{ + Payload: anypbReq, + Method: vault.MethodGetSecrets, + CapabilityId: vault.CapabilityID, + Metadata: capabilities.RequestMetadata{ + WorkflowID: params.WorkflowID, + }, + }) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("vault execute failed: %w", err)) + } + + vaultResp := &vault.GetSecretsResponse{} + if err := capResp.Payload.UnmarshalTo(vaultResp); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("failed to unmarshal vault response: %w", err)) + } + + result, err := translateVaultResponse(vaultResp, params.EnclavePublicKey, cfg.VaultPublicKey, cfg.Threshold) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, err) + } + + return h.jsonResponse(req, result) +} + +// resolveDONID determines the DON ID for a capability. 
+func (h *Handler) resolveDONID(ctx context.Context, cap capabilities.ExecutableCapability) (uint32, error) { + info, err := cap.Info(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get capability info: %w", err) + } + if info.IsLocal { + localNode, err := h.capRegistry.LocalNode(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get local node: %w", err) + } + return localNode.WorkflowDON.ID, nil + } + if info.DON == nil { + return 0, errors.New("capability is not associated with any DON") + } + return info.DON.ID, nil +} + +type vaultCapConfig struct { + VaultPublicKey string + Threshold int +} + +// translateVaultResponse converts a vault GetSecretsResponse to the enclave relay protocol format. +// Encoding conversion: hex (vault) -> base64 (enclave relay). +func translateVaultResponse(vaultResp *vault.GetSecretsResponse, enclaveKey, masterPK string, threshold int) (*confidentialrelaytypes.SecretsResponseResult, error) { + result := &confidentialrelaytypes.SecretsResponseResult{ + MasterPublicKey: "0x" + masterPK, + Threshold: threshold, + } + + for _, sr := range vaultResp.Responses { + if sr.GetError() != "" { + return nil, fmt.Errorf("vault error for secret %s/%s: %s", sr.Id.GetNamespace(), sr.Id.GetKey(), sr.GetError()) + } + + data := sr.GetData() + if data == nil { + return nil, fmt.Errorf("vault returned no data for secret %s/%s", sr.Id.GetNamespace(), sr.Id.GetKey()) + } + + encryptedBytes, err := hex.DecodeString(data.EncryptedValue) + if err != nil { + return nil, fmt.Errorf("failed to decode encrypted value for %s: %w", sr.Id.GetKey(), err) + } + + var shares []string + for _, es := range data.EncryptedDecryptionKeyShares { + if es.EncryptionKey == enclaveKey { + for _, share := range es.Shares { + shareBytes, err := hex.DecodeString(share) + if err != nil { + return nil, fmt.Errorf("failed to decode share: %w", err) + } + shares = append(shares, base64.StdEncoding.EncodeToString(shareBytes)) + } + break + } + } + if len(shares) == 0 { 
+ return nil, fmt.Errorf("no shares found for enclave key in secret %s/%s", sr.Id.GetNamespace(), sr.Id.GetKey()) + } + + result.Secrets = append(result.Secrets, confidentialrelaytypes.SecretEntry{ + ID: confidentialrelaytypes.SecretIdentifier{ + Key: sr.Id.GetKey(), + Namespace: sr.Id.GetNamespace(), + }, + Ciphertext: base64.StdEncoding.EncodeToString(encryptedBytes), + EncryptedShares: shares, + }) + } + + return result, nil +} + +func (h *Handler) handleCapabilityExecute(ctx context.Context, gatewayID string, req *jsonrpc.Request[json.RawMessage]) *jsonrpc.Response[json.RawMessage] { + var params confidentialrelaytypes.CapabilityRequestParams + if err := json.Unmarshal(*req.Params, &params); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInvalidParams, err) + } + + att := params.Attestation + params.Attestation = "" + if err := h.verifyAttestationHash(att, params, confidentialrelaytypes.DomainCapabilityExec); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, err) + } + + cap, err := h.capRegistry.GetExecutable(ctx, params.CapabilityID) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("capability not found: %w", err)) + } + + payloadBytes, err := base64.StdEncoding.DecodeString(params.Payload) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInvalidParams, fmt.Errorf("failed to decode payload: %w", err)) + } + + var sdkReq sdkpb.CapabilityRequest + if err := proto.Unmarshal(payloadBytes, &sdkReq); err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInvalidParams, fmt.Errorf("failed to unmarshal capability request: %w", err)) + } + + capReq := capabilities.CapabilityRequest{ + Payload: sdkReq.Payload, + Method: sdkReq.Method, + CapabilityId: params.CapabilityID, + Metadata: capabilities.RequestMetadata{ + WorkflowID: params.WorkflowID, + }, + } + + // Backward compatibility: extract values.Map from Payload into Inputs + // 
for old-style capabilities that only look at Inputs. + if sdkReq.Payload != nil { + var valPB valuespb.Value + if sdkReq.Payload.UnmarshalTo(&valPB) == nil { + if v, vErr := values.FromProto(&valPB); vErr == nil { + if m, ok := v.(*values.Map); ok { + capReq.Inputs = m + } + } + } + } + + capResp, execErr := cap.Execute(ctx, capReq) + + var result confidentialrelaytypes.CapabilityResponseResult + if execErr != nil { + result.Error = execErr.Error() + } else { + sdkResp, err := toSDKCapabilityResponse(capResp) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("converting capability response: %w", err)) + } + respBytes, err := proto.Marshal(sdkResp) + if err != nil { + return h.errorResponse(ctx, gatewayID, req, jsonrpc.ErrInternal, fmt.Errorf("marshalling capability response: %w", err)) + } + result.Payload = base64.StdEncoding.EncodeToString(respBytes) + } + + return h.jsonResponse(req, result) +} + +func (h *Handler) verifyAttestationHash(attestationB64 string, cleanParams any, domainTag string) error { + if attestationB64 == "" { + return errors.New("missing attestation") + } + + paramsJSON, err := json.Marshal(cleanParams) + if err != nil { + return fmt.Errorf("failed to marshal params for attestation: %w", err) + } + + hash := teeattestation.DomainHash(domainTag, paramsJSON) + + attestationBytes, err := base64.StdEncoding.DecodeString(attestationB64) + if err != nil { + return fmt.Errorf("failed to decode attestation: %w", err) + } + + return h.validateAttestation(attestationBytes, hash, h.trustedPCRs, h.caRootsPEM) +} + +func toSDKCapabilityResponse(capResp capabilities.CapabilityResponse) (*sdkpb.CapabilityResponse, error) { + if capResp.Payload != nil { + return &sdkpb.CapabilityResponse{ + Response: &sdkpb.CapabilityResponse_Payload{Payload: capResp.Payload}, + }, nil + } + + if capResp.Value != nil { + valProto := values.Proto(capResp.Value) + wrapped, err := anypb.New(valProto) + if err != nil { + return nil, 
fmt.Errorf("wrapping value map in Any: %w", err) + } + return &sdkpb.CapabilityResponse{ + Response: &sdkpb.CapabilityResponse_Payload{Payload: wrapped}, + }, nil + } + + return &sdkpb.CapabilityResponse{}, nil +} + +func (h *Handler) jsonResponse(req *jsonrpc.Request[json.RawMessage], result any) *jsonrpc.Response[json.RawMessage] { + resultBytes, err := json.Marshal(result) + if err != nil { + h.lggr.Errorw("failed to marshal response", "err", err) + return &jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: req.ID, + Method: req.Method, + Error: &jsonrpc.WireError{ + Code: jsonrpc.ErrInternal, + Message: err.Error(), + }, + } + } + resultJSON := json.RawMessage(resultBytes) + return &jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: req.ID, + Method: req.Method, + Result: &resultJSON, + } +} + +func (h *Handler) errorResponse( + ctx context.Context, + gatewayID string, + req *jsonrpc.Request[json.RawMessage], + errorCode int64, + err error, +) *jsonrpc.Response[json.RawMessage] { + h.lggr.Errorw("request error", "errorCode", errorCode, "err", err) + h.metrics.requestInternalError.Add(ctx, 1, metric.WithAttributes( + attribute.String("gateway_id", gatewayID), + attribute.Int64("error_code", errorCode), + )) + + return &jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: req.ID, + Method: req.Method, + Error: &jsonrpc.WireError{ + Code: errorCode, + Message: err.Error(), + }, + } +} diff --git a/core/capabilities/confidentialrelay/handler_test.go b/core/capabilities/confidentialrelay/handler_test.go new file mode 100644 index 00000000000..27f158d166a --- /dev/null +++ b/core/capabilities/confidentialrelay/handler_test.go @@ -0,0 +1,342 @@ +package confidentialrelay + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + 
"google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + confidentialrelaytypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/confidentialrelay" + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/values" + valuespb "github.com/smartcontractkit/chainlink-protos/cre/go/values/pb" +) + +func makeCapabilityPayload(t *testing.T, inputs map[string]any) string { + t.Helper() + wrapped, err := values.Wrap(inputs) + require.NoError(t, err) + payload, err := anypb.New(values.Proto(wrapped)) + require.NoError(t, err) + sdkReq := &sdkpb.CapabilityRequest{ + Id: "my-cap@1.0.0", + Payload: payload, + Method: "Execute", + } + b, err := proto.Marshal(sdkReq) + require.NoError(t, err) + return base64.StdEncoding.EncodeToString(b) +} + +const testAttestationB64 = "ZHVtbXktYXR0ZXN0YXRpb24=" // base64("dummy-attestation") + +func noopValidator(_ []byte, _, _ []byte, _ string) error { return nil } + +type mockGatewayConnector struct { + lastResp *jsonrpc.Response[json.RawMessage] + addedMethods []string + removed bool +} + +func (m *mockGatewayConnector) SendToGateway(_ context.Context, _ string, resp *jsonrpc.Response[json.RawMessage]) error { + m.lastResp = resp + return nil +} +func (m *mockGatewayConnector) AddHandler(_ context.Context, methods []string, _ core.GatewayConnectorHandler) error { + m.addedMethods = methods + return nil +} +func (m *mockGatewayConnector) RemoveHandler(_ context.Context, _ []string) error { + m.removed = true + return nil +} + +type mockExecutable struct { + infoResult capabilities.CapabilityInfo + infoErr error + execResult capabilities.CapabilityResponse + execErr error + lastRequest 
*capabilities.CapabilityRequest +} + +func (m *mockExecutable) Info(_ context.Context) (capabilities.CapabilityInfo, error) { + return m.infoResult, m.infoErr +} +func (m *mockExecutable) Execute(_ context.Context, req capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + m.lastRequest = &req + return m.execResult, m.execErr +} +func (m *mockExecutable) RegisterToWorkflow(_ context.Context, _ capabilities.RegisterToWorkflowRequest) error { + return nil +} +func (m *mockExecutable) UnregisterFromWorkflow(_ context.Context, _ capabilities.UnregisterFromWorkflowRequest) error { + return nil +} + +type mockCapRegistry struct { + core.UnimplementedCapabilitiesRegistry + executables map[string]*mockExecutable + configs map[string]capabilities.CapabilityConfiguration + localNode capabilities.Node +} + +func (m *mockCapRegistry) GetExecutable(_ context.Context, id string) (capabilities.ExecutableCapability, error) { + if exec, ok := m.executables[id]; ok { + return exec, nil + } + return nil, fmt.Errorf("capability not found: %s", id) +} +func (m *mockCapRegistry) ConfigForCapability(_ context.Context, capID string, _ uint32) (capabilities.CapabilityConfiguration, error) { + if cfg, ok := m.configs[capID]; ok { + return cfg, nil + } + return capabilities.CapabilityConfiguration{}, fmt.Errorf("config not found: %s", capID) +} +func (m *mockCapRegistry) LocalNode(_ context.Context) (capabilities.Node, error) { + return m.localNode, nil +} + +func newTestHandler(t *testing.T, registry core.CapabilitiesRegistry, gwConn gatewayConnector) *Handler { + t.Helper() + lggr, err := logger.New() + require.NoError(t, err) + h, err := NewHandler(registry, gwConn, []byte(`{}`), lggr) + require.NoError(t, err) + h.validateAttestation = noopValidator + return h +} + +func makeRequest(t *testing.T, method string, params any) *jsonrpc.Request[json.RawMessage] { + t.Helper() + b, err := json.Marshal(params) + require.NoError(t, err) + raw := json.RawMessage(b) + return 
&jsonrpc.Request[json.RawMessage]{ + Method: method, + ID: "req-1", + Params: &raw, + } +} + +func TestHandler_HandleGatewayMessage(t *testing.T) { + tests := []struct { + name string + registry func(t *testing.T) *mockCapRegistry + req func(t *testing.T) *jsonrpc.Request[json.RawMessage] + checkResp func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) + checkExecutable func(t *testing.T, reg *mockCapRegistry) + }{ + { + name: "capability execute success", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{ + executables: map[string]*mockExecutable{ + "my-cap@1.0.0": { + execResult: capabilities.CapabilityResponse{ + Payload: &anypb.Any{Value: []byte("result-proto-bytes")}, + }, + }, + }, + } + }, + req: func(t *testing.T) *jsonrpc.Request[json.RawMessage] { + return makeRequest(t, confidentialrelaytypes.MethodCapabilityExec, confidentialrelaytypes.CapabilityRequestParams{ + WorkflowID: "wf-1", + CapabilityID: "my-cap@1.0.0", + Payload: makeCapabilityPayload(t, map[string]any{"key": "val"}), + Attestation: testAttestationB64, + }) + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.Nil(t, resp.Error) + var result confidentialrelaytypes.CapabilityResponseResult + require.NoError(t, json.Unmarshal(*resp.Result, &result)) + decoded, err := base64.StdEncoding.DecodeString(result.Payload) + require.NoError(t, err) + var capResp sdkpb.CapabilityResponse + require.NoError(t, proto.Unmarshal(decoded, &capResp)) + require.NotNil(t, capResp.GetPayload()) + assert.Equal(t, "result-proto-bytes", string(capResp.GetPayload().GetValue())) + assert.Empty(t, result.Error) + }, + }, + { + name: "capability execute sets Inputs from Payload for backward compat", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{ + executables: map[string]*mockExecutable{ + "my-cap@1.0.0": { + execResult: capabilities.CapabilityResponse{}, + }, + }, + } + }, + req: func(t *testing.T) 
*jsonrpc.Request[json.RawMessage] { + return makeRequest(t, confidentialrelaytypes.MethodCapabilityExec, confidentialrelaytypes.CapabilityRequestParams{ + WorkflowID: "wf-1", + CapabilityID: "my-cap@1.0.0", + Payload: makeCapabilityPayload(t, map[string]any{"echo": "hello"}), + Attestation: testAttestationB64, + }) + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.Nil(t, resp.Error) + }, + checkExecutable: func(t *testing.T, reg *mockCapRegistry) { + exec := reg.executables["my-cap@1.0.0"] + require.NotNil(t, exec.lastRequest, "Execute should have been called") + require.NotNil(t, exec.lastRequest.Payload) + var valPB valuespb.Value + require.NoError(t, exec.lastRequest.Payload.UnmarshalTo(&valPB)) + require.NotNil(t, exec.lastRequest.Inputs) + unwrapped, err := exec.lastRequest.Inputs.Unwrap() + require.NoError(t, err) + m, ok := unwrapped.(map[string]any) + require.True(t, ok) + assert.Equal(t, "hello", m["echo"]) + }, + }, + { + name: "capability execute attestation failure", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{} + }, + req: func(t *testing.T) *jsonrpc.Request[json.RawMessage] { + return makeRequest(t, confidentialrelaytypes.MethodCapabilityExec, confidentialrelaytypes.CapabilityRequestParams{ + WorkflowID: "wf-1", + CapabilityID: "my-cap@1.0.0", + Payload: base64.StdEncoding.EncodeToString([]byte("payload")), + }) + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.NotNil(t, resp.Error) + assert.Equal(t, jsonrpc.ErrInternal, resp.Error.Code) + }, + }, + { + name: "capability execute not found", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{executables: map[string]*mockExecutable{}} + }, + req: func(t *testing.T) *jsonrpc.Request[json.RawMessage] { + return makeRequest(t, confidentialrelaytypes.MethodCapabilityExec, confidentialrelaytypes.CapabilityRequestParams{ + WorkflowID: "wf-1", + CapabilityID: 
"missing-cap@1.0.0", + Payload: base64.StdEncoding.EncodeToString([]byte("payload")), + Attestation: testAttestationB64, + }) + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.NotNil(t, resp.Error) + assert.Equal(t, jsonrpc.ErrInternal, resp.Error.Code) + assert.Contains(t, resp.Error.Message, "capability not found") + }, + }, + { + name: "capability execute error returned in result", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{ + executables: map[string]*mockExecutable{ + "fail-cap@1.0.0": {execErr: errors.New("execution failed")}, + }, + } + }, + req: func(t *testing.T) *jsonrpc.Request[json.RawMessage] { + sdkReq := &sdkpb.CapabilityRequest{Id: "fail-cap@1.0.0", Method: "Execute"} + b, err := proto.Marshal(sdkReq) + require.NoError(t, err) + return makeRequest(t, confidentialrelaytypes.MethodCapabilityExec, confidentialrelaytypes.CapabilityRequestParams{ + WorkflowID: "wf-1", + CapabilityID: "fail-cap@1.0.0", + Payload: base64.StdEncoding.EncodeToString(b), + Attestation: testAttestationB64, + }) + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.Nil(t, resp.Error) + var result confidentialrelaytypes.CapabilityResponseResult + require.NoError(t, json.Unmarshal(*resp.Result, &result)) + assert.Equal(t, "execution failed", result.Error) + assert.Empty(t, result.Payload) + }, + }, + { + name: "unsupported method", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{} + }, + req: func(t *testing.T) *jsonrpc.Request[json.RawMessage] { + return makeRequest(t, "unknown.method", nil) + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.NotNil(t, resp.Error) + assert.Equal(t, jsonrpc.ErrMethodNotFound, resp.Error.Code) + }, + }, + { + name: "invalid params JSON", + registry: func(_ *testing.T) *mockCapRegistry { + return &mockCapRegistry{} + }, + req: func(_ *testing.T) 
*jsonrpc.Request[json.RawMessage] { + raw := json.RawMessage([]byte(`{invalid json`)) + return &jsonrpc.Request[json.RawMessage]{ + Method: confidentialrelaytypes.MethodCapabilityExec, + ID: "req-1", + Params: &raw, + } + }, + checkResp: func(t *testing.T, resp *jsonrpc.Response[json.RawMessage]) { + require.NotNil(t, resp.Error) + assert.Equal(t, jsonrpc.ErrInvalidParams, resp.Error.Code) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gwConn := &mockGatewayConnector{} + reg := tt.registry(t) + h := newTestHandler(t, reg, gwConn) + err := h.HandleGatewayMessage(t.Context(), "gw-1", tt.req(t)) + assert.NoError(t, err) + require.NotNil(t, gwConn.lastResp) + tt.checkResp(t, gwConn.lastResp) + if tt.checkExecutable != nil { + tt.checkExecutable(t, reg) + } + }) + } +} + +func TestHandler_Lifecycle(t *testing.T) { + gwConn := &mockGatewayConnector{} + h := newTestHandler(t, &mockCapRegistry{}, gwConn) + + t.Run("start registers handler", func(t *testing.T) { + require.NoError(t, h.Start(t.Context())) + assert.Equal(t, h.Methods(), gwConn.addedMethods) + }) + + t.Run("close removes handler", func(t *testing.T) { + require.NoError(t, h.Close()) + assert.True(t, gwConn.removed) + }) + + t.Run("ID returns handler name", func(t *testing.T) { + id, err := h.ID(t.Context()) + require.NoError(t, err) + assert.Equal(t, HandlerName, id) + }) +} diff --git a/core/capabilities/confidentialrelay/service.go b/core/capabilities/confidentialrelay/service.go new file mode 100644 index 00000000000..b9ae214848f --- /dev/null +++ b/core/capabilities/confidentialrelay/service.go @@ -0,0 +1,70 @@ +package confidentialrelay + +import ( + "context" + "errors" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" + + gatewayconnector "github.com/smartcontractkit/chainlink/v2/core/capabilities/gateway_connector" +) + +// 
Service is a thin lifecycle wrapper around the confidential relay handler. +// The relay handler needs the gateway connector, which isn't available until +// the ServiceWrapper starts. This wrapper defers handler creation to Start(). +type Service struct { + services.Service + eng *services.Engine + + wrapper *gatewayconnector.ServiceWrapper + capRegistry core.CapabilitiesRegistry + trustedPCRs []byte + caRootsPEM string + lggr logger.Logger + + handler *Handler +} + +func NewService( + wrapper *gatewayconnector.ServiceWrapper, + capRegistry core.CapabilitiesRegistry, + trustedPCRs []byte, + caRootsPEM string, + lggr logger.Logger, +) *Service { + s := &Service{ + wrapper: wrapper, + capRegistry: capRegistry, + trustedPCRs: trustedPCRs, + caRootsPEM: caRootsPEM, + lggr: lggr, + } + s.Service, s.eng = services.Config{ + Name: "ConfidentialRelayService", + Start: s.start, + Close: s.close, + }.NewServiceEngine(lggr) + return s +} + +func (s *Service) start(ctx context.Context) error { + conn := s.wrapper.GetGatewayConnector() + if conn == nil { + return errors.New("gateway connector not available") + } + h, err := NewHandler(s.capRegistry, conn, s.trustedPCRs, s.lggr, s.caRootsPEM) + if err != nil { + return err + } + s.handler = h + return h.Start(ctx) +} + +func (s *Service) close() error { + if s.handler != nil { + return s.handler.Close() + } + return nil +} diff --git a/core/config/cre_config.go b/core/config/cre_config.go index 57bc4ad31a9..78a600d1490 100644 --- a/core/config/cre_config.go +++ b/core/config/cre_config.go @@ -13,6 +13,7 @@ type CRE interface { // When enabled, additional OTel tracing and logging is performed. DebugMode() bool LocalSecrets() map[string]string + ConfidentialRelay() CREConfidentialRelay } // WorkflowFetcher defines configuration for fetching workflow files @@ -21,6 +22,13 @@ type WorkflowFetcher interface { URL() string } +// CREConfidentialRelay defines configuration for the confidential relay handler. 
+type CREConfidentialRelay interface { + Enabled() bool + TrustedPCRs() string + CARootsPEM() string +} + // CRELinking defines configuration for connecting to the CRE linking service type CRELinking interface { URL() string diff --git a/core/config/toml/types.go b/core/config/toml/types.go index 74984ec5db9..a0c5b904d81 100644 --- a/core/config/toml/types.go +++ b/core/config/toml/types.go @@ -1895,7 +1895,8 @@ type CreConfig struct { // Requires [Tracing].Enabled = true for traces to be exported (trace export is gated by // Tracing.Enabled in initGlobals; Telemetry.Enabled is optional—traces work with or without it). // WARNING: This is not suitable for production use due to performance overhead. - DebugMode *bool `toml:",omitempty"` + DebugMode *bool `toml:",omitempty"` + ConfidentialRelay *ConfidentialRelayConfig `toml:",omitempty"` } // WorkflowFetcherConfig holds the configuration for fetching workflow files @@ -1903,6 +1904,15 @@ type WorkflowFetcherConfig struct { URL *string `toml:",omitempty"` } +// ConfidentialRelayConfig holds the configuration for the confidential relay handler. +// When Enabled is true, the node participates in the confidential relay DON, +// validating enclave attestations and proxying capability requests. 
+type ConfidentialRelayConfig struct { + Enabled *bool `toml:",omitempty"` + TrustedPCRs *string `toml:",omitempty"` + CARootsPEM *string `toml:",omitempty"` +} + // LinkingConfig holds the configuration for connecting to the CRE linking service type LinkingConfig struct { URL *string `toml:",omitempty"` @@ -1956,6 +1966,21 @@ func (c *CreConfig) setFrom(f *CreConfig) { if f.DebugMode != nil { c.DebugMode = f.DebugMode } + + if f.ConfidentialRelay != nil { + if c.ConfidentialRelay == nil { + c.ConfidentialRelay = &ConfidentialRelayConfig{} + } + if v := f.ConfidentialRelay.Enabled; v != nil { + c.ConfidentialRelay.Enabled = v + } + if v := f.ConfidentialRelay.TrustedPCRs; v != nil { + c.ConfidentialRelay.TrustedPCRs = v + } + if v := f.ConfidentialRelay.CARootsPEM; v != nil { + c.ConfidentialRelay.CARootsPEM = v + } + } } func (w *WorkflowFetcherConfig) ValidateConfig() error { diff --git a/core/services/chainlink/config_cre.go b/core/services/chainlink/config_cre.go index 6714d6a63d5..e144d49dfd3 100644 --- a/core/services/chainlink/config_cre.go +++ b/core/services/chainlink/config_cre.go @@ -105,6 +105,35 @@ func (c *creConfig) Linking() config.CRELinking { return &linkingConfig{url: url, tlsEnabled: tlsEnabled} } +type confidentialRelayConfig struct { + enabled bool + trustedPCRs string + caRootsPEM string +} + +func (cr *confidentialRelayConfig) Enabled() bool { return cr.enabled } +func (cr *confidentialRelayConfig) TrustedPCRs() string { return cr.trustedPCRs } +func (cr *confidentialRelayConfig) CARootsPEM() string { return cr.caRootsPEM } + +func (c *creConfig) ConfidentialRelay() config.CREConfidentialRelay { + if c.c.ConfidentialRelay == nil { + return &confidentialRelayConfig{} + } + enabled := false + if c.c.ConfidentialRelay.Enabled != nil { + enabled = *c.c.ConfidentialRelay.Enabled + } + trustedPCRs := "" + if c.c.ConfidentialRelay.TrustedPCRs != nil { + trustedPCRs = *c.c.ConfidentialRelay.TrustedPCRs + } + caRootsPEM := "" + if 
c.c.ConfidentialRelay.CARootsPEM != nil { + caRootsPEM = *c.c.ConfidentialRelay.CARootsPEM + } + return &confidentialRelayConfig{enabled: enabled, trustedPCRs: trustedPCRs, caRootsPEM: caRootsPEM} +} + func (c *creConfig) LocalSecrets() map[string]string { return c.s.LocalSecrets } diff --git a/core/services/cre/cre.go b/core/services/cre/cre.go index baf87584cc8..de260a53d64 100644 --- a/core/services/cre/cre.go +++ b/core/services/cre/cre.go @@ -32,6 +32,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/capabilities" "github.com/smartcontractkit/chainlink/v2/core/capabilities/compute" + "github.com/smartcontractkit/chainlink/v2/core/capabilities/confidentialrelay" gatewayconnector "github.com/smartcontractkit/chainlink/v2/core/capabilities/gateway_connector" "github.com/smartcontractkit/chainlink/v2/core/capabilities/localcapmgr" "github.com/smartcontractkit/chainlink/v2/core/capabilities/remote" @@ -169,6 +170,17 @@ func (s *Services) newSubservices( } s.GatewayConnectorWrapper = gatewayConnectorWrapper srvs = append(srvs, gatewayConnectorWrapper) + + if relayConfig := cfg.CRE().ConfidentialRelay(); relayConfig.Enabled() { + relayService := confidentialrelay.NewService( + gatewayConnectorWrapper, + opts.CapabilitiesRegistry, + []byte(relayConfig.TrustedPCRs()), + relayConfig.CARootsPEM(), + lggr, + ) + srvs = append(srvs, relayService) + } } if cfg.CRE().Linking().URL() != "" { diff --git a/core/services/gateway/handler_factory.go b/core/services/gateway/handler_factory.go index 76172b3dc9b..055021cb28a 100644 --- a/core/services/gateway/handler_factory.go +++ b/core/services/gateway/handler_factory.go @@ -17,6 +17,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/confidentialrelay" 
v2 "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/capabilities/v2" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions" "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/vault" @@ -29,7 +30,8 @@ const ( DummyHandlerType HandlerType = "dummy" WebAPICapabilitiesType HandlerType = "web-api-capabilities" // Handler for v0.1 HTTP capabilities for DAG workflows HTTPCapabilityType HandlerType = "http-capabilities" // Handler for v1.0 HTTP capabilities for NoDAG workflows - VaultHandlerType HandlerType = "vault" + VaultHandlerType HandlerType = "vault" + ConfidentialRelayHandlerType HandlerType = "confidential-compute-relay" ) type handlerFactory struct { @@ -87,6 +89,8 @@ func (hf *handlerFactory) NewHandler( case VaultHandlerType: requestAuthorizer := vaultcap.NewRequestAuthorizer(hf.lggr, hf.workflowRegistrySyncer) return vault.NewHandler(handlerConfig, donConfig, don, hf.capabilitiesRegistry, requestAuthorizer, hf.lggr, clockwork.NewRealClock(), hf.lf) + case ConfidentialRelayHandlerType: + return confidentialrelay.NewHandler(handlerConfig, donConfig, don, hf.lggr, clockwork.NewRealClock()) default: return nil, fmt.Errorf("unsupported handler type %s", handlerType) } diff --git a/core/services/gateway/handlers/confidentialrelay/aggregator.go b/core/services/gateway/handlers/confidentialrelay/aggregator.go new file mode 100644 index 00000000000..40af2848405 --- /dev/null +++ b/core/services/gateway/handlers/confidentialrelay/aggregator.go @@ -0,0 +1,50 @@ +package confidentialrelay + +import ( + "encoding/json" + "errors" + "strconv" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +var ( + errInsufficientResponsesForQuorum = errors.New("insufficient valid responses to reach quorum") + errQuorumUnobtainable = errors.New("quorum unobtainable") +) + +type aggregator struct{} + +func (a *aggregator) 
Aggregate(resps map[string]jsonrpc.Response[json.RawMessage], donF int, donMembersCount int, l logger.Logger) (*jsonrpc.Response[json.RawMessage], error) { + requiredQuorum := 2*donF + 1 + + if len(resps) < requiredQuorum { + return nil, errInsufficientResponsesForQuorum + } + + shaToCount := map[string]int{} + maxShaToCount := 0 + for _, r := range resps { + sha, err := r.Digest() + if err != nil { + l.Errorw("failed to compute digest of response during quorum validation, skipping...", "error", err) + continue + } + shaToCount[sha]++ + if shaToCount[sha] > maxShaToCount { + maxShaToCount = shaToCount[sha] + } + if shaToCount[sha] >= requiredQuorum { + return &r, nil + } + } + + remainingResponses := donMembersCount - len(resps) + if maxShaToCount+remainingResponses < requiredQuorum { + l.Warnw("quorum unattainable for request", "requiredQuorum", requiredQuorum, "remainingResponses", remainingResponses, "maxShaToCount", maxShaToCount) + return nil, errors.New(errQuorumUnobtainable.Error() + ". RequiredQuorum=" + strconv.Itoa(requiredQuorum) + ". 
maxShaToCount=" + strconv.Itoa(maxShaToCount) + " remainingResponses=" + strconv.Itoa(remainingResponses)) + } + + return nil, errInsufficientResponsesForQuorum +} diff --git a/core/services/gateway/handlers/confidentialrelay/handler.go b/core/services/gateway/handlers/confidentialrelay/handler.go new file mode 100644 index 00000000000..ad85c503d14 --- /dev/null +++ b/core/services/gateway/handlers/confidentialrelay/handler.go @@ -0,0 +1,432 @@ +package confidentialrelay + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "maps" + "strconv" + "sync" + "time" + + "github.com/jonboulle/clockwork" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/smartcontractkit/chainlink-common/pkg/beholder" + relaytypes "github.com/smartcontractkit/chainlink-common/pkg/capabilities/actions/confidentialrelay" + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/ratelimit" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + gwhandlers "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers" +) + +const ( + defaultCleanUpPeriod = 5 * time.Second + + // Re-exported from chainlink-common for local use and test convenience. 
+ MethodSecretsGet = relaytypes.MethodSecretsGet + MethodCapabilityExec = relaytypes.MethodCapabilityExec +) + +var _ gwhandlers.Handler = (*handler)(nil) + +type metrics struct { + requestInternalError metric.Int64Counter + requestUserError metric.Int64Counter + requestSuccess metric.Int64Counter +} + +func newMetrics() (*metrics, error) { + requestInternalError, err := beholder.GetMeter().Int64Counter("confidential_relay_gateway_request_internal_error") + if err != nil { + return nil, fmt.Errorf("failed to register internal error counter: %w", err) + } + + requestUserError, err := beholder.GetMeter().Int64Counter("confidential_relay_gateway_request_user_error") + if err != nil { + return nil, fmt.Errorf("failed to register user error counter: %w", err) + } + + requestSuccess, err := beholder.GetMeter().Int64Counter("confidential_relay_gateway_request_success") + if err != nil { + return nil, fmt.Errorf("failed to register success counter: %w", err) + } + + return &metrics{ + requestInternalError: requestInternalError, + requestUserError: requestUserError, + requestSuccess: requestSuccess, + }, nil +} + +type activeRequest struct { + req jsonrpc.Request[json.RawMessage] + responses map[string]*jsonrpc.Response[json.RawMessage] + mu sync.Mutex + + createdAt time.Time + gwhandlers.Callback +} + +func (ar *activeRequest) addResponseForNode(nodeAddr string, resp *jsonrpc.Response[json.RawMessage]) bool { + ar.mu.Lock() + defer ar.mu.Unlock() + _, exists := ar.responses[nodeAddr] + if exists { + return false + } + + ar.responses[nodeAddr] = resp + return true +} + +func (ar *activeRequest) copiedResponses() map[string]jsonrpc.Response[json.RawMessage] { + ar.mu.Lock() + defer ar.mu.Unlock() + copied := make(map[string]jsonrpc.Response[json.RawMessage], len(ar.responses)) + for k, response := range ar.responses { + var copiedResponse jsonrpc.Response[json.RawMessage] + if response != nil { + copiedResponse = *response + if response.Result != nil { + copiedResult := 
*response.Result + copiedResponse.Result = &copiedResult + } + if response.Error != nil { + copiedError := *response.Error + copiedResponse.Error = &copiedError + } + } + copied[k] = copiedResponse + } + return copied +} + +type relayAggregator interface { + Aggregate(resps map[string]jsonrpc.Response[json.RawMessage], donF int, donMembersCount int, l logger.Logger) (*jsonrpc.Response[json.RawMessage], error) +} + +type Config struct { + NodeRateLimiter ratelimit.RateLimiterConfig `json:"nodeRateLimiter"` + RequestTimeoutSec int `json:"requestTimeoutSec"` +} + +type handler struct { + services.StateMachine + donConfig *config.DONConfig + don gwhandlers.DON + codec api.JsonRPCCodec + lggr logger.Logger + mu sync.RWMutex + stopCh services.StopChan + + nodeRateLimiter *ratelimit.RateLimiter + requestTimeout time.Duration + + activeRequests map[string]*activeRequest + metrics *metrics + + aggregator relayAggregator + + clock clockwork.Clock +} + +func (h *handler) HealthReport() map[string]error { + return map[string]error{h.Name(): h.Healthy()} +} + +func (h *handler) Name() string { + return h.lggr.Name() +} + +func NewHandler(methodConfig json.RawMessage, donConfig *config.DONConfig, don gwhandlers.DON, lggr logger.Logger, clock clockwork.Clock) (*handler, error) { + var cfg Config + if err := json.Unmarshal(methodConfig, &cfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal method config: %w", err) + } + + if cfg.RequestTimeoutSec == 0 { + cfg.RequestTimeoutSec = 30 + } + + nodeRateLimiter, err := ratelimit.NewRateLimiter(cfg.NodeRateLimiter) + if err != nil { + return nil, fmt.Errorf("failed to create node rate limiter: %w", err) + } + + metrics, err := newMetrics() + if err != nil { + return nil, fmt.Errorf("failed to create metrics: %w", err) + } + + return &handler{ + donConfig: donConfig, + don: don, + lggr: logger.Named(lggr, "ConfidentialRelayHandler:"+donConfig.DonId), + requestTimeout: time.Duration(cfg.RequestTimeoutSec) * time.Second, + 
nodeRateLimiter: nodeRateLimiter, + activeRequests: make(map[string]*activeRequest), + mu: sync.RWMutex{}, + stopCh: make(services.StopChan), + metrics: metrics, + aggregator: &aggregator{}, + clock: clock, + }, nil +} + +func (h *handler) Start(_ context.Context) error { + return h.StartOnce("ConfidentialRelayHandler", func() error { + h.lggr.Info("starting confidential relay handler") + go func() { + ctx, cancel := h.stopCh.NewCtx() + defer cancel() + ticker := h.clock.NewTicker(defaultCleanUpPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.Chan(): + h.removeExpiredRequests(ctx) + case <-h.stopCh: + return + } + } + }() + return nil + }) +} + +func (h *handler) Close() error { + return h.StopOnce("ConfidentialRelayHandler", func() error { + h.lggr.Info("closing confidential relay handler") + close(h.stopCh) + return nil + }) +} + +func (h *handler) removeExpiredRequests(ctx context.Context) { + h.mu.RLock() + var expiredRequests []*activeRequest + now := h.clock.Now() + for _, userRequest := range h.activeRequests { + if now.Sub(userRequest.createdAt) > h.requestTimeout { + expiredRequests = append(expiredRequests, userRequest) + } + } + h.mu.RUnlock() + + for _, er := range expiredRequests { + var nodeResponses string + for nodeKey, nodeResponse := range er.responses { + nodeResponses += fmt.Sprintf("%s ---::: %v ", nodeKey, nodeResponse) + } + err := h.sendResponse(ctx, er, h.errorResponse(er.req, api.RequestTimeoutError, errors.New("request expired without getting quorum of responses from nodes. 
Available responses: "+nodeResponses), []byte(nodeResponses))) + if err != nil { + h.lggr.Errorw("error sending response to user", "requestID", er.req.ID, "error", err) + } + } +} + +func (h *handler) Methods() []string { + return []string{MethodSecretsGet, MethodCapabilityExec} +} + +func (h *handler) HandleLegacyUserMessage(_ context.Context, _ *api.Message, _ gwhandlers.Callback) error { + return errors.New("confidential relay handler does not support legacy messages") +} + +func (h *handler) HandleJSONRPCUserMessage(ctx context.Context, req jsonrpc.Request[json.RawMessage], callback gwhandlers.Callback) error { + if req.ID == "" { + return errors.New("request ID cannot be empty") + } + if len(req.ID) > 200 { + return errors.New("request ID is too long: " + strconv.Itoa(len(req.ID)) + ". max is 200 characters") + } + + l := logger.With(h.lggr, "method", req.Method, "requestID", req.ID) + l.Debugw("handling confidential relay request") + + ar, err := h.newActiveRequest(req, callback) + if err != nil { + return err + } + + return h.fanOutToNodes(ctx, l, ar) +} + +func (h *handler) newActiveRequest(req jsonrpc.Request[json.RawMessage], callback gwhandlers.Callback) (*activeRequest, error) { + h.mu.Lock() + defer h.mu.Unlock() + if h.activeRequests[req.ID] != nil { + h.lggr.Errorw("request id already exists", "requestID", req.ID) + return nil, errors.New("request ID already exists: " + req.ID) + } + ar := &activeRequest{ + Callback: callback, + req: req, + createdAt: h.clock.Now(), + responses: map[string]*jsonrpc.Response[json.RawMessage]{}, + } + h.activeRequests[req.ID] = ar + return ar, nil +} + +func (h *handler) getActiveRequest(requestID string) *activeRequest { + h.mu.RLock() + defer h.mu.RUnlock() + return h.activeRequests[requestID] +} + +func (h *handler) HandleNodeMessage(ctx context.Context, resp *jsonrpc.Response[json.RawMessage], nodeAddr string) error { + l := logger.With(h.lggr, "method", resp.Method, "requestID", resp.ID, "nodeAddr", nodeAddr) + 
l.Debugw("handling node response") + + if !h.nodeRateLimiter.Allow(nodeAddr) { + l.Debugw("node is rate limited", "nodeAddr", nodeAddr) + return nil + } + + ar := h.getActiveRequest(resp.ID) + if ar == nil { + l.Debugw("no pending request found for ID") + return nil + } + + ok := ar.addResponseForNode(nodeAddr, resp) + if !ok { + l.Errorw("duplicate response from node, ignoring", "nodeAddr", nodeAddr) + return nil + } + + copiedResponses := ar.copiedResponses() + aggregatedResp, err := h.aggregator.Aggregate(copiedResponses, h.donConfig.F, len(h.donConfig.Members), l) + switch { + case errors.Is(err, errInsufficientResponsesForQuorum): + l.Debugw("aggregating responses, waiting for other nodes...", "error", err) + return nil + case err != nil: + l.Error("quorum unobtainable, returning response to user...", "error", err, "responses", maps.Values(ar.responses)) + return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.FatalError, err, nil)) + } + + return h.sendSuccessResponse(ctx, l, ar, aggregatedResp) +} + +func (h *handler) fanOutToNodes(ctx context.Context, l logger.Logger, ar *activeRequest) error { + var nodeErrors []error + for _, node := range h.donConfig.Members { + err := h.don.SendToNode(ctx, node.Address, &ar.req) + if err != nil { + nodeErrors = append(nodeErrors, err) + l.Errorw("error sending request to node", "node", node.Address, "error", err) + } + } + + if len(nodeErrors) == len(h.donConfig.Members) && len(nodeErrors) > 0 { + return h.sendResponse(ctx, ar, h.errorResponse(ar.req, api.FatalError, errors.New("failed to forward user request to nodes"), nil)) + } + + l.Debugw("successfully forwarded request to relay nodes") + return nil +} + +func (h *handler) sendSuccessResponse(ctx context.Context, l logger.Logger, ar *activeRequest, resp *jsonrpc.Response[json.RawMessage]) error { + rawResponse, err := jsonrpc.EncodeResponse(resp) + if err != nil { + l.Errorw("failed to encode response", "error", err) + return h.sendResponse(ctx, ar, 
h.errorResponse(ar.req, api.NodeReponseEncodingError, fmt.Errorf("failed to marshal response: %w", err), nil)) + } + + var errorCode api.ErrorCode + if resp.Error != nil { + errorCode = api.FromJSONRPCErrorCode(resp.Error.Code) + } else { + errorCode = api.NoError + } + + l.Debugw("issued user callback", "errorCode", errorCode) + successResp := gwhandlers.UserCallbackPayload{ + RawResponse: rawResponse, + ErrorCode: errorCode, + } + return h.sendResponse(ctx, ar, successResp) +} + +func (h *handler) errorResponse( + req jsonrpc.Request[json.RawMessage], + errorCode api.ErrorCode, + err error, + data []byte, +) gwhandlers.UserCallbackPayload { + switch errorCode { + case api.FatalError: + case api.NodeReponseEncodingError: + h.lggr.Errorw(err.Error(), "requestID", req.ID) + err = errors.New(errorCode.String()) + case api.InvalidParamsError: + h.lggr.Errorw("invalid params", "requestID", req.ID, "params", string(*req.Params)) + err = errors.New("invalid params error: " + err.Error()) + case api.UnsupportedMethodError: + h.lggr.Errorw("unsupported method", "requestID", req.ID, "method", req.Method, "error", err.Error()) + err = errors.New("unsupported method(" + req.Method + "): " + err.Error()) + case api.UserMessageParseError: + h.lggr.Errorw("user message parse error", "requestID", req.ID, "error", err.Error()) + err = errors.New("user message parse error: " + err.Error()) + case api.NoError: + case api.UnsupportedDONIdError: + case api.HandlerError: + case api.RequestTimeoutError: + case api.StaleNodeResponseError: + } + + return gwhandlers.UserCallbackPayload{ + RawResponse: h.codec.EncodeNewErrorResponse( + req.ID, + api.ToJSONRPCErrorCode(errorCode), + err.Error(), + data, + ), + ErrorCode: errorCode, + } +} + +func (h *handler) sendResponse(ctx context.Context, userRequest *activeRequest, resp gwhandlers.UserCallbackPayload) error { + switch resp.ErrorCode { + case api.StaleNodeResponseError: + case api.FatalError: + case api.NodeReponseEncodingError: + case 
api.RequestTimeoutError: + case api.HandlerError: + h.metrics.requestInternalError.Add(ctx, 1, metric.WithAttributes( + attribute.String("don_id", h.donConfig.DonId), + attribute.String("error", resp.ErrorCode.String()), + )) + case api.InvalidParamsError: + case api.UnsupportedMethodError: + case api.UserMessageParseError: + case api.UnsupportedDONIdError: + h.metrics.requestUserError.Add(ctx, 1, metric.WithAttributes( + attribute.String("don_id", h.donConfig.DonId), + )) + case api.NoError: + h.metrics.requestSuccess.Add(ctx, 1, metric.WithAttributes( + attribute.String("don_id", h.donConfig.DonId), + )) + } + + err := userRequest.SendResponse(resp) + if err != nil { + h.lggr.Errorw("error sending response to user", "requestID", userRequest.req.ID, "error", err) + return err + } + + h.mu.Lock() + defer h.mu.Unlock() + delete(h.activeRequests, userRequest.req.ID) + h.lggr.Debugw("response sent to user", "requestID", userRequest.req.ID, "errorCode", resp.ErrorCode) + return nil +} diff --git a/core/services/gateway/handlers/confidentialrelay/handler_test.go b/core/services/gateway/handlers/confidentialrelay/handler_test.go new file mode 100644 index 00000000000..1434b23f1f9 --- /dev/null +++ b/core/services/gateway/handlers/confidentialrelay/handler_test.go @@ -0,0 +1,533 @@ +package confidentialrelay + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + jsonrpc "github.com/smartcontractkit/chainlink-common/pkg/jsonrpc2" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/ratelimit" + + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/api" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/config" + 
"github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/common" + "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/mocks" +) + +var nodeOne = config.NodeConfig{ + Name: "node1", + Address: "0x1234", +} + +func setupHandler(t *testing.T, numNodes int) (*handler, *common.Callback, *mocks.DON, *clockwork.FakeClock) { + t.Helper() + lggr := logger.Test(t) + don := mocks.NewDON(t) + + members := make([]config.NodeConfig, numNodes) + for i := range numNodes { + members[i] = config.NodeConfig{ + Name: fmt.Sprintf("node%d", i), + Address: fmt.Sprintf("0x%04d", i), + } + } + + donConfig := &config.DONConfig{ + DonId: "test_relay_don", + F: 1, + Members: members, + } + handlerConfig := Config{ + RequestTimeoutSec: 30, + NodeRateLimiter: ratelimit.RateLimiterConfig{ + GlobalRPS: 100, + GlobalBurst: 100, + PerSenderRPS: 10, + PerSenderBurst: 10, + }, + } + methodConfig, err := json.Marshal(handlerConfig) + require.NoError(t, err) + + clock := clockwork.NewFakeClock() + h, err := NewHandler(methodConfig, donConfig, don, lggr, clock) + require.NoError(t, err) + h.aggregator = &mockAggregator{} + cb := common.NewCallback() + return h, cb, don, clock +} + +type mockAggregator struct { + err error +} + +func (m *mockAggregator) Aggregate(_ map[string]jsonrpc.Response[json.RawMessage], _ int, _ int, _ logger.Logger) (*jsonrpc.Response[json.RawMessage], error) { + return nil, m.err +} + +type respondingMockAggregator struct{} + +func (m *respondingMockAggregator) Aggregate(resps map[string]jsonrpc.Response[json.RawMessage], _ int, _ int, _ logger.Logger) (*jsonrpc.Response[json.RawMessage], error) { + if len(resps) == 0 { + return nil, errInsufficientResponsesForQuorum + } + // Return the first response we find. 
+ for _, r := range resps { + return &r, nil + } + return nil, errInsufficientResponsesForQuorum +} + +func TestConfidentialRelayHandler_Methods(t *testing.T) { + h, _, _, _ := setupHandler(t, 4) + methods := h.Methods() + assert.Equal(t, []string{MethodSecretsGet, MethodCapabilityExec}, methods) +} + +func TestConfidentialRelayHandler_HandleLegacyUserMessage(t *testing.T) { + h, cb, _, _ := setupHandler(t, 4) + err := h.HandleLegacyUserMessage(t.Context(), nil, cb) + require.ErrorContains(t, err, "confidential relay handler does not support legacy messages") +} + +func TestConfidentialRelayHandler_RequestIDTooLong(t *testing.T) { + h, cb, _, _ := setupHandler(t, 4) + + longID := strings.Repeat("x", 201) + req := jsonrpc.Request[json.RawMessage]{ + ID: longID, + Method: MethodCapabilityExec, + } + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + expected := fmt.Sprintf("request ID is too long: %d. max is 200 characters", len(longID)) + require.EqualError(t, err, expected) +} + +func TestConfidentialRelayHandler_EmptyRequestID(t *testing.T) { + h, cb, _, _ := setupHandler(t, 4) + + req := jsonrpc.Request[json.RawMessage]{ + ID: "", + Method: MethodCapabilityExec, + } + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.EqualError(t, err, "request ID cannot be empty") +} + +func TestConfidentialRelayHandler_FanOutAndQuorumSuccess(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + h.aggregator = &respondingMockAggregator{} + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + params := json.RawMessage(`{"workflow_id":"wf1","secrets":[{"key":"k","namespace":"ns"}],"enclave_public_key":"pk"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-1", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + resultData := json.RawMessage(`{"secrets":[],"master_public_key":"mpk","threshold":1}`) + response := jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-1", + Method: 
MethodCapabilityExec, + Result: &resultData, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + assert.Equal(t, api.NoError, resp.ErrorCode) + var jsonResp jsonrpc.Response[json.RawMessage] + err = json.Unmarshal(resp.RawResponse, &jsonResp) + assert.NoError(t, err) + assert.Equal(t, "req-1", jsonResp.ID) + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + err = h.HandleNodeMessage(t.Context(), &response, "0x0000") + require.NoError(t, err) + wg.Wait() +} + +func TestConfidentialRelayHandler_QuorumWithRealAggregator(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + // Use the real aggregator; DON F=1 so quorum = 2*1+1 = 3 + h.aggregator = &aggregator{} + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-quorum", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + resultData := json.RawMessage(`{"payload":"result"}`) + makeResp := func() *jsonrpc.Response[json.RawMessage] { + rd := make(json.RawMessage, len(resultData)) + copy(rd, resultData) + return &jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-quorum", + Method: MethodCapabilityExec, + Result: &rd, + } + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + assert.Equal(t, api.NoError, resp.ErrorCode) + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + // Send 3 matching responses (2F+1 = 3) + for i := range 3 { + err = h.HandleNodeMessage(t.Context(), makeResp(), fmt.Sprintf("0x%04d", i)) + require.NoError(t, err) + } + wg.Wait() +} + +func TestConfidentialRelayHandler_QuorumWithDivergentResponses(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + h.aggregator = &aggregator{} + 
don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-diverge", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + assert.Equal(t, api.NoError, resp.ErrorCode) + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + // One divergent response + divergentResult := json.RawMessage(`{"secrets":[],"master_public_key":"DIFFERENT","threshold":1}`) + divergentResp := &jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-diverge", + Method: MethodCapabilityExec, + Result: &divergentResult, + } + err = h.HandleNodeMessage(t.Context(), divergentResp, "0x0000") + require.NoError(t, err) + + // Three matching responses (quorum = 3) + matchingResult := json.RawMessage(`{"secrets":[],"master_public_key":"mpk","threshold":1}`) + for i := 1; i <= 3; i++ { + rd := make(json.RawMessage, len(matchingResult)) + copy(rd, matchingResult) + resp := &jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-diverge", + Method: MethodCapabilityExec, + Result: &rd, + } + err = h.HandleNodeMessage(t.Context(), resp, fmt.Sprintf("0x%04d", i)) + require.NoError(t, err) + } + wg.Wait() +} + +func TestConfidentialRelayHandler_QuorumUnobtainable(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + h.aggregator = &mockAggregator{err: errQuorumUnobtainable} + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-unobtainable", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + response := jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-unobtainable", + Method: 
MethodCapabilityExec, + Error: &jsonrpc.WireError{ + Code: -32603, + Message: errQuorumUnobtainable.Error(), + }, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + var jsonResp jsonrpc.Response[json.RawMessage] + err = json.Unmarshal(resp.RawResponse, &jsonResp) + assert.NoError(t, err) + assert.Equal(t, "req-unobtainable", jsonResp.ID) + assert.NotNil(t, jsonResp.Error) + assert.Contains(t, jsonResp.Error.Message, "quorum unobtainable") + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + err = h.HandleNodeMessage(t.Context(), &response, "0x0000") + require.NoError(t, err) + wg.Wait() +} + +func TestConfidentialRelayHandler_RequestTimeout(t *testing.T) { + h, cb, don, clock := setupHandler(t, 4) + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + // Use the real aggregator so responses are not immediately satisfied + h.aggregator = &aggregator{} + + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-timeout", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + assert.Equal(t, api.RequestTimeoutError, resp.ErrorCode) + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + // Advance clock past the request timeout and trigger cleanup + clock.Advance(31 * time.Second) + h.removeExpiredRequests(t.Context()) + wg.Wait() +} + +func TestConfidentialRelayHandler_DuplicateRequestID(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-dup", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + err := 
h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + cb2 := common.NewCallback() + err = h.HandleJSONRPCUserMessage(t.Context(), req, cb2) + require.ErrorContains(t, err, "request ID already exists") +} + +func TestConfidentialRelayHandler_RateLimitedNode(t *testing.T) { + handlerConfig := Config{ + RequestTimeoutSec: 30, + NodeRateLimiter: ratelimit.RateLimiterConfig{ + GlobalRPS: 100, + GlobalBurst: 100, + PerSenderRPS: 0.001, // Effectively zero + PerSenderBurst: 1, + }, + } + methodConfig, err := json.Marshal(handlerConfig) + require.NoError(t, err) + + lggr := logger.Test(t) + don := mocks.NewDON(t) + donConfig := &config.DONConfig{ + DonId: "test_relay_don", + F: 1, + Members: []config.NodeConfig{nodeOne}, + } + clock := clockwork.NewFakeClock() + h, err := NewHandler(methodConfig, donConfig, don, lggr, clock) + require.NoError(t, err) + h.aggregator = &respondingMockAggregator{} + + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + cb := common.NewCallback() + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-ratelimit", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + err = h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + resultData := json.RawMessage(`{"secrets":[]}`) + response := jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-ratelimit", + Method: MethodCapabilityExec, + Result: &resultData, + } + + // First response from node uses the burst allowance + err = h.HandleNodeMessage(t.Context(), &response, nodeOne.Address) + require.NoError(t, err) + + // Verify callback was called + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer cancel() + resp, err := cb.Wait(ctx) + require.NoError(t, err) + assert.Equal(t, api.NoError, resp.ErrorCode) + + // Start a new request + cb2 := common.NewCallback() + req2 := jsonrpc.Request[json.RawMessage]{ + ID: 
"req-ratelimit-2", + Method: MethodCapabilityExec, + Params: ¶ms, + } + err = h.HandleJSONRPCUserMessage(t.Context(), req2, cb2) + require.NoError(t, err) + + response2 := jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-ratelimit-2", + Method: MethodCapabilityExec, + Result: &resultData, + } + + // Second response should be rate limited (silently dropped) + err = h.HandleNodeMessage(t.Context(), &response2, nodeOne.Address) + require.NoError(t, err) + + // Callback should NOT be called - verify with timeout + ctx2, cancel2 := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel2() + _, err = cb2.Wait(ctx2) + require.Error(t, err) // Should timeout +} + +func TestConfidentialRelayHandler_LateNodeResponse(t *testing.T) { + h, cb, _, _ := setupHandler(t, 4) + + resultData := json.RawMessage(`{"secrets":[]}`) + staleResponse := jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "nonexistent-request", + Method: MethodCapabilityExec, + Result: &resultData, + } + + // This should not error, just silently ignore + err := h.HandleNodeMessage(t.Context(), &staleResponse, "0x0000") + require.NoError(t, err) + + // Verify callback was not triggered + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Millisecond) + defer cancel() + _, err = cb.Wait(ctx) + require.Error(t, err) +} + +func TestConfidentialRelayHandler_AllNodesFanOutFail(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("connection refused")) + + params := json.RawMessage(`{"workflow_id":"wf1"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-allfail", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + assert.Equal(t, api.FatalError, resp.ErrorCode) + var jsonResp jsonrpc.Response[json.RawMessage] + err = 
json.Unmarshal(resp.RawResponse, &jsonResp) + assert.NoError(t, err) + assert.Contains(t, jsonResp.Error.Message, "failed to forward user request to nodes") + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + wg.Wait() +} + +func TestConfidentialRelayHandler_CapabilityExecMethod(t *testing.T) { + h, cb, don, _ := setupHandler(t, 4) + h.aggregator = &respondingMockAggregator{} + don.On("SendToNode", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + params := json.RawMessage(`{"workflow_id":"wf1","capability_id":"cap1","payload":"data"}`) + req := jsonrpc.Request[json.RawMessage]{ + ID: "req-cap", + Method: MethodCapabilityExec, + Params: ¶ms, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + resp, err := cb.Wait(t.Context()) + assert.NoError(t, err) + assert.Equal(t, api.NoError, resp.ErrorCode) + }() + + err := h.HandleJSONRPCUserMessage(t.Context(), req, cb) + require.NoError(t, err) + + resultData := json.RawMessage(`{"payload":"result"}`) + response := jsonrpc.Response[json.RawMessage]{ + Version: jsonrpc.JsonRpcVersion, + ID: "req-cap", + Method: MethodCapabilityExec, + Result: &resultData, + } + err = h.HandleNodeMessage(t.Context(), &response, "0x0000") + require.NoError(t, err) + wg.Wait() + don.AssertCalled(t, "SendToNode", mock.Anything, mock.Anything, mock.Anything) +} diff --git a/core/services/standardcapabilities/conversions/conversions.go b/core/services/standardcapabilities/conversions/conversions.go index 3f2130a7b19..1335d38a0d6 100644 --- a/core/services/standardcapabilities/conversions/conversions.go +++ b/core/services/standardcapabilities/conversions/conversions.go @@ -33,6 +33,8 @@ func GetCapabilityIDFromCommand(command string, config string) string { return "http-trigger@1.0.0-alpha" case "http_action": return "http-actions@1.0.0-alpha" // plural "actions" + case "mock": + return "mock@1.0.0" default: return "" } @@ -52,6 +54,8 @@ func 
GetCommandFromCapabilityID(capabilityID string) string { return "http_trigger" case strings.HasPrefix(capabilityID, "http-actions"): return "http_action" + case strings.HasPrefix(capabilityID, "mock"): + return "mock" default: return "" } diff --git a/deployment/cre/jobs/pkg/gateway_job.go b/deployment/cre/jobs/pkg/gateway_job.go index 2315f13afa7..df52381ac6a 100644 --- a/deployment/cre/jobs/pkg/gateway_job.go +++ b/deployment/cre/jobs/pkg/gateway_job.go @@ -14,9 +14,11 @@ const ( GatewayHandlerTypeWebAPICapabilities = "web-api-capabilities" GatewayHandlerTypeHTTPCapabilities = "http-capabilities" GatewayHandlerTypeVault = "vault" + GatewayHandlerTypeConfidentialRelay = "confidential-compute-relay" - ServiceNameWorkflows = "workflows" - ServiceNameVault = "vault" + ServiceNameWorkflows = "workflows" + ServiceNameVault = "vault" + ServiceNameConfidential = "confidential" minimumRequestTimeoutSec = 5 ) @@ -28,6 +30,8 @@ func HandlerServiceName(handlerType string) string { return ServiceNameVault case GatewayHandlerTypeHTTPCapabilities, GatewayHandlerTypeWebAPICapabilities: return ServiceNameWorkflows + case GatewayHandlerTypeConfidentialRelay: + return ServiceNameConfidential default: return handlerType } @@ -226,6 +230,8 @@ func (g GatewayJob) buildLegacyDons() ([]legacyDON, error) { hs = append(hs, newDefaultVaultHandler(g.RequestTimeoutSec)) case GatewayHandlerTypeHTTPCapabilities: hs = append(hs, newDefaultHTTPCapabilitiesHandler()) + case GatewayHandlerTypeConfidentialRelay: + hs = append(hs, newDefaultConfidentialRelayHandler()) default: return nil, errors.New("unknown handler type: " + ht) } @@ -266,6 +272,8 @@ func (g GatewayJob) buildServicesAndShardedDONs() ([]shardedDON, []service, erro handlers = append(handlers, newDefaultVaultHandler(g.RequestTimeoutSec)) case GatewayHandlerTypeHTTPCapabilities: handlers = append(handlers, newDefaultHTTPCapabilitiesHandler()) + case GatewayHandlerTypeConfidentialRelay: + handlers = append(handlers, 
newDefaultConfidentialRelayHandler()) default: return nil, nil, errors.New("unknown handler type: " + ht) } @@ -437,3 +445,22 @@ func newDefaultHTTPCapabilitiesHandler() handler { }, } } + +type confidentialRelayHandlerConfig struct { + NodeRateLimiter nodeRateLimiterConfig `toml:"NodeRateLimiter"` +} + +func newDefaultConfidentialRelayHandler() handler { + return handler{ + Name: GatewayHandlerTypeConfidentialRelay, + ServiceName: "confidential", + Config: confidentialRelayHandlerConfig{ + NodeRateLimiter: nodeRateLimiterConfig{ + GlobalBurst: 10, + GlobalRPS: 50, + PerSenderBurst: 10, + PerSenderRPS: 10, + }, + }, + } +} diff --git a/go.mod b/go.mod index 6263e463898..21bdc2c2fe7 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( github.com/esote/minmaxheap v1.0.0 github.com/ethereum/go-ethereum v1.17.1 github.com/fatih/color v1.18.0 - github.com/fxamacker/cbor/v2 v2.7.0 + github.com/fxamacker/cbor/v2 v2.9.0 github.com/gagliardetto/binary v0.8.0 github.com/gagliardetto/solana-go v1.13.0 github.com/getsentry/sentry-go v0.27.0 @@ -88,6 +88,7 @@ require ( github.com/smartcontractkit/chainlink-common v0.10.1-0.20260309085605-12d6180b51ff github.com/smartcontractkit/chainlink-common/keystore v1.0.2 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 + github.com/smartcontractkit/chainlink-common/pkg/teeattestation v0.0.0-20260316172927-2c727f64397c github.com/smartcontractkit/chainlink-data-streams v0.1.12-0.20260227110503-42b236799872 github.com/smartcontractkit/chainlink-evm v0.3.4-0.20260309171438-f10976da0b9b github.com/smartcontractkit/chainlink-evm/contracts/cre/gobindings v0.0.0-20260107191744-4b93f62cffe3 @@ -290,6 +291,7 @@ require ( github.com/hashicorp/yamux v0.1.2 // indirect github.com/hasura/go-graphql-client v0.15.1 // indirect github.com/hdevalence/ed25519consensus v0.2.0 // indirect + github.com/hf/nitrite v0.0.0-20241225144000-c2d5d3c4f303 // indirect github.com/holiman/billy v0.0.0-20250707135307-f2f9b9aae7db // indirect 
github.com/holiman/bloomfilter/v2 v2.0.3 // indirect github.com/huin/goupnp v1.3.0 // indirect diff --git a/go.sum b/go.sum index 122a54e0a02..7c19b721d1b 100644 --- a/go.sum +++ b/go.sum @@ -403,8 +403,9 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4 github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= -github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.2.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= +github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gabriel-vasile/mimetype v1.4.10 h1:zyueNbySn/z8mJZHLt6IPw0KoZsiQNszIpU+bX4+ZK0= github.com/gabriel-vasile/mimetype v1.4.10/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gagliardetto/anchor-go v1.0.0 h1:YNt9I/9NOrNzz5uuzfzByAcbp39Ft07w63iPqC/wi34= @@ -708,6 +709,8 @@ github.com/hasura/go-graphql-client v0.15.1 h1:mCb5I+8Bk3FU3GKWvf/zDXkTh7FbGlqJm github.com/hasura/go-graphql-client v0.15.1/go.mod h1:jfSZtBER3or+88Q9vFhWHiFMPppfYILRyl+0zsgPIIw= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= github.com/hdevalence/ed25519consensus v0.2.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= +github.com/hf/nitrite v0.0.0-20241225144000-c2d5d3c4f303 h1:XBSq4rXFUgD8ic6Mr7dBwJN/47yg87XpZQhiknfr4Cg= +github.com/hf/nitrite v0.0.0-20241225144000-c2d5d3c4f303/go.mod h1:ycRhVmo6wegyEl6WN+zXOHUTJvB0J2tiuH88q/McTK8= github.com/holiman/billy v0.0.0-20250707135307-f2f9b9aae7db 
h1:IZUYC/xb3giYwBLMnr8d0TGTzPKFGNTCGgGLoyeX330= github.com/holiman/billy v0.0.0-20250707135307-f2f9b9aae7db/go.mod h1:xTEYN9KCHxuYHs+NmrmzFcnvHMzLLNiGFafCb1n3Mfg= github.com/holiman/bloomfilter/v2 v2.0.3 h1:73e0e/V0tCydx14a0SCYS/EWCxgwLZ18CZcZKVu0fao= @@ -1191,6 +1194,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-common/pkg/monitoring v0.0.0-20251215152504-b1e41f508340 h1:PsjEI+5jZIz9AS4eOsLS5VpSWJINf38clXV3wryPyMk= github.com/smartcontractkit/chainlink-common/pkg/monitoring v0.0.0-20251215152504-b1e41f508340/go.mod h1:P/0OSXUlFaxxD4B/P6HWbxYtIRmmWGDJAvanq19879c= +github.com/smartcontractkit/chainlink-common/pkg/teeattestation v0.0.0-20260316172927-2c727f64397c h1:0Ciqup1r9884ZttfN4uS7NNuqqJodkl3WijIaQIF++E= +github.com/smartcontractkit/chainlink-common/pkg/teeattestation v0.0.0-20260316172927-2c727f64397c/go.mod h1:+X7Cb8ysHfNnPd74htGIInzZMHfUrpdDrjlr6+VW0gU= github.com/smartcontractkit/chainlink-data-streams v0.1.12-0.20260227110503-42b236799872 h1:/nhvP6cBqGLrf4JwA/1FHLxnJjFhKRP6xtXAPcpE8g0= github.com/smartcontractkit/chainlink-data-streams v0.1.12-0.20260227110503-42b236799872/go.mod h1:5jROIH/4cgHBQn875A+E2DCqvkBtrkHs+ciedcGTjNI= github.com/smartcontractkit/chainlink-evm v0.3.4-0.20260309171438-f10976da0b9b h1:afL24qwToT98lOIMO+e5nu7ndSWVPkPkvjL9/jwi66c= diff --git a/plugins/plugins.private.yaml b/plugins/plugins.private.yaml index 319564597ee..76fe21ca95a 100644 --- a/plugins/plugins.private.yaml +++ b/plugins/plugins.private.yaml @@ -52,5 +52,9 @@ plugins: installPath: "." 
confidential-http: - moduleURI: "github.com/smartcontractkit/confidential-compute/enclave/apps/confidential-http/capability" - gitRef: "ed10df3862dc8c70d85ee46f123138a87e7c7ed4" + gitRef: "1efd81acd9949fc79d14a46b1b55faa74fcb436e" installPath: "./cmd/confidential-http" + confidential-workflows: + - moduleURI: "github.com/smartcontractkit/confidential-compute/enclave/apps/confidential-workflows/capability" + gitRef: "1efd81acd9949fc79d14a46b1b55faa74fcb436e" + installPath: "./cmd/confidential-workflows" diff --git a/system-tests/lib/cre/features/confidential_relay/confidential_relay.go b/system-tests/lib/cre/features/confidential_relay/confidential_relay.go new file mode 100644 index 00000000000..aabb45bf24d --- /dev/null +++ b/system-tests/lib/cre/features/confidential_relay/confidential_relay.go @@ -0,0 +1,95 @@ +package confidentialrelay + +import ( + "context" + + tomlser "github.com/pelletier/go-toml/v2" + "github.com/pkg/errors" + "github.com/rs/zerolog" + + chainselectors "github.com/smartcontractkit/chain-selectors" + + corechainlink "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" + coretoml "github.com/smartcontractkit/chainlink/v2/core/config/toml" + "github.com/smartcontractkit/chainlink/deployment/cre/jobs/pkg" + "github.com/smartcontractkit/chainlink/system-tests/lib/cre" +) + +const flag = cre.ConfidentialRelayCapability + +type ConfidentialRelay struct{} + +func (o *ConfidentialRelay) Flag() cre.CapabilityFlag { + return flag +} + +func (o *ConfidentialRelay) PreEnvStartup( + ctx context.Context, + testLogger zerolog.Logger, + don *cre.DonMetadata, + topology *cre.Topology, + creEnv *cre.Environment, +) (*cre.PreEnvStartupOutput, error) { + registryChainID, chErr := chainselectors.ChainIdFromSelector(creEnv.RegistryChainSelector) + if chErr != nil { + return nil, errors.Wrapf(chErr, "failed to get chain ID from selector %d", creEnv.RegistryChainSelector) + } + + hErr := topology.AddGatewayHandlers(*don, 
[]string{pkg.GatewayHandlerTypeConfidentialRelay}) + if hErr != nil { + return nil, errors.Wrapf(hErr, "failed to add gateway handlers to gateway config for don %s", don.Name) + } + + cErr := don.ConfigureForGatewayAccess(registryChainID, *topology.GatewayConnectors) + if cErr != nil { + return nil, errors.Wrapf(cErr, "failed to add gateway connectors to node's TOML config for don %s", don.Name) + } + + // Set TOML config to activate the confidential relay handler on DON nodes. + capConfig, ok := don.CapabilityConfigs[flag] + if ok && capConfig.Values != nil { + ns := don.MustNodeSet() + for i := range ns.NodeSpecs { + currentConfig := ns.NodeSpecs[i].Node.TestConfigOverrides + var typedConfig corechainlink.Config + if currentConfig != "" { + if err := tomlser.Unmarshal([]byte(currentConfig), &typedConfig); err != nil { + return nil, errors.Wrapf(err, "failed to unmarshal node TOML config for node %d", i) + } + } + + enabled := true + relayConf := &coretoml.ConfidentialRelayConfig{Enabled: &enabled} + if v, exists := capConfig.Values["trustedPCRs"]; exists { + s := v.(string) + relayConf.TrustedPCRs = &s + } + if v, exists := capConfig.Values["caRootsPEM"]; exists { + s := v.(string) + relayConf.CARootsPEM = &s + } + typedConfig.CRE.ConfidentialRelay = relayConf + + out, err := tomlser.Marshal(typedConfig) + if err != nil { + return nil, errors.Wrapf(err, "failed to marshal node TOML config for node %d", i) + } + ns.NodeSpecs[i].Node.TestConfigOverrides = string(out) + } + } + + // No on-chain capability registration needed. The relay handler is a CRE subservice, + // not a registered capability. The mock capability that runs on the relay DON is + // registered separately via the mock flag. 
+ return &cre.PreEnvStartupOutput{}, nil +} + +func (o *ConfidentialRelay) PostEnvStartup( + ctx context.Context, + testLogger zerolog.Logger, + don *cre.Don, + dons *cre.Dons, + creEnv *cre.Environment, +) error { + return nil +} diff --git a/system-tests/lib/cre/types.go b/system-tests/lib/cre/types.go index ee593e156e6..94392d1db51 100644 --- a/system-tests/lib/cre/types.go +++ b/system-tests/lib/cre/types.go @@ -54,22 +54,23 @@ const ( // Capabilities const ( - ConsensusCapability CapabilityFlag = "ocr3" - DONTimeCapability CapabilityFlag = "don-time" - ConsensusCapabilityV2 CapabilityFlag = "consensus" // v2 - CronCapability CapabilityFlag = "cron" - EVMCapability CapabilityFlag = "evm" - CustomComputeCapability CapabilityFlag = "custom-compute" - WriteEVMCapability CapabilityFlag = "write-evm" - ReadContractCapability CapabilityFlag = "read-contract" - LogEventTriggerCapability CapabilityFlag = "log-event-trigger" - WebAPITargetCapability CapabilityFlag = "web-api-target" - WebAPITriggerCapability CapabilityFlag = "web-api-trigger" - MockCapability CapabilityFlag = "mock" - VaultCapability CapabilityFlag = "vault" - HTTPTriggerCapability CapabilityFlag = "http-trigger" - HTTPActionCapability CapabilityFlag = "http-action" - SolanaCapability CapabilityFlag = "solana" + ConsensusCapability CapabilityFlag = "ocr3" + DONTimeCapability CapabilityFlag = "don-time" + ConsensusCapabilityV2 CapabilityFlag = "consensus" // v2 + CronCapability CapabilityFlag = "cron" + EVMCapability CapabilityFlag = "evm" + CustomComputeCapability CapabilityFlag = "custom-compute" + WriteEVMCapability CapabilityFlag = "write-evm" + ReadContractCapability CapabilityFlag = "read-contract" + LogEventTriggerCapability CapabilityFlag = "log-event-trigger" + WebAPITargetCapability CapabilityFlag = "web-api-target" + WebAPITriggerCapability CapabilityFlag = "web-api-trigger" + MockCapability CapabilityFlag = "mock" + VaultCapability CapabilityFlag = "vault" + HTTPTriggerCapability 
CapabilityFlag = "http-trigger" + HTTPActionCapability CapabilityFlag = "http-action" + SolanaCapability CapabilityFlag = "solana" + ConfidentialRelayCapability CapabilityFlag = "confidential-relay" // Add more capabilities as needed )