From 0661abd1526a8f3f228e5deb56fb33b039bafad5 Mon Sep 17 00:00:00 2001 From: cawthorne Date: Fri, 5 Dec 2025 15:37:23 +0000 Subject: [PATCH 01/42] Enable LLO CRE Capability --- go.mod | 2 + go.sum | 3 + .../v2/triggers/streams/generate.go | 3 + .../streams/server/trigger_server_gen.go | 147 ++++++ .../v2/triggers/streams/trigger.pb.go | 469 ++++++++++++++++++ 5 files changed, 624 insertions(+) create mode 100644 pkg/capabilities/v2/triggers/streams/generate.go create mode 100644 pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go create mode 100644 pkg/capabilities/v2/triggers/streams/trigger.pb.go diff --git a/go.mod b/go.mod index 1de7be0a22..c90e4dfaa9 100644 --- a/go.mod +++ b/go.mod @@ -155,3 +155,5 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) + +replace github.com/smartcontractkit/chainlink-protos/cre/go => ../chainlink-protos/cre/go diff --git a/go.sum b/go.sum index 5013e47ee3..4ac9b674b6 100644 --- a/go.sum +++ b/go.sum @@ -330,8 +330,11 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= +<<<<<<< Updated upstream github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20251124151448-0448aefdaab9 h1:QRWXJusIj/IRY5Pl3JclNvDre0cZPd/5NbILwc4RV2M= github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20251124151448-0448aefdaab9/go.mod h1:jUC52kZzEnWF9tddHh85zolKybmLpbQ1oNA4FjOHt1Q= +======= +>>>>>>> Stashed changes github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 h1:B7itmjy+CMJ26elVw/cAJqqhBQ3Xa/mBYWK0/rQ5MuI= diff --git a/pkg/capabilities/v2/triggers/streams/generate.go b/pkg/capabilities/v2/triggers/streams/generate.go new file mode 100644 index 0000000000..037d5a347a --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/generate.go @@ -0,0 +1,3 @@ +//go:generate go run ../../gen --pkg=github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams --file=capabilities/streams/v1/trigger.proto +package streams + diff --git a/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go b/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go new file mode 100644 index 0000000000..a6d7490790 --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go @@ -0,0 +1,147 @@ +// Code generated by github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc, DO NOT EDIT. + +package server + +import ( + "context" + "fmt" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" +) + +// Avoid unused imports if there is configuration type +var _ = emptypb.Empty{} + +type StreamsCapability interface { + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) + UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error + + Start(ctx context.Context) error + Close() error + HealthReport() map[string]error + Name() string + Description() string + Ready() error + Initialise(ctx context.Context, dependencies core.StandardCapabilitiesDependencies) error +} + +func NewStreamsServer(capability StreamsCapability) *StreamsServer { + stopCh := make(chan struct{}) + return &StreamsServer{ + streamsCapability: streamsCapability{StreamsCapability: capability, stopCh: stopCh}, + stopCh: stopCh, + } +} + +type StreamsServer struct { + streamsCapability + capabilityRegistry core.CapabilitiesRegistry + stopCh chan struct{} +} + +func (c *StreamsServer) Initialise(ctx context.Context, dependencies core.StandardCapabilitiesDependencies) error { + if err := c.StreamsCapability.Initialise(ctx, dependencies); err != nil { + return fmt.Errorf("error when initializing capability: %w", err) + } + + c.capabilityRegistry = dependencies.CapabilityRegistry + + if err := dependencies.CapabilityRegistry.Add(ctx, &streamsCapability{ + StreamsCapability: c.StreamsCapability, + }); err != nil { + return fmt.Errorf("error when adding %s to the registry: %w", "streams-trigger@1.0.0", err) + } + + return nil +} + +func (c *StreamsServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if c.capabilityRegistry != nil { + if err := c.capabilityRegistry.Remove(ctx, "streams-trigger@1.0.0"); err != nil { + return err + } + } + + if c.stopCh != nil { + close(c.stopCh) + } + + return c.streamsCapability.Close() +} + +func (c *StreamsServer) Infos(ctx context.Context) ([]capabilities.CapabilityInfo, error) { + info, err := c.streamsCapability.Info(ctx) + if err != nil { + return nil, err + } + return []capabilities.CapabilityInfo{info}, nil +} + +type streamsCapability struct { + StreamsCapability + stopCh chan struct{} +} + +func (c *streamsCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { + // Maybe we do need to split it out, even if the user doesn't see it + return capabilities.NewCapabilityInfo("streams-trigger@1.0.0", capabilities.CapabilityTypeCombined, c.StreamsCapability.Description()) +} + +var _ capabilities.ExecutableAndTriggerCapability = (*streamsCapability)(nil) + +const StreamsID = "streams-trigger@1.0.0" + +func (c *streamsCapability) RegisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) { + switch request.Method { + case "Trigger": + input := &streams.Config{} + return capabilities.RegisterTrigger(ctx, c.stopCh, "streams-trigger@1.0.0", request, input, c.StreamsCapability.RegisterTrigger) + case "": + input := &streams.Config{} + return capabilities.RegisterTrigger(ctx, c.stopCh, "streams-trigger@1.0.0", request, input, c.StreamsCapability.RegisterTrigger) + default: + return nil, fmt.Errorf("trigger %s not found", request.Method) + } +} + +func (c *streamsCapability) UnregisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) error { + switch request.Method { + case "Trigger": + input := &streams.Config{} + _, err := capabilities.FromValueOrAny(request.Config, request.Payload, input) + if err != nil { + return err + } + return c.StreamsCapability.UnregisterTrigger(ctx, request.TriggerID, request.Metadata, input) + case "": + input := &streams.Config{} + _, err := capabilities.FromValueOrAny(request.Config, request.Payload, input) + if err != nil { + return err + } + return c.StreamsCapability.UnregisterTrigger(ctx, request.TriggerID, request.Metadata, input) + default: + return fmt.Errorf("method %s not found", request.Method) + } +} + +func (c *streamsCapability) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error { + return nil +} + +func (c *streamsCapability) UnregisterFromWorkflow(ctx context.Context, request capabilities.UnregisterFromWorkflowRequest) error { + return nil +} + +func (c *streamsCapability) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + return capabilities.CapabilityResponse{}, fmt.Errorf("method %s not found", request.Method) +} + diff --git a/pkg/capabilities/v2/triggers/streams/trigger.pb.go b/pkg/capabilities/v2/triggers/streams/trigger.pb.go new file mode 100644 index 0000000000..7ae1104d07 --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/trigger.pb.go @@ -0,0 +1,469 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.34.2 +// protoc v5.27.3 +// source: cre/capabilities/streams/v1/trigger.proto + +package streams + +import ( + _ "github.com/smartcontractkit/chainlink-protos/cre/go/tools/generator" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Configuration for the Streams Trigger +type Config struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // The IDs of the data feeds that will have their reports included in the trigger event. + // Feed IDs are hex-encoded strings (e.g., "0x000..."). + FeedIds []string `protobuf:"bytes,1,rep,name=feed_ids,json=feedIds,proto3" json:"feed_ids,omitempty"` + // The interval in milliseconds after which a new trigger event is generated. + MaxFrequencyMs uint64 `protobuf:"varint,2,opt,name=max_frequency_ms,json=maxFrequencyMs,proto3" json:"max_frequency_ms,omitempty"` +} + +func (x *Config) Reset() { + *x = Config{} + if protoimpl.UnsafeEnabled { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetFeedIds() []string { + if x != nil { + return x.FeedIds + } + return nil +} + +func (x *Config) GetMaxFrequencyMs() uint64 { + if x != nil { + return x.MaxFrequencyMs + } + return 0 +} + +// Metadata about the signers that produced the reports +type SignersMetadata struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // The IDs of the signers + Signers []string `protobuf:"bytes,1,rep,name=signers,proto3" json:"signers,omitempty"` + // The minimum number of signatures required to validate a report + MinRequiredSignatures int64 `protobuf:"varint,2,opt,name=min_required_signatures,json=minRequiredSignatures,proto3" json:"min_required_signatures,omitempty"` +} + +func (x *SignersMetadata) Reset() { + *x = SignersMetadata{} + if protoimpl.UnsafeEnabled { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SignersMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SignersMetadata) ProtoMessage() {} + +func (x *SignersMetadata) ProtoReflect() protoreflect.Message { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SignersMetadata.ProtoReflect.Descriptor instead. +func (*SignersMetadata) Descriptor() ([]byte, []int) { + return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{1} +} + +func (x *SignersMetadata) GetSigners() []string { + if x != nil { + return x.Signers + } + return nil +} + +func (x *SignersMetadata) GetMinRequiredSignatures() int64 { + if x != nil { + return x.MinRequiredSignatures + } + return 0 +} + +// A single feed report containing data and signatures +type FeedReport struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // The ID of the data feed (hex-encoded) + FeedId string `protobuf:"bytes,1,opt,name=feed_id,json=feedId,proto3" json:"feed_id,omitempty"` + // The full report as raw bytes + FullReport []byte `protobuf:"bytes,2,opt,name=full_report,json=fullReport,proto3" json:"full_report,omitempty"` + // Report context required to validate signatures + ReportContext []byte `protobuf:"bytes,3,opt,name=report_context,json=reportContext,proto3" json:"report_context,omitempty"` + // Signatures over the full report and report context + Signatures [][]byte `protobuf:"bytes,4,rep,name=signatures,proto3" json:"signatures,omitempty"` + // The benchmark price extracted from the full report + BenchmarkPrice []byte `protobuf:"bytes,5,opt,name=benchmark_price,json=benchmarkPrice,proto3" json:"benchmark_price,omitempty"` + // Timestamp when the observation was made + ObservationTimestamp int64 `protobuf:"varint,6,opt,name=observation_timestamp,json=observationTimestamp,proto3" json:"observation_timestamp,omitempty"` +} + +func (x *FeedReport) Reset() { + *x = FeedReport{} + if protoimpl.UnsafeEnabled { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *FeedReport) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FeedReport) ProtoMessage() {} + +func (x *FeedReport) ProtoReflect() protoreflect.Message { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FeedReport.ProtoReflect.Descriptor instead. +func (*FeedReport) Descriptor() ([]byte, []int) { + return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{2} +} + +func (x *FeedReport) GetFeedId() string { + if x != nil { + return x.FeedId + } + return "" +} + +func (x *FeedReport) GetFullReport() []byte { + if x != nil { + return x.FullReport + } + return nil +} + +func (x *FeedReport) GetReportContext() []byte { + if x != nil { + return x.ReportContext + } + return nil +} + +func (x *FeedReport) GetSignatures() [][]byte { + if x != nil { + return x.Signatures + } + return nil +} + +func (x *FeedReport) GetBenchmarkPrice() []byte { + if x != nil { + return x.BenchmarkPrice + } + return nil +} + +func (x *FeedReport) GetObservationTimestamp() int64 { + if x != nil { + return x.ObservationTimestamp + } + return 0 +} + +// The payload emitted by the Streams Trigger containing feed data +type Feed struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Timestamp when the trigger event was generated + Timestamp int64 `protobuf:"varint,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + // Metadata about the signers + Metadata *SignersMetadata `protobuf:"bytes,2,opt,name=metadata,proto3" json:"metadata,omitempty"` + // Array of feed reports + Payload []*FeedReport `protobuf:"bytes,3,rep,name=payload,proto3" json:"payload,omitempty"` +} + +func (x *Feed) Reset() { + *x = Feed{} + if protoimpl.UnsafeEnabled { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Feed) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Feed) ProtoMessage() {} + +func (x *Feed) ProtoReflect() protoreflect.Message { + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Feed.ProtoReflect.Descriptor instead. +func (*Feed) Descriptor() ([]byte, []int) { + return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{3} +} + +func (x *Feed) GetTimestamp() int64 { + if x != nil { + return x.Timestamp + } + return 0 +} + +func (x *Feed) GetMetadata() *SignersMetadata { + if x != nil { + return x.Metadata + } + return nil +} + +func (x *Feed) GetPayload() []*FeedReport { + if x != nil { + return x.Payload + } + return nil +} + +var File_cre_capabilities_streams_v1_trigger_proto protoreflect.FileDescriptor + +var file_cre_capabilities_streams_v1_trigger_proto_rawDesc = []byte{ + 0x0a, 0x29, 0x63, 0x72, 0x65, 0x2f, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, + 0x65, 0x73, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2f, 0x76, 0x31, 0x2f, 0x74, 0x72, + 0x69, 0x67, 0x67, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x61, 0x70, + 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x73, 0x2e, 0x76, 0x31, 0x1a, 0x2a, 0x74, 0x6f, 0x6f, 0x6c, 0x73, 0x2f, 0x67, 0x65, 0x6e, 0x65, + 0x72, 0x61, 0x74, 0x6f, 0x72, 0x2f, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x2f, 0x63, 0x72, + 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x22, 0x4d, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x19, 0x0a, 0x08, 0x66, 0x65, + 0x65, 0x64, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x66, 0x65, + 0x65, 0x64, 0x49, 0x64, 0x73, 0x12, 0x28, 0x0a, 0x10, 0x6d, 0x61, 0x78, 0x5f, 0x66, 0x72, 0x65, + 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x5f, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, + 0x0e, 0x6d, 0x61, 0x78, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x4d, 0x73, 0x22, + 0x63, 0x0a, 0x0f, 0x53, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, + 0x74, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x12, 0x36, 0x0a, 0x17, + 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x73, 0x69, 0x67, + 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x15, 0x6d, + 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x73, 0x22, 0xeb, 0x01, 0x0a, 0x0a, 0x46, 0x65, 0x65, 0x64, 0x52, 0x65, 0x70, + 0x6f, 0x72, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x65, 0x65, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x66, 0x65, 0x65, 0x64, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, + 0x66, 0x75, 0x6c, 0x6c, 0x5f, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x25, 0x0a, + 0x0e, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x6f, 0x6e, + 0x74, 0x65, 0x78, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, + 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0a, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x0f, 0x62, 0x65, 0x6e, 0x63, 0x68, 0x6d, 0x61, 0x72, + 0x6b, 0x5f, 0x70, 0x72, 0x69, 0x63, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0e, 0x62, + 0x65, 0x6e, 0x63, 0x68, 0x6d, 0x61, 0x72, 0x6b, 0x50, 0x72, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, + 0x15, 0x6f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6f, 0x62, + 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x22, 0xa9, 0x01, 0x0a, 0x04, 0x46, 0x65, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x74, + 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, + 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x44, 0x0a, 0x08, 0x6d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x63, 0x61, + 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x4d, 0x65, 0x74, + 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, + 0x3d, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x23, 0x2e, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2e, + 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x65, 0x65, 0x64, 0x52, + 0x65, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x32, 0x75, + 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x12, 0x4b, 0x0a, 0x07, 0x54, 0x72, 0x69, + 0x67, 0x67, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, + 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x1a, 0x1d, 0x2e, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, + 0x74, 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, + 0x46, 0x65, 0x65, 0x64, 0x30, 0x01, 0x1a, 0x1d, 0x82, 0xb5, 0x18, 0x19, 0x08, 0x01, 0x12, 0x15, + 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2d, 0x74, 0x72, 0x69, 0x67, 0x67, 0x65, 0x72, 0x40, + 0x31, 0x2e, 0x30, 0x2e, 0x30, 0x42, 0x53, 0x5a, 0x51, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6d, 0x61, 0x72, 0x74, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x61, 0x63, + 0x74, 0x6b, 0x69, 0x74, 0x2f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x6c, 0x69, 0x6e, 0x6b, 0x2d, 0x63, + 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, + 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2f, 0x76, 0x32, 0x2f, 0x74, 0x72, 0x69, 0x67, 0x67, 0x65, + 0x72, 0x73, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_cre_capabilities_streams_v1_trigger_proto_rawDescOnce sync.Once + file_cre_capabilities_streams_v1_trigger_proto_rawDescData = file_cre_capabilities_streams_v1_trigger_proto_rawDesc +) + +func file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP() []byte { + file_cre_capabilities_streams_v1_trigger_proto_rawDescOnce.Do(func() { + file_cre_capabilities_streams_v1_trigger_proto_rawDescData = protoimpl.X.CompressGZIP(file_cre_capabilities_streams_v1_trigger_proto_rawDescData) + }) + return file_cre_capabilities_streams_v1_trigger_proto_rawDescData +} + +var file_cre_capabilities_streams_v1_trigger_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_cre_capabilities_streams_v1_trigger_proto_goTypes = []any{ + (*Config)(nil), // 0: capabilities.streams.v1.Config + (*SignersMetadata)(nil), // 1: capabilities.streams.v1.SignersMetadata + (*FeedReport)(nil), // 2: capabilities.streams.v1.FeedReport + (*Feed)(nil), // 3: capabilities.streams.v1.Feed +} +var file_cre_capabilities_streams_v1_trigger_proto_depIdxs = []int32{ + 1, // 0: capabilities.streams.v1.Feed.metadata:type_name -> capabilities.streams.v1.SignersMetadata + 2, // 1: capabilities.streams.v1.Feed.payload:type_name -> capabilities.streams.v1.FeedReport + 0, // 2: capabilities.streams.v1.Streams.Trigger:input_type -> capabilities.streams.v1.Config + 3, // 3: capabilities.streams.v1.Streams.Trigger:output_type -> capabilities.streams.v1.Feed + 3, // [3:4] is the sub-list for method output_type + 2, // [2:3] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_cre_capabilities_streams_v1_trigger_proto_init() } +func file_cre_capabilities_streams_v1_trigger_proto_init() { + if File_cre_capabilities_streams_v1_trigger_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*Config); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1].Exporter = func(v any, i int) any { + switch v := v.(*SignersMetadata); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2].Exporter = func(v any, i int) any { + switch v := v.(*FeedReport); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_cre_capabilities_streams_v1_trigger_proto_msgTypes[3].Exporter = func(v any, i int) any { + switch v := v.(*Feed); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_cre_capabilities_streams_v1_trigger_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_cre_capabilities_streams_v1_trigger_proto_goTypes, + DependencyIndexes: file_cre_capabilities_streams_v1_trigger_proto_depIdxs, + MessageInfos: file_cre_capabilities_streams_v1_trigger_proto_msgTypes, + }.Build() + File_cre_capabilities_streams_v1_trigger_proto = out.File + file_cre_capabilities_streams_v1_trigger_proto_rawDesc = nil + file_cre_capabilities_streams_v1_trigger_proto_goTypes = nil + file_cre_capabilities_streams_v1_trigger_proto_depIdxs = nil +} From b276386fef073d3c76ab0733a3d55f7369846efe Mon Sep 17 00:00:00 2001 From: cawthorne Date: Fri, 5 Dec 2025 15:40:19 +0000 Subject: [PATCH 02/42] Resolve stash conflict --- go.sum | 5 ----- 1 file changed, 5 deletions(-) diff --git a/go.sum b/go.sum index 4ac9b674b6..d59af8648d 100644 --- a/go.sum +++ b/go.sum @@ -330,11 +330,6 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -<<<<<<< Updated upstream -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20251124151448-0448aefdaab9 h1:QRWXJusIj/IRY5Pl3JclNvDre0cZPd/5NbILwc4RV2M= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20251124151448-0448aefdaab9/go.mod h1:jUC52kZzEnWF9tddHh85zolKybmLpbQ1oNA4FjOHt1Q= -======= ->>>>>>> Stashed changes github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 h1:B7itmjy+CMJ26elVw/cAJqqhBQ3Xa/mBYWK0/rQ5MuI= From 3306fd5a5d4de58ef7d63e1017b0cf7425979e3d Mon Sep 17 00:00:00 2001 From: cawthorne Date: Fri, 5 Dec 2025 18:08:39 +0000 Subject: [PATCH 03/42] Add new Streams Capability --- .../v2/triggers/streams/authorization.go | 181 ++++++++ .../v2/triggers/streams/authorization_test.go | 405 ++++++++++++++++++ .../server/authorized_capability_test.go | 263 ++++++++++++ .../streams/server/authorized_server.go | 104 +++++ .../v2/triggers/streams/streams_test.go | 386 +++++++++++++++++ 5 files changed, 1339 insertions(+) create mode 100644 pkg/capabilities/v2/triggers/streams/authorization.go create mode 100644 pkg/capabilities/v2/triggers/streams/authorization_test.go create mode 100644 pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go create mode 100644 pkg/capabilities/v2/triggers/streams/server/authorized_server.go create mode 100644 pkg/capabilities/v2/triggers/streams/streams_test.go diff --git a/pkg/capabilities/v2/triggers/streams/authorization.go b/pkg/capabilities/v2/triggers/streams/authorization.go new file mode 100644 index 0000000000..2d992c6203 --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/authorization.go @@ -0,0 +1,181 @@ +package streams + +import ( + "fmt" + "regexp" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" +) + +// Authorizer handles workflow authorization for streams trigger +// Ensures only authorized workflows (e.g., Data Feeds) can use the trigger +type Authorizer struct { + allowedWorkflowIDs map[string]bool + allowedWorkflowPattern *regexp.Regexp + allowedWorkflowOwners map[string]bool + allowedWorkflowNamePattern *regexp.Regexp + enabled bool +} + +// AuthConfig configures authorization rules for the streams trigger +type AuthConfig struct { + // Enable authorization checks (set to false to disable authorization) + Enabled bool + + // AllowedWorkflowIDs is an explicit allowlist of workflow IDs + AllowedWorkflowIDs []string + + // AllowedWorkflowPattern is a regex pattern for allowed workflow IDs + // Example: "^df-.*" allows all workflows starting with "df-" + AllowedWorkflowPattern string + + // AllowedWorkflowOwners is an explicit allowlist of workflow owner addresses + // Example: ["0xDFOwner1", "0xDFOwner2"] + AllowedWorkflowOwners []string + + // AllowedWorkflowNamePattern is a regex pattern for allowed workflow names + // Example: "^data-feed-.*" for workflow names starting with "data-feed-" + AllowedWorkflowNamePattern string +} + +// NewAuthorizer creates a new authorizer with the given configuration +func NewAuthorizer(config AuthConfig) (*Authorizer, error) { + auth := &Authorizer{ + enabled: config.Enabled, + allowedWorkflowIDs: make(map[string]bool), + allowedWorkflowOwners: make(map[string]bool), + } + + // If authorization is disabled, return early + if !config.Enabled { + return auth, nil + } + + // Build workflow ID allowlist map for O(1) lookups + for _, id := range config.AllowedWorkflowIDs { + auth.allowedWorkflowIDs[id] = true + } + + // Build workflow owner allowlist map for O(1) lookups + for _, owner := range config.AllowedWorkflowOwners { + auth.allowedWorkflowOwners[owner] = true + } + + // Compile workflow ID pattern if provided + if config.AllowedWorkflowPattern != "" { + pattern, err := regexp.Compile(config.AllowedWorkflowPattern) + if err != nil { + return nil, fmt.Errorf("invalid workflow ID pattern '%s': %w", config.AllowedWorkflowPattern, err) + } + auth.allowedWorkflowPattern = pattern + } + + // Compile workflow name pattern if provided + if config.AllowedWorkflowNamePattern != "" { + pattern, err := regexp.Compile(config.AllowedWorkflowNamePattern) + if err != nil { + return nil, fmt.Errorf("invalid workflow name pattern '%s': %w", config.AllowedWorkflowNamePattern, err) + } + auth.allowedWorkflowNamePattern = pattern + } + + return auth, nil +} + +// NewDefaultDataFeedsAuthorizer creates an authorizer for Data Feeds workflows +// This is a convenience function for the common case +// Allows workflows with IDs starting with "df-" or names containing "data-feed" +func NewDefaultDataFeedsAuthorizer() (*Authorizer, error) { + return NewAuthorizer(AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*", // Allow workflow IDs starting with "df-" + AllowedWorkflowNamePattern: "data-feed", // Allow workflow names containing "data-feed" + }) +} + +// IsAuthorized checks if a workflow is authorized to use the streams trigger +// Returns nil if authorized, error otherwise +// Authorization checks (in order): +// 1. Explicit workflow ID allowlist +// 2. Workflow ID pattern matching +// 3. Workflow owner address allowlist +// 4. Workflow name pattern matching +// If ANY check passes, the workflow is authorized +func (a *Authorizer) IsAuthorized(metadata capabilities.RequestMetadata) error { + // If authorization is disabled, allow all + if !a.enabled { + return nil + } + + workflowID := metadata.WorkflowID + workflowName := metadata.WorkflowName + if workflowName == "" { + workflowName = metadata.DecodedWorkflowName + } + workflowOwner := metadata.WorkflowOwner + + // If no checks configured, deny by default + if len(a.allowedWorkflowIDs) == 0 && a.allowedWorkflowPattern == nil && + len(a.allowedWorkflowOwners) == 0 && a.allowedWorkflowNamePattern == nil { + return fmt.Errorf("workflow %s: no authorization checks configured, denying by default", workflowID) + } + + // Check 1: Explicit workflow ID allowlist + if len(a.allowedWorkflowIDs) > 0 { + if a.allowedWorkflowIDs[workflowID] { + return nil // Authorized + } + } + + // Check 2: Workflow ID pattern matching + if a.allowedWorkflowPattern != nil { + if a.allowedWorkflowPattern.MatchString(workflowID) { + return nil // Authorized + } + } + + // Check 3: Workflow owner allowlist + if len(a.allowedWorkflowOwners) > 0 && workflowOwner != "" { + if a.allowedWorkflowOwners[workflowOwner] { + return nil // Authorized + } + } + + // Check 4: Workflow name pattern matching + if a.allowedWorkflowNamePattern != nil && workflowName != "" { + if a.allowedWorkflowNamePattern.MatchString(workflowName) { + return nil // Authorized + } + } + + // None of the checks passed + return fmt.Errorf("workflow %s (name: %s, owner: %s) not authorized", workflowID, workflowName, workflowOwner) +} + +// String returns a human-readable description of the authorization rules +func (a *Authorizer) String() string { + if !a.enabled { + return "Authorization: Disabled (all workflows allowed)" + } + + desc := "Authorization: Enabled\n" + + if len(a.allowedWorkflowIDs) > 0 { + desc += fmt.Sprintf(" - Workflow ID allowlist: %d entries\n", len(a.allowedWorkflowIDs)) + } + + if a.allowedWorkflowPattern != nil { + desc += fmt.Sprintf(" - Workflow ID pattern: %s\n", a.allowedWorkflowPattern.String()) + } + + if len(a.allowedWorkflowOwners) > 0 { + desc += fmt.Sprintf(" - Workflow owner allowlist: %d entries\n", len(a.allowedWorkflowOwners)) + } + + if a.allowedWorkflowNamePattern != nil { + desc += fmt.Sprintf(" - Workflow name pattern: %s\n", a.allowedWorkflowNamePattern.String()) + } + + return desc +} + diff --git a/pkg/capabilities/v2/triggers/streams/authorization_test.go b/pkg/capabilities/v2/triggers/streams/authorization_test.go new file mode 100644 index 0000000000..9c5592efcb --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/authorization_test.go @@ -0,0 +1,405 @@ +package streams_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" +) + +func TestAuthorizerDisabled(t *testing.T) { + config := streams.AuthConfig{ + Enabled: false, + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + // Any workflow should be allowed when disabled + metadata := capabilities.RequestMetadata{ + WorkflowID: "any-workflow", + WorkflowName: "anything", + WorkflowOwner: "0xAnyOwner", + } + + err = auth.IsAuthorized(metadata) + assert.NoError(t, err, "Should allow all workflows when authorization is disabled") +} + +func TestAuthorizerWorkflowIDAllowlist(t *testing.T) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowIDs: []string{ + "df-prod-1", + "df-prod-2", + "df-staging-1", + }, + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + tests := []struct { + name string + workflowID string + expectError bool + }{ + {"workflow in allowlist 1", "df-prod-1", false}, + {"workflow in allowlist 2", "df-prod-2", false}, + {"workflow in allowlist 3", "df-staging-1", false}, + {"workflow not in allowlist", "other-workflow", true}, + {"workflow similar but not exact", "df-prod-10", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := capabilities.RequestMetadata{ + WorkflowID: tt.workflowID, + } + err := auth.IsAuthorized(metadata) + if tt.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), "not authorized") + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAuthorizerWorkflowIDPattern(t *testing.T) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*-mainnet$", + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + tests := []struct { + name string + workflowID string + expectError bool + }{ + {"matches pattern 1", "df-btc-mainnet", false}, + {"matches pattern 2", "df-eth-mainnet", false}, + {"matches pattern 3", "df-link-usd-mainnet", false}, + {"doesn't match - no prefix", "other-mainnet", true}, + {"doesn't match - no suffix", "df-btc", true}, + {"doesn't match - wrong suffix", "df-btc-testnet", true}, + {"doesn't match - completely different", "malicious-workflow", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := capabilities.RequestMetadata{ + WorkflowID: tt.workflowID, + } + err := auth.IsAuthorized(metadata) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAuthorizerWorkflowOwner(t *testing.T) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowOwners: []string{ + "0xDFOwner1", + "0xDFOwner2", + }, + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + tests := []struct { + name string + workflowOwner string + expectError bool + }{ + {"owner in allowlist 1", "0xDFOwner1", false}, + {"owner in allowlist 2", "0xDFOwner2", false}, + {"owner not in allowlist", "0xOtherOwner", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := capabilities.RequestMetadata{ + WorkflowID: "some-workflow", + WorkflowOwner: tt.workflowOwner, + } + err := auth.IsAuthorized(metadata) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAuthorizerWorkflowNamePattern(t *testing.T) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowNamePattern: "data-feed", + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + tests := []struct { + name string + workflowName string + expectError bool + }{ + {"matches pattern 1", "data-feed-btc-usd", false}, + {"matches pattern 2", "mainnet-data-feed", false}, + {"matches pattern 3", "data-feed", false}, + {"doesn't match", "other-workflow", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := capabilities.RequestMetadata{ + WorkflowID: "some-id", + WorkflowName: tt.workflowName, + } + err := auth.IsAuthorized(metadata) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAuthorizerInvalidPattern(t *testing.T) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "[invalid(regex", + } + + _, err := streams.NewAuthorizer(config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid workflow ID pattern") +} + +func TestAuthorizerInvalidNamePattern(t *testing.T) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowNamePattern: "[invalid(regex", + } + + _, err := streams.NewAuthorizer(config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid workflow name pattern") +} + +func TestAuthorizerCombinedChecksAnyMatch(t *testing.T) { + // If ANY check passes, workflow is authorized + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*", + AllowedWorkflowOwners: []string{"0xDFOwner"}, + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + tests := []struct { + name string + workflowID string + workflowOwner string + expectError bool + }{ + { + name: "matches ID pattern", + workflowID: "df-prod-1", + workflowOwner: "0xOther", + expectError: false, // Passes ID pattern check + }, + { + name: "matches owner", + workflowID: "other-workflow", + workflowOwner: "0xDFOwner", + expectError: false, // Passes owner check + }, + { + name: "matches both", + workflowID: "df-prod-1", + workflowOwner: "0xDFOwner", + expectError: false, // Passes both checks + }, + { + name: "matches neither", + workflowID: "other-workflow", + workflowOwner: "0xOther", + expectError: true, // Fails both checks + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := capabilities.RequestMetadata{ + WorkflowID: tt.workflowID, + WorkflowOwner: tt.workflowOwner, + } + err := auth.IsAuthorized(metadata) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAuthorizerNoChecksConfigured(t *testing.T) { + // If no checks configured, deny by default + config := streams.AuthConfig{ + Enabled: true, + // No checks configured + } + + auth, err := streams.NewAuthorizer(config) + require.NoError(t, err) + + metadata := capabilities.RequestMetadata{ + WorkflowID: "any-workflow", + } + + err = auth.IsAuthorized(metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no authorization checks configured") +} + +func TestDefaultDataFeedsAuthorizer(t *testing.T) { + auth, err := streams.NewDefaultDataFeedsAuthorizer() + require.NoError(t, err) + + tests := []struct { + name string + workflowID string + workflowName string + expectError bool + }{ + { + name: "DF workflow ID", + workflowID: "df-btc-usd", + workflowName: "", + expectError: false, + }, + { + name: "DF workflow name", + workflowID: "other-id", + workflowName: "data-feed-eth-usd", + expectError: false, + }, + { + name: "Neither matches", + workflowID: "other-id", + workflowName: "other-name", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + metadata := capabilities.RequestMetadata{ + WorkflowID: tt.workflowID, + WorkflowName: tt.workflowName, + } + err := auth.IsAuthorized(metadata) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAuthorizerString(t *testing.T) { + // Test disabled + auth1, _ := streams.NewAuthorizer(streams.AuthConfig{Enabled: false}) + str := auth1.String() + assert.Contains(t, str, "Disabled") + + // Test with ID allowlist + auth2, _ := streams.NewAuthorizer(streams.AuthConfig{ + Enabled: true, + AllowedWorkflowIDs: []string{"id1", "id2"}, + }) + str = auth2.String() + assert.Contains(t, str, "Enabled") + assert.Contains(t, str, "Workflow ID allowlist") + + // Test with pattern + auth3, _ := streams.NewAuthorizer(streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*", + }) + str = auth3.String() + assert.Contains(t, str, "pattern") + + // Test with owner allowlist + auth4, _ := streams.NewAuthorizer(streams.AuthConfig{ + Enabled: true, + AllowedWorkflowOwners: []string{"0xOwner1"}, + }) + str = auth4.String() + assert.Contains(t, str, "owner") +} + +// BenchmarkAuthorizerCheck benchmarks the authorization check +func BenchmarkAuthorizerCheck(b *testing.B) { + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*", + } + + auth, _ := streams.NewAuthorizer(config) + + metadata := capabilities.RequestMetadata{ + WorkflowID: "df-prod-1", + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = auth.IsAuthorized(metadata) + } +} + +func BenchmarkAuthorizerCheckAllowlist(b *testing.B) { + // Create large allowlist + allowlist := make([]string, 1000) + for i := 0; i < 1000; i++ { + allowlist[i] = fmt.Sprintf("df-workflow-%d", i) + } + + config := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowIDs: allowlist, + } + + auth, _ := streams.NewAuthorizer(config) + + metadata := capabilities.RequestMetadata{ + WorkflowID: "df-workflow-500", // Middle of allowlist + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = auth.IsAuthorized(metadata) + } +} diff --git a/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go b/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go new file mode 100644 index 0000000000..437566b6f0 --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go @@ -0,0 +1,263 @@ +package server_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams/server" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" +) + +func TestAuthorizedCapabilityBlocksUnauthorizedWorkflows(t *testing.T) { + lggr, _ := logger.New() + mockCap := &mockStreamsCapability{} + + // Create authorized capability with DF authorization + authCap, err := server.NewDefaultDataFeedsCapability(mockCap, lggr) + require.NoError(t, err) + + // Test 1: Authorized DF workflow (by ID pattern) - should succeed + ch, err := authCap.RegisterTrigger( + context.Background(), + "trigger-1", + capabilities.RequestMetadata{ + WorkflowID: "df-btc-usd", + WorkflowOwner: "0xDF001", + WorkflowName: "Bitcoin Data Feed", + }, + &streams.Config{FeedIds: []string{"0x001"}}, + ) + assert.NoError(t, err) + assert.NotNil(t, ch) + + // Test 2: Authorized DF workflow (by name pattern) - should succeed + ch, err = authCap.RegisterTrigger( + context.Background(), + "trigger-2", + capabilities.RequestMetadata{ + WorkflowID: "workflow-123", + WorkflowOwner: "0xDF002", + WorkflowName: "mainnet-data-feed-eth", + }, + &streams.Config{FeedIds: []string{"0x002"}}, + ) + assert.NoError(t, err) + assert.NotNil(t, ch) + + // Test 3: Unauthorized workflow (doesn't match ID or name) - should fail with auth error + ch, err = authCap.RegisterTrigger( + context.Background(), + "trigger-3", + capabilities.RequestMetadata{ + WorkflowID: "other-workflow", + WorkflowOwner: "0xOTHER", + WorkflowName: "Other Workflow", + }, + &streams.Config{FeedIds: []string{"0x003"}}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization failed") + assert.Nil(t, ch) +} + +func TestAuthorizedCapabilityWithCustomConfig(t *testing.T) { + lggr, _ := logger.New() + mockCap := &mockStreamsCapability{} + + // Custom authorization: specific allowlist + authConfig := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowIDs: []string{ + "df-prod-btc-usd", + "df-prod-eth-usd", + }, + } + + authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) + require.NoError(t, err) + + // Test allowed workflow + ch, err := authCap.RegisterTrigger( + context.Background(), + "trigger-1", + capabilities.RequestMetadata{ + WorkflowID: "df-prod-btc-usd", + }, + &streams.Config{}, + ) + assert.NoError(t, err) + assert.NotNil(t, ch) + + // Test non-allowed workflow + ch, err = authCap.RegisterTrigger( + context.Background(), + "trigger-2", + capabilities.RequestMetadata{ + WorkflowID: "df-prod-link-usd", // Not in allowlist + }, + &streams.Config{}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization failed") + assert.Nil(t, ch) +} + +func TestAuthorizedCapabilityWorkflowOwnerAllowlist(t *testing.T) { + lggr, _ := logger.New() + mockCap := &mockStreamsCapability{} + + // Authorization by owner address + authConfig := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowOwners: []string{ + "0xDFOwner1", + "0xDFOwner2", + }, + } + + authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) + require.NoError(t, err) + + // Test allowed owner + ch, err := authCap.RegisterTrigger( + context.Background(), + "trigger-1", + capabilities.RequestMetadata{ + WorkflowID: "any-workflow-id", + WorkflowOwner: "0xDFOwner1", + }, + &streams.Config{}, + ) + assert.NoError(t, err) + assert.NotNil(t, ch) + + // Test non-allowed owner + ch, err = authCap.RegisterTrigger( + context.Background(), + "trigger-2", + capabilities.RequestMetadata{ + WorkflowID: "any-workflow-id", + WorkflowOwner: "0xOtherOwner", + }, + &streams.Config{}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization failed") + assert.Nil(t, ch) +} + +func TestAuthorizedCapabilityDisabled(t *testing.T) { + lggr, _ := logger.New() + mockCap := &mockStreamsCapability{} + + // Disable authorization + authConfig := streams.AuthConfig{ + Enabled: false, + } + + authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) + require.NoError(t, err) + + // Any workflow should be allowed + ch, err := authCap.RegisterTrigger( + context.Background(), + "trigger-1", + capabilities.RequestMetadata{ + WorkflowID: "any-workflow", + WorkflowOwner: "0xAnyone", + WorkflowName: "Anything", + }, + &streams.Config{}, + ) + assert.NoError(t, err) + assert.NotNil(t, ch) +} + +func TestAuthorizedCapabilityUnregisterAlsoChecksAuth(t *testing.T) { + lggr, _ := logger.New() + mockCap := &mockStreamsCapability{} + + // Authorization enabled + authConfig := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*", + } + + authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) + require.NoError(t, err) + + // Test authorized unregister + err = authCap.UnregisterTrigger( + context.Background(), + "trigger-1", + capabilities.RequestMetadata{ + WorkflowID: "df-prod-btc", + }, + &streams.Config{}, + ) + assert.NoError(t, err) + + // Test unauthorized unregister + err = authCap.UnregisterTrigger( + context.Background(), + "trigger-2", + capabilities.RequestMetadata{ + WorkflowID: "other-workflow", + }, + &streams.Config{}, + ) + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization failed") +} + +// Mock implementations for testing + +type mockStreamsCapability struct { + registerCalled bool +} + +func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) { + m.registerCalled = true + ch := make(chan capabilities.TriggerAndId[*streams.Feed]) + close(ch) + return ch, nil +} + +func (m *mockStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error { + return nil +} + +func (m *mockStreamsCapability) Start(ctx context.Context) error { + return nil +} + +func (m *mockStreamsCapability) Close() error { + return nil +} + +func (m *mockStreamsCapability) HealthReport() map[string]error { + return map[string]error{} +} + +func (m *mockStreamsCapability) Name() string { + return "MockStreams" +} + +func (m *mockStreamsCapability) Description() string { + return "Mock" +} + +func (m *mockStreamsCapability) Ready() error { + return nil +} + +func (m *mockStreamsCapability) Initialise(ctx context.Context, deps core.StandardCapabilitiesDependencies) error { + return nil +} + diff --git a/pkg/capabilities/v2/triggers/streams/server/authorized_server.go b/pkg/capabilities/v2/triggers/streams/server/authorized_server.go new file mode 100644 index 0000000000..4d5ba3c4cc --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/server/authorized_server.go @@ -0,0 +1,104 @@ +package server + +import ( + "context" + "fmt" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// AuthorizedStreamsCapability wraps StreamsCapability with authorization checks +type AuthorizedStreamsCapability struct { + StreamsCapability + authorizer *streams.Authorizer + lggr logger.Logger +} + +// NewAuthorizedStreamsCapability creates a new capability with authorization enabled +func NewAuthorizedStreamsCapability(capability StreamsCapability, authConfig streams.AuthConfig, lggr logger.Logger) (*AuthorizedStreamsCapability, error) { + authorizer, err := streams.NewAuthorizer(authConfig) + if err != nil { + return nil, fmt.Errorf("failed to create authorizer: %w", err) + } + + return &AuthorizedStreamsCapability{ + StreamsCapability: capability, + authorizer: authorizer, + lggr: logger.Named(lggr, "AuthorizedStreamsCapability"), + }, nil +} + +// NewDefaultDataFeedsCapability creates a capability with default Data Feeds authorization +// Only workflows with IDs starting with "df-" or names containing "data-feed" will be allowed +func NewDefaultDataFeedsCapability(capability StreamsCapability, lggr logger.Logger) (*AuthorizedStreamsCapability, error) { + authConfig := streams.AuthConfig{ + Enabled: true, + AllowedWorkflowPattern: "^df-.*", + AllowedWorkflowNamePattern: "data-feed", + } + + return NewAuthorizedStreamsCapability(capability, authConfig, lggr) +} + +// RegisterTrigger wraps the base RegisterTrigger with authorization check +func (a *AuthorizedStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) { + // Authorization check + if err := a.authorizer.IsAuthorized(metadata); err != nil { + a.lggr.Warnw("Unauthorized trigger registration attempt", + "workflowID", metadata.WorkflowID, + "workflowOwner", metadata.WorkflowOwner, + "error", err, + ) + return nil, fmt.Errorf("authorization failed: %w", err) + } + + a.lggr.Debugw("Authorized trigger registration", + "workflowID", metadata.WorkflowID, + "triggerID", triggerID, + ) + + // Call the underlying implementation + return a.StreamsCapability.RegisterTrigger(ctx, triggerID, metadata, input) +} + +// UnregisterTrigger wraps the base UnregisterTrigger with authorization check +func (a *AuthorizedStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error { + // Authorization check + if err := a.authorizer.IsAuthorized(metadata); err != nil { + a.lggr.Warnw("Unauthorized trigger unregistration attempt", + "workflowID", metadata.WorkflowID, + "error", err, + ) + return fmt.Errorf("authorization failed: %w", err) + } + + a.lggr.Debugw("Authorized trigger unregistration", + "workflowID", metadata.WorkflowID, + "triggerID", triggerID, + ) + + // Call the underlying implementation + return a.StreamsCapability.UnregisterTrigger(ctx, triggerID, metadata, input) +} + +// NewAuthorizedStreamsServer creates a server wrapping an authorized capability +func NewAuthorizedStreamsServer(capability StreamsCapability, authConfig streams.AuthConfig, lggr logger.Logger) (*StreamsServer, error) { + authCap, err := NewAuthorizedStreamsCapability(capability, authConfig, lggr) + if err != nil { + return nil, err + } + + return NewStreamsServer(authCap), nil +} + +// NewDefaultDataFeedsServer creates a server with default Data Feeds authorization +func NewDefaultDataFeedsServer(capability StreamsCapability, lggr logger.Logger) (*StreamsServer, error) { + authCap, err := NewDefaultDataFeedsCapability(capability, lggr) + if err != nil { + return nil, err + } + + return NewStreamsServer(authCap), nil +} diff --git a/pkg/capabilities/v2/triggers/streams/streams_test.go b/pkg/capabilities/v2/triggers/streams/streams_test.go new file mode 100644 index 0000000000..718b4632d0 --- /dev/null +++ b/pkg/capabilities/v2/triggers/streams/streams_test.go @@ -0,0 +1,386 @@ +package streams_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/libocr/ragep2p/types" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams/server" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" +) + +// TestProtoTypesExist verifies that all protobuf types are properly generated +func TestProtoTypesExist(t *testing.T) { + // Config type + config := &streams.Config{ + FeedIds: []string{"0x0001", "0x0002"}, + MaxFrequencyMs: 5000, + } + assert.NotNil(t, config) + assert.Len(t, config.FeedIds, 2) + assert.Equal(t, uint64(5000), config.MaxFrequencyMs) + + // Feed type + feed := &streams.Feed{ + Timestamp: 1234567890, + Metadata: &streams.SignersMetadata{ + Signers: []string{"signer1", "signer2"}, + MinRequiredSignatures: 2, + }, + Payload: []*streams.FeedReport{ + { + FeedId: "0x0001", + FullReport: []byte("report-data"), + ReportContext: []byte("context"), + Signatures: [][]byte{[]byte("sig1")}, + BenchmarkPrice: []byte("price"), + ObservationTimestamp: 1234567890, + }, + }, + } + assert.NotNil(t, feed) + assert.Equal(t, int64(1234567890), feed.Timestamp) + assert.Len(t, feed.Payload, 1) +} + +// TestConfigGetters verifies getter methods work +func TestConfigGetters(t *testing.T) { + config := &streams.Config{ + FeedIds: []string{"0xfeed1", "0xfeed2", "0xfeed3"}, + MaxFrequencyMs: 10000, + } + + assert.Equal(t, []string{"0xfeed1", "0xfeed2", "0xfeed3"}, config.GetFeedIds()) + assert.Equal(t, uint64(10000), config.GetMaxFrequencyMs()) +} + +// TestFeedGetters verifies Feed getter methods +func TestFeedGetters(t *testing.T) { + metadata := &streams.SignersMetadata{ + Signers: []string{"signer1"}, + MinRequiredSignatures: 1, + } + + feed := &streams.Feed{ + Timestamp: 9999999999, + Metadata: metadata, + Payload: []*streams.FeedReport{}, + } + + assert.Equal(t, int64(9999999999), feed.GetTimestamp()) + assert.Equal(t, metadata, feed.GetMetadata()) + assert.NotNil(t, feed.GetPayload()) +} + +// TestStreamsCapabilityInterface verifies the server interface +func TestStreamsCapabilityInterface(t *testing.T) { + // Verify interface is defined correctly + var _ server.StreamsCapability = (*mockStreamsCapability)(nil) +} + +// mockStreamsCapability implements server.StreamsCapability for testing +type mockStreamsCapability struct { + registerCalled bool + unregisterCalled bool + startCalled bool + closeCalled bool +} + +func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) { + m.registerCalled = true + ch := make(chan capabilities.TriggerAndId[*streams.Feed], 1) + return ch, nil +} + +func (m *mockStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error { + m.unregisterCalled = true + return nil +} + +func (m *mockStreamsCapability) Start(ctx context.Context) error { + m.startCalled = true + return nil +} + +func (m *mockStreamsCapability) Close() error { + m.closeCalled = true + return nil +} + +func (m *mockStreamsCapability) HealthReport() map[string]error { + return map[string]error{"mock": nil} +} + +func (m *mockStreamsCapability) Name() string { + return "MockStreamsCapability" +} + +func (m *mockStreamsCapability) Description() string { + return "Mock implementation for testing" +} + +func (m *mockStreamsCapability) Ready() error { + return nil +} + +func (m *mockStreamsCapability) Initialise(ctx context.Context, deps core.StandardCapabilitiesDependencies) error { + return nil +} + +// TestStreamsServerCreation tests creating a server wrapper +func TestStreamsServerCreation(t *testing.T) { + mock := &mockStreamsCapability{} + srv := server.NewStreamsServer(mock) + + require.NotNil(t, srv) + + // Test initialization + ctx := context.Background() + mockRegistry := &mockCapabilityRegistry{} + deps := core.StandardCapabilitiesDependencies{ + CapabilityRegistry: mockRegistry, + } + + err := srv.Initialise(ctx, deps) + assert.NoError(t, err) + + // Start should be called separately + err = mock.Start(ctx) + assert.NoError(t, err) + assert.True(t, mock.startCalled) + + // Test close + err = srv.Close() + assert.NoError(t, err) + // Note: Close on server doesn't automatically call Close on capability + err = mock.Close() + assert.NoError(t, err) + assert.True(t, mock.closeCalled) +} + +// TestTriggerRegistration tests the trigger registration flow +func TestTriggerRegistration(t *testing.T) { + mock := &mockStreamsCapability{} + + ctx := context.Background() + triggerID := "test-trigger-123" + metadata := capabilities.RequestMetadata{ + WorkflowID: "test-workflow", + } + + config := &streams.Config{ + FeedIds: []string{"0x0001"}, + MaxFrequencyMs: 1000, + } + + ch, err := mock.RegisterTrigger(ctx, triggerID, metadata, config) + require.NoError(t, err) + require.NotNil(t, ch) + assert.True(t, mock.registerCalled) + + // Test unregister + err = mock.UnregisterTrigger(ctx, triggerID, metadata, config) + assert.NoError(t, err) + assert.True(t, mock.unregisterCalled) +} + +// TestFeedReportStructure tests the FeedReport structure +func TestFeedReportStructure(t *testing.T) { + report := &streams.FeedReport{ + FeedId: "0xfeedid12345", + FullReport: []byte("full-report-bytes"), + ReportContext: []byte("report-context"), + Signatures: [][]byte{[]byte("sig1"), []byte("sig2")}, + BenchmarkPrice: []byte("benchmark-price-bytes"), + ObservationTimestamp: 1700000000, + } + + assert.Equal(t, "0xfeedid12345", report.GetFeedId()) + assert.Equal(t, []byte("full-report-bytes"), report.GetFullReport()) + assert.Equal(t, []byte("report-context"), report.GetReportContext()) + assert.Len(t, report.GetSignatures(), 2) + assert.Equal(t, []byte("benchmark-price-bytes"), report.GetBenchmarkPrice()) + assert.Equal(t, int64(1700000000), report.GetObservationTimestamp()) +} + +// TestSignersMetadata tests the SignersMetadata structure +func TestSignersMetadata(t *testing.T) { + metadata := &streams.SignersMetadata{ + Signers: []string{"0xsigner1", "0xsigner2", "0xsigner3"}, + MinRequiredSignatures: 2, + } + + assert.Len(t, metadata.GetSigners(), 3) + assert.Equal(t, int64(2), metadata.GetMinRequiredSignatures()) +} + +// TestConfigValidation tests configuration validation scenarios +func TestConfigValidation(t *testing.T) { + tests := []struct { + name string + config *streams.Config + expectValid bool + }{ + { + name: "valid config with single feed", + config: &streams.Config{ + FeedIds: []string{"0x0001"}, + MaxFrequencyMs: 1000, + }, + expectValid: true, + }, + { + name: "valid config with multiple feeds", + config: &streams.Config{ + FeedIds: []string{"0x0001", "0x0002", "0x0003"}, + MaxFrequencyMs: 5000, + }, + expectValid: true, + }, + { + name: "high frequency", + config: &streams.Config{ + FeedIds: []string{"0x0001"}, + MaxFrequencyMs: 100, + }, + expectValid: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Basic validation - config should be creatable + assert.NotNil(t, tt.config) + assert.NotEmpty(t, tt.config.FeedIds) + assert.Greater(t, tt.config.MaxFrequencyMs, uint64(0)) + }) + } +} + +// mockCapabilityRegistry for testing +type mockCapabilityRegistry struct { + added []capabilities.BaseCapability + removed []string +} + +func (m *mockCapabilityRegistry) Add(ctx context.Context, capability capabilities.BaseCapability) error { + m.added = append(m.added, capability) + return nil +} + +func (m *mockCapabilityRegistry) Remove(ctx context.Context, id string) error { + m.removed = append(m.removed, id) + return nil +} + +func (m *mockCapabilityRegistry) Get(ctx context.Context, id string) (capabilities.BaseCapability, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) GetTrigger(ctx context.Context, id string) (capabilities.TriggerCapability, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) GetAction(ctx context.Context, id string) (capabilities.ActionCapability, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) GetExecutable(ctx context.Context, id string) (capabilities.ExecutableCapability, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) GetConsensus(ctx context.Context, id string) (capabilities.ConsensusCapability, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) GetTarget(ctx context.Context, id string) (capabilities.TargetCapability, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) List(ctx context.Context) ([]capabilities.BaseCapability, error) { + return m.added, nil +} + +func (m *mockCapabilityRegistry) ConfigForCapability(ctx context.Context, capabilityID string, capabilityDonID uint32) (capabilities.CapabilityConfiguration, error) { + return capabilities.CapabilityConfiguration{}, nil +} + +func (m *mockCapabilityRegistry) DONsForCapability(ctx context.Context, id string) ([]capabilities.DONWithNodes, error) { + return nil, nil +} + +func (m *mockCapabilityRegistry) LocalNode(ctx context.Context) (capabilities.Node, error) { + return capabilities.Node{}, nil +} + +func (m *mockCapabilityRegistry) NodeByPeerID(ctx context.Context, peerID types.PeerID) (capabilities.Node, error) { + return capabilities.Node{}, nil +} + +// TestServerLifecycle tests the complete server lifecycle +func TestServerLifecycle(t *testing.T) { + mock := &mockStreamsCapability{} + srv := server.NewStreamsServer(mock) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + mockRegistry := &mockCapabilityRegistry{} + deps := core.StandardCapabilitiesDependencies{ + CapabilityRegistry: mockRegistry, + } + + // Initialize + err := srv.Initialise(ctx, deps) + require.NoError(t, err) + assert.Len(t, mockRegistry.added, 1, "Capability should be registered") + + // Start must be called separately + err = mock.Start(ctx) + require.NoError(t, err) + assert.True(t, mock.startCalled, "Start should be called") + + // Get infos + infos, err := srv.Infos(ctx) + require.NoError(t, err) + require.Len(t, infos, 1) + assert.Equal(t, "streams-trigger@1.0.0", infos[0].ID) + + // Close + err = srv.Close() + require.NoError(t, err) + assert.True(t, mock.closeCalled, "Close should be called") + assert.Len(t, mockRegistry.removed, 1, "Capability should be unregistered") + assert.Equal(t, "streams-trigger@1.0.0", mockRegistry.removed[0]) +} + +// BenchmarkFeedCreation benchmarks creating Feed objects +func BenchmarkFeedCreation(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = &streams.Feed{ + Timestamp: int64(i), + Metadata: &streams.SignersMetadata{ + Signers: []string{"signer1", "signer2"}, + MinRequiredSignatures: 2, + }, + Payload: []*streams.FeedReport{ + { + FeedId: "0x0001", + FullReport: []byte("report"), + ReportContext: []byte("context"), + Signatures: [][]byte{[]byte("sig")}, + BenchmarkPrice: []byte("price"), + ObservationTimestamp: int64(i), + }, + }, + } + } +} + From a835972b107902581e89beb5b0dbe7e6780af155 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Fri, 5 Dec 2025 10:15:41 -0500 Subject: [PATCH 04/42] Add remaining functions to CLI --- keystore/cli/cli.go | 216 +++++++++++++++++++++++++++------------ keystore/cli/cli_test.go | 105 +++++++++++++++++-- 2 files changed, 247 insertions(+), 74 deletions(-) diff --git a/keystore/cli/cli.go b/keystore/cli/cli.go index 9946e13c50..a99b7f92f0 100644 --- a/keystore/cli/cli.go +++ b/keystore/cli/cli.go @@ -52,14 +52,14 @@ KEYSTORE_PASSWORD is the password used to encrypt the key material before storag cmd.PersistentFlags().String("keystore-db-url", "", "Overrides KEYSTORE_DB_URL environment variable") cmd.PersistentFlags().String("keystore-password", "", "Overrides KEYSTORE_PASSWORD environment variable. Not recommended as will leave shell traces.") - cmd.AddCommand(NewListCmd(), NewGetCmd(), NewCreateCmd(), NewDeleteCmd(), NewExportCmd(), NewImportCmd(), NewSetMetadataCmd()) + cmd.AddCommand(NewListCmd(), NewGetCmd(), NewCreateCmd(), NewDeleteCmd(), NewExportCmd(), NewImportCmd(), NewSetMetadataCmd(), NewSignCmd(), NewVerifyCmd(), NewEncryptCmd(), NewDecryptCmd(), NewDeriveSharedSecretCmd()) return cmd } func NewListCmd() *cobra.Command { cmd := cobra.Command{ Use: "list", Short: "List keys", - RunE: func(cmd *cobra.Command, _ []string) error { + RunE: func(cmd *cobra.Command, args []string) error { ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) defer cancel() k, err := loadKeystore(ctx, cmd) @@ -84,31 +84,10 @@ func NewListCmd() *cobra.Command { func NewGetCmd() *cobra.Command { cmd := cobra.Command{ Use: "get", Short: "Get keys", - RunE: func(cmd *cobra.Command, _ []string) error { - jsonBytes, err := readJSONInput(cmd) - if err != nil { - return err - } - var req ks.GetKeysRequest - if err := json.Unmarshal(jsonBytes, &req); err != nil { - return fmt.Errorf("invalid JSON request: %w", err) - } - ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) - defer cancel() - k, err := loadKeystore(ctx, cmd) - if err != nil { - return err - } - resp, err := k.GetKeys(ctx, req) - if err != nil { - return err - } - jsonBytes, err = json.Marshal(resp) - if err != nil { - return err - } - _, err = cmd.OutOrStdout().Write(jsonBytes) - return err + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.GetKeysRequest, ks.GetKeysResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.GetKeysRequest) (ks.GetKeysResponse, error) { + return k.GetKeys(ctx, req) + }) }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") @@ -119,32 +98,10 @@ func NewGetCmd() *cobra.Command { func NewCreateCmd() *cobra.Command { cmd := cobra.Command{ Use: "create", Short: "Create a key", - RunE: func(cmd *cobra.Command, _ []string) error { - jsonBytesIn, err := readJSONInput(cmd) - if err != nil { - return err - } - var req ks.CreateKeysRequest - err = json.Unmarshal(jsonBytesIn, &req) - if err != nil { - return err - } - ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) - defer cancel() - k, err := loadKeystore(ctx, cmd) - if err != nil { - return err - } - resp, err := k.CreateKeys(ctx, req) - if err != nil { - return err - } - jsonBytes, err := json.Marshal(resp) - if err != nil { - return err - } - _, err = cmd.OutOrStdout().Write(jsonBytes) - return err + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.CreateKeysRequest, ks.CreateKeysResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.CreateKeysRequest) (ks.CreateKeysResponse, error) { + return k.CreateKeys(ctx, req) + }) }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") @@ -201,13 +158,120 @@ func NewDeleteCmd() *cobra.Command { func NewExportCmd() *cobra.Command { cmd := cobra.Command{ Use: "export", Short: "Export a key to an encrypted JSON file", + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.ExportKeysRequest, ks.ExportKeysResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.ExportKeysRequest) (ks.ExportKeysResponse, error) { + return k.ExportKeys(ctx, req) + }) + }, + } + cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"Keys\": [{\"KeyName\": \"key1\", \"Enc\": {\"Password\": \"pass\", \"ScryptParams\": {\"N\": 1024, \"P\": 1, \"R\": 8}}}]}'") + return &cmd +} + +func NewImportCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "import", Short: "Import an encrypted key JSON file", + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.ImportKeysRequest, ks.ImportKeysResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.ImportKeysRequest) (ks.ImportKeysResponse, error) { + return k.ImportKeys(ctx, req) + }) + }, + } + cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"Keys\": [{\"KeyName\": \"key1\", \"Data\": \"encBytes\", \"Password\": \"pass\"}]}'") + return &cmd +} + +func NewSetMetadataCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "set-metadata", Short: "Set metadata for keys", + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.SetMetadataRequest, ks.SetMetadataResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.SetMetadataRequest) (ks.SetMetadataResponse, error) { + return k.SetMetadata(ctx, req) + }) + }, + } + cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"Updates\": [{\"KeyName\": \"key1\", \"Metadata\": \"base64-encoded-metadata\"}]}'") + return &cmd +} + +func zeroValue[T any]() T { + var t T + return t +} + +func runKeystoreCommand[Req any, Resp any](cmd *cobra.Command, args []string, fn func(ctx context.Context, k ks.Keystore, + req Req) (Resp, error)) error { + jsonBytes, err := readJSONInput(cmd) + if err != nil { + return err + } + var req Req + err = json.Unmarshal(jsonBytes, &req) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) + defer cancel() + k, err := loadKeystore(ctx, cmd) + if err != nil { + return err + } + resp, err := fn(ctx, k, req) + if err != nil { + return err + } + jsonBytesOut, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = cmd.OutOrStdout().Write(jsonBytesOut) + if err != nil { + return err + } + return nil +} + +func NewSignCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "sign", Short: "Sign data with a key", + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.SignRequest, ks.SignResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.SignRequest) (ks.SignResponse, error) { + return k.Sign(ctx, req) + }) + }, + } + cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyName\": \"key1\", \"Data\": \"base64-encoded-data\"}'") + return &cmd +} + +func NewVerifyCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "verify", Short: "Verify a signature", + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.VerifyRequest, ks.VerifyResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.VerifyRequest) (ks.VerifyResponse, error) { + return k.Verify(ctx, req) + }) + }, + } + cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyType\": \"Ed25519\", \"PublicKey\": \"base64-pubkey\", \"Data\": \"base64-data\", \"Signature\": \"base64-sig\"}'") + return &cmd +} + +func NewEncryptCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "encrypt", Short: "Encrypt data to a remote public key", RunE: func(cmd *cobra.Command, _ []string) error { - jsonBytesIn, err := readJSONInput(cmd) + jsonBytes, err := readJSONInput(cmd) if err != nil { return err } - var req ks.ExportKeysRequest - err = json.Unmarshal(jsonBytesIn, &req) + var req ks.EncryptRequest + err = json.Unmarshal(jsonBytes, &req) if err != nil { return err } @@ -217,7 +281,7 @@ func NewExportCmd() *cobra.Command { if err != nil { return err } - resp, err := k.ExportKeys(ctx, req) + resp, err := k.Encrypt(ctx, req) if err != nil { return err } @@ -230,19 +294,19 @@ func NewExportCmd() *cobra.Command { }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"Keys\": [{\"KeyName\": \"key1\", \"Enc\": {\"Password\": \"pass\", \"ScryptParams\": {\"N\": 1024, \"P\": 1, \"R\": 8}}}]}'") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"RemoteKeyType\": \"X25519\", \"RemotePubKey\": \"base64-pubkey\", \"Data\": \"base64-data\"}'") return &cmd } -func NewImportCmd() *cobra.Command { +func NewDecryptCmd() *cobra.Command { cmd := cobra.Command{ - Use: "import", Short: "Import an encrypted key JSON file", + Use: "decrypt", Short: "Decrypt data with a key", RunE: func(cmd *cobra.Command, _ []string) error { jsonBytes, err := readJSONInput(cmd) if err != nil { return err } - var req ks.ImportKeysRequest + var req ks.DecryptRequest err = json.Unmarshal(jsonBytes, &req) if err != nil { return err @@ -253,24 +317,32 @@ func NewImportCmd() *cobra.Command { if err != nil { return err } - _, err = k.ImportKeys(ctx, req) + resp, err := k.Decrypt(ctx, req) + if err != nil { + return err + } + jsonBytesOut, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = cmd.OutOrStdout().Write(jsonBytesOut) return err }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"Keys\": [{\"KeyName\": \"key1\", \"Data\": \"encBytes\", \"Password\": \"pass\"}]}'") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyName\": \"key1\", \"EncryptedData\": \"base64-encrypted-data\"}'") return &cmd } -func NewSetMetadataCmd() *cobra.Command { +func NewDeriveSharedSecretCmd() *cobra.Command { cmd := cobra.Command{ - Use: "set-metadata", Short: "Set metadata for keys", + Use: "derive-shared-secret", Short: "Derive a shared secret between a key and a remote public key", RunE: func(cmd *cobra.Command, _ []string) error { jsonBytes, err := readJSONInput(cmd) if err != nil { return err } - var req ks.SetMetadataRequest + var req ks.DeriveSharedSecretRequest err = json.Unmarshal(jsonBytes, &req) if err != nil { return err @@ -281,12 +353,20 @@ func NewSetMetadataCmd() *cobra.Command { if err != nil { return err } - _, err = k.SetMetadata(ctx, req) + resp, err := k.DeriveSharedSecret(ctx, req) + if err != nil { + return err + } + jsonBytesOut, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = cmd.OutOrStdout().Write(jsonBytesOut) return err }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"Updates\": [{\"KeyName\": \"key1\", \"Metadata\": \"base64-encoded-metadata\"}]}'") + cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyName\": \"key1\", \"RemotePubKey\": \"base64-remote-pubkey\"}'") return &cmd } diff --git a/keystore/cli/cli_test.go b/keystore/cli/cli_test.go index ce869cd0bc..2ccfaa7f11 100644 --- a/keystore/cli/cli_test.go +++ b/keystore/cli/cli_test.go @@ -2,6 +2,7 @@ package cli_test import ( "bytes" + "crypto/sha256" "encoding/base64" "encoding/json" "os" @@ -14,18 +15,25 @@ import ( "github.com/smartcontractkit/chainlink-common/keystore/cli" ) -func TestCLI(t *testing.T) { +func setupKeystore(t *testing.T) func(t *testing.T) { tempDir := t.TempDir() - defer os.RemoveAll(tempDir) keystoreFile := filepath.Join(tempDir, "keystore.json") f, err := os.Create(keystoreFile) require.NoError(t, err) - defer f.Close() - os.Setenv("KEYSTORE_FILE_PATH", keystoreFile) - os.Setenv("KEYSTORE_PASSWORD", "testpassword") + t.Setenv("KEYSTORE_FILE_PATH", keystoreFile) + t.Setenv("KEYSTORE_PASSWORD", "testpassword") + return func(t *testing.T) { + f.Close() + os.RemoveAll(tempDir) + } +} + +func TestAdminCLI(t *testing.T) { + teardown := setupKeystore(t) + defer teardown(t) // No error just listing help. - _, err = runCommand(t, nil, "") + _, err := runCommand(t, nil, "") require.NoError(t, err) // Create a key. @@ -114,6 +122,91 @@ func TestCLI(t *testing.T) { require.Empty(t, resp.Keys) } +func TestSignerCLI(t *testing.T) { + teardown := setupKeystore(t) + defer teardown(t) + + // Create an ECDSA key for signing. + _, err := runCommand(t, nil, "create", "-d", `{"Keys": [{"KeyName": "ecdsakey", "KeyType": "ECDSA_S256"}]}`) + require.NoError(t, err) + + // Get the key to retrieve the public key. + out, err := runCommand(t, nil, "get", "-d", `{"KeyNames": ["ecdsakey"]}`) + require.NoError(t, err) + getResp := ks.GetKeysResponse{} + err = json.Unmarshal(out.Bytes(), &getResp) + require.NoError(t, err) + require.Len(t, getResp.Keys, 1) + publicKey := getResp.Keys[0].KeyInfo.PublicKey + + // ECDSA_S256 requires a 32-byte hash to sign. + dataToSign := sha256.Sum256([]byte("hello world")) + dataB64 := base64.StdEncoding.EncodeToString(dataToSign[:]) + out, err = runCommand(t, nil, "sign", "-d", `{"KeyName": "ecdsakey", "Data": "`+dataB64+`"}`) + require.NoError(t, err) + signResp := ks.SignResponse{} + err = json.Unmarshal(out.Bytes(), &signResp) + require.NoError(t, err) + require.NotEmpty(t, signResp.Signature) + + // Verify the signature. + sigB64 := base64.StdEncoding.EncodeToString(signResp.Signature) + pubKeyB64 := base64.StdEncoding.EncodeToString(publicKey) + out, err = runCommand(t, nil, "verify", "-d", `{"KeyType": "ECDSA_S256", "PublicKey": "`+pubKeyB64+`", "Data": "`+dataB64+`", "Signature": "`+sigB64+`"}`) + require.NoError(t, err) + verifyResp := ks.VerifyResponse{} + err = json.Unmarshal(out.Bytes(), &verifyResp) + require.NoError(t, err) + require.True(t, verifyResp.Valid) +} + +func TestEncryptDecryptCLI(t *testing.T) { + teardown := setupKeystore(t) + defer teardown(t) + + // Create an X25519 key for encryption. + _, err := runCommand(t, nil, "create", "-d", `{"Keys": [{"KeyName": "x25519key", "KeyType": "X25519"}]}`) + require.NoError(t, err) + + // Get the key to retrieve the public key. + out, err := runCommand(t, nil, "get", "-d", `{"KeyNames": ["x25519key"]}`) + require.NoError(t, err) + getResp := ks.GetKeysResponse{} + err = json.Unmarshal(out.Bytes(), &getResp) + require.NoError(t, err) + require.Len(t, getResp.Keys, 1) + publicKey := getResp.Keys[0].KeyInfo.PublicKey + + // Encrypt some data to the key's public key. + plaintext := []byte("secret message") + pubKeyB64 := base64.StdEncoding.EncodeToString(publicKey) + plaintextB64 := base64.StdEncoding.EncodeToString(plaintext) + out, err = runCommand(t, nil, "encrypt", "-d", `{"RemoteKeyType": "X25519", "RemotePubKey": "`+pubKeyB64+`", "Data": "`+plaintextB64+`"}`) + require.NoError(t, err) + encryptResp := ks.EncryptResponse{} + err = json.Unmarshal(out.Bytes(), &encryptResp) + require.NoError(t, err) + require.NotEmpty(t, encryptResp.EncryptedData) + + // Decrypt the data. + encryptedB64 := base64.StdEncoding.EncodeToString(encryptResp.EncryptedData) + out, err = runCommand(t, nil, "decrypt", "-d", `{"KeyName": "x25519key", "EncryptedData": "`+encryptedB64+`"}`) + require.NoError(t, err) + decryptResp := ks.DecryptResponse{} + err = json.Unmarshal(out.Bytes(), &decryptResp) + require.NoError(t, err) + require.Equal(t, plaintext, decryptResp.Data) + + // Derive shared secret from key's perspective (key private + key public). + pubKeyB64 = base64.StdEncoding.EncodeToString(publicKey) + out, err = runCommand(t, nil, "derive-shared-secret", "-d", `{"KeyName": "x25519key", "RemotePubKey": "`+pubKeyB64+`"}`) + require.NoError(t, err) + deriveResp := ks.DeriveSharedSecretResponse{} + err = json.Unmarshal(out.Bytes(), &deriveResp) + require.NoError(t, err) + require.NotEmpty(t, deriveResp.SharedSecret) +} + func runCommand(t *testing.T, in *bytes.Buffer, args ...string) (bytes.Buffer, error) { // Cobra commands are stateful which can cause subtle bugs if not reset. // For simplicity just create a fresh object. From 7fedbfbaa4461378ad715f76ee8ba18fbc4402aa Mon Sep 17 00:00:00 2001 From: connorwstein Date: Fri, 5 Dec 2025 10:23:20 -0500 Subject: [PATCH 05/42] Cleanup encrypt/decrypt --- keystore/cli/cli.go | 98 ++++------------------------------------ keystore/cli/cli_test.go | 9 ---- 2 files changed, 9 insertions(+), 98 deletions(-) diff --git a/keystore/cli/cli.go b/keystore/cli/cli.go index a99b7f92f0..1dea1d6d02 100644 --- a/keystore/cli/cli.go +++ b/keystore/cli/cli.go @@ -52,7 +52,7 @@ KEYSTORE_PASSWORD is the password used to encrypt the key material before storag cmd.PersistentFlags().String("keystore-db-url", "", "Overrides KEYSTORE_DB_URL environment variable") cmd.PersistentFlags().String("keystore-password", "", "Overrides KEYSTORE_PASSWORD environment variable. Not recommended as will leave shell traces.") - cmd.AddCommand(NewListCmd(), NewGetCmd(), NewCreateCmd(), NewDeleteCmd(), NewExportCmd(), NewImportCmd(), NewSetMetadataCmd(), NewSignCmd(), NewVerifyCmd(), NewEncryptCmd(), NewDecryptCmd(), NewDeriveSharedSecretCmd()) + cmd.AddCommand(NewListCmd(), NewGetCmd(), NewCreateCmd(), NewDeleteCmd(), NewExportCmd(), NewImportCmd(), NewSetMetadataCmd(), NewSignCmd(), NewVerifyCmd(), NewEncryptCmd(), NewDecryptCmd()) return cmd } @@ -265,32 +265,10 @@ func NewVerifyCmd() *cobra.Command { func NewEncryptCmd() *cobra.Command { cmd := cobra.Command{ Use: "encrypt", Short: "Encrypt data to a remote public key", - RunE: func(cmd *cobra.Command, _ []string) error { - jsonBytes, err := readJSONInput(cmd) - if err != nil { - return err - } - var req ks.EncryptRequest - err = json.Unmarshal(jsonBytes, &req) - if err != nil { - return err - } - ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) - defer cancel() - k, err := loadKeystore(ctx, cmd) - if err != nil { - return err - } - resp, err := k.Encrypt(ctx, req) - if err != nil { - return err - } - jsonBytesOut, err := json.Marshal(resp) - if err != nil { - return err - } - _, err = cmd.OutOrStdout().Write(jsonBytesOut) - return err + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.EncryptRequest, ks.EncryptResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.EncryptRequest) (ks.EncryptResponse, error) { + return k.Encrypt(ctx, req) + }) }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") @@ -301,32 +279,10 @@ func NewEncryptCmd() *cobra.Command { func NewDecryptCmd() *cobra.Command { cmd := cobra.Command{ Use: "decrypt", Short: "Decrypt data with a key", - RunE: func(cmd *cobra.Command, _ []string) error { - jsonBytes, err := readJSONInput(cmd) - if err != nil { - return err - } - var req ks.DecryptRequest - err = json.Unmarshal(jsonBytes, &req) - if err != nil { - return err - } - ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) - defer cancel() - k, err := loadKeystore(ctx, cmd) - if err != nil { - return err - } - resp, err := k.Decrypt(ctx, req) - if err != nil { - return err - } - jsonBytesOut, err := json.Marshal(resp) - if err != nil { - return err - } - _, err = cmd.OutOrStdout().Write(jsonBytesOut) - return err + RunE: func(cmd *cobra.Command, args []string) error { + return runKeystoreCommand[ks.DecryptRequest, ks.DecryptResponse](cmd, args, func(ctx context.Context, k ks.Keystore, req ks.DecryptRequest) (ks.DecryptResponse, error) { + return k.Decrypt(ctx, req) + }) }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") @@ -334,42 +290,6 @@ func NewDecryptCmd() *cobra.Command { return &cmd } -func NewDeriveSharedSecretCmd() *cobra.Command { - cmd := cobra.Command{ - Use: "derive-shared-secret", Short: "Derive a shared secret between a key and a remote public key", - RunE: func(cmd *cobra.Command, _ []string) error { - jsonBytes, err := readJSONInput(cmd) - if err != nil { - return err - } - var req ks.DeriveSharedSecretRequest - err = json.Unmarshal(jsonBytes, &req) - if err != nil { - return err - } - ctx, cancel := context.WithTimeout(cmd.Context(), KeystoreLoadTimeout) - defer cancel() - k, err := loadKeystore(ctx, cmd) - if err != nil { - return err - } - resp, err := k.DeriveSharedSecret(ctx, req) - if err != nil { - return err - } - jsonBytesOut, err := json.Marshal(resp) - if err != nil { - return err - } - _, err = cmd.OutOrStdout().Write(jsonBytesOut) - return err - }, - } - cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyName\": \"key1\", \"RemotePubKey\": \"base64-remote-pubkey\"}'") - return &cmd -} - func loadKeystore(ctx context.Context, cmd *cobra.Command) (ks.Keystore, error) { root := cmd.Root() filePath, err := root.Flags().GetString("keystore-file-path") diff --git a/keystore/cli/cli_test.go b/keystore/cli/cli_test.go index 2ccfaa7f11..c7adc1b111 100644 --- a/keystore/cli/cli_test.go +++ b/keystore/cli/cli_test.go @@ -196,15 +196,6 @@ func TestEncryptDecryptCLI(t *testing.T) { err = json.Unmarshal(out.Bytes(), &decryptResp) require.NoError(t, err) require.Equal(t, plaintext, decryptResp.Data) - - // Derive shared secret from key's perspective (key private + key public). - pubKeyB64 = base64.StdEncoding.EncodeToString(publicKey) - out, err = runCommand(t, nil, "derive-shared-secret", "-d", `{"KeyName": "x25519key", "RemotePubKey": "`+pubKeyB64+`"}`) - require.NoError(t, err) - deriveResp := ks.DeriveSharedSecretResponse{} - err = json.Unmarshal(out.Bytes(), &deriveResp) - require.NoError(t, err) - require.NotEmpty(t, deriveResp.SharedSecret) } func runCommand(t *testing.T, in *bytes.Buffer, args ...string) (bytes.Buffer, error) { From 6ede41258696601efe9d2a153d235245b4ab70e0 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Fri, 5 Dec 2025 13:07:16 -0500 Subject: [PATCH 06/42] Cleanup, UX improvements --- keystore/admin.go | 2 +- keystore/cli/cli.go | 62 +++++++++++++++++++++++++++++++++++++++---- keystore/encryptor.go | 11 +++++--- keystore/keystore.go | 16 ++++++++--- keystore/signer.go | 4 +-- 5 files changed, 81 insertions(+), 14 deletions(-) diff --git a/keystore/admin.go b/keystore/admin.go index 0aa4b69d94..f6754ba053 100644 --- a/keystore/admin.go +++ b/keystore/admin.go @@ -232,7 +232,7 @@ func (ks *keystore) CreateKeys(ctx context.Context, req CreateKeysRequest) (Crea } ksCopy[keyReq.KeyName] = newKey(keyReq.KeyType, internal.NewRaw(privateKey.Bytes()), publicKey, time.Now(), []byte{}) default: - return CreateKeysResponse{}, fmt.Errorf("%w: %s", ErrUnsupportedKeyType, keyReq.KeyType) + return CreateKeysResponse{}, fmt.Errorf("%w: %s, available key types: %s", ErrUnsupportedKeyType, keyReq.KeyType, AllKeyTypes.String()) } created := ksCopy[keyReq.KeyName].createdAt diff --git a/keystore/cli/cli.go b/keystore/cli/cli.go index 1dea1d6d02..1bf4d71aa0 100644 --- a/keystore/cli/cli.go +++ b/keystore/cli/cli.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log/slog" "os" "strings" "time" @@ -244,7 +245,32 @@ func NewSignCmd() *cobra.Command { }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyName\": \"key1\", \"Data\": \"base64-encoded-data\"}'") + cmd.Flags().StringP("data", "d", "", ` + Inline JSON request. Data is base64-encoded. + Example: + echo -n 'hello' | base64 + aGVsbG8= + + ./keystore list | jq + { + "Keys": [ + { + "KeyName": "mykey", + "KeyType": "Ed25519", + "CreatedAt": "2025-01-01T00:00:00Z", + "PublicKey": "GJnS+erQbyuEm1byCjXy+6JqyX5hrGLE8oUuHSb9DFc=" + } + ] + } + ./keystore sign -d '{"KeyName": "mykey", "Data": "aGVsbG8="}' | jq + { + "Signature": "OVPaQIwQAZycQtiGjhwxZ3KmAdXOHczwi3LpwQTCbtMHfy5mmrp0KusICSO0lzCMeQvxJcd5y6f3siQsohQeCg==" + } + ./keystore verify -d '{"KeyType": "Ed25519", "PublicKey": "GJnS+erQbyuEm1byCjXy+6JqyX5hrGLE8oUuHSb9DFc=", "Data": "aGVsbG8=", "Signature": "OVPaQIwQAZycQtiGjhwxZ3KmAdXOHczwi3LpwQTCbtMHfy5mmrp0KusICSO0lzCMeQvxJcd5y6f3siQsohQeCg=="}' | jq + { + "Valid": true + } + `) return &cmd } @@ -258,7 +284,7 @@ func NewVerifyCmd() *cobra.Command { }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyType\": \"Ed25519\", \"PublicKey\": \"base64-pubkey\", \"Data\": \"base64-data\", \"Signature\": \"base64-sig\"}'") + cmd.Flags().StringP("data", "d", "", `inline JSON request. All byte fields are base64-encoded. Example: '{"KeyType": "Ed25519", "PublicKey": "", "Data": "aGVsbG8=", "Signature": ""}'`) return &cmd } @@ -272,7 +298,32 @@ func NewEncryptCmd() *cobra.Command { }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"RemoteKeyType\": \"X25519\", \"RemotePubKey\": \"base64-pubkey\", \"Data\": \"base64-data\"}'") + cmd.Flags().StringP("data", "d", "", ` + Inline JSON request. Data/RemotePubKey are base64-encoded. + Example: + echo -n 'hello' | base64 + aGVsbG8= + + ./keystore list | jq + { + "Keys": [ + { + "KeyName": "x25519key", + "KeyType": "X25519", + "CreatedAt": "2025-01-01T00:00:00Z", + "PublicKey": "GJnS+erQbyuEm1byCjXy+6JqyX5hrGLE8oUuHSb9DFc=" + } + ] + } + ./keystore encrypt -d '{"RemoteKeyType": "X25519", "RemotePubKey": "GJnS+erQbyuEm1byCjXy+6JqyX5hrGLE8oUuHSb9DFc=", "Data": "aGVsbG8="}' | jq + { + "EncryptedData": "ZGVjb3JhdGVkRGF0YQ==" + } + ./keystore decrypt -d '{"KeyName": "x25519key", "EncryptedData": "ZGVjb3JhdGVkRGF0YQ=="}' | jq + { + "Data": "aGVsbG8=" + } + `) return &cmd } @@ -286,7 +337,7 @@ func NewDecryptCmd() *cobra.Command { }, } cmd.Flags().StringP("file", "f", "", "input file path (use \"-\" for stdin)") - cmd.Flags().StringP("data", "d", "", "inline JSON request, e.g. '{\"KeyName\": \"key1\", \"EncryptedData\": \"base64-encrypted-data\"}'") + cmd.Flags().StringP("data", "d", "", `inline JSON request. EncryptedData is base64-encoded. Example: '{"KeyName": "mykey", "EncryptedData": ""}'`) return &cmd } @@ -339,7 +390,8 @@ func loadKeystore(ctx context.Context, cmd *cobra.Command) (ks.Keystore, error) } // Can revisit whether custom scrypt params are actually needed in a CLI context // (I doubt it, so simpler to leave out). - return ks.LoadKeystore(ctx, storage, password) + lggr := slog.New(slog.NewTextHandler(os.Stdout, nil)) + return ks.LoadKeystore(ctx, storage, password, ks.WithLogger(lggr)) } // readJSONInput reads JSON from either -f/--file (file path, "-" for stdin) or -d/--data (inline JSON). diff --git a/keystore/encryptor.go b/keystore/encryptor.go index 61723f7aa9..ff650aec3f 100644 --- a/keystore/encryptor.go +++ b/keystore/encryptor.go @@ -102,6 +102,11 @@ func (k *keystore) Encrypt(ctx context.Context, req EncryptRequest) (EncryptResp return EncryptResponse{}, ErrEncryptionFailed } + if len(req.RemotePubKey) == 0 { + k.lggr.Error("encrypt failed: remote public key is empty") + return EncryptResponse{}, ErrEncryptionFailed + } + switch req.RemoteKeyType { case X25519: encrypted, err := k.encryptX25519Anonymous(req.Data, req.RemotePubKey) @@ -120,7 +125,7 @@ func (k *keystore) Encrypt(ctx context.Context, req EncryptRequest) (EncryptResp } return EncryptResponse{EncryptedData: encrypted}, nil default: - k.lggr.Error("encrypt failed: unsupported remote key type", "remoteKeyType", req.RemoteKeyType.String()) + k.lggr.Error("encrypt failed: unsupported remote key type, available key types: %s", "remoteKeyType", req.RemoteKeyType.String(), AllEncryptionKeyTypes.String()) return EncryptResponse{}, ErrEncryptionFailed } } @@ -156,7 +161,7 @@ func (k *keystore) Decrypt(ctx context.Context, req DecryptRequest) (DecryptResp } return DecryptResponse{Data: decrypted}, nil default: - k.lggr.Error("decrypt failed: unsupported key type", "keyType", key.keyType.String()) + k.lggr.Error("decrypt failed: unsupported key type, available key types: %s", "keyType", key.keyType.String(), AllEncryptionKeyTypes.String()) return DecryptResponse{}, ErrDecryptionFailed } } @@ -209,7 +214,7 @@ func (k *keystore) DeriveSharedSecret(ctx context.Context, req DeriveSharedSecre } return DeriveSharedSecretResponse{SharedSecret: shared}, nil default: - k.lggr.Error("derive shared secret failed: unsupported key type", "keyType", key.keyType.String()) + k.lggr.Error("derive shared secret failed: unsupported key type, available key types: %s", "keyType", key.keyType.String(), AllEncryptionKeyTypes.String()) return DeriveSharedSecretResponse{}, ErrSharedSecretFailed } } diff --git a/keystore/keystore.go b/keystore/keystore.go index ccc107424b..206ee4685a 100644 --- a/keystore/keystore.go +++ b/keystore/keystore.go @@ -98,9 +98,19 @@ const ( ECDSA_S256 KeyType = "ECDSA_S256" ) -var AllKeyTypes = []KeyType{X25519, ECDH_P256, Ed25519, ECDSA_S256} -var AllEncryptionKeyTypes = []KeyType{X25519, ECDH_P256} -var AllDigitalSignatureKeyTypes = []KeyType{Ed25519, ECDSA_S256} +type KeyTypeList []KeyType + +func (k KeyTypeList) String() string { + types := make([]string, 0, len(k)) + for _, k := range k { + types = append(types, k.String()) + } + return strings.Join(types, ", ") +} + +var AllKeyTypes = KeyTypeList{X25519, ECDH_P256, Ed25519, ECDSA_S256} +var AllEncryptionKeyTypes = KeyTypeList{X25519, ECDH_P256} +var AllDigitalSignatureKeyTypes = KeyTypeList{Ed25519, ECDSA_S256} type ScryptParams struct { N int diff --git a/keystore/signer.go b/keystore/signer.go index 948e9f8777..c909b73388 100644 --- a/keystore/signer.go +++ b/keystore/signer.go @@ -82,7 +82,7 @@ func (k *keystore) Sign(ctx context.Context, req SignRequest) (SignResponse, err Signature: signature, }, nil default: - return SignResponse{}, fmt.Errorf("unsupported key type: %s", key.keyType) + return SignResponse{}, fmt.Errorf("unsupported key type: %s, available key types: %s", key.keyType, AllDigitalSignatureKeyTypes.String()) } } @@ -119,6 +119,6 @@ func (k *keystore) Verify(ctx context.Context, req VerifyRequest) (VerifyRespons Valid: valid, }, nil default: - return VerifyResponse{}, fmt.Errorf("unsupported key type: %s", req.KeyType) + return VerifyResponse{}, fmt.Errorf("unsupported key type: %s, available key types: %s", req.KeyType, AllDigitalSignatureKeyTypes.String()) } } From a68137696ae905ebee89e289ba4e2e692a9dbdc2 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Fri, 5 Dec 2025 13:09:42 -0500 Subject: [PATCH 07/42] Remove log --- keystore/cli/cli.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keystore/cli/cli.go b/keystore/cli/cli.go index 1bf4d71aa0..eff6b8b8f4 100644 --- a/keystore/cli/cli.go +++ b/keystore/cli/cli.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "io" - "log/slog" "os" "strings" "time" @@ -390,8 +389,7 @@ func loadKeystore(ctx context.Context, cmd *cobra.Command) (ks.Keystore, error) } // Can revisit whether custom scrypt params are actually needed in a CLI context // (I doubt it, so simpler to leave out). - lggr := slog.New(slog.NewTextHandler(os.Stdout, nil)) - return ks.LoadKeystore(ctx, storage, password, ks.WithLogger(lggr)) + return ks.LoadKeystore(ctx, storage, password) } // readJSONInput reads JSON from either -f/--file (file path, "-" for stdin) or -d/--data (inline JSON). From 4b04fc9acc2459db43dc69f7a7d4eef3cff4ad76 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Fri, 5 Dec 2025 13:12:48 -0500 Subject: [PATCH 08/42] Fix log statements --- keystore/encryptor.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keystore/encryptor.go b/keystore/encryptor.go index ff650aec3f..a641685adf 100644 --- a/keystore/encryptor.go +++ b/keystore/encryptor.go @@ -125,7 +125,7 @@ func (k *keystore) Encrypt(ctx context.Context, req EncryptRequest) (EncryptResp } return EncryptResponse{EncryptedData: encrypted}, nil default: - k.lggr.Error("encrypt failed: unsupported remote key type, available key types: %s", "remoteKeyType", req.RemoteKeyType.String(), AllEncryptionKeyTypes.String()) + k.lggr.Error("encrypt failed: unsupported remote key type", "remoteKeyType", req.RemoteKeyType.String(), "availableKeyTypes", AllEncryptionKeyTypes.String()) return EncryptResponse{}, ErrEncryptionFailed } } @@ -161,7 +161,7 @@ func (k *keystore) Decrypt(ctx context.Context, req DecryptRequest) (DecryptResp } return DecryptResponse{Data: decrypted}, nil default: - k.lggr.Error("decrypt failed: unsupported key type, available key types: %s", "keyType", key.keyType.String(), AllEncryptionKeyTypes.String()) + k.lggr.Error("decrypt failed: unsupported key type", "keyType", key.keyType.String(), "availableKeyTypes", AllEncryptionKeyTypes.String()) return DecryptResponse{}, ErrDecryptionFailed } } @@ -214,7 +214,7 @@ func (k *keystore) DeriveSharedSecret(ctx context.Context, req DeriveSharedSecre } return DeriveSharedSecretResponse{SharedSecret: shared}, nil default: - k.lggr.Error("derive shared secret failed: unsupported key type, available key types: %s", "keyType", key.keyType.String(), AllEncryptionKeyTypes.String()) + k.lggr.Error("derive shared secret failed: unsupported key type", "keyType", key.keyType.String(), "availableKeyTypes", AllEncryptionKeyTypes.String()) return DeriveSharedSecretResponse{}, ErrSharedSecretFailed } } From 96d550af538a7405be75dbe1d5adb171f1f177e9 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Fri, 5 Dec 2025 13:23:34 -0500 Subject: [PATCH 09/42] Fix linter add doc link --- keystore/README.md | 2 ++ keystore/cli/cli.go | 5 ----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/keystore/README.md b/keystore/README.md index 14924d70da..f72cac1362 100644 --- a/keystore/README.md +++ b/keystore/README.md @@ -1,3 +1,5 @@ +[![Go Reference](https://pkg.go.dev/badge/github.com/smartcontractkit/chainlink-common/keystore.svg)](https://pkg.go.dev/github.com/smartcontractkit/chainlink-common/keystore) + WARNING: In development do not use in production. # Keystore diff --git a/keystore/cli/cli.go b/keystore/cli/cli.go index eff6b8b8f4..912301eaff 100644 --- a/keystore/cli/cli.go +++ b/keystore/cli/cli.go @@ -197,11 +197,6 @@ func NewSetMetadataCmd() *cobra.Command { return &cmd } -func zeroValue[T any]() T { - var t T - return t -} - func runKeystoreCommand[Req any, Resp any](cmd *cobra.Command, args []string, fn func(ctx context.Context, k ks.Keystore, req Req) (Resp, error)) error { jsonBytes, err := readJSONInput(cmd) From 64b637799aa526d84bb23be26e2db644f0a11651 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Thu, 4 Dec 2025 21:33:33 -0600 Subject: [PATCH 10/42] pkg/settings/cresettings: add PerWorkflow.ChainAllowed --- pkg/capabilities/capabilities.go | 33 +++- pkg/capabilities/capabilities_test.go | 29 +++ pkg/contexts/chains.go | 23 +++ pkg/settings/cresettings/defaults.json | 7 + pkg/settings/cresettings/defaults.toml | 7 + pkg/settings/cresettings/settings.go | 8 + pkg/settings/cresettings/settings_test.go | 95 +++++++++- pkg/settings/json.go | 28 ++- pkg/settings/keys.go | 3 +- pkg/settings/limits/bound.go | 6 +- pkg/settings/limits/errors.go | 21 +++ pkg/settings/limits/factory.go | 48 +++-- pkg/settings/limits/gate.go | 212 ++++++++++++++++++++++ pkg/settings/limits/gate_test.go | 127 +++++++++++++ pkg/settings/limits/limits.go | 2 +- pkg/settings/map.go | 85 +++++++++ pkg/settings/settings.go | 8 + 17 files changed, 702 insertions(+), 40 deletions(-) create mode 100644 pkg/contexts/chains.go create mode 100644 pkg/settings/limits/gate.go create mode 100644 pkg/settings/limits/gate_test.go create mode 100644 pkg/settings/map.go diff --git a/pkg/capabilities/capabilities.go b/pkg/capabilities/capabilities.go index a4ae4d5350..ce5074678d 100644 --- a/pkg/capabilities/capabilities.go +++ b/pkg/capabilities/capabilities.go @@ -5,6 +5,7 @@ import ( "fmt" "iter" "regexp" + "strconv" "strings" "time" @@ -12,8 +13,9 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" - "github.com/smartcontractkit/chainlink-common/pkg/contexts" "github.com/smartcontractkit/chainlink-protos/cre/go/values" + + "github.com/smartcontractkit/chainlink-common/pkg/contexts" ) // CapabilityType is an enum for the type of capability. @@ -180,6 +182,35 @@ func ParseID(id string) (name string, labels iter.Seq2[string, string], version return } +// ChainSelectorLabel returns a chain selector value from the labels if one is present. +// It supports both a normal key/value pair, and sequential keys for historical reasons. +func ChainSelectorLabel(labels iter.Seq2[string, string]) (*uint64, error) { + const key = "ChainSelector" + var next bool + for k, v := range labels { + if next { + cs, err := strconv.ParseUint(k, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid chain selector: %s", v) + } + return &cs, nil + } + if k == key { + if v != "" { + cs, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid chain selector: %s", v) + } + return &cs, nil + } else { + // empty value means it will be in the next key + next = true + } + } + } + return nil, nil +} + type RegisterToWorkflowRequest struct { Metadata RegistrationMetadata Config *values.Map diff --git a/pkg/capabilities/capabilities_test.go b/pkg/capabilities/capabilities_test.go index 6ddc59f1f0..5060b0f172 100644 --- a/pkg/capabilities/capabilities_test.go +++ b/pkg/capabilities/capabilities_test.go @@ -302,3 +302,32 @@ func TestParseID(t *testing.T) { }) } } + +func TestChainSelectorLabel(t *testing.T) { + for _, tc := range []struct { + id string + cs *uint64 + errMsg string + }{ + {"none@v1.0.0", nil, ""}, + {"kv:ChainSelector_1@v1.0.0", ptr[uint64](1), ""}, + {"kk:ChainSelector:1@v1.0.0", ptr[uint64](1), ""}, + {"kv-others:k_v:ChainSelector_1@v1.0.0", ptr[uint64](1), ""}, + {"kk-others:k_v:ChainSelector:1@v1.0.0", ptr[uint64](1), ""}, + + {"kv:ChainSelector_foo@v1.0.0", ptr[uint64](1), "invalid chain selector"}, + {"kk:ChainSelector:bar@v1.0.0", ptr[uint64](1), "invalid chain selector"}, + } { + t.Run(tc.id, func(t *testing.T) { + _, labels, _ := ParseID(tc.id) + cs, err := ChainSelectorLabel(labels) + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + } else { + require.Equal(t, tc.cs, cs) + } + }) + } +} + +func ptr[T any](v T) *T { return &v } diff --git a/pkg/contexts/chains.go b/pkg/contexts/chains.go new file mode 100644 index 0000000000..c41319dddb --- /dev/null +++ b/pkg/contexts/chains.go @@ -0,0 +1,23 @@ +package contexts + +import ( + "context" + "errors" +) + +const chainSelectorCtxKey key = "chainSelectorCtx" + +// WithChainSelector returns a new context that includes the chain selector. +// Use ChainSelectorValue to get the value. +func WithChainSelector(ctx context.Context, cs uint64) context.Context { + return context.WithValue(ctx, chainSelectorCtxKey, cs) +} + +// ChainSelectorValue returns the chain selector, if one was set via WithChainSelector. +func ChainSelectorValue(ctx context.Context) (uint64, error) { + val := Value[uint64](ctx, chainSelectorCtxKey) + if val == 0 { + return 0, errors.New("context missing chain selector") + } + return val, nil +} diff --git a/pkg/settings/cresettings/defaults.json b/pkg/settings/cresettings/defaults.json index 15e715595a..5dac0d1f70 100644 --- a/pkg/settings/cresettings/defaults.json +++ b/pkg/settings/cresettings/defaults.json @@ -41,6 +41,13 @@ "ConsensusCallsLimit": "2000", "LogLineLimit": "1kb", "LogEventLimit": "1000", + "ChainAllowed": { + "Default": "false", + "Values": { + "12922642891491394802": "true", + "3379446385462418246": "true" + } + }, "CRONTrigger": { "FastestScheduleInterval": "30s", "RateLimit": "every30s:1" diff --git a/pkg/settings/cresettings/defaults.toml b/pkg/settings/cresettings/defaults.toml index 210c421331..c19170047e 100644 --- a/pkg/settings/cresettings/defaults.toml +++ b/pkg/settings/cresettings/defaults.toml @@ -42,6 +42,13 @@ ConsensusCallsLimit = '2000' LogLineLimit = '1kb' LogEventLimit = '1000' +[PerWorkflow.ChainAllowed] +Default = 'false' + +[PerWorkflow.ChainAllowed.Values] +12922642891491394802 = 'true' +3379446385462418246 = 'true' + [PerWorkflow.CRONTrigger] FastestScheduleInterval = '30s' RateLimit = 'every30s:1' diff --git a/pkg/settings/cresettings/settings.go b/pkg/settings/cresettings/settings.go index 0f0d54d1c1..9efc145133 100644 --- a/pkg/settings/cresettings/settings.go +++ b/pkg/settings/cresettings/settings.go @@ -103,6 +103,12 @@ var Default = Schema{ ConsensusCallsLimit: Int(2000), LogLineLimit: Size(config.KByte), LogEventLimit: Int(1_000), + ChainAllowed: PerChainSelector(Bool(false), map[string]bool{ + // geth-testnet + "3379446385462418246": true, + // geth-devnet2 + "12922642891491394802": true, + }), CRONTrigger: cronTrigger{ FastestScheduleInterval: Duration(30 * time.Second), @@ -210,6 +216,8 @@ type Workflows struct { LogLineLimit Setting[config.Size] LogEventLimit Setting[int] `unit:"{log}"` + ChainAllowed SettingMap[bool] + CRONTrigger cronTrigger HTTPTrigger httpTrigger LogTrigger logTrigger diff --git a/pkg/settings/cresettings/settings_test.go b/pkg/settings/cresettings/settings_test.go index 96024ef72b..95fedb083b 100644 --- a/pkg/settings/cresettings/settings_test.go +++ b/pkg/settings/cresettings/settings_test.go @@ -16,6 +16,8 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/contexts" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" ) var update = flag.Bool("update", false, "update the golden files of this test") @@ -74,6 +76,12 @@ func TestSchema_Unmarshal(t *testing.T) { }, "PerWorkflow": { "WASMMemoryLimit": "250mb", + "ChainAllowed": { + "Default": "false", + "Values": { + "1": "true" + } + }, "CRONTrigger": { "RateLimit": "every10s:5" }, @@ -103,6 +111,10 @@ func TestSchema_Unmarshal(t *testing.T) { assert.Equal(t, 48*time.Hour, cfg.PerOrg.ZeroBalancePruningTimeout.DefaultValue) assert.Equal(t, 99, cfg.PerOwner.WorkflowExecutionConcurrencyLimit.DefaultValue) assert.Equal(t, 250*config.MByte, cfg.PerWorkflow.WASMMemoryLimit.DefaultValue) + assert.Equal(t, false, cfg.PerWorkflow.ChainAllowed.Default.DefaultValue) + assert.Equal(t, "true", cfg.PerWorkflow.ChainAllowed.Values["1"]) + assert.NotNil(t, cfg.PerWorkflow.ChainAllowed.Default.Parse) + assert.NotNil(t, cfg.PerWorkflow.ChainAllowed.KeyFromCtx) assert.Equal(t, config.Rate{Limit: rate.Every(10 * time.Second), Burst: 5}, cfg.PerWorkflow.CRONTrigger.RateLimit.DefaultValue) assert.Equal(t, config.Rate{Limit: rate.Every(30 * time.Second), Burst: 3}, cfg.PerWorkflow.HTTPTrigger.RateLimit.DefaultValue) assert.Equal(t, config.Rate{Limit: rate.Every(13 * time.Second), Burst: 6}, cfg.PerWorkflow.LogTrigger.EventRateLimit.DefaultValue) @@ -142,11 +154,6 @@ func TestDefaultGetter(t *testing.T) { }`) reinit() // set default vars - _ = ` -[workflow.test-wf-id] -PerWorkflow.HTTPAction.CallLimit = 20 -` - // Default unchanged got, err = limit.GetOrDefault(ctx, DefaultGetter) require.NoError(t, err) @@ -158,3 +165,81 @@ PerWorkflow.HTTPAction.CallLimit = 20 require.Equal(t, 20, got) } + +func TestDefaultGetter_SettingMap(t *testing.T) { + limit := Default.PerWorkflow.ChainAllowed + + ctx := contexts.WithCRE(t.Context(), contexts.CRE{Owner: "owner-id", Workflow: "foo"}) + ctx = contexts.WithChainSelector(ctx, 1234) + overrideCtx := contexts.WithCRE(t.Context(), contexts.CRE{Owner: "owner-id", Workflow: "test-wf-id"}) + overrideCtx = contexts.WithChainSelector(overrideCtx, 1234) + + // None allowed by default + got, err := limit.GetOrDefault(ctx, DefaultGetter) + require.NoError(t, err) + require.False(t, got) + got, err = limit.GetOrDefault(overrideCtx, DefaultGetter) + require.NoError(t, err) + require.False(t, got) + + t.Cleanup(reinit) // restore default vars + + // Org override to allow + t.Setenv(envNameSettings, `{ + "workflow": { + "test-wf-id": { + "PerWorkflow": { + "ChainAllowed": { + "Values": { + "1234": "true" + } + } + } + } + } +}`) + reinit() // set default vars + got, err = limit.GetOrDefault(ctx, DefaultGetter) + require.NoError(t, err) + require.False(t, got) + got, err = limit.GetOrDefault(overrideCtx, DefaultGetter) + require.NoError(t, err) + require.True(t, got) + + // Org override to allow by default, but disallow some + t.Setenv(envNameSettings, `{ + "workflow": { + "test-wf-id": { + "PerWorkflow": { + "ChainAllowed": { + "Default": true, + "Values": { + "1234": "false" + } + } + } + } + } +}`) + reinit() // set default vars + got, err = limit.GetOrDefault(ctx, DefaultGetter) + require.NoError(t, err) + require.False(t, got) + got, err = limit.GetOrDefault(overrideCtx, DefaultGetter) + require.NoError(t, err) + require.False(t, got) + got, err = limit.GetOrDefault(contexts.WithChainSelector(overrideCtx, 42), DefaultGetter) + require.NoError(t, err) + require.True(t, got) +} + +func TestChainAllows(t *testing.T) { + gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t)}, Default.PerWorkflow.ChainAllowed) + require.NoError(t, err) + + ctx := contexts.WithCRE(t.Context(), contexts.CRE{Owner: "owner-id", Workflow: "foo"}) + + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) + assert.ErrorIs(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234)), limits.ErrorNotAllowed{}) +} diff --git a/pkg/settings/json.go b/pkg/settings/json.go index a498ce70eb..c7c29461ff 100644 --- a/pkg/settings/json.go +++ b/pkg/settings/json.go @@ -8,14 +8,16 @@ import ( "io/fs" "maps" "slices" + "strconv" "strings" ) // CombineJSONFiles reads a set of JSON config files and combines them in to one file. The expected inputs are: -// - global.json -// - org/*.json -// - owner/*.json -// - workflow/*.json +// - global.json +// - org/*.json +// - owner/*.json +// - workflow/*.json +// // The directory and file names translate to keys in the JSON structure, while the file extensions are discarded. // For example: owner/0x1234.json:Foo.Bar becomes owner.0x1234.Foo.Bar func CombineJSONFiles(files fs.FS) ([]byte, error) { @@ -152,11 +154,17 @@ func (s *jsonSettings) get(key string) (string, error) { } field := parts[len(parts)-1] - switch t := m[field].(type) { - case string: - return t, nil - case json.Number: - return t.String(), nil + if val, ok := m[field]; ok { + switch t := val.(type) { + case string: + return t, nil + case json.Number: + return t.String(), nil + case bool: + return strconv.FormatBool(t), nil + default: + return "", fmt.Errorf("non-string value: %s: %t(%v)", key, val, val) + } } return "", nil // no value } @@ -166,7 +174,7 @@ type jsonGetter struct { } // NewJSONGetter returns a static Getter backed by the given JSON. -//TODO https://smartcontract-it.atlassian.net/browse/CAPPL-775 +// TODO https://smartcontract-it.atlassian.net/browse/CAPPL-775 // NewJSONRegistry with polling & subscriptions func NewJSONGetter(b []byte) (Getter, error) { s, err := newJSONSettings(b) diff --git a/pkg/settings/keys.go b/pkg/settings/keys.go index 8d89f855c1..d96ce73314 100644 --- a/pkg/settings/keys.go +++ b/pkg/settings/keys.go @@ -17,8 +17,9 @@ func (s Scope) rawKeys(ctx context.Context, key string) (keys []string, err erro if i.IsTenantRequired() { err = errors.Join(err, fmt.Errorf("empty %s key", i)) } + } else { + keys = append(keys, i.String()+"."+tenant+"."+key) } - keys = append(keys, i.String()+"."+tenant+"."+key) } keys = append(keys, ScopeGlobal.String()+"."+key) // ScopeGlobal return diff --git a/pkg/settings/limits/bound.go b/pkg/settings/limits/bound.go index e2cee9cc69..815bd722c3 100644 --- a/pkg/settings/limits/bound.go +++ b/pkg/settings/limits/bound.go @@ -54,9 +54,8 @@ func newBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimit updater: newUpdater[N](nil, func(ctx context.Context) (N, error) { return bound.GetOrDefault(ctx, f.Settings) }, nil), - defaultBound: bound.DefaultValue, - key: bound.Key, - scope: bound.Scope, + key: bound.Key, + scope: bound.Scope, } b.updater.recordLimit = func(ctx context.Context, n N) { b.recordBound(ctx, n) } @@ -115,7 +114,6 @@ func newBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimit type boundLimiter[N Number] struct { *updater[N] - defaultBound N key string // optional scope settings.Scope diff --git a/pkg/settings/limits/errors.go b/pkg/settings/limits/errors.go index 4d1a570aa1..0aede8454e 100644 --- a/pkg/settings/limits/errors.go +++ b/pkg/settings/limits/errors.go @@ -144,3 +144,24 @@ func (e ErrorQueueFull) Error() string { } var ErrQueueEmpty = fmt.Errorf("queue is empty") + +type ErrorNotAllowed struct { + Key string + + Scope settings.Scope + Tenant string +} + +func (e ErrorNotAllowed) GRPCStatus() *status.Status { + return status.New(codes.PermissionDenied, e.Error()) +} + +func (e ErrorNotAllowed) Is(target error) bool { + _, ok := target.(ErrorNotAllowed) //nolint:errcheck // implementing errors.Is + return ok +} + +func (e ErrorNotAllowed) Error() string { + which, who := errArgs(e.Key, e.Scope, e.Tenant) + return fmt.Sprintf("%slimited%s: not allowed", which, who) +} diff --git a/pkg/settings/limits/factory.go b/pkg/settings/limits/factory.go index ca6b87e541..97c876d03b 100644 --- a/pkg/settings/limits/factory.go +++ b/pkg/settings/limits/factory.go @@ -31,10 +31,10 @@ func (f Factory) NewRateLimiter(rate settings.Setting[config.Rate]) (RateLimiter // MakeRateLimiter creates a RateLimiter for the given rate and configured by the Factory. // If Meter is set, the following metrics will be emitted -// - rate.*.limit - float gauge -// - rate.*.burst - int gauge -// - rate.*.usage - int counter -// - rate.*.denied - int histogram +// - rate.*.limit - float gauge +// - rate.*.burst - int gauge +// - rate.*.usage - int counter +// - rate.*.denied - int histogram func (f Factory) MakeRateLimiter(rate settings.Setting[config.Rate]) (RateLimiter, error) { if rate.Scope == settings.ScopeGlobal { return f.globalRateLimiter(rate) @@ -49,10 +49,11 @@ func (f Factory) NewTimeLimiter(timeout settings.Setting[time.Duration]) (TimeLi // MakeTimeLimiter returns a TimeLimiter for given timeout, and configured by the Factory. // If Meter is set, the following metrics will be emitted -// - time.*.limit - float gauge -// - time.*.runtime - float gauge -// - time.*.success - int counter -// - time.*.timeout - int counter +// - time.*.limit - float gauge +// - time.*.runtime - float gauge +// - time.*.success - int counter +// - time.*.timeout - int counter +// // Note: Unit will be ignored. All TimeLimiters emit seconds as "s". func (f Factory) MakeTimeLimiter(timeout settings.Setting[time.Duration]) (TimeLimiter, error) { return f.newTimeLimiter(timeout) @@ -65,10 +66,10 @@ func NewResourcePoolLimiter[N Number](f Factory, limit settings.Setting[N]) (Res // MakeResourcePoolLimiter returns a ResourcePoolLimiter for the given limit, and configured by the Factory. // If Meter is set, the following metrics will be emitted -// - resource.*.limit - gauge -// - resource.*.usage - gauge -// - resource.*.amount - histogram -// - resource.*.denied - histogram +// - resource.*.limit - gauge +// - resource.*.usage - gauge +// - resource.*.amount - histogram +// - resource.*.denied - histogram func MakeResourcePoolLimiter[N Number](f Factory, limit settings.Setting[N]) (ResourcePoolLimiter[N], error) { if limit.Scope == settings.ScopeGlobal { return newGlobalResourcePoolLimiter(f, limit) @@ -78,21 +79,32 @@ func MakeResourcePoolLimiter[N Number](f Factory, limit settings.Setting[N]) (Re // MakeBoundLimiter returns a BoundLimiter for the given bound and configured by the Factory. // If Meter is set, the following metrics will be emitted -// - bound.*.limit - gauge -// - bound.*.usage - histogram -// - bound.*.denied - histogram +// - bound.*.limit - gauge +// - bound.*.usage - histogram +// - bound.*.denied - histogram func MakeBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimiter[N], error) { return newBoundLimiter(f, bound) } // MakeQueueLimiter returns a QueueLimiter for the given limit and configured by the Factory. // If Meter is set, the following metrics will be emitted -// - queue.*.limit - int gauge -// - queue.*.usage - int gauge -// - queue.*.denied - int histogram +// - queue.*.limit - int gauge +// - queue.*.usage - int gauge +// - queue.*.denied - int histogram func MakeQueueLimiter[T any](f Factory, limit settings.Setting[int]) (QueueLimiter[T], error) { if limit.Scope == settings.ScopeGlobal { return newUnscopedQueue[T](f, limit) } return newScopedQueue[T](f, limit) } + +// MakeGateLimiter returns a GateLimiter for the given limit and configured by the factory. +// If Meter is set, the following metrics will be emitted +// - gate.*.limit - int gauge +// - gate.*.usage - int counter +// - gate.*.denied - int counter +// +// OPT: accept an interface for limit +func MakeGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, error) { + return newGateLimiter(f, limit) +} diff --git a/pkg/settings/limits/gate.go b/pkg/settings/limits/gate.go new file mode 100644 index 0000000000..20f9d0ceaf --- /dev/null +++ b/pkg/settings/limits/gate.go @@ -0,0 +1,212 @@ +package limits + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + + "go.opentelemetry.io/otel/metric" + + "github.com/smartcontractkit/chainlink-common/pkg/contexts" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/settings" +) + +type GateLimiter interface { + Limiter[bool] + AllowErr(context.Context) error +} + +func NewGateLimiter(open bool) GateLimiter { + return &simpleGateLimiter{open: open} +} + +type simpleGateLimiter struct { + open bool + closed atomic.Bool +} + +func (s *simpleGateLimiter) Close() error { s.closed.Store(true); return nil } + +func (s *simpleGateLimiter) Limit(ctx context.Context) (bool, error) { + return s.open, nil +} + +func (s *simpleGateLimiter) AllowErr(ctx context.Context) error { + if ok, err := s.Limit(ctx); err != nil { + return err + } else if !ok { + return ErrorNotAllowed{} + } + return nil +} + +// OPT: interface satisfied by Setting[bool] & SettingMap[bool] +func newGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, error) { + g := &gateLimiter{ + updater: newUpdater[bool](nil, func(ctx context.Context) (bool, error) { + return limit.GetOrDefault(ctx, f.Settings) + }, nil), + key: limit.Default.Key, + scope: limit.Default.Scope, + } + g.updater.recordLimit = func(ctx context.Context, b bool) { g.recordStatus(ctx, b) } + + if f.Meter != nil { + if g.key == "" { + return nil, errors.New("metrics require Key to be set") + } + limitGauge, err := f.Meter.Int64Gauge("gate."+g.key+".limit", metric.WithUnit(limit.Default.Unit)) + if err != nil { + return nil, err + } + g.recordStatus = func(ctx context.Context, b bool, options ...metric.RecordOption) { + var val int64 + if b { + val = 1 + } + limitGauge.Record(ctx, val, options...) + } + usageCounter, err := f.Meter.Int64Counter("gate."+g.key+".usage", metric.WithUnit(limit.Default.Unit)) + if err != nil { + return nil, err + } + g.recordUsage = func(ctx context.Context, options ...metric.AddOption) { + usageCounter.Add(ctx, 1, options...) + } + deniedCounter, err := f.Meter.Int64Counter("gate."+g.key+".denied", metric.WithUnit(limit.Default.Unit)) + if err != nil { + return nil, err + } + g.recordDenied = func(ctx context.Context, options ...metric.AddOption) { + deniedCounter.Add(ctx, 1, options...) + } + } else { + g.recordStatus = func(ctx context.Context, value bool, options ...metric.RecordOption) {} + g.recordUsage = func(ctx context.Context, options ...metric.AddOption) {} + g.recordDenied = func(ctx context.Context, options ...metric.AddOption) {} + } + + if f.Logger != nil { + g.lggr = logger.Sugared(f.Logger).Named("GateLimiter").With("key", limit.Default.Key) + } + + // OPT: support settings.Registry subscriptions + //if f.Settings != nil { + // if r, ok := f.Settings.(settings.Registry); ok { + // g.subFn = func(ctx context.Context) (<-chan settings.Update[bool], func()) { + // return limit.Subscribe(ctx, r) + // } + // } + //} + + // OPT: restore with support for SettingMap + //if limit.Default.Scope == settings.ScopeGlobal { + // g.updateCRE(contexts.CRE{}) + // go g.updateLoop(contexts.CRE{}) + //} + close(g.done) + + return g, nil +} + +type gateLimiter struct { + *updater[bool] + + key string // optional + scope settings.Scope + + recordStatus func(ctx context.Context, value bool, options ...metric.RecordOption) + recordUsage func(ctx context.Context, options ...metric.AddOption) + recordDenied func(ctx context.Context, options ...metric.AddOption) + + // opt: reap after period of non-use + updaters sync.Map // map[string]*updater[N] + wg services.WaitGroup // tracks and blocks updaters background routines +} + +func (g *gateLimiter) Close() (err error) { + g.wg.Wait() + + // cleanup + if g.scope == settings.ScopeGlobal { + return g.updater.Close() + } else { + g.updaters.Range(func(key, value any) bool { + // opt: parallelize + err = errors.Join(err, value.(*updater[bool]).Close()) + return true + }) + } + return +} + +func (g *gateLimiter) Limit(ctx context.Context) (bool, error) { + if err := g.wg.TryAdd(1); err != nil { + return false, err + } + defer g.wg.Done() + + _, limit, err := g.get(ctx) + if err != nil { + return false, err + } + + return limit, nil +} + +func (g *gateLimiter) AllowErr(ctx context.Context) error { + if err := g.wg.TryAdd(1); err != nil { + return err + } + defer g.wg.Done() + + tenant, open, err := g.get(ctx) + if err != nil { + return err + } else if !open { + g.recordDenied(ctx, withScope(ctx, g.scope)) + return ErrorNotAllowed{Key: g.key, Scope: g.scope, Tenant: tenant} + } + g.recordUsage(ctx, withScope(ctx, g.scope)) + return nil +} + +func (g *gateLimiter) get(ctx context.Context) (tenant string, open bool, err error) { + if g.scope != settings.ScopeGlobal { + tenant = g.scope.Value(ctx) + if tenant == "" { + if !g.scope.IsTenantRequired() { + kvs := contexts.CREValue(ctx).LoggerKVs() + g.lggr.Warnw("Unable to get scoped gate status due to missing tenant: failing open", append([]any{"scope", g.scope}, kvs...)...) + return + } + err = fmt.Errorf("unable to get scoped gate status due to missing tenant for scope: %s", g.scope) + return + } + + u := newUpdater(g.lggr, g.getLimitFn, g.subFn) + actual, loaded := g.updaters.LoadOrStore(tenant, u) + cre := g.scope.RoundCRE(contexts.CREValue(ctx)) + if !loaded { + // OPT: restore with support for SettingMap + //u.cre.Store(cre) + //go u.updateLoop(cre) + close(u.done) + } else { + u = actual.(*updater[bool]) + u.updateCRE(cre) + } + } + + open, err = g.getLimitFn(ctx) + if err != nil { + g.lggr.Errorw("Failed to get status. Using default value", "default", open, "err", err) + } + // TODO: include map key in attributes + g.recordStatus(ctx, open, withScope(ctx, g.scope)) + return +} diff --git a/pkg/settings/limits/gate_test.go b/pkg/settings/limits/gate_test.go new file mode 100644 index 0000000000..6e98e8e8bd --- /dev/null +++ b/pkg/settings/limits/gate_test.go @@ -0,0 +1,127 @@ +package limits + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + + "github.com/smartcontractkit/chainlink-common/pkg/contexts" + "github.com/smartcontractkit/chainlink-common/pkg/settings" +) + +func ExampleGateLimiter_AllowErr() { + ctx := context.Background() + gl := NewGateLimiter(true) + + open, err := gl.Limit(ctx) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("open:", open) + + err = gl.AllowErr(ctx) + fmt.Println("allow:", err) + + gl = NewGateLimiter(false) + + open, err = gl.Limit(ctx) + if err != nil { + fmt.Println(err) + return + } + fmt.Println("open:", open) + + err = gl.AllowErr(ctx) + fmt.Println("allow:", err) + + // Output: + // open: true + // allow: + // open: false + // allow: limited: not allowed +} + +func TestMakeGateLimiter(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + scope settings.Scope + cre contexts.CRE + }{ + {settings.ScopeGlobal, contexts.CRE{}}, + {settings.ScopeOwner, contexts.CRE{Owner: "ow-id"}}, + } { + t.Run(tt.scope.String(), func(t *testing.T) { + t.Parallel() + mc := newMetricsChecker(t) + f := Factory{Meter: mc.Meter(t.Name())} + limit := settings.PerChainSelector(settings.Bool(false), + map[string]bool{ + "42": true, + }) + + limit.Default.Key = "foo.bar" + limit.Default.Scope = tt.scope + gl, err := MakeGateLimiter(f, limit) + require.NoError(t, err) + t.Cleanup(func() { assert.NoError(t, gl.Close()) }) + + ctx := t.Context() + ctx = contexts.WithCRE(ctx, tt.cre) + + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 42))) + var errGate ErrorNotAllowed + if assert.ErrorAs(t, gl.AllowErr(contexts.WithChainSelector(ctx, 100)), &errGate) { + assert.Equal(t, "foo.bar", errGate.Key) + assert.Equal(t, tt.scope, errGate.Scope) + } + + ms := mc.lastResourceFirstScopeMetric(t) + + attrs := attribute.NewSet(kvsFromScope(ctx, tt.scope)...) + + require.Equal(t, metrics{ + { + Name: "gate.foo.bar.limit", + Data: metricdata.Gauge[int64]{ + DataPoints: []metricdata.DataPoint[int64]{ + {Attributes: attrs, Value: int64(0)}, + }, + }, + }, + { + Name: "gate.foo.bar.usage", + Data: metricdata.Sum[int64]{ + DataPoints: []metricdata.DataPoint[int64]{ + { + Attributes: attrs, + Value: int64(1), + }, + }, + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + }, + }, + { + Name: "gate.foo.bar.denied", + Data: metricdata.Sum[int64]{ + DataPoints: []metricdata.DataPoint[int64]{ + { + Attributes: attrs, + Value: int64(1), + }, + }, + Temporality: metricdata.CumulativeTemporality, + IsMonotonic: true, + }, + }, + }, ms) + }) + } +} diff --git a/pkg/settings/limits/limits.go b/pkg/settings/limits/limits.go index 48ec1ac900..bd944417c5 100644 --- a/pkg/settings/limits/limits.go +++ b/pkg/settings/limits/limits.go @@ -8,7 +8,7 @@ // Every limit requires a default value. Additional features like Otel metrics and dynamic updates are available by // using the [settings.Setting] variants. // -// Limiter errors are GRPC [codes.ResourceExhausted] and [codes.DeadlineExceeded]. +// Limiter errors are GRPC [codes.ResourceExhausted], [codes.DeadlineExceeded], and [code.PermissionDenied]. package limits import ( diff --git a/pkg/settings/map.go b/pkg/settings/map.go new file mode 100644 index 0000000000..786cadfe85 --- /dev/null +++ b/pkg/settings/map.go @@ -0,0 +1,85 @@ +package settings + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/smartcontractkit/chainlink-common/pkg/contexts" +) + +// PerChainSelector returns a new SettingMap for the given values, which is keyed on +// chain selector from the context.Context. +func PerChainSelector[T any](defaultValue Setting[T], vals map[string]T) SettingMap[T] { + svals := make(map[string]string, len(vals)) + for k, v := range vals { + svals[k] = fmt.Sprint(v) + } + return SettingMap[T]{ + Default: defaultValue, + Values: svals, + KeyFromCtx: contexts.ChainSelectorValue, + } +} + +type SettingMap[T any] struct { + Default Setting[T] + Values map[string]string // unparsed + KeyFromCtx func(context.Context) (uint64, error) `json:"-" toml:"-"` +} + +func (s *SettingMap[T]) initSetting(key string, scope Scope, unit *string) error { + if s.KeyFromCtx == nil { + return errors.New("missing KeyFromCtx func") + } + return s.Default.initSetting(key, scope, unit) +} + +// GetOrDefault gets the setting from the Getter for the given Scope, or returns the default value with an error. +func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, err error) { + if s.KeyFromCtx == nil { + return s.Default.DefaultValue, errors.New("missing KeyFromCtx func") + } + k, err := s.KeyFromCtx(ctx) + if err != nil { + return s.Default.DefaultValue, fmt.Errorf("failed to get value from context: %w", err) + } + if g == nil { + if str, ok := s.Values[strconv.FormatUint(k, 10)]; ok { + value, err = s.Default.Parse(str) + if err != nil { + return s.Default.DefaultValue, err + } + return + } + return s.Default.DefaultValue, nil + } + + valueKey := s.Default.Key + ".Values." + strconv.FormatUint(k, 10) + defaultKey := s.Default.Key + ".Default" + + // Values override + str, err := g.GetScoped(ctx, s.Default.Scope, valueKey) + if err != nil { + return s.Default.DefaultValue, err + } else if str != "" { + value, err = s.Default.Parse(str) + if err != nil { + return s.Default.DefaultValue, err + } + return + } + + // Default override + str, err = g.GetScoped(ctx, s.Default.Scope, defaultKey) + if err != nil || str == "" { + return s.Default.DefaultValue, err + } + + value, err = s.Default.Parse(str) + if err != nil { + return s.Default.DefaultValue, err + } + return +} diff --git a/pkg/settings/settings.go b/pkg/settings/settings.go index 29e78d0508..e8b955bbbb 100644 --- a/pkg/settings/settings.go +++ b/pkg/settings/settings.go @@ -3,6 +3,7 @@ package settings import ( "context" "encoding" + "errors" "fmt" "net/url" "reflect" @@ -48,6 +49,9 @@ func (s *Setting[T]) UnmarshalText(b []byte) (err error) { if len(b) >= 2 && b[0] == '"' && b[len(b)-1] == '"' { b = b[1 : len(b)-1] // unquote string } + if s.Parse == nil { + return errors.New("missing Parse func") + } s.DefaultValue, err = s.Parse(string(b)) if err != nil { err = fmt.Errorf("%s: failed to parse %s: %w", s.Key, string(b), err) @@ -78,6 +82,10 @@ func MarshaledText[T encoding.TextUnmarshaler](defaultValue T) Setting[T] { }) } +func Bool(defaultValue bool) Setting[bool] { + return NewSetting(defaultValue, strconv.ParseBool) +} + func Duration(defaultValue time.Duration) Setting[time.Duration] { s := NewSetting(defaultValue, time.ParseDuration) s.Unit = "s" From 5bdfc30734b1960a537be2faf94b63127c26b536 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Wed, 10 Dec 2025 11:08:44 -0500 Subject: [PATCH 11/42] Code owners --- .github/CODEOWNERS | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a579f15220..743b5e8a01 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -4,8 +4,10 @@ # Please define less specific codeowner paths before more specific codeowner paths in order for the more specific rule to have priority + * @smartcontractkit/foundations +/keystore @connorwstein @pavel-rakov @smartcontractkit/prodsec-public /pkg/beholder/ @smartcontractkit/data-tooling /pkg/capabilities/v2/chain-capabilities @smartcontractkit/keystone @smartcontractkit/capabilities-team @smartcontractkit/bix-framework /pkg/chains/evm @smartcontractkit/bix-framework From ae9dd8c45c85403d958a7753c5f809f1a9b0c935 Mon Sep 17 00:00:00 2001 From: connorwstein Date: Wed, 10 Dec 2025 11:13:35 -0500 Subject: [PATCH 12/42] Prod sec co-owns keystore --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 743b5e8a01..52e0187a1e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -7,7 +7,7 @@ * @smartcontractkit/foundations -/keystore @connorwstein @pavel-rakov @smartcontractkit/prodsec-public +/keystore @smartcontractkit/prodsec-public @smartcontractkit/foundations /pkg/beholder/ @smartcontractkit/data-tooling /pkg/capabilities/v2/chain-capabilities @smartcontractkit/keystone @smartcontractkit/capabilities-team @smartcontractkit/bix-framework /pkg/chains/evm @smartcontractkit/bix-framework From f62663f147244f44cec9bd2e5c3a28554fe6ff3b Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 10 Dec 2025 14:50:46 -0500 Subject: [PATCH 13/42] pkg/settings: fix SettingMap.GetOrDefault --- pkg/settings/cresettings/settings.go | 2 + pkg/settings/cresettings/settings_test.go | 73 +++++++++++++++++++++-- pkg/settings/map.go | 13 ++-- 3 files changed, 79 insertions(+), 9 deletions(-) diff --git a/pkg/settings/cresettings/settings.go b/pkg/settings/cresettings/settings.go index 9efc145133..bab448cb05 100644 --- a/pkg/settings/cresettings/settings.go +++ b/pkg/settings/cresettings/settings.go @@ -40,6 +40,8 @@ func reinit() { if err != nil { log.Fatalf("failed to initialize settings: %v", err) } + } else { + DefaultGetter = nil } } diff --git a/pkg/settings/cresettings/settings_test.go b/pkg/settings/cresettings/settings_test.go index 95fedb083b..74fa55a89d 100644 --- a/pkg/settings/cresettings/settings_test.go +++ b/pkg/settings/cresettings/settings_test.go @@ -17,6 +17,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/contexts" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/settings" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" ) @@ -199,6 +200,10 @@ func TestDefaultGetter_SettingMap(t *testing.T) { } }`) reinit() // set default vars + + // ensure merged values; defaults must remain + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"]) + // confirm got, err = limit.GetOrDefault(ctx, DefaultGetter) require.NoError(t, err) require.False(t, got) @@ -233,13 +238,73 @@ func TestDefaultGetter_SettingMap(t *testing.T) { require.True(t, got) } -func TestChainAllows(t *testing.T) { - gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t)}, Default.PerWorkflow.ChainAllowed) +func TestDefaultEnvVars(t *testing.T) { + // confirm defaults + require.Equal(t, "", Default.PerWorkflow.ChainAllowed.Values["1234"]) + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"]) + + t.Cleanup(reinit) // restore after + + // update defaults + t.Setenv(envNameSettingsDefault, `{ + "PerWorkflow": { + "ChainAllowed": { + "Values": { + "1234": "true" + } + } + } +}`) + reinit() // set default vars + + // confirm through Default + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["1234"]) + // without affecting others (they must merge) + require.Equal(t, "true", Default.PerWorkflow.ChainAllowed.Values["3379446385462418246"]) + + // confirm through DefaultGetter + gl, err := limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: DefaultGetter}, Default.PerWorkflow.ChainAllowed) require.NoError(t, err) - ctx := contexts.WithCRE(t.Context(), contexts.CRE{Owner: "owner-id", Workflow: "foo"}) + ctx := contexts.WithCRE(t.Context(), contexts.CRE{Org: "foo", Owner: "owner-id", Workflow: "foo"}) + // defaults and global override allowed + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234))) + + // update overrides + t.Setenv(envNameSettingsDefault, "{}") + t.Setenv(envNameSettings, `{ + "global": { + "PerWorkflow": { + "ChainAllowed": { + "Values": { + "1234": "true" + } + } + } + } +}`) + + reinit() // set default vars + + // confirm through DefaultGetter + gl, err = limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: DefaultGetter}, Default.PerWorkflow.ChainAllowed) + require.NoError(t, err) + + // defaults and global override allowed + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234))) + + // confirm through an empty, but non-nil getter + getter, err := settings.NewJSONGetter([]byte(`{}`)) + require.NoError(t, err) + gl, err = limits.MakeGateLimiter(limits.Factory{Logger: logger.Test(t), Settings: getter}, Default.PerWorkflow.ChainAllowed) + require.NoError(t, err) + // defaults and global override allowed assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 3379446385462418246))) assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 12922642891491394802))) - assert.ErrorIs(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234)), limits.ErrorNotAllowed{}) + assert.NoError(t, gl.AllowErr(contexts.WithChainSelector(ctx, 1234))) } diff --git a/pkg/settings/map.go b/pkg/settings/map.go index 786cadfe85..c758d9c169 100644 --- a/pkg/settings/map.go +++ b/pkg/settings/map.go @@ -45,16 +45,19 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er if err != nil { return s.Default.DefaultValue, fmt.Errorf("failed to get value from context: %w", err) } - if g == nil { + valueOrDefault := func() (T, error) { if str, ok := s.Values[strconv.FormatUint(k, 10)]; ok { value, err = s.Default.Parse(str) if err != nil { return s.Default.DefaultValue, err } - return + return value, nil } return s.Default.DefaultValue, nil } + if g == nil { + return valueOrDefault() + } valueKey := s.Default.Key + ".Values." + strconv.FormatUint(k, 10) defaultKey := s.Default.Key + ".Default" @@ -66,7 +69,7 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er } else if str != "" { value, err = s.Default.Parse(str) if err != nil { - return s.Default.DefaultValue, err + return valueOrDefault() } return } @@ -74,12 +77,12 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er // Default override str, err = g.GetScoped(ctx, s.Default.Scope, defaultKey) if err != nil || str == "" { - return s.Default.DefaultValue, err + return valueOrDefault() } value, err = s.Default.Parse(str) if err != nil { - return s.Default.DefaultValue, err + return valueOrDefault() } return } From 3dd3707c0a6f412fd8851c5738c0f76940e31bd7 Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Fri, 21 Nov 2025 15:02:17 +0000 Subject: [PATCH 14/42] pkg/types/llo: ChannelDefinitionCache accepts the previous outcome definitions --- pkg/types/llo/types.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/types/llo/types.go b/pkg/types/llo/types.go index cef021cc37..61fbb803d4 100644 --- a/pkg/types/llo/types.go +++ b/pkg/types/llo/types.go @@ -359,7 +359,7 @@ func (c ChannelDefinitions) Value() (driver.Value, error) { type ChannelID = uint32 type ChannelDefinitionCache interface { - Definitions() ChannelDefinitions + Definitions(previous ChannelDefinitions) ChannelDefinitions services.Service } From d56805f6d61d16b74983110c956e1e0f708cb100 Mon Sep 17 00:00:00 2001 From: Bruno Moura Date: Thu, 4 Dec 2025 16:50:07 +0000 Subject: [PATCH 15/42] pkg/types/llo: ChannelDefiniton.Equals considers tombstones and source for equality check --- pkg/types/llo/types.go | 8 ++++++++ pkg/types/llo/types_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/pkg/types/llo/types.go b/pkg/types/llo/types.go index 61fbb803d4..a95dc0ee22 100644 --- a/pkg/types/llo/types.go +++ b/pkg/types/llo/types.go @@ -279,6 +279,14 @@ type ChannelDefinition struct { } func (a ChannelDefinition) Equals(b ChannelDefinition) bool { + if a.Tombstone != b.Tombstone { + return false + } + + if a.Source != b.Source { + return false + } + if a.ReportFormat != b.ReportFormat { return false } diff --git a/pkg/types/llo/types_test.go b/pkg/types/llo/types_test.go index 8d4cd18742..f7b51e53ed 100644 --- a/pkg/types/llo/types_test.go +++ b/pkg/types/llo/types_test.go @@ -101,6 +101,37 @@ func Test_ChannelDefinition_Equals(t *testing.T) { } assert.False(t, a.Equals(b)) }) + t.Run("different Tombstone", func(t *testing.T) { + a := ChannelDefinition{ + ReportFormat: ReportFormatJSON, + Streams: []Stream{{0, AggregatorMedian}, {1, AggregatorMode}}, + Opts: nil, + } + b := ChannelDefinition{ + ReportFormat: ReportFormatJSON, + Streams: []Stream{{0, AggregatorMedian}, {1, AggregatorMode}}, + Opts: nil, + Tombstone: true, + } + assert.False(t, a.Equals(b)) + }) + + t.Run("different Source", func(t *testing.T) { + a := ChannelDefinition{ + ReportFormat: ReportFormatJSON, + Streams: []Stream{{0, AggregatorMedian}, {1, AggregatorMode}}, + Opts: nil, + Source: 1, + } + b := ChannelDefinition{ + ReportFormat: ReportFormatJSON, + Streams: []Stream{{0, AggregatorMedian}, {1, AggregatorMode}}, + Opts: nil, + Source: 2, + } + assert.False(t, a.Equals(b)) + }) + t.Run("equal", func(t *testing.T) { a := ChannelDefinition{ ReportFormat: ReportFormatJSON, From 65fd87d06ab4f2ec1056af5c0e211b04138418c1 Mon Sep 17 00:00:00 2001 From: Erik Burton Date: Fri, 14 Nov 2025 14:34:31 -0800 Subject: [PATCH 16/42] fix: refactor for readability --- go.mod | 1 + go.sum | 2 + pkg/loop/cmd/loopinstall/install.go | 242 +++++++++++++------------ pkg/loop/cmd/loopinstall/validation.go | 18 +- 4 files changed, 148 insertions(+), 115 deletions(-) diff --git a/go.mod b/go.mod index c90e4dfaa9..2945de30d6 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/kylelemons/godebug v1.1.0 github.com/lib/pq v1.10.9 github.com/marcboeker/go-duckdb v1.8.5 + github.com/mattn/go-shellwords v1.0.12 github.com/mr-tron/base58 v1.2.0 github.com/pelletier/go-toml v1.9.5 github.com/pelletier/go-toml/v2 v2.2.4 diff --git a/go.sum b/go.sum index d59af8648d..182fa64d83 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk= +github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/pkg/loop/cmd/loopinstall/install.go b/pkg/loop/cmd/loopinstall/install.go index fcd258f699..3d8b18f8c3 100644 --- a/pkg/loop/cmd/loopinstall/install.go +++ b/pkg/loop/cmd/loopinstall/install.go @@ -12,6 +12,8 @@ import ( "strings" "sync" "time" + + shellwords "github.com/mattn/go-shellwords" ) // execCommand is a function variable that can be replaced in tests @@ -51,108 +53,149 @@ func mergeOrReplaceEnvVars(existing []string, newVars []string) []string { return result } -// downloadAndInstallPlugin downloads and installs a single plugin -func downloadAndInstallPlugin(pluginType string, pluginIdx int, plugin PluginDef, defaults DefaultsConfig) error { - if !isPluginEnabled(plugin) { - log.Printf("Skipping disabled plugin %s[%d]", pluginType, pluginIdx) - return nil +func determineModuleDirectory(goPrivate, fullModulePath string) (string, error) { + cmd := exec.Command("go", "mod", "download", "-json", fullModulePath) + var out bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = os.Stderr + + if goPrivate != "" { + // Inherit the current environment and override GOPRIVATE. + // Note: Not really sure why this is needed - tried to simplify existing logic + cmd.Env = append(os.Environ(), "GOPRIVATE="+goPrivate) } - moduleURI := plugin.ModuleURI - gitRef := plugin.GitRef - installPath := plugin.InstallPath + if err := execCommand(cmd); err != nil { + return "", fmt.Errorf("failed to download module %s: %w", fullModulePath, err) + } - // Validate inputs - if err := validateModuleURI(moduleURI); err != nil { - return fmt.Errorf("validation failed: %w", err) + var result ModDownloadResult + if err := json.Unmarshal(out.Bytes(), &result); err != nil { + return "", fmt.Errorf("failed to parse go mod download output: %w", err) } - if gitRef != "" { - if err := validateGitRef(gitRef); err != nil { - return fmt.Errorf("validation failed: %w", err) + if result.Dir == "" { + return "", fmt.Errorf("empty module directory returned for %s", fullModulePath) + } + + return result.Dir, nil +} + +func determineGoFlags(defaultGoFlags, pluginGoFlags string) ([]string, error) { + var flags []string + parser := shellwords.NewParser() + + // Determine base flags + if envGoFlags := os.Getenv("CL_PLUGIN_GOFLAGS"); envGoFlags != "" { + log.Printf("Overriding config's default goflags with CL_PLUGIN_GOFLAGS env var: %s", envGoFlags) + f, err := parser.Parse(envGoFlags) + if err != nil { + return nil, err } + flags = f + } else if defaultGoFlags != "" { + f, err := parser.Parse(defaultGoFlags) + if err != nil { + return nil, err + } + flags = f } - if err := validateInstallPath(installPath); err != nil { - return fmt.Errorf("validation failed: %w", err) + // Append plugin-specific flags + if pluginGoFlags != "" { + f, err := parser.Parse(pluginGoFlags) + if err != nil { + return nil, err + } + flags = append(flags, f...) } - // Full module path with git reference - fullModulePath := moduleURI - if gitRef != "" { - fullModulePath = fmt.Sprintf("%s@%s", moduleURI, gitRef) + // Validate + if len(flags) > 0 { + if err := validateGoFlags(strings.Join(flags, " ")); err != nil { + return nil, err + } } - log.Printf("Installing plugin %s[%d] from %s", pluginType, pluginIdx, fullModulePath) + return flags, nil +} - // Get GOPRIVATE environment variable - goPrivate := os.Getenv("GOPRIVATE") +func determineInstallArg(installPath, moduleURI string) string { + // Determine the actual argument for 'go install' based on installPath and moduleURI. + // - installPath is the user-provided path from YAML (no environment variable expansion). + // - moduleURI is the URI of the module being downloaded and installed (no environment variable expansion). + // The 'go install' command will be run with cmd.Dir set to the root of the downloaded moduleURI. + // Therefore, installArg must be "." or a path starting with "./" relative to the module root. - // Download the module and get its directory - var moduleDir string - { - cmd := exec.Command("go", "mod", "download", "-json", fullModulePath) - var out bytes.Buffer - cmd.Stdout = &out - cmd.Stderr = os.Stderr + // Case 1: installPath is the moduleURI itself. Install the module root. + if installPath == moduleURI { + return "." + } - // Set GOPRIVATE environment variable while preserving other environment variables - if goPrivate != "" { - // Start with all current environment variables - env := os.Environ() - - // Find and replace GOPRIVATE if it exists, or add it if it doesn't - goprivateFound := false - for i, e := range env { - if strings.HasPrefix(e, "GOPRIVATE=") { - env[i] = "GOPRIVATE=" + goPrivate - goprivateFound = true - break - } - } + // Case 2: installPath is a sub-package of moduleURI (e.g., "moduleURI/cmd/plugin"). + if after, ok := strings.CutPrefix(installPath, moduleURI+"/"); ok { + // Extract the relative path and prefix with "./". + relativePath := after + cleanedRelativePath := strings.TrimLeft(relativePath, "/") // Handles "moduleURI///subpath" + if cleanedRelativePath == "" || cleanedRelativePath == "." { // Handles "moduleURI/" or "moduleURI/." + return "." + } - // Add GOPRIVATE if it wasn't already in the environment - if !goprivateFound { - env = append(env, "GOPRIVATE="+goPrivate) - } + // cleanedRelativePath is like "cmd/plugin" or "sub/../pkg". Prepend "./". + return "./" + cleanedRelativePath + } - cmd.Env = env - } + // Case 3: installPath is not moduleURI and not a sub-package of moduleURI. + // Assumed to be: + // a) A path already relative to the module root (e.g., "cmd/plugin", "./cmd/plugin", "."). + // b) A full path to a different module (e.g., "github.com/other/mod"). + // For (b), prefixing with "./" when cmd.Dir is set is problematic but replicates prior behavior if any. - if err := execCommand(cmd); err != nil { - return fmt.Errorf("failed to download module %s: %w", fullModulePath, err) - } + // Simple case + if installPath == "." { + return "." + } - var result ModDownloadResult - if err := json.Unmarshal(out.Bytes(), &result); err != nil { - return fmt.Errorf("failed to parse go mod download output: %w", err) - } + // Already correctly formatted (e.g., "./cmd/plugin", "./sub/../pkg") + if strings.HasPrefix(installPath, "./") { + return installPath + } - moduleDir = result.Dir - if moduleDir == "" { - return fmt.Errorf("empty module directory returned for %s", fullModulePath) - } + // Needs "./" prefix. Handles "cmd/plugin", "/cmd/plugin", "github.com/other/mod". + return "./" + strings.TrimLeft(installPath, "/") +} + +// downloadAndInstallPlugin downloads and installs a single plugin +func downloadAndInstallPlugin(pluginType string, pluginIdx int, plugin PluginDef, defaults DefaultsConfig) error { + if !isPluginEnabled(plugin) { + log.Printf("Skipping disabled plugin %s[%d]", pluginType, pluginIdx) + return nil } - // Build goflags - goflags := defaults.GoFlags - if envGoFlags := os.Getenv("CL_PLUGIN_GOFLAGS"); envGoFlags != "" { - goflags = envGoFlags + // Validate inputs + if err := plugin.Validate(); err != nil { + return fmt.Errorf("plugin input validation failed: %w", err) } - // If goflags from plugindef is set, append it - if plugin.Flags != "" { - if goflags != "" { - goflags += " " - } - goflags += plugin.Flags + moduleURI := plugin.ModuleURI + gitRef := plugin.GitRef + installPath := plugin.InstallPath + + // Full module path with git reference + fullModulePath := moduleURI + if gitRef != "" { + fullModulePath = fmt.Sprintf("%s@%s", moduleURI, gitRef) } - // Validate goflags - if goflags != "" { - if err := validateGoFlags(goflags); err != nil { - return fmt.Errorf("validation failed: %w", err) - } + log.Printf("Installing plugin %s[%d] from %s", pluginType, pluginIdx, fullModulePath) + + // Get GOPRIVATE environment variable + goPrivate := os.Getenv("GOPRIVATE") + + // Download the module and get its directory + moduleDir, err := determineModuleDirectory(goPrivate, fullModulePath) + if err != nil { + return fmt.Errorf("failed to determine module directory: %w", err) } // Build env vars from defaults, environment variable, and plugin-specific settings @@ -168,42 +211,7 @@ func downloadAndInstallPlugin(pluginType string, pluginIdx int, plugin PluginDef // Install the plugin { - // Determine the actual argument for 'go install' based on installPath and moduleURI. - // installPath is the user-provided path from YAML (no environment variable expansion). - // moduleURI is the URI of the module being downloaded and installed (no environment variable expansion). - // The 'go install' command will be run with cmd.Dir set to the root of the downloaded moduleURI. - // Therefore, installArg must be "." or a path starting with "./" relative to the module root. - var installArg string - if installPath == moduleURI { - // Case 1: installPath is the moduleURI itself. Install the module root. - installArg = "." - } else if after, ok := strings.CutPrefix(installPath, moduleURI+"/"); ok { - // Case 2: installPath is a sub-package of moduleURI (e.g., "moduleURI/cmd/plugin"). - // Extract the relative path and prefix with "./". - relativePath := after - cleanedRelativePath := strings.TrimLeft(relativePath, "/") // Handles "moduleURI///subpath" - if cleanedRelativePath == "" || cleanedRelativePath == "." { // Handles "moduleURI/" or "moduleURI/." - installArg = "." - } else { - // cleanedRelativePath is like "cmd/plugin" or "sub/../pkg". Prepend "./". - installArg = "./" + cleanedRelativePath - } - } else { - // Case 3: installPath is not moduleURI and not a sub-package of moduleURI. - // Assumed to be: - // a) A path already relative to the module root (e.g., "cmd/plugin", "./cmd/plugin", "."). - // b) A full path to a different module (e.g., "github.com/other/mod"). - // For (b), prefixing with "./" when cmd.Dir is set is problematic but replicates prior behavior if any. - if installPath == "." { - installArg = "." - } else if strings.HasPrefix(installPath, "./") { - // Already correctly formatted (e.g., "./cmd/plugin", "./sub/../pkg") - installArg = installPath - } else { - // Needs "./" prefix. Handles "cmd/plugin", "/cmd/plugin", "github.com/other/mod". - installArg = "./" + strings.TrimLeft(installPath, "/") - } - } + installArg := determineInstallArg(installPath, moduleURI) binaryName := filepath.Base(installArg) if binaryName == "." { @@ -222,9 +230,15 @@ func downloadAndInstallPlugin(pluginType string, pluginIdx int, plugin PluginDef outputPath := filepath.Join(outputDir, binaryName) + // Build goflags + goflags, err := determineGoFlags(defaults.GoFlags, plugin.Flags) + if err != nil { + return fmt.Errorf("validation failed: %w", err) + } + args := []string{"build", "-o", outputPath} - if goflags != "" { - args = append(args, strings.Fields(goflags)...) + if len(goflags) != 0 { + args = append(args, goflags...) } args = append(args, installArg) diff --git a/pkg/loop/cmd/loopinstall/validation.go b/pkg/loop/cmd/loopinstall/validation.go index 3287df7033..453019017f 100644 --- a/pkg/loop/cmd/loopinstall/validation.go +++ b/pkg/loop/cmd/loopinstall/validation.go @@ -6,6 +6,22 @@ import ( "strings" ) +func (plugin PluginDef) Validate() error { + if err := validateModuleURI(plugin.ModuleURI); err != nil { + return err + } + if plugin.GitRef != "" { + if err := validateGitRef(plugin.GitRef); err != nil { + return err + } + } + if err := validateInstallPath(plugin.InstallPath); err != nil { + return err + } + + return nil +} + // validateModuleURI ensures the module URI follows Go module conventions func validateModuleURI(uri string) error { // Check for valid Go module path format @@ -35,7 +51,7 @@ func validateInstallPath(path string) error { // validateGoFlags ensures flags are safe and prevents command injection func validateGoFlags(flags string) error { - // Check for potentially dangerous characters that could enable command injection + // Check for potentially dangerous characters or substrings that could enable command injection dangerousPatterns := []string{ ";", "&&", "||", "`", "$", "|", "<", ">", "#", "//", "shutdown", "reboot", "rm -", "format", "mkfs", "dd", From 072bce3ac0b211b55fd9dc55bde09ee255a1f292 Mon Sep 17 00:00:00 2001 From: Erik Burton Date: Mon, 17 Nov 2025 10:20:18 -0800 Subject: [PATCH 17/42] feat: support local installs --- pkg/loop/cmd/loopinstall/install.go | 242 ++++++++++++++--------- pkg/loop/cmd/loopinstall/install_test.go | 237 ++++++++++++++++++++++ 2 files changed, 381 insertions(+), 98 deletions(-) create mode 100644 pkg/loop/cmd/loopinstall/install_test.go diff --git a/pkg/loop/cmd/loopinstall/install.go b/pkg/loop/cmd/loopinstall/install.go index 3d8b18f8c3..90999c23ed 100644 --- a/pkg/loop/cmd/loopinstall/install.go +++ b/pkg/loop/cmd/loopinstall/install.go @@ -22,7 +22,7 @@ var execCommand = func(cmd *exec.Cmd) error { } // mergeOrReplaceEnvVars merges new environment variables into an existing slice, -// replacing any existing variables with the same key +// replacing any existing variables with the same key. func mergeOrReplaceEnvVars(existing []string, newVars []string) []string { result := make([]string, len(existing)) copy(result, existing) @@ -53,7 +53,33 @@ func mergeOrReplaceEnvVars(existing []string, newVars []string) []string { return result } -func determineModuleDirectory(goPrivate, fullModulePath string) (string, error) { +// determineModuleDirectory locates the directory to build from. +// - Local path (absolute or "./relative"): resolve and return the directory (no download). +func determineModuleDirectoryLocal(pluginKey, moduleURI string) (string, error) { + log.Printf("%s - resolving local module path %q", pluginKey, moduleURI) + abs, err := filepath.Abs(moduleURI) + if err != nil { + return "", fmt.Errorf("failed to resolve local module path %q: %w", moduleURI, err) + } + info, err := os.Stat(abs) + if err != nil { + return "", fmt.Errorf("local module path %q not accessible: %w", abs, err) + } + if !info.IsDir() { + return "", fmt.Errorf("local module path %q is not a directory", abs) + } + return abs, nil +} + +// determineModuleDirectory locates the directory to build from. +// - Remote module path (e.g., "github.com/org/repo@ref"): use `go mod download -json` to get a module cache dir. +func determineModuleDirectoryRemote(pluginKey, moduleURI, gitRef, goPrivate string) (string, error) { + fullModulePath := moduleURI + if gitRef != "" { + fullModulePath = fmt.Sprintf("%s@%s", moduleURI, gitRef) + } + log.Printf("%s - downloading remote module %s", pluginKey, fullModulePath) + cmd := exec.Command("go", "mod", "download", "-json", fullModulePath) var out bytes.Buffer cmd.Stdout = &out @@ -61,7 +87,6 @@ func determineModuleDirectory(goPrivate, fullModulePath string) (string, error) if goPrivate != "" { // Inherit the current environment and override GOPRIVATE. - // Note: Not really sure why this is needed - tried to simplify existing logic cmd.Env = append(os.Environ(), "GOPRIVATE="+goPrivate) } @@ -81,13 +106,18 @@ func determineModuleDirectory(goPrivate, fullModulePath string) (string, error) return result.Dir, nil } -func determineGoFlags(defaultGoFlags, pluginGoFlags string) ([]string, error) { +// determineGoFlags resolves go build flags in priority order: +// 1) CL_PLUGIN_GOFLAGS env var (overrides config) +// 2) defaults.GoFlags +// 3) plugin-specific goflags appended +// It validates flags if any are present. +func determineGoFlags(pluginKey, defaultGoFlags, pluginGoFlags string) ([]string, error) { var flags []string parser := shellwords.NewParser() // Determine base flags if envGoFlags := os.Getenv("CL_PLUGIN_GOFLAGS"); envGoFlags != "" { - log.Printf("Overriding config's default goflags with CL_PLUGIN_GOFLAGS env var: %s", envGoFlags) + log.Printf("%s - overriding config's default goflags with CL_PLUGIN_GOFLAGS env var: %s", pluginKey, envGoFlags) f, err := parser.Parse(envGoFlags) if err != nil { return nil, err @@ -120,155 +150,170 @@ func determineGoFlags(defaultGoFlags, pluginGoFlags string) ([]string, error) { return flags, nil } -func determineInstallArg(installPath, moduleURI string) string { - // Determine the actual argument for 'go install' based on installPath and moduleURI. - // - installPath is the user-provided path from YAML (no environment variable expansion). - // - moduleURI is the URI of the module being downloaded and installed (no environment variable expansion). - // The 'go install' command will be run with cmd.Dir set to the root of the downloaded moduleURI. - // Therefore, installArg must be "." or a path starting with "./" relative to the module root. +// determineInstallArg computes the argument passed to `go build` given we're changing cmd.Dir. +// For remote modules, we keep the legacy behavior. +// For local moduleURIs, we compute a relative path from the module root to the installPath +// so the resulting arg is "." or "./sub/package". +func determineInstallArg(installPath, moduleURI string, isLocal bool) string { + cleanInstallPath := filepath.Clean(installPath) + cleanModuleURI := filepath.Clean(moduleURI) + + // Local modules + if isLocal { + // 1 - If building the module root + if cleanInstallPath == cleanModuleURI || cleanInstallPath == "." { + return "." + } + // 2 - If installPath is inside the module root, return "./" + if rel, err := filepath.Rel(cleanModuleURI, cleanInstallPath); err == nil && rel != "" && !strings.HasPrefix(rel, "..") { + rel = filepath.ToSlash(rel) + if rel == "." { + return "." + } + return "./" + rel + } + + // 3 - If installPath is already relative to the module root, normalize "./" prefix + if !filepath.IsAbs(cleanInstallPath) { + cleanInstallPath = filepath.ToSlash(cleanInstallPath) + if cleanInstallPath == "." || strings.HasPrefix(cleanInstallPath, "./") { + return cleanInstallPath + } + return "./" + strings.TrimLeft(cleanInstallPath, "/") + } + + // Absolute path outside module root: still give a relative-looking arg; + // cmd.Dir will be set to module root so Go expects package paths like "./x/y". + return "./" + filepath.ToSlash(strings.TrimLeft(cleanInstallPath, string(filepath.Separator))) + } - // Case 1: installPath is the moduleURI itself. Install the module root. + // Remote modules + // 1 - installPath is the module root itself. if installPath == moduleURI { return "." } - // Case 2: installPath is a sub-package of moduleURI (e.g., "moduleURI/cmd/plugin"). + // 2 - installPath is a sub-package of moduleURI. if after, ok := strings.CutPrefix(installPath, moduleURI+"/"); ok { - // Extract the relative path and prefix with "./". - relativePath := after - cleanedRelativePath := strings.TrimLeft(relativePath, "/") // Handles "moduleURI///subpath" - if cleanedRelativePath == "" || cleanedRelativePath == "." { // Handles "moduleURI/" or "moduleURI/." + cleanedRelativePath := strings.TrimLeft(after, "/") + if cleanedRelativePath == "" || cleanedRelativePath == "." { return "." } - - // cleanedRelativePath is like "cmd/plugin" or "sub/../pkg". Prepend "./". return "./" + cleanedRelativePath } - // Case 3: installPath is not moduleURI and not a sub-package of moduleURI. - // Assumed to be: - // a) A path already relative to the module root (e.g., "cmd/plugin", "./cmd/plugin", "."). - // b) A full path to a different module (e.g., "github.com/other/mod"). - // For (b), prefixing with "./" when cmd.Dir is set is problematic but replicates prior behavior if any. - - // Simple case + // 3 - other inputs; normalize to a "./" path. if installPath == "." { return "." } - - // Already correctly formatted (e.g., "./cmd/plugin", "./sub/../pkg") if strings.HasPrefix(installPath, "./") { return installPath } - - // Needs "./" prefix. Handles "cmd/plugin", "/cmd/plugin", "github.com/other/mod". return "./" + strings.TrimLeft(installPath, "/") } -// downloadAndInstallPlugin downloads and installs a single plugin +// downloadAndInstallPlugin downloads (if remote) and builds the plugin. +// For local moduleURIs (absolute or "./relative"), we skip network download, +// ignore gitRef (with a log message), and build directly from the local dir. func downloadAndInstallPlugin(pluginType string, pluginIdx int, plugin PluginDef, defaults DefaultsConfig) error { + pluginKey := fmt.Sprintf("%s[%d]", pluginType, pluginIdx) if !isPluginEnabled(plugin) { - log.Printf("Skipping disabled plugin %s[%d]", pluginType, pluginIdx) + log.Printf("%s - skipping disabled plugin", pluginKey) return nil } // Validate inputs if err := plugin.Validate(); err != nil { - return fmt.Errorf("plugin input validation failed: %w", err) + return fmt.Errorf("%s - plugin input validation failed: %w", pluginKey, err) } moduleURI := plugin.ModuleURI gitRef := plugin.GitRef installPath := plugin.InstallPath - // Full module path with git reference - fullModulePath := moduleURI - if gitRef != "" { - fullModulePath = fmt.Sprintf("%s@%s", moduleURI, gitRef) - } - - log.Printf("Installing plugin %s[%d] from %s", pluginType, pluginIdx, fullModulePath) - - // Get GOPRIVATE environment variable goPrivate := os.Getenv("GOPRIVATE") - // Download the module and get its directory - moduleDir, err := determineModuleDirectory(goPrivate, fullModulePath) + // Determine the directory to run `go build` in. + isLocal := filepath.IsAbs(moduleURI) || strings.HasPrefix(moduleURI, "."+string(filepath.Separator)) + moduleDir, err := func() (string, error) { + if isLocal { + return determineModuleDirectoryLocal(pluginKey, moduleURI) + } + return determineModuleDirectoryRemote(pluginKey, moduleURI, gitRef, goPrivate) + }() if err != nil { - return fmt.Errorf("failed to determine module directory: %w", err) + return fmt.Errorf("%s - failed to determine module directory: %w", pluginKey, err) + } + if moduleDir == "" { + return fmt.Errorf("%s - empty module directory resolved", pluginKey) } - // Build env vars from defaults, environment variable, and plugin-specific settings + log.Printf("%s - installing plugin from %s", pluginKey, moduleDir) + + // Build env vars from defaults, environment variable, and plugin-specific settings. envVars := defaults.EnvVars if envEnvVars := os.Getenv("CL_PLUGIN_ENVVARS"); envEnvVars != "" { envVars = mergeOrReplaceEnvVars(envVars, strings.Fields(envEnvVars)) } - - // Merge plugin-specific env vars if len(plugin.EnvVars) != 0 { envVars = mergeOrReplaceEnvVars(envVars, plugin.EnvVars) } - // Install the plugin - { - installArg := determineInstallArg(installPath, moduleURI) - - binaryName := filepath.Base(installArg) - if binaryName == "." { - binaryName = filepath.Base(moduleURI) - } - - // Determine output directory - outputDir := os.Getenv("GOBIN") - if outputDir == "" { - gopath := os.Getenv("GOPATH") - if gopath == "" { - gopath = filepath.Join(os.Getenv("HOME"), "go") - } - outputDir = filepath.Join(gopath, "bin") - } - - outputPath := filepath.Join(outputDir, binaryName) + // Compute build target relative to module root ('.' or './subpkg'). + installArg := determineInstallArg(installPath, moduleURI, isLocal) - // Build goflags - goflags, err := determineGoFlags(defaults.GoFlags, plugin.Flags) - if err != nil { - return fmt.Errorf("validation failed: %w", err) - } + // Derive output binary name. When arg is ".", use the module/repo (or local dir) name. + binaryName := filepath.Base(installArg) + if binaryName == "." { + binaryName = filepath.Base(filepath.Clean(moduleURI)) + } - args := []string{"build", "-o", outputPath} - if len(goflags) != 0 { - args = append(args, goflags...) + // Determine output directory (GOBIN, or GOPATH/bin, or $HOME/go/bin). + outputDir := os.Getenv("GOBIN") + if outputDir == "" { + gopath := os.Getenv("GOPATH") + if gopath == "" { + gopath = filepath.Join(os.Getenv("HOME"), "go") } - args = append(args, installArg) + outputDir = filepath.Join(gopath, "bin") + } + outputPath := filepath.Join(outputDir, binaryName) - cmd := exec.Command("go", args...) - cmd.Dir = moduleDir - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + // Build goflags + goflags, err := determineGoFlags(pluginKey, defaults.GoFlags, plugin.Flags) + if err != nil { + return fmt.Errorf("%s - goflags validation failed: %w", pluginKey, err) + } - // Start with all current environment variables - cmd.Env = os.Environ() + // Assemble `go build` command. + args := []string{"build", "-o", outputPath} + if len(goflags) != 0 { + args = append(args, goflags...) + } + args = append(args, installArg) - // Set GOPRIVATE environment variable if provided - if goPrivate != "" { - cmd.Env = mergeOrReplaceEnvVars(cmd.Env, []string{"GOPRIVATE=" + goPrivate}) - } + cmd := exec.Command("go", args...) + cmd.Dir = moduleDir + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr - // Add/replace custom environment variables (e.g., GOOS, GOARCH, CGO_ENABLED) - cmd.Env = mergeOrReplaceEnvVars(cmd.Env, envVars) + // Start with all current environment variables. + cmd.Env = os.Environ() + if goPrivate != "" { + cmd.Env = mergeOrReplaceEnvVars(cmd.Env, []string{"GOPRIVATE=" + goPrivate}) + } + cmd.Env = mergeOrReplaceEnvVars(cmd.Env, envVars) - log.Printf("Running install command: go %s (in directory: %s)", strings.Join(args, " "), moduleDir) + log.Printf("%s - running install command: go %s (in directory: %s)", pluginKey, strings.Join(args, " "), moduleDir) - if err := execCommand(cmd); err != nil { - return fmt.Errorf("failed to install plugin %s[%d]: %w", pluginType, pluginIdx, err) - } + if err := execCommand(cmd); err != nil { + return fmt.Errorf("%s - failed to install plugin: %w", pluginKey, err) } return nil } -// writeBuildManifest writes installation artifacts to the specified file +// writeBuildManifest writes installation artifacts to the specified file. func writeBuildManifest(tasks []PluginInstallTask, outputFile string) error { manifest := BuildManifest{ BuildTime: time.Now().UTC().Format(time.RFC3339), @@ -279,8 +324,7 @@ func writeBuildManifest(tasks []PluginInstallTask, outputFile string) error { for _, task := range tasks { configPath := task.ConfigFile if !filepath.IsAbs(configPath) { - absPath, err := filepath.Abs(configPath) - if err == nil { + if absPath, err := filepath.Abs(configPath); err == nil { configPath = absPath } } @@ -320,7 +364,7 @@ func writeBuildManifest(tasks []PluginInstallTask, outputFile string) error { return nil } -// installPlugins installs plugins concurrently using worker pool pattern +// installPlugins installs plugins concurrently using a worker pool. func installPlugins(tasks []PluginInstallTask, concurrency int, verbose bool, outputFile string) error { if len(tasks) == 0 { log.Println("No enabled plugins found to install") @@ -329,6 +373,7 @@ func installPlugins(tasks []PluginInstallTask, concurrency int, verbose bool, ou log.Printf("Installing %d plugins with concurrency %d", len(tasks), concurrency) + // Optionally write the manifest first (so artifacts exist even if a build fails). if outputFile != "" { if err := writeBuildManifest(tasks, outputFile); err != nil { return fmt.Errorf("failed to write installation artifacts: %w", err) @@ -349,6 +394,7 @@ func installPlugins(tasks []PluginInstallTask, concurrency int, verbose bool, ou } start := time.Now() + err := downloadAndInstallPlugin(task.PluginType, 0, task.Plugin, task.Defaults) duration := time.Since(start) @@ -407,7 +453,7 @@ func installPlugins(tasks []PluginInstallTask, concurrency int, verbose bool, ou return nil } -// setupOutputFile ensures the output directory exists +// setupOutputFile ensures the output path is absolute (and its directory exists is handled elsewhere). func setupOutputFile(outputFile string) (string, error) { if !filepath.IsAbs(outputFile) { wd, err := os.Getwd() diff --git a/pkg/loop/cmd/loopinstall/install_test.go b/pkg/loop/cmd/loopinstall/install_test.go new file mode 100644 index 0000000000..4d501d74b3 --- /dev/null +++ b/pkg/loop/cmd/loopinstall/install_test.go @@ -0,0 +1,237 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +// --- helpers --- + +// withMockExec temporarily swaps execCommand and restores it after the test. +func withMockExec(t *testing.T, f func(cmd *exec.Cmd) error, body func()) { + t.Helper() + orig := execCommand + execCommand = f + defer func() { execCommand = orig }() + body() +} + +// normalize slashes for stable asserts across platforms. +func toSlash(p string) string { return filepath.ToSlash(p) } + +// --- determineModuleDirectoryLocal --- + +func TestDetermineModuleDirectoryLocal(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) string // returns the path to test + wantErrSub string // substring expected in error, or empty for success + }{ + { + name: "success_directory", + setup: func(t *testing.T) string { + return t.TempDir() + }, + wantErrSub: "", + }, + { + name: "not_a_directory", + setup: func(t *testing.T) string { + td := t.TempDir() + fp := filepath.Join(td, "file.txt") + if err := os.WriteFile(fp, []byte("hi"), 0o600); err != nil { + t.Fatal(err) + } + return fp + }, + wantErrSub: "is not a directory", + }, + { + name: "not_accessible", + setup: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "does-not-exist-foo-bar-baz") + }, + wantErrSub: "not accessible", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + path := tc.setup(t) + got, err := determineModuleDirectoryLocal("plugin[0]", path) + + // error cases + if err != nil { + if tc.wantErrSub == "" { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(err.Error(), tc.wantErrSub) { + t.Fatalf("expected error containing %q, got %v", tc.wantErrSub, err) + } + return + } + + // success case + want, _ := filepath.Abs(path) + if got != want { + t.Fatalf("got %q, want %q", got, want) + } + }) + } +} + +// --- determineModuleDirectoryRemote --- + +func TestDetermineModuleDirectoryRemote_Success(t *testing.T) { + wantDir := filepath.Join(t.TempDir(), "gomodcache", "module") + mod := "github.com/acme/thing" + ref := "v1.2.3" + + withMockExec(t, func(cmd *exec.Cmd) error { + // Basic command shape + if len(cmd.Args) < 5 || + cmd.Args[0] != "go" || + cmd.Args[1] != "mod" || + cmd.Args[2] != "download" || + cmd.Args[3] != "-json" || + cmd.Args[4] != fmt.Sprintf("%s@%s", mod, ref) { + t.Fatalf("unexpected args: %v", cmd.Args) + } + // GOPRIVATE propagates only if provided + found := false + for _, e := range cmd.Env { + if strings.HasPrefix(e, "GOPRIVATE=") { + found = true + break + } + } + if !found { + t.Fatalf("expected GOPRIVATE in env") + } + + // Simulate go's JSON + type dl struct{ Dir string } + enc := json.NewEncoder(cmd.Stdout) + _ = enc.Encode(dl{Dir: wantDir}) + return nil + }, func() { + got, err := determineModuleDirectoryRemote("plugin[0]", mod, ref, "github.com/private/*") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != wantDir { + t.Fatalf("got %q, want %q", got, wantDir) + } + }) +} + +func TestDetermineModuleDirectoryRemote_EmptyDir(t *testing.T) { + withMockExec(t, func(cmd *exec.Cmd) error { + // write empty object (no Dir) + _, _ = fmt.Fprint(cmd.Stdout, `{}`) + return nil + }, func() { + _, err := determineModuleDirectoryRemote("plugin[0]", "github.com/acme/thing", "main", "") + if err == nil || !strings.Contains(err.Error(), "empty module directory") { + t.Fatalf("expected empty module directory error, got %v", err) + } + }) +} + +func TestDetermineModuleDirectoryRemote_CommandError(t *testing.T) { + withMockExec(t, func(cmd *exec.Cmd) error { + return errors.New("boom") + }, func() { + _, err := determineModuleDirectoryRemote("plugin[0]", "github.com/acme/thing", "main", "") + if err == nil || !strings.Contains(err.Error(), "failed to download module") { + t.Fatalf("expected download failure, got %v", err) + } + }) +} + +func TestDetermineModuleDirectoryRemote_InvalidJSON(t *testing.T) { + withMockExec(t, func(cmd *exec.Cmd) error { + _, _ = fmt.Fprint(cmd.Stdout, `not-json`) + return nil + }, func() { + _, err := determineModuleDirectoryRemote("plugin[0]", "github.com/acme/thing", "main", "") + if err == nil || !strings.Contains(err.Error(), "failed to parse") { + t.Fatalf("expected parse error, got %v", err) + } + }) +} + +// --- determineInstallArg --- + +func TestDetermineInstallArg_LocalModule(t *testing.T) { + // Use an absolute module root to exercise local path handling + modRoot := filepath.Clean(t.TempDir()) + + tests := []struct { + name string + installPath string + want string + }{ + {"root equals module", modRoot, "."}, + {"dot", ".", "."}, + {"subdir in module", filepath.Join(modRoot, "cmd", "tool"), "./cmd/tool"}, + {"relative subdir", "cmd/tool", "./cmd/tool"}, + { + "absolute outside module becomes relative-ish", + func() string { + // a path that is surely outside modRoot + base := string(os.PathSeparator) + "opt" + string(os.PathSeparator) + "other" + return filepath.Clean(base) + }(), + func() string { + abs := func() string { + base := string(os.PathSeparator) + "opt" + string(os.PathSeparator) + "other" + return filepath.Clean(base) + }() + return "./" + toSlash(strings.TrimLeft(abs, string(os.PathSeparator))) + }(), + }, + {"already ./ prefixed", "./cmd/tool", "./cmd/tool"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := determineInstallArg(tc.installPath, modRoot, true) + if toSlash(got) != toSlash(tc.want) { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} + +func TestDetermineInstallArg_RemoteModule(t *testing.T) { + module := "github.com/acme/repo" + + tests := []struct { + name string + installPath string + want string + }{ + {"root equals module", module, "."}, + {"subpackage absolute import", module + "/sub/pkg", "./sub/pkg"}, + {"dot", ".", "."}, + {"plain relative", "sub/pkg", "./sub/pkg"}, + {"already ./ prefixed", "./sub/pkg", "./sub/pkg"}, + {"normalize leading slash", "/sub/pkg", "./sub/pkg"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := determineInstallArg(tc.installPath, module, false) + if toSlash(got) != toSlash(tc.want) { + t.Fatalf("got %q, want %q", got, tc.want) + } + }) + } +} From c22b1ded4e0a1c9c0269e701aee0d4effd33a164 Mon Sep 17 00:00:00 2001 From: Oliver Townsend <133903322+ogtownsend@users.noreply.github.com> Date: Mon, 15 Dec 2025 07:22:34 -0800 Subject: [PATCH 18/42] Port generic chain-agnostic balance monitor to cl-common (#1728) * Move generic chain-agnostic balance monitor to cl-common * gomods tidy * lint * lint * fix typo --- pkg/monitoring/balance/generic_balance.go | 197 ++++++++++++++++++++++ pkg/monitoring/balance/metadata.go | 100 +++++++++++ pkg/monitoring/balance/metrics.go | 50 ++++++ pkg/monitoring/go.mod | 12 +- pkg/monitoring/go.sum | 17 ++ 5 files changed, 374 insertions(+), 2 deletions(-) create mode 100644 pkg/monitoring/balance/generic_balance.go create mode 100644 pkg/monitoring/balance/metadata.go create mode 100644 pkg/monitoring/balance/metrics.go diff --git a/pkg/monitoring/balance/generic_balance.go b/pkg/monitoring/balance/generic_balance.go new file mode 100644 index 0000000000..e6a5151080 --- /dev/null +++ b/pkg/monitoring/balance/generic_balance.go @@ -0,0 +1,197 @@ +// Package balance provides a generic chain-agnostic balance monitoring service +// that tracks account balances across different blockchain networks. +package balance + +import ( + "context" + "fmt" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" + "github.com/smartcontractkit/chainlink-common/pkg/utils" +) + +// Config defines the balance monitor configuration. +type GenericBalanceConfig struct { + BalancePollPeriod config.Duration +} + +// GenericBalanceClient defines the interface for getting account balances. +type GenericBalanceClient interface { + GetAccountBalance(addr string) (float64, error) +} + +// GenericBalanceMonitorOpts contains the options for creating a new balance monitor. +type GenericBalanceMonitorOpts struct { + ChainInfo ChainInfo + ChainNativeCurrency string + + Config GenericBalanceConfig + Logger logger.Logger + Keystore core.Keystore + NewGenericBalanceClient func() (GenericBalanceClient, error) + + // Maps a public key to an account address (optional, can return key as is) + KeyToAccountMapper func(context.Context, string) (string, error) +} + +// ChainInfo contains information about the blockchain network. +type ChainInfo struct { + ChainFamilyName string + ChainID string + NetworkName string + NetworkNameFull string +} + +// NewGenericBalanceMonitor returns a balance monitoring services.Service which reports the balance of all Keystore accounts. +func NewGenericBalanceMonitor(opts GenericBalanceMonitorOpts) (services.Service, error) { + // Try to create a new gauge for account balance + gauge, err := NewGaugeAccBalance(opts.ChainNativeCurrency) + if err != nil { + return nil, fmt.Errorf("failed to create gauge: %w", err) + } + + lggr := logger.Named(opts.Logger, "BalanceMonitor") + return &genericBalanceMonitor{ + cfg: opts.Config, + lggr: lggr, + ks: opts.Keystore, + + newReader: opts.NewGenericBalanceClient, + keyToAccountMapper: opts.KeyToAccountMapper, + updateFn: func(ctx context.Context, acc string, balance float64) { + lggr.Infow("Account balance updated", "unit", opts.ChainNativeCurrency, "account", acc, "balance", balance) + gauge.Record(ctx, balance, acc, opts.ChainInfo) + }, + + stop: make(chan struct{}), + done: make(chan struct{}), + }, nil +} + +type genericBalanceMonitor struct { + services.StateMachine + cfg GenericBalanceConfig + lggr logger.Logger + ks core.Keystore + + // Returns a new GenericBalanceClient + newReader func() (GenericBalanceClient, error) + // Maps a public key to an account address (optional, can return key as is) + keyToAccountMapper func(context.Context, string) (string, error) + // Updates the balance metric + updateFn func(ctx context.Context, acc string, balance float64) // overridable for testing + + // Cached instance, intermittently reset to nil. + reader GenericBalanceClient + + stop services.StopChan + done chan struct{} +} + +func (m *genericBalanceMonitor) Name() string { + return m.lggr.Name() +} + +func (m *genericBalanceMonitor) Start(context.Context) error { + return m.StartOnce(m.Name(), func() error { + go m.start() + return nil + }) +} + +func (m *genericBalanceMonitor) Close() error { + return m.StopOnce(m.Name(), func() error { + close(m.stop) + <-m.done + return nil + }) +} + +func (m *genericBalanceMonitor) HealthReport() map[string]error { + return map[string]error{m.Name(): m.Healthy()} +} + +// monitor fn continuously updates balances, until stop signal is received. +func (m *genericBalanceMonitor) start() { + defer close(m.done) + ctx, cancel := m.stop.NewCtx() + defer cancel() + + period := m.cfg.BalancePollPeriod.Duration() + tick := time.After(utils.WithJitter(period)) + for { + select { + case <-m.stop: + return + case <-tick: + m.updateBalances(ctx) + tick = time.After(utils.WithJitter(period)) + } + } +} + +// getReader returns the stored GenericBalanceClient, creating a new one if necessary. +func (m *genericBalanceMonitor) getReader() (GenericBalanceClient, error) { + if m.reader == nil { + var err error + m.reader, err = m.newReader() + if err != nil { + return nil, err + } + } + return m.reader, nil +} + +// updateBalances updates the balances of all accounts in the keystore, using the provided GenericBalanceClient and the updateFn. +func (m *genericBalanceMonitor) updateBalances(ctx context.Context) { + m.lggr.Debug("Updating account balances") + keys, err := m.ks.Accounts(ctx) + if err != nil { + m.lggr.Errorw("Failed to get keys", "err", err) + return + } + if len(keys) == 0 { + return + } + reader, err := m.getReader() + if err != nil { + m.lggr.Errorw("Failed to get client", "err", err) + return + } + + var gotSomeBals bool + for _, pk := range keys { + // Check for shutdown signal, since Balance blocks and may be slow. + select { + case <-m.stop: + return + default: + } + + // Account address can always be derived from the public key currently + // TODO: if we need to support key rotation, the keystore should store the address explicitly + // Notice: this is chain-specific key to account mapping injected (e.g., relevant for Aptos key management) + accAddr, err := m.keyToAccountMapper(ctx, pk) + if err != nil { + m.lggr.Errorw("Failed to convert public key to account address", "err", err) + continue + } + + balance, err := reader.GetAccountBalance(accAddr) + if err != nil { + m.lggr.Errorw("Failed to get balance", "account", accAddr, "err", err) + continue + } + gotSomeBals = true + m.updateFn(ctx, accAddr, balance) + } + + // Try a new client next time. // TODO: This is for multinode + if !gotSomeBals { + m.reader = nil + } +} diff --git a/pkg/monitoring/balance/metadata.go b/pkg/monitoring/balance/metadata.go new file mode 100644 index 0000000000..b86541e87a --- /dev/null +++ b/pkg/monitoring/balance/metadata.go @@ -0,0 +1,100 @@ +// Package balance provides a generic chain-agnostic balance monitoring service +// that tracks account balances across different blockchain networks. +package balance + +import ( + "encoding/hex" + + "go.opentelemetry.io/otel/attribute" +) + +const ( + // WorkflowExecutionIDShortLen is the length of the short version of the WorkflowExecutionId (label) + WorkflowExecutionIDShortLen = 3 // first 3 characters, 16^3 = 4.096 possibilities (mid-high cardinality) +) + +// TODO: Refactor as a proto referenced from the other proto files (telemetry messages) +type ExecutionMetadata struct { + // Execution Context - Source + SourceID string + // Execution Context - Chain + ChainFamilyName string + ChainID string + NetworkName string + NetworkNameFull string + // Execution Context - Workflow (capabilities.RequestMetadata) + WorkflowID string + WorkflowOwner string + WorkflowExecutionID string + WorkflowName string + WorkflowDonID uint32 + WorkflowDonConfigVersion uint32 + ReferenceID string + // Execution Context - Capability + CapabilityType string + CapabilityID string + CapabilityTimestampStart uint32 + CapabilityTimestampEmit uint32 +} + +// Attributes returns common attributes used for metrics +func (m ExecutionMetadata) Attributes() []attribute.KeyValue { + // Decode workflow name attribute for output + workflowName := m.decodeWorkflowName() + + return []attribute.KeyValue{ + // Execution Context - Source + attribute.String("source_id", ValOrUnknown(m.SourceID)), + // Execution Context - Chain + attribute.String("chain_family_name", ValOrUnknown(m.ChainFamilyName)), + attribute.String("chain_id", ValOrUnknown(m.ChainID)), + attribute.String("network_name", ValOrUnknown(m.NetworkName)), + attribute.String("network_name_full", ValOrUnknown(m.NetworkNameFull)), + // Execution Context - Workflow (capabilities.RequestMetadata) + attribute.String("workflow_id", ValOrUnknown(m.WorkflowID)), + attribute.String("workflow_owner", ValOrUnknown(m.WorkflowOwner)), + // Notice: We lower the cardinality on the WorkflowExecutionID so it can be used by metrics + // This label has good chances to be unique per workflow, in a reasonable bounded time window + // TODO: enable this when sufficiently tested (PromQL queries like alerts might need to change if this is used) + // attribute.String("workflow_execution_id_short", ValShortOrUnknown(m.WorkflowExecutionID, WorkflowExecutionIDShortLen)), + attribute.String("workflow_name", ValOrUnknown(workflowName)), + attribute.Int64("workflow_don_id", int64(m.WorkflowDonID)), + attribute.Int64("workflow_don_config_version", int64(m.WorkflowDonConfigVersion)), + attribute.String("reference_id", ValOrUnknown(m.ReferenceID)), + // Execution Context - Capability + attribute.String("capability_type", ValOrUnknown(m.CapabilityType)), + attribute.String("capability_id", ValOrUnknown(m.CapabilityID)), + // Notice: we don't include the timestamps here (high cardinality) + } +} + +// decodeWorkflowName decodes the workflow name from hex string to raw string (underlying, output) +func (m ExecutionMetadata) decodeWorkflowName() string { + bytes, err := hex.DecodeString(m.WorkflowName) + if err != nil { + // This should never happen + bytes = []byte("unknown-decode-error") + } + return string(bytes) +} + +// ValOrUnknown returns the value if it is not empty, otherwise it returns "unknown" +// This is needed to avoid issues during exporting OTel metrics to Prometheus +// For more details see https://smartcontract-it.atlassian.net/browse/INFOPLAT-1349 +func ValOrUnknown(val string) string { + if val == "" { + return "unknown" + } + return val +} + +// ValShortOrUnknown returns the short len value if not empty or available, otherwise it returns "unknown" +func ValShortOrUnknown(val string, maxLen int) string { + if val == "" || maxLen <= 0 { + return "unknown" + } + if maxLen > len(val) { + return val + } + return val[:maxLen] +} diff --git a/pkg/monitoring/balance/metrics.go b/pkg/monitoring/balance/metrics.go new file mode 100644 index 0000000000..25b8cf137e --- /dev/null +++ b/pkg/monitoring/balance/metrics.go @@ -0,0 +1,50 @@ +// Package balance provides a generic chain-agnostic balance monitoring service +// that tracks account balances across different blockchain networks. +package balance + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/smartcontractkit/chainlink-common/pkg/beholder" +) + +// GaugeAccBalance defines a new gauge metric for account balance +type GaugeAccBalance struct { + // account_balance + gauge metric.Float64Gauge +} + +func NewGaugeAccBalance(unitStr string) (*GaugeAccBalance, error) { + name := "account_balance" + description := "Balance for configured WT account" + gauge, err := beholder.GetMeter().Float64Gauge(name, metric.WithUnit(unitStr), metric.WithDescription(description)) + if err != nil { + return nil, fmt.Errorf("failed to create new gauge %s: %+w", name, err) + } + return &GaugeAccBalance{gauge}, nil +} + +func (g *GaugeAccBalance) Record(ctx context.Context, balance float64, account string, chainInfo ChainInfo) { + oAttrs := metric.WithAttributeSet(g.GetAttributes(account, chainInfo)) + g.gauge.Record(ctx, balance, oAttrs) + + // TODO: consider also recording record in Prom for availability to NOPs +} + +func (g *GaugeAccBalance) GetAttributes(account string, chainInfo ChainInfo) attribute.Set { + return attribute.NewSet( + attribute.String("account", account), + + // Execution Context - Source + attribute.String("source_id", ValOrUnknown(account)), // reusing account as source_id + // Execution Context - Chain + attribute.String("chain_family_name", ValOrUnknown(chainInfo.ChainFamilyName)), + attribute.String("chain_id", ValOrUnknown(chainInfo.ChainID)), + attribute.String("network_name", ValOrUnknown(chainInfo.NetworkName)), + attribute.String("network_name_full", ValOrUnknown(chainInfo.NetworkNameFull)), + ) +} diff --git a/pkg/monitoring/go.mod b/pkg/monitoring/go.mod index d7c27adfb4..532de60dff 100644 --- a/pkg/monitoring/go.mod +++ b/pkg/monitoring/go.mod @@ -12,12 +12,16 @@ require ( github.com/smartcontractkit/chainlink-common v0.7.1-0.20250627153434-ed6ed7b7fcd7 github.com/smartcontractkit/libocr v0.0.0-20250220133800-f3b940c4f298 github.com/stretchr/testify v1.10.0 + go.opentelemetry.io/otel v1.35.0 + go.opentelemetry.io/otel/metric v1.35.0 go.uber.org/goleak v1.3.0 google.golang.org/protobuf v1.36.7 ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.15.2 // indirect @@ -29,14 +33,17 @@ require ( github.com/go-playground/locales v0.13.0 // indirect github.com/go-playground/universal-translator v0.17.0 // indirect github.com/go-playground/validator/v10 v10.4.1 // indirect + github.com/go-viper/mapstructure/v2 v2.3.0 // indirect github.com/golang/snappy v0.0.5-0.20220116011046-fa5810519dcb // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 // indirect + github.com/invopop/jsonschema v0.12.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/leodido/go-urn v1.2.0 // indirect github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/moby/sys/sequential v0.6.0 // indirect github.com/moby/sys/user v0.3.0 // indirect github.com/moby/term v0.5.2 // indirect @@ -53,13 +60,15 @@ require ( github.com/prometheus/procfs v0.16.0 // indirect github.com/santhosh-tekuri/jsonschema/v5 v5.2.0 // indirect github.com/shirou/gopsutil/v4 v4.25.2 // indirect + github.com/shopspring/decimal v1.4.0 // indirect + github.com/smartcontractkit/chainlink-common/pkg/values v0.0.0-20250626141212-e50b2e7ffe2d // indirect github.com/smartcontractkit/freeport v0.1.1 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/testcontainers/testcontainers-go v0.36.0 // indirect github.com/tklauser/go-sysconf v0.3.15 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.60.0 // indirect - go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.0.0-20240823153156-2a54df7bffb9 // indirect go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.6.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.35.0 // indirect @@ -71,7 +80,6 @@ require ( go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.28.0 // indirect go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.28.0 // indirect go.opentelemetry.io/otel/log v0.6.0 // indirect - go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/sdk/log v0.6.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.35.0 // indirect diff --git a/pkg/monitoring/go.sum b/pkg/monitoring/go.sum index c08f580af6..cbaec80e12 100644 --- a/pkg/monitoring/go.sum +++ b/pkg/monitoring/go.sum @@ -4,8 +4,12 @@ github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEK github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= @@ -53,6 +57,8 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-viper/mapstructure/v2 v2.3.0 h1:27XbWsHIqhbdR5TIC911OfYvgSaW93HM+dX7970Q7jk= +github.com/go-viper/mapstructure/v2 v2.3.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -67,6 +73,9 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1 h1:e9Rjr40Z98/clHv5Yg79Is0NtosR5LXRvdr7o/6NwbA= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.1/go.mod h1:tIxuGz/9mpox++sgp9fJjHO0+q1X9/UOWd798aAm22M= +github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= +github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= @@ -86,6 +95,8 @@ github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr32 github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/magiconair/properties v1.8.9 h1:nWcCbLq1N2v/cpNsy5WvQ37Fb+YElfq20WJ/a8RkpQM= github.com/magiconair/properties v1.8.9/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= @@ -139,10 +150,14 @@ github.com/santhosh-tekuri/jsonschema/v5 v5.2.0 h1:WCcC4vZDS1tYNxjWlwRJZQy28r8CM github.com/santhosh-tekuri/jsonschema/v5 v5.2.0/go.mod h1:FKdcjfQW6rpZSnxxUvEA5H/cDPdvJ/SZJQLWWXWGrZ0= github.com/shirou/gopsutil/v4 v4.25.2 h1:NMscG3l2CqtWFS86kj3vP7soOczqrQYIEhO/pMvvQkk= github.com/shirou/gopsutil/v4 v4.25.2/go.mod h1:34gBYJzyqCDT11b6bMHP0XCvWeU3J61XRT7a2EmCRTA= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartcontractkit/chainlink-common v0.7.1-0.20250627153434-ed6ed7b7fcd7 h1:Z6irOxlyglCP9qbJVndcoVxApJpUwFsQsJrhsfHYMGI= github.com/smartcontractkit/chainlink-common v0.7.1-0.20250627153434-ed6ed7b7fcd7/go.mod h1:mRKPMPyJhg1RBjxtRTL2gHvRhTcZ+nk2Upu/u97Y16M= +github.com/smartcontractkit/chainlink-common/pkg/values v0.0.0-20250626141212-e50b2e7ffe2d h1:86gp4tIXRb6ccSrjcm4gV8iA5wJN6er3rJY9f2UxRLU= +github.com/smartcontractkit/chainlink-common/pkg/values v0.0.0-20250626141212-e50b2e7ffe2d/go.mod h1:QUEPHdSkH19Or+E1iMGG+rDQ6jpCTIbm//9Osa6MXDE= github.com/smartcontractkit/freeport v0.1.1 h1:B5fhEtmgomdIhw03uPVbVTP6oPv27fBhZsoZZMSIS8I= github.com/smartcontractkit/freeport v0.1.1/go.mod h1:T4zH9R8R8lVWKfU7tUvYz2o2jMv1OpGCdpY2j2QZXzU= github.com/smartcontractkit/libocr v0.0.0-20250220133800-f3b940c4f298 h1:PKiqnVOTChlH4a4ljJKL3OKGRgYfIpJS4YD1daAIKks= @@ -165,6 +180,8 @@ github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfj github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= From 58ecd6ac20b813fdce9494971ae5e4cf00570780 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Mon, 15 Dec 2025 10:25:04 -0500 Subject: [PATCH 19/42] pkg/utils: fix SleeperTask race (#1737) --- pkg/utils/sleeper_task.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pkg/utils/sleeper_task.go b/pkg/utils/sleeper_task.go index 02dc970b35..2024881c4f 100644 --- a/pkg/utils/sleeper_task.go +++ b/pkg/utils/sleeper_task.go @@ -2,8 +2,6 @@ package utils import ( "context" - "fmt" - "time" "github.com/smartcontractkit/chainlink-common/pkg/services" ) @@ -70,11 +68,7 @@ func NewSleeperTaskCtx(w WorkerCtx) *SleeperTask { func (s *SleeperTask) Stop() error { return s.StopOnce("SleeperTask-"+s.worker.Name(), func() error { close(s.chStop) - select { - case <-s.chDone: - case <-time.After(15 * time.Second): - return fmt.Errorf("SleeperTask-%s took too long to stop", s.worker.Name()) - } + <-s.chDone return nil }) } From b1813502fe97b8008bce901db6f354ef0ca45897 Mon Sep 17 00:00:00 2001 From: Prashant Yadav <34992934+prashantkumar1982@users.noreply.github.com> Date: Thu, 18 Dec 2025 23:23:52 -0800 Subject: [PATCH 20/42] Better error handling for empty server responses (#1741) --- observability-lib/grafana/datasource.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/observability-lib/grafana/datasource.go b/observability-lib/grafana/datasource.go index 163ae2a800..b8756961aa 100644 --- a/observability-lib/grafana/datasource.go +++ b/observability-lib/grafana/datasource.go @@ -1,6 +1,10 @@ package grafana -import "github.com/smartcontractkit/chainlink-common/observability-lib/api" +import ( + "errors" + + "github.com/smartcontractkit/chainlink-common/observability-lib/api" +) type DataSource struct { ID uint @@ -26,6 +30,9 @@ func GetDataSourceFromGrafana(name string, grafanaURL string, grafanaToken strin if err != nil { return nil, err } + if datasource.Name == "" { + return nil, errors.New("unexpected empty response. please check connection or vpn settings") + } return &DataSource{ID: datasource.ID, Name: datasource.Name, UID: datasource.UID, Type: datasource.Type}, nil } From 746138b955f30ff281f534664016ec256c7c3624 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Fri, 19 Dec 2025 10:45:53 -0500 Subject: [PATCH 21/42] pkg/settings/cresettings: remove deprecated fields (#1738) --- pkg/settings/cresettings/defaults.json | 11 +------ pkg/settings/cresettings/defaults.toml | 9 ------ pkg/settings/cresettings/settings.go | 36 +++-------------------- pkg/settings/cresettings/settings_test.go | 1 - 4 files changed, 5 insertions(+), 52 deletions(-) diff --git a/pkg/settings/cresettings/defaults.json b/pkg/settings/cresettings/defaults.json index 5dac0d1f70..da605124b9 100644 --- a/pkg/settings/cresettings/defaults.json +++ b/pkg/settings/cresettings/defaults.json @@ -1,7 +1,6 @@ { "WorkflowLimit": "200", "WorkflowExecutionConcurrencyLimit": "200", - "WorkflowTriggerRateLimit": "200rps:200", "GatewayIncomingPayloadSizeLimit": "1mb", "VaultCiphertextSizeLimit": "2kb", "VaultIdentifierKeySizeLimit": "64b", @@ -10,16 +9,13 @@ "VaultPluginBatchSizeLimit": "20", "VaultRequestBatchSizeLimit": "10", "PerOrg": { - "WorkflowDeploymentRateLimit": "every1m0s:1", "ZeroBalancePruningTimeout": "24h0m0s" }, "PerOwner": { "WorkflowExecutionConcurrencyLimit": "5", - "WorkflowTriggerRateLimit": "5rps:5", "VaultSecretsLimit": "100" }, "PerWorkflow": { - "TriggerRateLimit": "every30s:3", "TriggerRegistrationsTimeout": "10s", "TriggerSubscriptionTimeout": "15s", "TriggerSubscriptionLimit": "10", @@ -36,9 +32,6 @@ "WASMCompressedBinarySizeLimit": "20mb", "WASMConfigSizeLimit": "1mb", "WASMSecretsSizeLimit": "1mb", - "WASMResponseSizeLimit": "100kb", - "ConsensusObservationSizeLimit": "100kb", - "ConsensusCallsLimit": "2000", "LogLineLimit": "1kb", "LogEventLimit": "1000", "ChainAllowed": { @@ -49,14 +42,12 @@ } }, "CRONTrigger": { - "FastestScheduleInterval": "30s", - "RateLimit": "every30s:1" + "FastestScheduleInterval": "30s" }, "HTTPTrigger": { "RateLimit": "every30s:3" }, "LogTrigger": { - "Limit": "5", "EventRateLimit": "every6s:10", "EventSizeLimit": "5kb", "FilterAddressLimit": "5", diff --git a/pkg/settings/cresettings/defaults.toml b/pkg/settings/cresettings/defaults.toml index c19170047e..195f3977d7 100644 --- a/pkg/settings/cresettings/defaults.toml +++ b/pkg/settings/cresettings/defaults.toml @@ -1,6 +1,5 @@ WorkflowLimit = '200' WorkflowExecutionConcurrencyLimit = '200' -WorkflowTriggerRateLimit = '200rps:200' GatewayIncomingPayloadSizeLimit = '1mb' VaultCiphertextSizeLimit = '2kb' VaultIdentifierKeySizeLimit = '64b' @@ -10,16 +9,13 @@ VaultPluginBatchSizeLimit = '20' VaultRequestBatchSizeLimit = '10' [PerOrg] -WorkflowDeploymentRateLimit = 'every1m0s:1' ZeroBalancePruningTimeout = '24h0m0s' [PerOwner] WorkflowExecutionConcurrencyLimit = '5' -WorkflowTriggerRateLimit = '5rps:5' VaultSecretsLimit = '100' [PerWorkflow] -TriggerRateLimit = 'every30s:3' TriggerRegistrationsTimeout = '10s' TriggerSubscriptionTimeout = '15s' TriggerSubscriptionLimit = '10' @@ -36,9 +32,6 @@ WASMBinarySizeLimit = '100mb' WASMCompressedBinarySizeLimit = '20mb' WASMConfigSizeLimit = '1mb' WASMSecretsSizeLimit = '1mb' -WASMResponseSizeLimit = '100kb' -ConsensusObservationSizeLimit = '100kb' -ConsensusCallsLimit = '2000' LogLineLimit = '1kb' LogEventLimit = '1000' @@ -51,13 +44,11 @@ Default = 'false' [PerWorkflow.CRONTrigger] FastestScheduleInterval = '30s' -RateLimit = 'every30s:1' [PerWorkflow.HTTPTrigger] RateLimit = 'every30s:3' [PerWorkflow.LogTrigger] -Limit = '5' EventRateLimit = 'every6s:10' EventSizeLimit = '5kb' FilterAddressLimit = '5' diff --git a/pkg/settings/cresettings/settings.go b/pkg/settings/cresettings/settings.go index bab448cb05..6a7c4978c9 100644 --- a/pkg/settings/cresettings/settings.go +++ b/pkg/settings/cresettings/settings.go @@ -54,7 +54,6 @@ var Config Schema var Default = Schema{ WorkflowLimit: Int(200), WorkflowExecutionConcurrencyLimit: Int(200), - WorkflowTriggerRateLimit: Rate(200, 200), GatewayIncomingPayloadSizeLimit: Size(1 * config.MByte), // DANGER(cedric): Be extremely careful changing these vault limits as they act as a default value @@ -69,12 +68,10 @@ var Default = Schema{ VaultRequestBatchSizeLimit: Int(10), PerOrg: Orgs{ - WorkflowDeploymentRateLimit: Rate(rate.Every(time.Minute), 1), - ZeroBalancePruningTimeout: Duration(24 * time.Hour), + ZeroBalancePruningTimeout: Duration(24 * time.Hour), }, PerOwner: Owners{ WorkflowExecutionConcurrencyLimit: Int(5), - WorkflowTriggerRateLimit: Rate(5, 5), // DANGER(cedric): Be extremely careful changing this vault limit as it acts as a default value // used by the Vault OCR plugin -- changing this value could cause issues with the plugin during an image @@ -83,7 +80,6 @@ var Default = Schema{ VaultSecretsLimit: Int(100), }, PerWorkflow: Workflows{ - TriggerRateLimit: Rate(rate.Every(30*time.Second), 3), TriggerRegistrationsTimeout: Duration(10 * time.Second), TriggerEventQueueLimit: Int(1_000), TriggerEventQueueTimeout: Duration(10 * time.Minute), @@ -100,9 +96,6 @@ var Default = Schema{ WASMCompressedBinarySizeLimit: Size(20 * config.MByte), WASMConfigSizeLimit: Size(config.MByte), WASMSecretsSizeLimit: Size(config.MByte), - WASMResponseSizeLimit: Size(100 * config.KByte), - ConsensusObservationSizeLimit: Size(100 * config.KByte), - ConsensusCallsLimit: Int(2000), LogLineLimit: Size(config.KByte), LogEventLimit: Int(1_000), ChainAllowed: PerChainSelector(Bool(false), map[string]bool{ @@ -114,13 +107,11 @@ var Default = Schema{ CRONTrigger: cronTrigger{ FastestScheduleInterval: Duration(30 * time.Second), - RateLimit: Rate(rate.Every(30*time.Second), 1), }, HTTPTrigger: httpTrigger{ RateLimit: Rate(rate.Every(30*time.Second), 3), }, LogTrigger: logTrigger{ - Limit: Int(5), EventRateLimit: Rate(rate.Every(time.Minute/10), 10), FilterAddressLimit: Int(5), FilterTopicsPerSlotLimit: Int(10), @@ -156,9 +147,7 @@ var Default = Schema{ type Schema struct { WorkflowLimit Setting[int] `unit:"{workflow}"` WorkflowExecutionConcurrencyLimit Setting[int] `unit:"{workflow}"` - // Deprecated - WorkflowTriggerRateLimit Setting[config.Rate] - GatewayIncomingPayloadSizeLimit Setting[config.Size] + GatewayIncomingPayloadSizeLimit Setting[config.Size] VaultCiphertextSizeLimit Setting[config.Size] VaultIdentifierKeySizeLimit Setting[config.Size] @@ -172,21 +161,15 @@ type Schema struct { PerWorkflow Workflows `scope:"workflow"` } type Orgs struct { - // Deprecated - WorkflowDeploymentRateLimit Setting[config.Rate] - ZeroBalancePruningTimeout Setting[time.Duration] + ZeroBalancePruningTimeout Setting[time.Duration] } type Owners struct { WorkflowExecutionConcurrencyLimit Setting[int] `unit:"{workflow}"` - // Deprecated - WorkflowTriggerRateLimit Setting[config.Rate] - VaultSecretsLimit Setting[int] `unit:"{secret}"` + VaultSecretsLimit Setting[int] `unit:"{secret}"` } type Workflows struct { - // Deprecated - TriggerRateLimit Setting[config.Rate] TriggerRegistrationsTimeout Setting[time.Duration] TriggerSubscriptionTimeout Setting[time.Duration] TriggerSubscriptionLimit Setting[int] `unit:"{subscription}"` @@ -207,13 +190,6 @@ type Workflows struct { WASMCompressedBinarySizeLimit Setting[config.Size] WASMConfigSizeLimit Setting[config.Size] WASMSecretsSizeLimit Setting[config.Size] - // Deprecated: use ExecutionResponseLimit - WASMResponseSizeLimit Setting[config.Size] - - // Deprecated: use Consensus.ObservationSizeLimit - ConsensusObservationSizeLimit Setting[config.Size] - // Deprecated: use Consensus.CallLimit - ConsensusCallsLimit Setting[int] `unit:"{call}"` LogLineLimit Setting[config.Size] LogEventLimit Setting[int] `unit:"{log}"` @@ -232,15 +208,11 @@ type Workflows struct { type cronTrigger struct { FastestScheduleInterval Setting[time.Duration] - // Deprecated: to be removed - RateLimit Setting[config.Rate] } type httpTrigger struct { RateLimit Setting[config.Rate] } type logTrigger struct { - // Deprecated - Limit Setting[int] `unit:"{trigger}"` EventRateLimit Setting[config.Rate] EventSizeLimit Setting[config.Size] FilterAddressLimit Setting[int] `unit:"{address}"` diff --git a/pkg/settings/cresettings/settings_test.go b/pkg/settings/cresettings/settings_test.go index 74fa55a89d..284a9cfb33 100644 --- a/pkg/settings/cresettings/settings_test.go +++ b/pkg/settings/cresettings/settings_test.go @@ -116,7 +116,6 @@ func TestSchema_Unmarshal(t *testing.T) { assert.Equal(t, "true", cfg.PerWorkflow.ChainAllowed.Values["1"]) assert.NotNil(t, cfg.PerWorkflow.ChainAllowed.Default.Parse) assert.NotNil(t, cfg.PerWorkflow.ChainAllowed.KeyFromCtx) - assert.Equal(t, config.Rate{Limit: rate.Every(10 * time.Second), Burst: 5}, cfg.PerWorkflow.CRONTrigger.RateLimit.DefaultValue) assert.Equal(t, config.Rate{Limit: rate.Every(30 * time.Second), Burst: 3}, cfg.PerWorkflow.HTTPTrigger.RateLimit.DefaultValue) assert.Equal(t, config.Rate{Limit: rate.Every(13 * time.Second), Burst: 6}, cfg.PerWorkflow.LogTrigger.EventRateLimit.DefaultValue) assert.Equal(t, 5, cfg.PerWorkflow.HTTPAction.CallLimit.DefaultValue) From 04f56c23058a794511af5c343af364b9d326ca22 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Mon, 29 Dec 2025 13:11:28 -0500 Subject: [PATCH 22/42] pkg/capabilities: fix Request/RegistrationMetadata.ContextWithCRE to preserve org id (#1744) --- pkg/capabilities/capabilities.go | 18 +++++++++-------- pkg/capabilities/capabilities_test.go | 29 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/pkg/capabilities/capabilities.go b/pkg/capabilities/capabilities.go index ce5074678d..7bc081d437 100644 --- a/pkg/capabilities/capabilities.go +++ b/pkg/capabilities/capabilities.go @@ -117,10 +117,11 @@ type RequestMetadata struct { } func (m *RequestMetadata) ContextWithCRE(ctx context.Context) context.Context { - return contexts.WithCRE(ctx, contexts.CRE{ - Owner: m.WorkflowOwner, - Workflow: m.WorkflowID, - }) + val := contexts.CREValue(ctx) + // preserve org, if set + val.Owner = m.WorkflowOwner + val.Workflow = m.WorkflowID + return contexts.WithCRE(ctx, val) } type RegistrationMetadata struct { @@ -131,10 +132,11 @@ type RegistrationMetadata struct { } func (m *RegistrationMetadata) ContextWithCRE(ctx context.Context) context.Context { - return contexts.WithCRE(ctx, contexts.CRE{ - Owner: m.WorkflowOwner, - Workflow: m.WorkflowID, - }) + val := contexts.CREValue(ctx) + // preserve org, if set + val.Owner = m.WorkflowOwner + val.Workflow = m.WorkflowID + return contexts.WithCRE(ctx, val) } // CapabilityRequest is a struct for the Execute request of a capability. diff --git a/pkg/capabilities/capabilities_test.go b/pkg/capabilities/capabilities_test.go index 5060b0f172..06e3bac43b 100644 --- a/pkg/capabilities/capabilities_test.go +++ b/pkg/capabilities/capabilities_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/contexts" "github.com/smartcontractkit/chainlink-protos/cre/go/values" ) @@ -331,3 +332,31 @@ func TestChainSelectorLabel(t *testing.T) { } func ptr[T any](v T) *T { return &v } + +func TestRequestMetadata_ContextWithCRE(t *testing.T) { + ctx := t.Context() + require.Equal(t, "", contexts.CREValue(ctx).Org) + + // set it + ctx = contexts.WithCRE(ctx, contexts.CRE{Org: "org-id"}) + require.Equal(t, "org-id", contexts.CREValue(ctx).Org) + + // preserve it + md := RequestMetadata{WorkflowOwner: "owner-id", WorkflowID: "workflow-id"} + ctx = md.ContextWithCRE(ctx) + require.Equal(t, "org-id", contexts.CREValue(ctx).Org) +} + +func TestRegistrationMetadata_ContextWithCRE(t *testing.T) { + ctx := t.Context() + require.Equal(t, "", contexts.CREValue(ctx).Org) + + // set it + ctx = contexts.WithCRE(ctx, contexts.CRE{Org: "org-id"}) + require.Equal(t, "org-id", contexts.CREValue(ctx).Org) + + // preserve it + md := RegistrationMetadata{WorkflowOwner: "owner-id", WorkflowID: "workflow-id"} + ctx = md.ContextWithCRE(ctx) + require.Equal(t, "org-id", contexts.CREValue(ctx).Org) +} From e46cb3b354791ebb302ed164b11cc3b1e1ce540f Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Tue, 30 Dec 2025 13:04:49 -0500 Subject: [PATCH 23/42] CRE-1613: Fix internal errors passed from WASM host to guest and add a standard test for it (#1745) * CRE-1613: Fix internal errors passed from WASM host to guest and add a standard test for it * Fix a typo --- .../host/internal/rawsdk/helpers_wasip1.go | 3 +- pkg/workflows/wasm/host/module.go | 4 ++- pkg/workflows/wasm/host/standard_test.go | 31 +++++++++++++++++++ .../main_wasip1.go | 15 +++++++++ 4 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 pkg/workflows/wasm/host/standard_tests/host_wasm_write_errors_are_respected/main_wasip1.go diff --git a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go index 3bbfa0a2ac..409acbc880 100644 --- a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go +++ b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go @@ -197,7 +197,8 @@ func await[I, O proto.Message](input I, output O, fn awaitFn) { bytes := fn(mptr, mlen, responsePtr, responseLen) if bytes < 0 { - SendError(errors.New("awaitCapabilities returned an error")) + response = response[:-bytes] + SendError(fmt.Errorf("awaitCapabilities returned an error %s", string(response))) } if proto.Unmarshal(response[:bytes], output) != nil { diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 73aad9f098..493d209031 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -1066,7 +1066,9 @@ func truncateWasmWrite(caller *wasmtime.Caller, src []byte, ptr int32, size int3 src = src[:size] } - return write(memory, src, ptr, size) + // truncateWasmWrite is only called for returning error strings + // Therefore, we need to return the negated bytes written to indicate the failure to the guest. + return -write(memory, src, ptr, size) } // write copies the given src byte slice into the memory at the given pointer and max size. diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index 4273c35647..eb17367d92 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -118,6 +118,37 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { assert.Equal(t, "truefalse", result) } +func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { + t.Parallel() + mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") + mockExecutionHelper.EXPECT().CallCapability(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, request *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error) { + // In this test the response from the capability is successful, + // but the WASM didn't provide a large enough buffer to fit it + // 500 MB will suffice for the overflow on writes. + + tooLargeResponse := make([]byte, 500000000) + + // Since the bytes in the payload shouldn't be read, we don't need a valid proto + payload := &anypb.Any{ + TypeUrl: "fake", + Value: tooLargeResponse, + } + + return &sdk.CapabilityResponse{Response: &sdk.CapabilityResponse_Payload{Payload: payload}}, nil + }) + + m := makeTestModule(t) + request := triggerExecuteRequest(t, 0, &basictrigger.Outputs{CoolOutput: anyTestTriggerValue}) + errStr := executeWithError(t, m, request, mockExecutionHelper) + + // Use Contains instead of Equal for flexibility, as languages have different conventions for errors. + require.Contains(t, errStr, ResponseBufferTooSmall) +} + func TestStandardModeSwitch(t *testing.T) { t.Parallel() t.Run("successful mode switch", func(t *testing.T) { diff --git a/pkg/workflows/wasm/host/standard_tests/host_wasm_write_errors_are_respected/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/host_wasm_write_errors_are_respected/main_wasip1.go new file mode 100644 index 0000000000..3a7d3025e3 --- /dev/null +++ b/pkg/workflows/wasm/host/standard_tests/host_wasm_write_errors_are_respected/main_wasip1.go @@ -0,0 +1,15 @@ +package main + +import ( + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func main() { + input := &basicaction.Inputs{InputThing: true} + rId := rawsdk.DoRequestAsync("basic-test-action@1.0.0", "PerformAction", sdk.Mode_MODE_DON, input) + + rawsdk.Await(rId, &basicaction.Outputs{}) + rawsdk.SendResponse("should not get here as Await sends error on errors...") +} From 9d87379812c7a2c579206761aee9d2a8f7fbd73e Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 31 Dec 2025 08:56:14 -0500 Subject: [PATCH 24/42] pkg/contexts: expand CRE.Normalized (#1746) --- pkg/contexts/contexts.go | 8 ++++++++ pkg/settings/keys_test.go | 4 ++-- pkg/settings/limits/resource_test.go | 6 +++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pkg/contexts/contexts.go b/pkg/contexts/contexts.go index 1e6928f825..b414a92e9b 100644 --- a/pkg/contexts/contexts.go +++ b/pkg/contexts/contexts.go @@ -38,8 +38,16 @@ type CRE struct { // Normalized returns a possibly modified CRE with normalized values. func (c CRE) Normalized() CRE { + c.Org = strings.TrimPrefix(c.Org, "org_") + c.Org = strings.TrimPrefix(c.Org, "0x") + c.Org = strings.ToLower(c.Org) + + c.Owner = strings.TrimPrefix(c.Owner, "owner_") c.Owner = strings.TrimPrefix(c.Owner, "0x") c.Owner = strings.ToLower(c.Owner) + + c.Workflow = strings.TrimPrefix(c.Workflow, "0x") + c.Workflow = strings.ToLower(c.Workflow) return c } diff --git a/pkg/settings/keys_test.go b/pkg/settings/keys_test.go index d9a5775e97..430293689e 100644 --- a/pkg/settings/keys_test.go +++ b/pkg/settings/keys_test.go @@ -10,9 +10,9 @@ import ( func TestTenant_rawKeys(t *testing.T) { const ( - org = "AcmeCorporation" + org = "acmecorporation" owner = "1234abcd" - workflow = "ABCDEFGH" + workflow = "abcdefgh" key = "foo" ) for _, test := range []struct { diff --git a/pkg/settings/limits/resource_test.go b/pkg/settings/limits/resource_test.go index 4c943062f9..b587c6a984 100644 --- a/pkg/settings/limits/resource_test.go +++ b/pkg/settings/limits/resource_test.go @@ -72,7 +72,7 @@ func ExampleResourceLimiter_Use() { func ExampleMultiResourcePoolLimiter() { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - ctx = contexts.WithCRE(ctx, contexts.CRE{Org: "orgID", Owner: "owner-id", Workflow: "workflowID"}) + ctx = contexts.WithCRE(ctx, contexts.CRE{Org: "org-id", Owner: "owner-id", Workflow: "workflow-id"}) global := GlobalResourcePoolLimiter[int](100) freeGlobal, err := global.Wait(ctx, 95) if err != nil { @@ -118,9 +118,9 @@ func ExampleMultiResourcePoolLimiter() { free() // Output: // resource limited: cannot use 10, already using 95/100 - // resource limited for org[orgID]: cannot use 10, already using 45/50 + // resource limited for org[org-id]: cannot use 10, already using 45/50 // resource limited for owner[owner-id]: cannot use 10, already using 15/20 - // resource limited for workflow[workflowID]: cannot use 10, already using 5/10 + // resource limited for workflow[workflow-id]: cannot use 10, already using 5/10 // } From bf51f53e7c96ad80d644c2cf7854380919ae00d1 Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Fri, 2 Jan 2026 13:44:18 -0800 Subject: [PATCH 25/42] [CRE] Adjust default limits (#1748) 1. Bump capability concurrency up to 30. Relaying on concurrency limit can cause problems with OCR when different nodes throttle different requests. Let's rely on per-capability limits instead. 2. Reduce consensus call limit to something sane. 3. Reduce trigger event queus size. Usually queueing a lot of events makes them cross the expiration threshold anyway. --- pkg/settings/cresettings/defaults.json | 6 +++--- pkg/settings/cresettings/defaults.toml | 6 +++--- pkg/settings/cresettings/settings.go | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/settings/cresettings/defaults.json b/pkg/settings/cresettings/defaults.json index da605124b9..62bbfc34c8 100644 --- a/pkg/settings/cresettings/defaults.json +++ b/pkg/settings/cresettings/defaults.json @@ -19,9 +19,9 @@ "TriggerRegistrationsTimeout": "10s", "TriggerSubscriptionTimeout": "15s", "TriggerSubscriptionLimit": "10", - "TriggerEventQueueLimit": "1000", + "TriggerEventQueueLimit": "50", "TriggerEventQueueTimeout": "10m0s", - "CapabilityConcurrencyLimit": "3", + "CapabilityConcurrencyLimit": "30", "CapabilityCallTimeout": "3m0s", "SecretsConcurrencyLimit": "5", "ExecutionConcurrencyLimit": "5", @@ -67,7 +67,7 @@ }, "Consensus": { "ObservationSizeLimit": "100kb", - "CallLimit": "2000" + "CallLimit": "20" }, "HTTPAction": { "CallLimit": "5", diff --git a/pkg/settings/cresettings/defaults.toml b/pkg/settings/cresettings/defaults.toml index 195f3977d7..013475840d 100644 --- a/pkg/settings/cresettings/defaults.toml +++ b/pkg/settings/cresettings/defaults.toml @@ -19,9 +19,9 @@ VaultSecretsLimit = '100' TriggerRegistrationsTimeout = '10s' TriggerSubscriptionTimeout = '15s' TriggerSubscriptionLimit = '10' -TriggerEventQueueLimit = '1000' +TriggerEventQueueLimit = '50' TriggerEventQueueTimeout = '10m0s' -CapabilityConcurrencyLimit = '3' +CapabilityConcurrencyLimit = '30' CapabilityCallTimeout = '3m0s' SecretsConcurrencyLimit = '5' ExecutionConcurrencyLimit = '5' @@ -68,7 +68,7 @@ PayloadSizeLimit = '5kb' [PerWorkflow.Consensus] ObservationSizeLimit = '100kb' -CallLimit = '2000' +CallLimit = '20' [PerWorkflow.HTTPAction] CallLimit = '5' diff --git a/pkg/settings/cresettings/settings.go b/pkg/settings/cresettings/settings.go index 6a7c4978c9..2fd0964197 100644 --- a/pkg/settings/cresettings/settings.go +++ b/pkg/settings/cresettings/settings.go @@ -81,11 +81,11 @@ var Default = Schema{ }, PerWorkflow: Workflows{ TriggerRegistrationsTimeout: Duration(10 * time.Second), - TriggerEventQueueLimit: Int(1_000), + TriggerEventQueueLimit: Int(50), TriggerEventQueueTimeout: Duration(10 * time.Minute), TriggerSubscriptionTimeout: Duration(15 * time.Second), TriggerSubscriptionLimit: Int(10), - CapabilityConcurrencyLimit: Int(3), + CapabilityConcurrencyLimit: Int(30), // we should rely on per-capability execution limits instead of concurrency limit CapabilityCallTimeout: Duration(3 * time.Minute), SecretsConcurrencyLimit: Int(5), ExecutionConcurrencyLimit: Int(5), @@ -132,7 +132,7 @@ var Default = Schema{ }, Consensus: consensus{ ObservationSizeLimit: Size(100 * config.KByte), - CallLimit: Int(2000), + CallLimit: Int(20), }, HTTPAction: httpAction{ CallLimit: Int(5), From 00708ed44b2a01fa0e0334d5643406e4bf632ad9 Mon Sep 17 00:00:00 2001 From: mchain0 Date: Mon, 5 Jan 2026 09:45:00 +0100 Subject: [PATCH 26/42] CRE-1601: Ring OCR plugin for shard orchestration (#1742) * cre-1601: shard orchestrator plugin for delegate * cre-1601: consistent hashing and plugin test * cre-1601: tidy * cre-1601: pb generate * cre-1601: review improvement * cre-1601: review improvement * cre-1601: review improvement * cre-1601: remove previous outcome, use outctx.SeqNr instead * cre-1601: transition state machine * cre-1601: removed TransmissionScheduleOverride * cre-1601: comments * cre-1601: rename plugin to ring * cre-1601: renames and cleanup * cre-1601: delegate integrations adjustments * cre-1601: proto comments; import fix * cre-1601: snake_case consistent for proto fields * cre-1601: remove unused field * cre-1601: more proto comments * cre-1601: proto cleanup, orphans removed * cre-1601: log overrides, log plugin config * cre-1601: better comment * cre-1601: deterministic time; f check for round; improved time median; improved workflows dedup; improved comments; * cre-1601: log improvement * cre-1601: shard count health refactor * cre-1601: hash ring pure function refactor for both storage and observation * cre-1601: Transmitter notifies Arbiter * cre-1601: store in two states, steady and transition; enque for allocation trigger post round; * cre-1601: comments improved * cre-1601: more comments improvements * cre-1601: test extension to validate workflows to shards eassignments * cre-1601: test improvement - distribution check by percents * cre-1601: using maps improvement * cre-1601: remove number of shards limits (ref. review) * cre-1601: initial state fix; intial transition state until OCR round; tests adjustments to simulate state confirmation; * cre-1601: more tests; boosting test coverage * cre-1601: better comments * cre-1601: better comments * cre-1601: improved test for plugin outcome * cre-1601: comments improved * cre-1601: comment improvement * cre-1601: more tests; more coverage * cre-1601: refactor of state; state verification tests; related changes; * cre-1601: bool wrapper ShardStatus to extend with weights later * cre-1601: ArbiterScaler.Status() wantShards being part of the observations in the consensus * cre-1601: improved comments; improved naming; simplified logic; observations validation; ring performance improvements * cre-1601: on 1st round make prior outcome wantShards equal to the current data from the arbiter * cre-1601: removed unnecessary check (code review suggestions) * Fail Observation() on Arbiter error --------- Co-authored-by: Bolek Kulbabinski <1416262+bolekk@users.noreply.github.com> --- go.mod | 3 +- go.sum | 2 + pkg/types/plugin.go | 1 + pkg/workflows/ring/factory.go | 79 +++ pkg/workflows/ring/factory_test.go | 70 +++ pkg/workflows/ring/pb/arbiter.pb.go | 301 +++++++++ pkg/workflows/ring/pb/arbiter.proto | 35 ++ pkg/workflows/ring/pb/arbiter_grpc.pb.go | 262 ++++++++ pkg/workflows/ring/pb/consensus.pb.go | 442 +++++++++++++ pkg/workflows/ring/pb/consensus.proto | 39 ++ pkg/workflows/ring/pb/generate.go | 6 + .../ring/pb/shard_orchestrator.pb.go | 421 +++++++++++++ .../ring/pb/shard_orchestrator.proto | 41 ++ .../ring/pb/shard_orchestrator_grpc.pb.go | 160 +++++ pkg/workflows/ring/pb/shared.pb.go | 123 ++++ pkg/workflows/ring/pb/shared.proto | 9 + pkg/workflows/ring/plugin.go | 276 ++++++++ pkg/workflows/ring/plugin_test.go | 593 ++++++++++++++++++ pkg/workflows/ring/state.go | 76 +++ pkg/workflows/ring/state_test.go | 265 ++++++++ pkg/workflows/ring/store.go | 219 +++++++ pkg/workflows/ring/store_test.go | 313 +++++++++ pkg/workflows/ring/transmitter.go | 76 +++ pkg/workflows/ring/transmitter_test.go | 173 +++++ pkg/workflows/ring/utils.go | 66 ++ pkg/workflows/ring/utils_test.go | 12 + 26 files changed, 4062 insertions(+), 1 deletion(-) create mode 100644 pkg/workflows/ring/factory.go create mode 100644 pkg/workflows/ring/factory_test.go create mode 100644 pkg/workflows/ring/pb/arbiter.pb.go create mode 100644 pkg/workflows/ring/pb/arbiter.proto create mode 100644 pkg/workflows/ring/pb/arbiter_grpc.pb.go create mode 100644 pkg/workflows/ring/pb/consensus.pb.go create mode 100644 pkg/workflows/ring/pb/consensus.proto create mode 100644 pkg/workflows/ring/pb/generate.go create mode 100644 pkg/workflows/ring/pb/shard_orchestrator.pb.go create mode 100644 pkg/workflows/ring/pb/shard_orchestrator.proto create mode 100644 pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go create mode 100644 pkg/workflows/ring/pb/shared.pb.go create mode 100644 pkg/workflows/ring/pb/shared.proto create mode 100644 pkg/workflows/ring/plugin.go create mode 100644 pkg/workflows/ring/plugin_test.go create mode 100644 pkg/workflows/ring/state.go create mode 100644 pkg/workflows/ring/state_test.go create mode 100644 pkg/workflows/ring/store.go create mode 100644 pkg/workflows/ring/store_test.go create mode 100644 pkg/workflows/ring/transmitter.go create mode 100644 pkg/workflows/ring/transmitter_test.go create mode 100644 pkg/workflows/ring/utils.go create mode 100644 pkg/workflows/ring/utils_test.go diff --git a/go.mod b/go.mod index 2945de30d6..1f89bb1a8d 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,9 @@ require ( github.com/XSAM/otelsql v0.37.0 github.com/andybalholm/brotli v1.1.1 github.com/atombender/go-jsonschema v0.16.1-0.20240916205339-a74cd4e2851c + github.com/buraksezer/consistent v0.10.0 github.com/bytecodealliance/wasmtime-go/v28 v28.0.0 + github.com/cespare/xxhash/v2 v2.3.0 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc github.com/dominikbraun/graph v0.23.0 github.com/fxamacker/cbor/v2 v2.7.0 @@ -88,7 +90,6 @@ require ( github.com/buger/goterm v1.0.4 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v5 v5.0.2 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudevents/sdk-go/binding/format/protobuf/v2 v2.16.1 // indirect github.com/cloudevents/sdk-go/v2 v2.16.1 // indirect github.com/fatih/color v1.18.0 // indirect diff --git a/go.sum b/go.sum index 182fa64d83..8a39debc83 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/buger/goterm v1.0.4 h1:Z9YvGmOih81P0FbVtEYTFF6YsSgxSUKEhf/f9bTMXbY= github.com/buger/goterm v1.0.4/go.mod h1:HiFWV3xnkolgrBV3mY8m0X0Pumt4zg4QhbdOzQtB8tE= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/buraksezer/consistent v0.10.0 h1:hqBgz1PvNLC5rkWcEBVAL9dFMBWz6I0VgUCW25rrZlU= +github.com/buraksezer/consistent v0.10.0/go.mod h1:6BrVajWq7wbKZlTOUPs/XVfR8c0maujuPowduSpZqmw= github.com/bytecodealliance/wasmtime-go/v28 v28.0.0 h1:aBU8cexP2rPZ0Qz488kvn2NXvWZHL2aG1/+n7Iv+xGc= github.com/bytecodealliance/wasmtime-go/v28 v28.0.0/go.mod h1:4OCU0xAW9ycwtX4nMF4zxwgJBJ5/0eMfJiHB0wAmkV4= github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= diff --git a/pkg/types/plugin.go b/pkg/types/plugin.go index 88ffeece73..aed29442ee 100644 --- a/pkg/types/plugin.go +++ b/pkg/types/plugin.go @@ -17,6 +17,7 @@ const ( OCR3Capability OCR2PluginType = "ocr3-capability" VaultPlugin OCR2PluginType = "vault-plugin" DonTimePlugin OCR2PluginType = "dontime" + RingPlugin OCR2PluginType = "ring" SecureMint OCR2PluginType = "securemint" CCIPCommit OCR2PluginType = "ccip-commit" diff --git a/pkg/workflows/ring/factory.go b/pkg/workflows/ring/factory.go new file mode 100644 index 0000000000..8b85d02c86 --- /dev/null +++ b/pkg/workflows/ring/factory.go @@ -0,0 +1,79 @@ +package ring + +import ( + "context" + "errors" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types/core" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" +) + +const ( + defaultMaxPhaseOutputBytes = 1000000 // 1 MB + defaultMaxReportCount = 1 + defaultBatchSize = 100 +) + +var _ core.OCR3ReportingPluginFactory = &Factory{} + +type Factory struct { + store *Store + arbiterScaler pb.ArbiterScalerClient + config *ConsensusConfig + lggr logger.Logger + + services.StateMachine +} + +func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) { + if arbiterScaler == nil { + return nil, errors.New("arbiterScaler is required") + } + if cfg == nil { + cfg = &ConsensusConfig{ + BatchSize: defaultBatchSize, + } + } + return &Factory{ + store: s, + arbiterScaler: arbiterScaler, + config: cfg, + lggr: logger.Named(lggr, "RingPluginFactory"), + }, nil +} + +func (o *Factory) NewReportingPlugin(_ context.Context, config ocr3types.ReportingPluginConfig) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) { + plugin, err := NewPlugin(o.store, o.arbiterScaler, config, o.lggr, o.config) + pluginInfo := ocr3types.ReportingPluginInfo{ + Name: "RingPlugin", + Limits: ocr3types.ReportingPluginLimits{ + MaxQueryLength: defaultMaxPhaseOutputBytes, + MaxObservationLength: defaultMaxPhaseOutputBytes, + MaxOutcomeLength: defaultMaxPhaseOutputBytes, + MaxReportLength: defaultMaxPhaseOutputBytes, + MaxReportCount: defaultMaxReportCount, + }, + } + return plugin, pluginInfo, err +} + +func (o *Factory) Start(ctx context.Context) error { + return o.StartOnce("RingPlugin", func() error { + return nil + }) +} + +func (o *Factory) Close() error { + return o.StopOnce("RingPlugin", func() error { + return nil + }) +} + +func (o *Factory) Name() string { return o.lggr.Name() } + +func (o *Factory) HealthReport() map[string]error { + return map[string]error{o.Name(): o.Healthy()} +} diff --git a/pkg/workflows/ring/factory_test.go b/pkg/workflows/ring/factory_test.go new file mode 100644 index 0000000000..5a81c2a0fc --- /dev/null +++ b/pkg/workflows/ring/factory_test.go @@ -0,0 +1,70 @@ +package ring + +import ( + "context" + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/stretchr/testify/require" +) + +func TestFactory_NewFactory(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + arbiter := &mockArbiter{} + + t.Run("with_nil_config", func(t *testing.T) { + f, err := NewFactory(store, arbiter, lggr, nil) + require.NoError(t, err) + require.NotNil(t, f) + }) + + t.Run("with_custom_config", func(t *testing.T) { + cfg := &ConsensusConfig{BatchSize: 50} + f, err := NewFactory(store, arbiter, lggr, cfg) + require.NoError(t, err) + require.NotNil(t, f) + }) + + t.Run("nil_arbiter_returns_error", func(t *testing.T) { + _, err := NewFactory(store, nil, lggr, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "arbiterScaler is required") + }) +} + +func TestFactory_NewReportingPlugin(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + f, err := NewFactory(store, &mockArbiter{}, lggr, nil) + require.NoError(t, err) + + config := ocr3types.ReportingPluginConfig{N: 4, F: 1} + plugin, info, err := f.NewReportingPlugin(context.Background(), config) + require.NoError(t, err) + require.NotNil(t, plugin) + require.NotEmpty(t, info.Name) + require.Equal(t, "RingPlugin", info.Name) + require.Equal(t, defaultMaxReportCount, info.Limits.MaxReportCount) +} + +func TestFactory_Lifecycle(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + f, err := NewFactory(store, &mockArbiter{}, lggr, nil) + require.NoError(t, err) + + err = f.Start(context.Background()) + require.NoError(t, err) + + name := f.Name() + require.NotEmpty(t, name) + + report := f.HealthReport() + require.NotNil(t, report) + require.Contains(t, report, name) + + err = f.Close() + require.NoError(t, err) +} diff --git a/pkg/workflows/ring/pb/arbiter.pb.go b/pkg/workflows/ring/pb/arbiter.pb.go new file mode 100644 index 0000000000..bee0a949f3 --- /dev/null +++ b/pkg/workflows/ring/pb/arbiter.pb.go @@ -0,0 +1,301 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.8 +// protoc v5.29.3 +// source: arbiter.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ShardStatusRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Status map[uint32]*ShardStatus `protobuf:"bytes,1,rep,name=status,proto3" json:"status,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // shard_id -> status + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShardStatusRequest) Reset() { + *x = ShardStatusRequest{} + mi := &file_arbiter_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShardStatusRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShardStatusRequest) ProtoMessage() {} + +func (x *ShardStatusRequest) ProtoReflect() protoreflect.Message { + mi := &file_arbiter_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShardStatusRequest.ProtoReflect.Descriptor instead. +func (*ShardStatusRequest) Descriptor() ([]byte, []int) { + return file_arbiter_proto_rawDescGZIP(), []int{0} +} + +func (x *ShardStatusRequest) GetStatus() map[uint32]*ShardStatus { + if x != nil { + return x.Status + } + return nil +} + +type ArbiterResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + WantShards uint32 `protobuf:"varint,1,opt,name=want_shards,json=wantShards,proto3" json:"want_shards,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ArbiterResponse) Reset() { + *x = ArbiterResponse{} + mi := &file_arbiter_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ArbiterResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ArbiterResponse) ProtoMessage() {} + +func (x *ArbiterResponse) ProtoReflect() protoreflect.Message { + mi := &file_arbiter_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ArbiterResponse.ProtoReflect.Descriptor instead. +func (*ArbiterResponse) Descriptor() ([]byte, []int) { + return file_arbiter_proto_rawDescGZIP(), []int{1} +} + +func (x *ArbiterResponse) GetWantShards() uint32 { + if x != nil { + return x.WantShards + } + return 0 +} + +type ReplicaStatus struct { + state protoimpl.MessageState `protogen:"open.v1"` + WantShards uint32 `protobuf:"varint,1,opt,name=want_shards,json=wantShards,proto3" json:"want_shards,omitempty"` + Status map[uint32]*ShardStatus `protobuf:"bytes,2,rep,name=status,proto3" json:"status,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReplicaStatus) Reset() { + *x = ReplicaStatus{} + mi := &file_arbiter_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReplicaStatus) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReplicaStatus) ProtoMessage() {} + +func (x *ReplicaStatus) ProtoReflect() protoreflect.Message { + mi := &file_arbiter_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReplicaStatus.ProtoReflect.Descriptor instead. +func (*ReplicaStatus) Descriptor() ([]byte, []int) { + return file_arbiter_proto_rawDescGZIP(), []int{2} +} + +func (x *ReplicaStatus) GetWantShards() uint32 { + if x != nil { + return x.WantShards + } + return 0 +} + +func (x *ReplicaStatus) GetStatus() map[uint32]*ShardStatus { + if x != nil { + return x.Status + } + return nil +} + +type ConsensusWantShardsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + NShards uint32 `protobuf:"varint,1,opt,name=n_shards,json=nShards,proto3" json:"n_shards,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ConsensusWantShardsRequest) Reset() { + *x = ConsensusWantShardsRequest{} + mi := &file_arbiter_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ConsensusWantShardsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConsensusWantShardsRequest) ProtoMessage() {} + +func (x *ConsensusWantShardsRequest) ProtoReflect() protoreflect.Message { + mi := &file_arbiter_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ConsensusWantShardsRequest.ProtoReflect.Descriptor instead. +func (*ConsensusWantShardsRequest) Descriptor() ([]byte, []int) { + return file_arbiter_proto_rawDescGZIP(), []int{3} +} + +func (x *ConsensusWantShardsRequest) GetNShards() uint32 { + if x != nil { + return x.NShards + } + return 0 +} + +var File_arbiter_proto protoreflect.FileDescriptor + +const file_arbiter_proto_rawDesc = "" + + "\n" + + "\rarbiter.proto\x12\x04ring\x1a\x1bgoogle/protobuf/empty.proto\x1a\fshared.proto\"\xa0\x01\n" + + "\x12ShardStatusRequest\x12<\n" + + "\x06status\x18\x01 \x03(\v2$.ring.ShardStatusRequest.StatusEntryR\x06status\x1aL\n" + + "\vStatusEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\rR\x03key\x12'\n" + + "\x05value\x18\x02 \x01(\v2\x11.ring.ShardStatusR\x05value:\x028\x01\"2\n" + + "\x0fArbiterResponse\x12\x1f\n" + + "\vwant_shards\x18\x01 \x01(\rR\n" + + "wantShards\"\xb7\x01\n" + + "\rReplicaStatus\x12\x1f\n" + + "\vwant_shards\x18\x01 \x01(\rR\n" + + "wantShards\x127\n" + + "\x06status\x18\x02 \x03(\v2\x1f.ring.ReplicaStatus.StatusEntryR\x06status\x1aL\n" + + "\vStatusEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\rR\x03key\x12'\n" + + "\x05value\x18\x02 \x01(\v2\x11.ring.ShardStatusR\x05value:\x028\x01\"7\n" + + "\x1aConsensusWantShardsRequest\x12\x19\n" + + "\bn_shards\x18\x01 \x01(\rR\anShards2P\n" + + "\aArbiter\x12E\n" + + "\x12GetDesiredReplicas\x12\x18.ring.ShardStatusRequest\x1a\x15.ring.ArbiterResponse2\x97\x01\n" + + "\rArbiterScaler\x125\n" + + "\x06Status\x12\x16.google.protobuf.Empty\x1a\x13.ring.ReplicaStatus\x12O\n" + + "\x13ConsensusWantShards\x12 .ring.ConsensusWantShardsRequest\x1a\x16.google.protobuf.EmptyBDZBgithub.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pbb\x06proto3" + +var ( + file_arbiter_proto_rawDescOnce sync.Once + file_arbiter_proto_rawDescData []byte +) + +func file_arbiter_proto_rawDescGZIP() []byte { + file_arbiter_proto_rawDescOnce.Do(func() { + file_arbiter_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_arbiter_proto_rawDesc), len(file_arbiter_proto_rawDesc))) + }) + return file_arbiter_proto_rawDescData +} + +var file_arbiter_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_arbiter_proto_goTypes = []any{ + (*ShardStatusRequest)(nil), // 0: ring.ShardStatusRequest + (*ArbiterResponse)(nil), // 1: ring.ArbiterResponse + (*ReplicaStatus)(nil), // 2: ring.ReplicaStatus + (*ConsensusWantShardsRequest)(nil), // 3: ring.ConsensusWantShardsRequest + nil, // 4: ring.ShardStatusRequest.StatusEntry + nil, // 5: ring.ReplicaStatus.StatusEntry + (*ShardStatus)(nil), // 6: ring.ShardStatus + (*emptypb.Empty)(nil), // 7: google.protobuf.Empty +} +var file_arbiter_proto_depIdxs = []int32{ + 4, // 0: ring.ShardStatusRequest.status:type_name -> ring.ShardStatusRequest.StatusEntry + 5, // 1: ring.ReplicaStatus.status:type_name -> ring.ReplicaStatus.StatusEntry + 6, // 2: ring.ShardStatusRequest.StatusEntry.value:type_name -> ring.ShardStatus + 6, // 3: ring.ReplicaStatus.StatusEntry.value:type_name -> ring.ShardStatus + 0, // 4: ring.Arbiter.GetDesiredReplicas:input_type -> ring.ShardStatusRequest + 7, // 5: ring.ArbiterScaler.Status:input_type -> google.protobuf.Empty + 3, // 6: ring.ArbiterScaler.ConsensusWantShards:input_type -> ring.ConsensusWantShardsRequest + 1, // 7: ring.Arbiter.GetDesiredReplicas:output_type -> ring.ArbiterResponse + 2, // 8: ring.ArbiterScaler.Status:output_type -> ring.ReplicaStatus + 7, // 9: ring.ArbiterScaler.ConsensusWantShards:output_type -> google.protobuf.Empty + 7, // [7:10] is the sub-list for method output_type + 4, // [4:7] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name +} + +func init() { file_arbiter_proto_init() } +func file_arbiter_proto_init() { + if File_arbiter_proto != nil { + return + } + file_shared_proto_init() + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_arbiter_proto_rawDesc), len(file_arbiter_proto_rawDesc)), + NumEnums: 0, + NumMessages: 6, + NumExtensions: 0, + NumServices: 2, + }, + GoTypes: file_arbiter_proto_goTypes, + DependencyIndexes: file_arbiter_proto_depIdxs, + MessageInfos: file_arbiter_proto_msgTypes, + }.Build() + File_arbiter_proto = out.File + file_arbiter_proto_goTypes = nil + file_arbiter_proto_depIdxs = nil +} diff --git a/pkg/workflows/ring/pb/arbiter.proto b/pkg/workflows/ring/pb/arbiter.proto new file mode 100644 index 0000000000..4b86ab256a --- /dev/null +++ b/pkg/workflows/ring/pb/arbiter.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package ring; + +import "google/protobuf/empty.proto"; +import "shared.proto"; + +option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"; + +message ShardStatusRequest { + map status = 1; // shard_id -> status +} + +message ArbiterResponse { + uint32 want_shards = 1; +} + +service Arbiter { + rpc GetDesiredReplicas(ShardStatusRequest) returns (ArbiterResponse); // called periodically by Scaler +} + +message ReplicaStatus { + uint32 want_shards = 1; + map status = 2; +} + +message ConsensusWantShardsRequest { + uint32 n_shards = 1; +} + +service ArbiterScaler { + rpc Status(google.protobuf.Empty) returns (ReplicaStatus); // called to collect current status of shards by Ring plugin + rpc ConsensusWantShards(ConsensusWantShardsRequest) returns (google.protobuf.Empty); // called at the end of the round with consensus shard count by Ring plugin +} + diff --git a/pkg/workflows/ring/pb/arbiter_grpc.pb.go b/pkg/workflows/ring/pb/arbiter_grpc.pb.go new file mode 100644 index 0000000000..34a66a8a7e --- /dev/null +++ b/pkg/workflows/ring/pb/arbiter_grpc.pb.go @@ -0,0 +1,262 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v5.29.3 +// source: arbiter.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + Arbiter_GetDesiredReplicas_FullMethodName = "/ring.Arbiter/GetDesiredReplicas" +) + +// ArbiterClient is the client API for Arbiter service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ArbiterClient interface { + GetDesiredReplicas(ctx context.Context, in *ShardStatusRequest, opts ...grpc.CallOption) (*ArbiterResponse, error) +} + +type arbiterClient struct { + cc grpc.ClientConnInterface +} + +func NewArbiterClient(cc grpc.ClientConnInterface) ArbiterClient { + return &arbiterClient{cc} +} + +func (c *arbiterClient) GetDesiredReplicas(ctx context.Context, in *ShardStatusRequest, opts ...grpc.CallOption) (*ArbiterResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ArbiterResponse) + err := c.cc.Invoke(ctx, Arbiter_GetDesiredReplicas_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ArbiterServer is the server API for Arbiter service. +// All implementations must embed UnimplementedArbiterServer +// for forward compatibility. +type ArbiterServer interface { + GetDesiredReplicas(context.Context, *ShardStatusRequest) (*ArbiterResponse, error) + mustEmbedUnimplementedArbiterServer() +} + +// UnimplementedArbiterServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedArbiterServer struct{} + +func (UnimplementedArbiterServer) GetDesiredReplicas(context.Context, *ShardStatusRequest) (*ArbiterResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetDesiredReplicas not implemented") +} +func (UnimplementedArbiterServer) mustEmbedUnimplementedArbiterServer() {} +func (UnimplementedArbiterServer) testEmbeddedByValue() {} + +// UnsafeArbiterServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ArbiterServer will +// result in compilation errors. +type UnsafeArbiterServer interface { + mustEmbedUnimplementedArbiterServer() +} + +func RegisterArbiterServer(s grpc.ServiceRegistrar, srv ArbiterServer) { + // If the following call pancis, it indicates UnimplementedArbiterServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&Arbiter_ServiceDesc, srv) +} + +func _Arbiter_GetDesiredReplicas_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ShardStatusRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ArbiterServer).GetDesiredReplicas(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: Arbiter_GetDesiredReplicas_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ArbiterServer).GetDesiredReplicas(ctx, req.(*ShardStatusRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Arbiter_ServiceDesc is the grpc.ServiceDesc for Arbiter service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Arbiter_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ring.Arbiter", + HandlerType: (*ArbiterServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetDesiredReplicas", + Handler: _Arbiter_GetDesiredReplicas_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "arbiter.proto", +} + +const ( + ArbiterScaler_Status_FullMethodName = "/ring.ArbiterScaler/Status" + ArbiterScaler_ConsensusWantShards_FullMethodName = "/ring.ArbiterScaler/ConsensusWantShards" +) + +// ArbiterScalerClient is the client API for ArbiterScaler service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ArbiterScalerClient interface { + Status(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*ReplicaStatus, error) + ConsensusWantShards(ctx context.Context, in *ConsensusWantShardsRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) +} + +type arbiterScalerClient struct { + cc grpc.ClientConnInterface +} + +func NewArbiterScalerClient(cc grpc.ClientConnInterface) ArbiterScalerClient { + return &arbiterScalerClient{cc} +} + +func (c *arbiterScalerClient) Status(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*ReplicaStatus, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ReplicaStatus) + err := c.cc.Invoke(ctx, ArbiterScaler_Status_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *arbiterScalerClient) ConsensusWantShards(ctx context.Context, in *ConsensusWantShardsRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(emptypb.Empty) + err := c.cc.Invoke(ctx, ArbiterScaler_ConsensusWantShards_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ArbiterScalerServer is the server API for ArbiterScaler service. +// All implementations must embed UnimplementedArbiterScalerServer +// for forward compatibility. +type ArbiterScalerServer interface { + Status(context.Context, *emptypb.Empty) (*ReplicaStatus, error) + ConsensusWantShards(context.Context, *ConsensusWantShardsRequest) (*emptypb.Empty, error) + mustEmbedUnimplementedArbiterScalerServer() +} + +// UnimplementedArbiterScalerServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedArbiterScalerServer struct{} + +func (UnimplementedArbiterScalerServer) Status(context.Context, *emptypb.Empty) (*ReplicaStatus, error) { + return nil, status.Errorf(codes.Unimplemented, "method Status not implemented") +} +func (UnimplementedArbiterScalerServer) ConsensusWantShards(context.Context, *ConsensusWantShardsRequest) (*emptypb.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method ConsensusWantShards not implemented") +} +func (UnimplementedArbiterScalerServer) mustEmbedUnimplementedArbiterScalerServer() {} +func (UnimplementedArbiterScalerServer) testEmbeddedByValue() {} + +// UnsafeArbiterScalerServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ArbiterScalerServer will +// result in compilation errors. +type UnsafeArbiterScalerServer interface { + mustEmbedUnimplementedArbiterScalerServer() +} + +func RegisterArbiterScalerServer(s grpc.ServiceRegistrar, srv ArbiterScalerServer) { + // If the following call pancis, it indicates UnimplementedArbiterScalerServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ArbiterScaler_ServiceDesc, srv) +} + +func _ArbiterScaler_Status_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(emptypb.Empty) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ArbiterScalerServer).Status(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ArbiterScaler_Status_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ArbiterScalerServer).Status(ctx, req.(*emptypb.Empty)) + } + return interceptor(ctx, in, info, handler) +} + +func _ArbiterScaler_ConsensusWantShards_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ConsensusWantShardsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ArbiterScalerServer).ConsensusWantShards(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ArbiterScaler_ConsensusWantShards_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ArbiterScalerServer).ConsensusWantShards(ctx, req.(*ConsensusWantShardsRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// ArbiterScaler_ServiceDesc is the grpc.ServiceDesc for ArbiterScaler service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ArbiterScaler_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ring.ArbiterScaler", + HandlerType: (*ArbiterScalerServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Status", + Handler: _ArbiterScaler_Status_Handler, + }, + { + MethodName: "ConsensusWantShards", + Handler: _ArbiterScaler_ConsensusWantShards_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "arbiter.proto", +} diff --git a/pkg/workflows/ring/pb/consensus.pb.go b/pkg/workflows/ring/pb/consensus.pb.go new file mode 100644 index 0000000000..109e05910f --- /dev/null +++ b/pkg/workflows/ring/pb/consensus.pb.go @@ -0,0 +1,442 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.8 +// protoc v5.29.3 +// source: consensus.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Observation struct { + state protoimpl.MessageState `protogen:"open.v1"` + ShardStatus map[uint32]*ShardStatus `protobuf:"bytes,1,rep,name=shard_status,json=shardStatus,proto3" json:"shard_status,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // shard_id -> status + WorkflowIds []string `protobuf:"bytes,2,rep,name=workflow_ids,json=workflowIds,proto3" json:"workflow_ids,omitempty"` + Now *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=now,proto3" json:"now,omitempty"` + WantShards uint32 `protobuf:"varint,4,opt,name=want_shards,json=wantShards,proto3" json:"want_shards,omitempty"` // from ArbiterScaler.Status() + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Observation) Reset() { + *x = Observation{} + mi := &file_consensus_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Observation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Observation) ProtoMessage() {} + +func (x *Observation) ProtoReflect() protoreflect.Message { + mi := &file_consensus_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Observation.ProtoReflect.Descriptor instead. +func (*Observation) Descriptor() ([]byte, []int) { + return file_consensus_proto_rawDescGZIP(), []int{0} +} + +func (x *Observation) GetShardStatus() map[uint32]*ShardStatus { + if x != nil { + return x.ShardStatus + } + return nil +} + +func (x *Observation) GetWorkflowIds() []string { + if x != nil { + return x.WorkflowIds + } + return nil +} + +func (x *Observation) GetNow() *timestamppb.Timestamp { + if x != nil { + return x.Now + } + return nil +} + +func (x *Observation) GetWantShards() uint32 { + if x != nil { + return x.WantShards + } + return 0 +} + +type WorkflowRoute struct { + state protoimpl.MessageState `protogen:"open.v1"` + Shard uint32 `protobuf:"varint,1,opt,name=shard,proto3" json:"shard,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WorkflowRoute) Reset() { + *x = WorkflowRoute{} + mi := &file_consensus_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WorkflowRoute) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WorkflowRoute) ProtoMessage() {} + +func (x *WorkflowRoute) ProtoReflect() protoreflect.Message { + mi := &file_consensus_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WorkflowRoute.ProtoReflect.Descriptor instead. +func (*WorkflowRoute) Descriptor() ([]byte, []int) { + return file_consensus_proto_rawDescGZIP(), []int{1} +} + +func (x *WorkflowRoute) GetShard() uint32 { + if x != nil { + return x.Shard + } + return 0 +} + +type Transition struct { + state protoimpl.MessageState `protogen:"open.v1"` + WantShards uint32 `protobuf:"varint,1,opt,name=want_shards,json=wantShards,proto3" json:"want_shards,omitempty"` + LastStableCount uint32 `protobuf:"varint,2,opt,name=last_stable_count,json=lastStableCount,proto3" json:"last_stable_count,omitempty"` + ChangesSafeAfter *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=changes_safe_after,json=changesSafeAfter,proto3" json:"changes_safe_after,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Transition) Reset() { + *x = Transition{} + mi := &file_consensus_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Transition) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Transition) ProtoMessage() {} + +func (x *Transition) ProtoReflect() protoreflect.Message { + mi := &file_consensus_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Transition.ProtoReflect.Descriptor instead. +func (*Transition) Descriptor() ([]byte, []int) { + return file_consensus_proto_rawDescGZIP(), []int{2} +} + +func (x *Transition) GetWantShards() uint32 { + if x != nil { + return x.WantShards + } + return 0 +} + +func (x *Transition) GetLastStableCount() uint32 { + if x != nil { + return x.LastStableCount + } + return 0 +} + +func (x *Transition) GetChangesSafeAfter() *timestamppb.Timestamp { + if x != nil { + return x.ChangesSafeAfter + } + return nil +} + +type RoutingState struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + // Types that are valid to be assigned to State: + // + // *RoutingState_Transition + // *RoutingState_RoutableShards + State isRoutingState_State `protobuf_oneof:"state"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RoutingState) Reset() { + *x = RoutingState{} + mi := &file_consensus_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RoutingState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RoutingState) ProtoMessage() {} + +func (x *RoutingState) ProtoReflect() protoreflect.Message { + mi := &file_consensus_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RoutingState.ProtoReflect.Descriptor instead. +func (*RoutingState) Descriptor() ([]byte, []int) { + return file_consensus_proto_rawDescGZIP(), []int{3} +} + +func (x *RoutingState) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *RoutingState) GetState() isRoutingState_State { + if x != nil { + return x.State + } + return nil +} + +func (x *RoutingState) GetTransition() *Transition { + if x != nil { + if x, ok := x.State.(*RoutingState_Transition); ok { + return x.Transition + } + } + return nil +} + +func (x *RoutingState) GetRoutableShards() uint32 { + if x != nil { + if x, ok := x.State.(*RoutingState_RoutableShards); ok { + return x.RoutableShards + } + } + return 0 +} + +type isRoutingState_State interface { + isRoutingState_State() +} + +type RoutingState_Transition struct { + Transition *Transition `protobuf:"bytes,2,opt,name=transition,proto3,oneof"` +} + +type RoutingState_RoutableShards struct { + RoutableShards uint32 `protobuf:"varint,3,opt,name=routable_shards,json=routableShards,proto3,oneof"` +} + +func (*RoutingState_Transition) isRoutingState_State() {} + +func (*RoutingState_RoutableShards) isRoutingState_State() {} + +type Outcome struct { + state protoimpl.MessageState `protogen:"open.v1"` + State *RoutingState `protobuf:"bytes,1,opt,name=state,proto3" json:"state,omitempty"` // used internally for ring plugin + Routes map[string]*WorkflowRoute `protobuf:"bytes,2,rep,name=routes,proto3" json:"routes,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` // used by consumers to route requests to the appropriate shard + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Outcome) Reset() { + *x = Outcome{} + mi := &file_consensus_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Outcome) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Outcome) ProtoMessage() {} + +func (x *Outcome) ProtoReflect() protoreflect.Message { + mi := &file_consensus_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Outcome.ProtoReflect.Descriptor instead. +func (*Outcome) Descriptor() ([]byte, []int) { + return file_consensus_proto_rawDescGZIP(), []int{4} +} + +func (x *Outcome) GetState() *RoutingState { + if x != nil { + return x.State + } + return nil +} + +func (x *Outcome) GetRoutes() map[string]*WorkflowRoute { + if x != nil { + return x.Routes + } + return nil +} + +var File_consensus_proto protoreflect.FileDescriptor + +const file_consensus_proto_rawDesc = "" + + "\n" + + "\x0fconsensus.proto\x12\x04ring\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\fshared.proto\"\x99\x02\n" + + "\vObservation\x12E\n" + + "\fshard_status\x18\x01 \x03(\v2\".ring.Observation.ShardStatusEntryR\vshardStatus\x12!\n" + + "\fworkflow_ids\x18\x02 \x03(\tR\vworkflowIds\x12,\n" + + "\x03now\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\x03now\x12\x1f\n" + + "\vwant_shards\x18\x04 \x01(\rR\n" + + "wantShards\x1aQ\n" + + "\x10ShardStatusEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\rR\x03key\x12'\n" + + "\x05value\x18\x02 \x01(\v2\x11.ring.ShardStatusR\x05value:\x028\x01\"%\n" + + "\rWorkflowRoute\x12\x14\n" + + "\x05shard\x18\x01 \x01(\rR\x05shard\"\xa3\x01\n" + + "\n" + + "Transition\x12\x1f\n" + + "\vwant_shards\x18\x01 \x01(\rR\n" + + "wantShards\x12*\n" + + "\x11last_stable_count\x18\x02 \x01(\rR\x0flastStableCount\x12H\n" + + "\x12changes_safe_after\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\x10changesSafeAfter\"\x86\x01\n" + + "\fRoutingState\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x122\n" + + "\n" + + "transition\x18\x02 \x01(\v2\x10.ring.TransitionH\x00R\n" + + "transition\x12)\n" + + "\x0froutable_shards\x18\x03 \x01(\rH\x00R\x0eroutableShardsB\a\n" + + "\x05state\"\xb6\x01\n" + + "\aOutcome\x12(\n" + + "\x05state\x18\x01 \x01(\v2\x12.ring.RoutingStateR\x05state\x121\n" + + "\x06routes\x18\x02 \x03(\v2\x19.ring.Outcome.RoutesEntryR\x06routes\x1aN\n" + + "\vRoutesEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + + "\x05value\x18\x02 \x01(\v2\x13.ring.WorkflowRouteR\x05value:\x028\x01BDZBgithub.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pbb\x06proto3" + +var ( + file_consensus_proto_rawDescOnce sync.Once + file_consensus_proto_rawDescData []byte +) + +func file_consensus_proto_rawDescGZIP() []byte { + file_consensus_proto_rawDescOnce.Do(func() { + file_consensus_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_consensus_proto_rawDesc), len(file_consensus_proto_rawDesc))) + }) + return file_consensus_proto_rawDescData +} + +var file_consensus_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_consensus_proto_goTypes = []any{ + (*Observation)(nil), // 0: ring.Observation + (*WorkflowRoute)(nil), // 1: ring.WorkflowRoute + (*Transition)(nil), // 2: ring.Transition + (*RoutingState)(nil), // 3: ring.RoutingState + (*Outcome)(nil), // 4: ring.Outcome + nil, // 5: ring.Observation.ShardStatusEntry + nil, // 6: ring.Outcome.RoutesEntry + (*timestamppb.Timestamp)(nil), // 7: google.protobuf.Timestamp + (*ShardStatus)(nil), // 8: ring.ShardStatus +} +var file_consensus_proto_depIdxs = []int32{ + 5, // 0: ring.Observation.shard_status:type_name -> ring.Observation.ShardStatusEntry + 7, // 1: ring.Observation.now:type_name -> google.protobuf.Timestamp + 7, // 2: ring.Transition.changes_safe_after:type_name -> google.protobuf.Timestamp + 2, // 3: ring.RoutingState.transition:type_name -> ring.Transition + 3, // 4: ring.Outcome.state:type_name -> ring.RoutingState + 6, // 5: ring.Outcome.routes:type_name -> ring.Outcome.RoutesEntry + 8, // 6: ring.Observation.ShardStatusEntry.value:type_name -> ring.ShardStatus + 1, // 7: ring.Outcome.RoutesEntry.value:type_name -> ring.WorkflowRoute + 8, // [8:8] is the sub-list for method output_type + 8, // [8:8] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name +} + +func init() { file_consensus_proto_init() } +func file_consensus_proto_init() { + if File_consensus_proto != nil { + return + } + file_shared_proto_init() + file_consensus_proto_msgTypes[3].OneofWrappers = []any{ + (*RoutingState_Transition)(nil), + (*RoutingState_RoutableShards)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_consensus_proto_rawDesc), len(file_consensus_proto_rawDesc)), + NumEnums: 0, + NumMessages: 7, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_consensus_proto_goTypes, + DependencyIndexes: file_consensus_proto_depIdxs, + MessageInfos: file_consensus_proto_msgTypes, + }.Build() + File_consensus_proto = out.File + file_consensus_proto_goTypes = nil + file_consensus_proto_depIdxs = nil +} diff --git a/pkg/workflows/ring/pb/consensus.proto b/pkg/workflows/ring/pb/consensus.proto new file mode 100644 index 0000000000..2efaf4e1b0 --- /dev/null +++ b/pkg/workflows/ring/pb/consensus.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +package ring; + +import "google/protobuf/timestamp.proto"; +import "shared.proto"; + +option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"; + +message Observation { + map shard_status = 1; // shard_id -> status + repeated string workflow_ids = 2; + google.protobuf.Timestamp now = 3; + uint32 want_shards = 4; // from ArbiterScaler.Status() +} + +message WorkflowRoute { + uint32 shard = 1; +} + +message Transition { + uint32 want_shards = 1; + uint32 last_stable_count = 2; + google.protobuf.Timestamp changes_safe_after = 3; +} + +message RoutingState { + uint64 id = 1; + oneof state { + Transition transition = 2; + uint32 routable_shards = 3; + } +} + +message Outcome { + RoutingState state = 1; // used internally for ring plugin + map routes = 2; // used by consumers to route requests to the appropriate shard +} + diff --git a/pkg/workflows/ring/pb/generate.go b/pkg/workflows/ring/pb/generate.go new file mode 100644 index 0000000000..850f3eeb44 --- /dev/null +++ b/pkg/workflows/ring/pb/generate.go @@ -0,0 +1,6 @@ +//go:generate protoc --go_out=. --go_opt=paths=source_relative shared.proto +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative arbiter.proto +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative shard_orchestrator.proto +//go:generate protoc --go_out=. --go_opt=paths=source_relative consensus.proto + +package pb diff --git a/pkg/workflows/ring/pb/shard_orchestrator.pb.go b/pkg/workflows/ring/pb/shard_orchestrator.pb.go new file mode 100644 index 0000000000..7a3f8491e8 --- /dev/null +++ b/pkg/workflows/ring/pb/shard_orchestrator.pb.go @@ -0,0 +1,421 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.8 +// protoc v5.29.3 +// source: shard_orchestrator.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type GetWorkflowShardMappingRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + WorkflowIds []string `protobuf:"bytes,1,rep,name=workflow_ids,json=workflowIds,proto3" json:"workflow_ids,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWorkflowShardMappingRequest) Reset() { + *x = GetWorkflowShardMappingRequest{} + mi := &file_shard_orchestrator_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWorkflowShardMappingRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWorkflowShardMappingRequest) ProtoMessage() {} + +func (x *GetWorkflowShardMappingRequest) ProtoReflect() protoreflect.Message { + mi := &file_shard_orchestrator_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWorkflowShardMappingRequest.ProtoReflect.Descriptor instead. +func (*GetWorkflowShardMappingRequest) Descriptor() ([]byte, []int) { + return file_shard_orchestrator_proto_rawDescGZIP(), []int{0} +} + +func (x *GetWorkflowShardMappingRequest) GetWorkflowIds() []string { + if x != nil { + return x.WorkflowIds + } + return nil +} + +type WorkflowMappingState struct { + state protoimpl.MessageState `protogen:"open.v1"` + OldShardId uint32 `protobuf:"varint,1,opt,name=old_shard_id,json=oldShardId,proto3" json:"old_shard_id,omitempty"` + NewShardId uint32 `protobuf:"varint,2,opt,name=new_shard_id,json=newShardId,proto3" json:"new_shard_id,omitempty"` + InTransition bool `protobuf:"varint,3,opt,name=in_transition,json=inTransition,proto3" json:"in_transition,omitempty"` + LastUpdated *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=last_updated,json=lastUpdated,proto3" json:"last_updated,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WorkflowMappingState) Reset() { + *x = WorkflowMappingState{} + mi := &file_shard_orchestrator_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WorkflowMappingState) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WorkflowMappingState) ProtoMessage() {} + +func (x *WorkflowMappingState) ProtoReflect() protoreflect.Message { + mi := &file_shard_orchestrator_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WorkflowMappingState.ProtoReflect.Descriptor instead. +func (*WorkflowMappingState) Descriptor() ([]byte, []int) { + return file_shard_orchestrator_proto_rawDescGZIP(), []int{1} +} + +func (x *WorkflowMappingState) GetOldShardId() uint32 { + if x != nil { + return x.OldShardId + } + return 0 +} + +func (x *WorkflowMappingState) GetNewShardId() uint32 { + if x != nil { + return x.NewShardId + } + return 0 +} + +func (x *WorkflowMappingState) GetInTransition() bool { + if x != nil { + return x.InTransition + } + return false +} + +func (x *WorkflowMappingState) GetLastUpdated() *timestamppb.Timestamp { + if x != nil { + return x.LastUpdated + } + return nil +} + +type GetWorkflowShardMappingResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Mappings map[string]uint32 `protobuf:"bytes,1,rep,name=mappings,proto3" json:"mappings,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"` + MappingStates map[string]*WorkflowMappingState `protobuf:"bytes,2,rep,name=mapping_states,json=mappingStates,proto3" json:"mapping_states,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + Timestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` + MappingVersion uint64 `protobuf:"varint,4,opt,name=mapping_version,json=mappingVersion,proto3" json:"mapping_version,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetWorkflowShardMappingResponse) Reset() { + *x = GetWorkflowShardMappingResponse{} + mi := &file_shard_orchestrator_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetWorkflowShardMappingResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetWorkflowShardMappingResponse) ProtoMessage() {} + +func (x *GetWorkflowShardMappingResponse) ProtoReflect() protoreflect.Message { + mi := &file_shard_orchestrator_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetWorkflowShardMappingResponse.ProtoReflect.Descriptor instead. +func (*GetWorkflowShardMappingResponse) Descriptor() ([]byte, []int) { + return file_shard_orchestrator_proto_rawDescGZIP(), []int{2} +} + +func (x *GetWorkflowShardMappingResponse) GetMappings() map[string]uint32 { + if x != nil { + return x.Mappings + } + return nil +} + +func (x *GetWorkflowShardMappingResponse) GetMappingStates() map[string]*WorkflowMappingState { + if x != nil { + return x.MappingStates + } + return nil +} + +func (x *GetWorkflowShardMappingResponse) GetTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.Timestamp + } + return nil +} + +func (x *GetWorkflowShardMappingResponse) GetMappingVersion() uint64 { + if x != nil { + return x.MappingVersion + } + return 0 +} + +type ReportWorkflowTriggerRegistrationRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + SourceShardId uint32 `protobuf:"varint,1,opt,name=source_shard_id,json=sourceShardId,proto3" json:"source_shard_id,omitempty"` + RegisteredWorkflows map[string]uint32 `protobuf:"bytes,2,rep,name=registered_workflows,json=registeredWorkflows,proto3" json:"registered_workflows,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"` + ReportTimestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=report_timestamp,json=reportTimestamp,proto3" json:"report_timestamp,omitempty"` + TotalActiveWorkflows uint32 `protobuf:"varint,4,opt,name=total_active_workflows,json=totalActiveWorkflows,proto3" json:"total_active_workflows,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReportWorkflowTriggerRegistrationRequest) Reset() { + *x = ReportWorkflowTriggerRegistrationRequest{} + mi := &file_shard_orchestrator_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReportWorkflowTriggerRegistrationRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReportWorkflowTriggerRegistrationRequest) ProtoMessage() {} + +func (x *ReportWorkflowTriggerRegistrationRequest) ProtoReflect() protoreflect.Message { + mi := &file_shard_orchestrator_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReportWorkflowTriggerRegistrationRequest.ProtoReflect.Descriptor instead. +func (*ReportWorkflowTriggerRegistrationRequest) Descriptor() ([]byte, []int) { + return file_shard_orchestrator_proto_rawDescGZIP(), []int{3} +} + +func (x *ReportWorkflowTriggerRegistrationRequest) GetSourceShardId() uint32 { + if x != nil { + return x.SourceShardId + } + return 0 +} + +func (x *ReportWorkflowTriggerRegistrationRequest) GetRegisteredWorkflows() map[string]uint32 { + if x != nil { + return x.RegisteredWorkflows + } + return nil +} + +func (x *ReportWorkflowTriggerRegistrationRequest) GetReportTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.ReportTimestamp + } + return nil +} + +func (x *ReportWorkflowTriggerRegistrationRequest) GetTotalActiveWorkflows() uint32 { + if x != nil { + return x.TotalActiveWorkflows + } + return 0 +} + +type ReportWorkflowTriggerRegistrationResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReportWorkflowTriggerRegistrationResponse) Reset() { + *x = ReportWorkflowTriggerRegistrationResponse{} + mi := &file_shard_orchestrator_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReportWorkflowTriggerRegistrationResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReportWorkflowTriggerRegistrationResponse) ProtoMessage() {} + +func (x *ReportWorkflowTriggerRegistrationResponse) ProtoReflect() protoreflect.Message { + mi := &file_shard_orchestrator_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReportWorkflowTriggerRegistrationResponse.ProtoReflect.Descriptor instead. +func (*ReportWorkflowTriggerRegistrationResponse) Descriptor() ([]byte, []int) { + return file_shard_orchestrator_proto_rawDescGZIP(), []int{4} +} + +func (x *ReportWorkflowTriggerRegistrationResponse) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +var File_shard_orchestrator_proto protoreflect.FileDescriptor + +const file_shard_orchestrator_proto_rawDesc = "" + + "\n" + + "\x18shard_orchestrator.proto\x12\x04ring\x1a\x1fgoogle/protobuf/timestamp.proto\"C\n" + + "\x1eGetWorkflowShardMappingRequest\x12!\n" + + "\fworkflow_ids\x18\x01 \x03(\tR\vworkflowIds\"\xbe\x01\n" + + "\x14WorkflowMappingState\x12 \n" + + "\fold_shard_id\x18\x01 \x01(\rR\n" + + "oldShardId\x12 \n" + + "\fnew_shard_id\x18\x02 \x01(\rR\n" + + "newShardId\x12#\n" + + "\rin_transition\x18\x03 \x01(\bR\finTransition\x12=\n" + + "\flast_updated\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\vlastUpdated\"\xd1\x03\n" + + "\x1fGetWorkflowShardMappingResponse\x12O\n" + + "\bmappings\x18\x01 \x03(\v23.ring.GetWorkflowShardMappingResponse.MappingsEntryR\bmappings\x12_\n" + + "\x0emapping_states\x18\x02 \x03(\v28.ring.GetWorkflowShardMappingResponse.MappingStatesEntryR\rmappingStates\x128\n" + + "\ttimestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12'\n" + + "\x0fmapping_version\x18\x04 \x01(\x04R\x0emappingVersion\x1a;\n" + + "\rMappingsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\x1a\\\n" + + "\x12MappingStatesEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x120\n" + + "\x05value\x18\x02 \x01(\v2\x1a.ring.WorkflowMappingStateR\x05value:\x028\x01\"\x93\x03\n" + + "(ReportWorkflowTriggerRegistrationRequest\x12&\n" + + "\x0fsource_shard_id\x18\x01 \x01(\rR\rsourceShardId\x12z\n" + + "\x14registered_workflows\x18\x02 \x03(\v2G.ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntryR\x13registeredWorkflows\x12E\n" + + "\x10report_timestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\x0freportTimestamp\x124\n" + + "\x16total_active_workflows\x18\x04 \x01(\rR\x14totalActiveWorkflows\x1aF\n" + + "\x18RegisteredWorkflowsEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\"E\n" + + ")ReportWorkflowTriggerRegistrationResponse\x12\x18\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess2\x89\x02\n" + + "\x18ShardOrchestratorService\x12f\n" + + "\x17GetWorkflowShardMapping\x12$.ring.GetWorkflowShardMappingRequest\x1a%.ring.GetWorkflowShardMappingResponse\x12\x84\x01\n" + + "!ReportWorkflowTriggerRegistration\x12..ring.ReportWorkflowTriggerRegistrationRequest\x1a/.ring.ReportWorkflowTriggerRegistrationResponseBDZBgithub.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pbb\x06proto3" + +var ( + file_shard_orchestrator_proto_rawDescOnce sync.Once + file_shard_orchestrator_proto_rawDescData []byte +) + +func file_shard_orchestrator_proto_rawDescGZIP() []byte { + file_shard_orchestrator_proto_rawDescOnce.Do(func() { + file_shard_orchestrator_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_shard_orchestrator_proto_rawDesc), len(file_shard_orchestrator_proto_rawDesc))) + }) + return file_shard_orchestrator_proto_rawDescData +} + +var file_shard_orchestrator_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_shard_orchestrator_proto_goTypes = []any{ + (*GetWorkflowShardMappingRequest)(nil), // 0: ring.GetWorkflowShardMappingRequest + (*WorkflowMappingState)(nil), // 1: ring.WorkflowMappingState + (*GetWorkflowShardMappingResponse)(nil), // 2: ring.GetWorkflowShardMappingResponse + (*ReportWorkflowTriggerRegistrationRequest)(nil), // 3: ring.ReportWorkflowTriggerRegistrationRequest + (*ReportWorkflowTriggerRegistrationResponse)(nil), // 4: ring.ReportWorkflowTriggerRegistrationResponse + nil, // 5: ring.GetWorkflowShardMappingResponse.MappingsEntry + nil, // 6: ring.GetWorkflowShardMappingResponse.MappingStatesEntry + nil, // 7: ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry + (*timestamppb.Timestamp)(nil), // 8: google.protobuf.Timestamp +} +var file_shard_orchestrator_proto_depIdxs = []int32{ + 8, // 0: ring.WorkflowMappingState.last_updated:type_name -> google.protobuf.Timestamp + 5, // 1: ring.GetWorkflowShardMappingResponse.mappings:type_name -> ring.GetWorkflowShardMappingResponse.MappingsEntry + 6, // 2: ring.GetWorkflowShardMappingResponse.mapping_states:type_name -> ring.GetWorkflowShardMappingResponse.MappingStatesEntry + 8, // 3: ring.GetWorkflowShardMappingResponse.timestamp:type_name -> google.protobuf.Timestamp + 7, // 4: ring.ReportWorkflowTriggerRegistrationRequest.registered_workflows:type_name -> ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry + 8, // 5: ring.ReportWorkflowTriggerRegistrationRequest.report_timestamp:type_name -> google.protobuf.Timestamp + 1, // 6: ring.GetWorkflowShardMappingResponse.MappingStatesEntry.value:type_name -> ring.WorkflowMappingState + 0, // 7: ring.ShardOrchestratorService.GetWorkflowShardMapping:input_type -> ring.GetWorkflowShardMappingRequest + 3, // 8: ring.ShardOrchestratorService.ReportWorkflowTriggerRegistration:input_type -> ring.ReportWorkflowTriggerRegistrationRequest + 2, // 9: ring.ShardOrchestratorService.GetWorkflowShardMapping:output_type -> ring.GetWorkflowShardMappingResponse + 4, // 10: ring.ShardOrchestratorService.ReportWorkflowTriggerRegistration:output_type -> ring.ReportWorkflowTriggerRegistrationResponse + 9, // [9:11] is the sub-list for method output_type + 7, // [7:9] is the sub-list for method input_type + 7, // [7:7] is the sub-list for extension type_name + 7, // [7:7] is the sub-list for extension extendee + 0, // [0:7] is the sub-list for field type_name +} + +func init() { file_shard_orchestrator_proto_init() } +func file_shard_orchestrator_proto_init() { + if File_shard_orchestrator_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_shard_orchestrator_proto_rawDesc), len(file_shard_orchestrator_proto_rawDesc)), + NumEnums: 0, + NumMessages: 8, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_shard_orchestrator_proto_goTypes, + DependencyIndexes: file_shard_orchestrator_proto_depIdxs, + MessageInfos: file_shard_orchestrator_proto_msgTypes, + }.Build() + File_shard_orchestrator_proto = out.File + file_shard_orchestrator_proto_goTypes = nil + file_shard_orchestrator_proto_depIdxs = nil +} diff --git a/pkg/workflows/ring/pb/shard_orchestrator.proto b/pkg/workflows/ring/pb/shard_orchestrator.proto new file mode 100644 index 0000000000..c7e3c1668e --- /dev/null +++ b/pkg/workflows/ring/pb/shard_orchestrator.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +package ring; + +import "google/protobuf/timestamp.proto"; + +option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"; + +message GetWorkflowShardMappingRequest { + repeated string workflow_ids = 1; +} + +message WorkflowMappingState { + uint32 old_shard_id = 1; + uint32 new_shard_id = 2; + bool in_transition = 3; + google.protobuf.Timestamp last_updated = 4; +} + +message GetWorkflowShardMappingResponse { + map mappings = 1; + map mapping_states = 2; + google.protobuf.Timestamp timestamp = 3; + uint64 mapping_version = 4; +} + +message ReportWorkflowTriggerRegistrationRequest { + uint32 source_shard_id = 1; + map registered_workflows = 2; + google.protobuf.Timestamp report_timestamp = 3; + uint32 total_active_workflows = 4; +} + +message ReportWorkflowTriggerRegistrationResponse { + bool success = 1; +} + +service ShardOrchestratorService { + rpc GetWorkflowShardMapping(GetWorkflowShardMappingRequest) returns (GetWorkflowShardMappingResponse); // returns shard assignments for the specified workflows + rpc ReportWorkflowTriggerRegistration(ReportWorkflowTriggerRegistrationRequest) returns (ReportWorkflowTriggerRegistrationResponse); // called by shards to report their workflow trigger registrations +} diff --git a/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go b/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go new file mode 100644 index 0000000000..d2ab234c3a --- /dev/null +++ b/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go @@ -0,0 +1,160 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc v5.29.3 +// source: shard_orchestrator.proto + +package pb + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName = "/ring.ShardOrchestratorService/GetWorkflowShardMapping" + ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName = "/ring.ShardOrchestratorService/ReportWorkflowTriggerRegistration" +) + +// ShardOrchestratorServiceClient is the client API for ShardOrchestratorService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type ShardOrchestratorServiceClient interface { + GetWorkflowShardMapping(ctx context.Context, in *GetWorkflowShardMappingRequest, opts ...grpc.CallOption) (*GetWorkflowShardMappingResponse, error) + ReportWorkflowTriggerRegistration(ctx context.Context, in *ReportWorkflowTriggerRegistrationRequest, opts ...grpc.CallOption) (*ReportWorkflowTriggerRegistrationResponse, error) +} + +type shardOrchestratorServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewShardOrchestratorServiceClient(cc grpc.ClientConnInterface) ShardOrchestratorServiceClient { + return &shardOrchestratorServiceClient{cc} +} + +func (c *shardOrchestratorServiceClient) GetWorkflowShardMapping(ctx context.Context, in *GetWorkflowShardMappingRequest, opts ...grpc.CallOption) (*GetWorkflowShardMappingResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetWorkflowShardMappingResponse) + err := c.cc.Invoke(ctx, ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *shardOrchestratorServiceClient) ReportWorkflowTriggerRegistration(ctx context.Context, in *ReportWorkflowTriggerRegistrationRequest, opts ...grpc.CallOption) (*ReportWorkflowTriggerRegistrationResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ReportWorkflowTriggerRegistrationResponse) + err := c.cc.Invoke(ctx, ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ShardOrchestratorServiceServer is the server API for ShardOrchestratorService service. +// All implementations must embed UnimplementedShardOrchestratorServiceServer +// for forward compatibility. +type ShardOrchestratorServiceServer interface { + GetWorkflowShardMapping(context.Context, *GetWorkflowShardMappingRequest) (*GetWorkflowShardMappingResponse, error) + ReportWorkflowTriggerRegistration(context.Context, *ReportWorkflowTriggerRegistrationRequest) (*ReportWorkflowTriggerRegistrationResponse, error) + mustEmbedUnimplementedShardOrchestratorServiceServer() +} + +// UnimplementedShardOrchestratorServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedShardOrchestratorServiceServer struct{} + +func (UnimplementedShardOrchestratorServiceServer) GetWorkflowShardMapping(context.Context, *GetWorkflowShardMappingRequest) (*GetWorkflowShardMappingResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetWorkflowShardMapping not implemented") +} +func (UnimplementedShardOrchestratorServiceServer) ReportWorkflowTriggerRegistration(context.Context, *ReportWorkflowTriggerRegistrationRequest) (*ReportWorkflowTriggerRegistrationResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ReportWorkflowTriggerRegistration not implemented") +} +func (UnimplementedShardOrchestratorServiceServer) mustEmbedUnimplementedShardOrchestratorServiceServer() { +} +func (UnimplementedShardOrchestratorServiceServer) testEmbeddedByValue() {} + +// UnsafeShardOrchestratorServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ShardOrchestratorServiceServer will +// result in compilation errors. +type UnsafeShardOrchestratorServiceServer interface { + mustEmbedUnimplementedShardOrchestratorServiceServer() +} + +func RegisterShardOrchestratorServiceServer(s grpc.ServiceRegistrar, srv ShardOrchestratorServiceServer) { + // If the following call pancis, it indicates UnimplementedShardOrchestratorServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ShardOrchestratorService_ServiceDesc, srv) +} + +func _ShardOrchestratorService_GetWorkflowShardMapping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetWorkflowShardMappingRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ShardOrchestratorServiceServer).GetWorkflowShardMapping(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ShardOrchestratorServiceServer).GetWorkflowShardMapping(ctx, req.(*GetWorkflowShardMappingRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _ShardOrchestratorService_ReportWorkflowTriggerRegistration_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ReportWorkflowTriggerRegistrationRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ShardOrchestratorServiceServer).ReportWorkflowTriggerRegistration(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ShardOrchestratorServiceServer).ReportWorkflowTriggerRegistration(ctx, req.(*ReportWorkflowTriggerRegistrationRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// ShardOrchestratorService_ServiceDesc is the grpc.ServiceDesc for ShardOrchestratorService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ShardOrchestratorService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ring.ShardOrchestratorService", + HandlerType: (*ShardOrchestratorServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetWorkflowShardMapping", + Handler: _ShardOrchestratorService_GetWorkflowShardMapping_Handler, + }, + { + MethodName: "ReportWorkflowTriggerRegistration", + Handler: _ShardOrchestratorService_ReportWorkflowTriggerRegistration_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "shard_orchestrator.proto", +} diff --git a/pkg/workflows/ring/pb/shared.pb.go b/pkg/workflows/ring/pb/shared.pb.go new file mode 100644 index 0000000000..d138993056 --- /dev/null +++ b/pkg/workflows/ring/pb/shared.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.8 +// protoc v5.29.3 +// source: shared.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ShardStatus struct { + state protoimpl.MessageState `protogen:"open.v1"` + IsHealthy bool `protobuf:"varint,1,opt,name=is_healthy,json=isHealthy,proto3" json:"is_healthy,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShardStatus) Reset() { + *x = ShardStatus{} + mi := &file_shared_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShardStatus) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShardStatus) ProtoMessage() {} + +func (x *ShardStatus) ProtoReflect() protoreflect.Message { + mi := &file_shared_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShardStatus.ProtoReflect.Descriptor instead. +func (*ShardStatus) Descriptor() ([]byte, []int) { + return file_shared_proto_rawDescGZIP(), []int{0} +} + +func (x *ShardStatus) GetIsHealthy() bool { + if x != nil { + return x.IsHealthy + } + return false +} + +var File_shared_proto protoreflect.FileDescriptor + +const file_shared_proto_rawDesc = "" + + "\n" + + "\fshared.proto\x12\x04ring\",\n" + + "\vShardStatus\x12\x1d\n" + + "\n" + + "is_healthy\x18\x01 \x01(\bR\tisHealthyBDZBgithub.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pbb\x06proto3" + +var ( + file_shared_proto_rawDescOnce sync.Once + file_shared_proto_rawDescData []byte +) + +func file_shared_proto_rawDescGZIP() []byte { + file_shared_proto_rawDescOnce.Do(func() { + file_shared_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_shared_proto_rawDesc), len(file_shared_proto_rawDesc))) + }) + return file_shared_proto_rawDescData +} + +var file_shared_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_shared_proto_goTypes = []any{ + (*ShardStatus)(nil), // 0: ring.ShardStatus +} +var file_shared_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_shared_proto_init() } +func file_shared_proto_init() { + if File_shared_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_shared_proto_rawDesc), len(file_shared_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_shared_proto_goTypes, + DependencyIndexes: file_shared_proto_depIdxs, + MessageInfos: file_shared_proto_msgTypes, + }.Build() + File_shared_proto = out.File + file_shared_proto_goTypes = nil + file_shared_proto_depIdxs = nil +} diff --git a/pkg/workflows/ring/pb/shared.proto b/pkg/workflows/ring/pb/shared.proto new file mode 100644 index 0000000000..b0adafcc2f --- /dev/null +++ b/pkg/workflows/ring/pb/shared.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package ring; + +option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"; + +message ShardStatus { + bool is_healthy = 1; +} \ No newline at end of file diff --git a/pkg/workflows/ring/plugin.go b/pkg/workflows/ring/plugin.go new file mode 100644 index 0000000000..740009ef47 --- /dev/null +++ b/pkg/workflows/ring/plugin.go @@ -0,0 +1,276 @@ +package ring + +import ( + "context" + "errors" + "slices" + "sync" + "time" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/libocr/quorumhelper" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" +) + +type Plugin struct { + mu sync.RWMutex + + store *Store + arbiterScaler pb.ArbiterScalerClient + config ocr3types.ReportingPluginConfig + lggr logger.Logger + + batchSize int + timeToSync time.Duration +} + +var _ ocr3types.ReportingPlugin[[]byte] = (*Plugin)(nil) + +type ConsensusConfig struct { + BatchSize int + TimeToSync time.Duration +} + +const ( + DefaultBatchSize = 100 + DefaultTimeToSync = 5 * time.Minute +) + +func NewPlugin(store *Store, arbiterScaler pb.ArbiterScalerClient, config ocr3types.ReportingPluginConfig, lggr logger.Logger, cfg *ConsensusConfig) (*Plugin, error) { + if arbiterScaler == nil { + return nil, errors.New("RingOCR arbiterScaler is required") + } + if cfg == nil { + cfg = &ConsensusConfig{ + BatchSize: DefaultBatchSize, + TimeToSync: DefaultTimeToSync, + } + } + + if cfg.BatchSize <= 0 { + lggr.Infow("using default batchSize", "default", DefaultBatchSize) + cfg.BatchSize = DefaultBatchSize + } + if cfg.TimeToSync <= 0 { + lggr.Infow("using default timeToSync", "default", DefaultTimeToSync) + cfg.TimeToSync = DefaultTimeToSync + } + + lggr.Infow("RingPlugin config", + "batchSize", cfg.BatchSize, + "timeToSync", cfg.TimeToSync, + ) + + return &Plugin{ + store: store, + arbiterScaler: arbiterScaler, + config: config, + lggr: logger.Named(lggr, "RingPlugin"), + batchSize: cfg.BatchSize, + timeToSync: cfg.TimeToSync, + }, nil +} + +//coverage:ignore +func (p *Plugin) Query(_ context.Context, _ ocr3types.OutcomeContext) (types.Query, error) { + return nil, nil +} + +func (p *Plugin) Observation(ctx context.Context, _ ocr3types.OutcomeContext, _ types.Query) (types.Observation, error) { + var wantShards uint32 + var shardStatus map[uint32]*pb.ShardStatus + + status, err := p.arbiterScaler.Status(ctx, &emptypb.Empty{}) + if err != nil { + // NOTE: consider a fallback data source if Arbiter is not available + p.lggr.Errorw("RingOCR failed to get arbiter scaler status", "error", err) + return nil, err + } + wantShards = status.WantShards + shardStatus = status.Status + + allWorkflowIDs := make([]string, 0) + for wfID := range p.store.GetAllRoutingState() { + allWorkflowIDs = append(allWorkflowIDs, wfID) + } + + pendingAllocs := p.store.GetPendingAllocations() + p.lggr.Infow("RingOCR Observation pending allocations", "pendingAllocs", pendingAllocs) + + allWorkflowIDs = append(allWorkflowIDs, pendingAllocs...) + allWorkflowIDs = uniqueSorted(allWorkflowIDs) + p.lggr.Infow("RingOCR Observation all workflow IDs unique", "allWorkflowIDs", allWorkflowIDs, "wantShards", wantShards) + + observation := &pb.Observation{ + ShardStatus: shardStatus, + WorkflowIds: allWorkflowIDs, + Now: timestamppb.Now(), + WantShards: wantShards, + } + + return proto.MarshalOptions{Deterministic: true}.Marshal(observation) +} + +func (p *Plugin) ValidateObservation(_ context.Context, _ ocr3types.OutcomeContext, _ types.Query, ao types.AttributedObservation) error { + observation := &pb.Observation{} + if err := proto.Unmarshal(ao.Observation, observation); err != nil { + return err + } + if observation.Now == nil { + return errors.New("observation missing timestamp") + } + if observation.WantShards == 0 { + return errors.New("observation missing WantShards") + } + return nil +} + +func (p *Plugin) ObservationQuorum(_ context.Context, _ ocr3types.OutcomeContext, _ types.Query, aos []types.AttributedObservation) (quorumReached bool, err error) { + return quorumhelper.ObservationCountReachesObservationQuorum(quorumhelper.QuorumTwoFPlusOne, p.config.N, p.config.F, aos), nil +} + +func (p *Plugin) collectShardInfo(aos []types.AttributedObservation) (shardHealth map[uint32]int, workflows []string, timestamps []time.Time, wantShardVotes map[commontypes.OracleID]uint32) { + shardHealth = make(map[uint32]int) + wantShardVotes = make(map[commontypes.OracleID]uint32) + for _, ao := range aos { + observation := &pb.Observation{} + _ = proto.Unmarshal(ao.Observation, observation) // validated in ValidateObservation + + for shardID, status := range observation.ShardStatus { + if status != nil && status.IsHealthy { + shardHealth[shardID]++ + } + } + + workflows = append(workflows, observation.WorkflowIds...) + timestamps = append(timestamps, observation.Now.AsTime()) + + wantShardVotes[ao.Observer] = observation.WantShards + } + return shardHealth, workflows, timestamps, wantShardVotes +} + +func (p *Plugin) getHealthyShards(shardHealth map[uint32]int) []uint32 { + var healthyShards []uint32 + for shardID, votes := range shardHealth { + if votes > p.config.F { + healthyShards = append(healthyShards, shardID) + p.store.SetShardHealth(shardID, true) + } + } + slices.Sort(healthyShards) + + return healthyShards +} + +func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ types.Query, aos []types.AttributedObservation) (ocr3types.Outcome, error) { + currentShardHealth, allWorkflows, nows, wantShardVotes := p.collectShardInfo(aos) + p.lggr.Infow("RingOCR Outcome collect shard info", "currentShardHealth", currentShardHealth, "wantShardVotes", wantShardVotes) + + // Use the median timestamp to determine the current time + slices.SortFunc(nows, time.Time.Compare) + now := nows[len(nows)/2] + + // Use median for wantShards consensus (all validated observations have WantShards > 0) + votes := make([]uint32, 0, len(wantShardVotes)) + for _, v := range wantShardVotes { + votes = append(votes, v) + } + slices.Sort(votes) + wantShards := votes[len(votes)/2] + + // Bootstrap from Arbiter's current shard count on 1st round; subsequent rounds build on prior outcome + prior := &pb.Outcome{} + if outctx.PreviousOutcome == nil { + prior.Routes = map[string]*pb.WorkflowRoute{} + prior.State = &pb.RoutingState{Id: outctx.SeqNr, State: &pb.RoutingState_RoutableShards{RoutableShards: wantShards}} + } else if err := proto.Unmarshal(outctx.PreviousOutcome, prior); err != nil { + return nil, err + } + + allWorkflows = uniqueSorted(allWorkflows) + + healthyShards := p.getHealthyShards(currentShardHealth) + + nextState, err := NextState(prior.State, wantShards, now, p.timeToSync) + if err != nil { + return nil, err + } + + // Deterministic hashing ensures all nodes agree on workflow-to-shard assignments + // without coordination, preventing protocol failures from inconsistent routing + ring := newShardRing(healthyShards) + routes := make(map[string]*pb.WorkflowRoute) + for _, wfID := range allWorkflows { + shard, err := locateShard(ring, wfID) + if err != nil { + p.lggr.Warnw("RingOCR failed to locate shard for workflow", "workflowID", wfID, "error", err) + shard = 0 // fallback to shard 0 when no healthy shards + } + routes[wfID] = &pb.WorkflowRoute{Shard: shard} + } + + outcome := &pb.Outcome{ + State: nextState, + Routes: routes, + } + + p.lggr.Infow("RingOCR Outcome", "healthyShards", len(healthyShards), "totalObservations", len(aos), "workflowCount", len(routes)) + + return proto.MarshalOptions{Deterministic: true}.Marshal(outcome) +} + +func (p *Plugin) Reports(_ context.Context, _ uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportPlus[[]byte], error) { + allOraclesTransmitNow := &ocr3types.TransmissionSchedule{ + Transmitters: make([]commontypes.OracleID, p.config.N), + TransmissionDelays: make([]time.Duration, p.config.N), + } + + for i := 0; i < p.config.N; i++ { + allOraclesTransmitNow.Transmitters[i] = commontypes.OracleID(i) + } + + info, err := structpb.NewStruct(map[string]any{ + "keyBundleName": "evm", + }) + if err != nil { + return nil, err + } + infoBytes, err := proto.MarshalOptions{Deterministic: true}.Marshal(info) + if err != nil { + return nil, err + } + + return []ocr3types.ReportPlus[[]byte]{ + { + ReportWithInfo: ocr3types.ReportWithInfo[[]byte]{ + Report: types.Report(outcome), + Info: infoBytes, + }, + }, + }, nil +} + +//coverage:ignore +func (p *Plugin) ShouldAcceptAttestedReport(_ context.Context, _ uint64, _ ocr3types.ReportWithInfo[[]byte]) (bool, error) { + return true, nil +} + +//coverage:ignore +func (p *Plugin) ShouldTransmitAcceptedReport(_ context.Context, _ uint64, _ ocr3types.ReportWithInfo[[]byte]) (bool, error) { + return true, nil +} + +//coverage:ignore +func (p *Plugin) Close() error { + return nil +} diff --git a/pkg/workflows/ring/plugin_test.go b/pkg/workflows/ring/plugin_test.go new file mode 100644 index 0000000000..99f70d1b9d --- /dev/null +++ b/pkg/workflows/ring/plugin_test.go @@ -0,0 +1,593 @@ +package ring + +import ( + "context" + "testing" + "time" + + "github.com/smartcontractkit/libocr/commontypes" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/libocr/offchainreporting2/types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" +) + +type mockArbiter struct { + status *pb.ReplicaStatus +} + +func (m *mockArbiter) Status(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*pb.ReplicaStatus, error) { + if m.status != nil { + return m.status, nil + } + return &pb.ReplicaStatus{}, nil +} + +func (m *mockArbiter) ConsensusWantShards(ctx context.Context, req *pb.ConsensusWantShardsRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + return &emptypb.Empty{}, nil +} + +var twoHealthyShards = []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, +} + +func toShardStatus(m map[uint32]bool) map[uint32]*pb.ShardStatus { + result := make(map[uint32]*pb.ShardStatus, len(m)) + for k, v := range m { + result[k] = &pb.ShardStatus{IsHealthy: v} + } + return result +} + +func TestPlugin_Outcome(t *testing.T) { + t.Run("WithMultiNodeObservations", func(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + store.SetAllShardHealth(map[uint32]bool{0: true, 1: true, 2: true}) + + config := ocr3types.ReportingPluginConfig{ + N: 4, F: 1, + OffchainConfig: []byte{}, + MaxDurationObservation: 0, + MaxDurationShouldAcceptAttestedReport: 0, + MaxDurationShouldTransmitAcceptedReport: 0, + } + + plugin, err := NewPlugin(store, &mockArbiter{}, config, lggr, nil) + require.NoError(t, err) + + ctx := t.Context() + intialSeqNr := uint64(42) + outcomeCtx := ocr3types.OutcomeContext{SeqNr: intialSeqNr} + + // Observations from 4 NOPs reporting health, workflows, and wantShards=3 + observations := []struct { + name string + shardStatus map[uint32]*pb.ShardStatus + workflows []string + wantShards uint32 + }{ + { + name: "NOP 0", + shardStatus: toShardStatus(map[uint32]bool{0: true, 1: true, 2: true}), + workflows: []string{"wf-A", "wf-B", "wf-C"}, + wantShards: 3, + }, + { + name: "NOP 1", + shardStatus: toShardStatus(map[uint32]bool{0: true, 1: true, 2: true}), + workflows: []string{"wf-B", "wf-C", "wf-D"}, + wantShards: 3, + }, + { + name: "NOP 2", + shardStatus: toShardStatus(map[uint32]bool{0: true, 1: true, 2: false}), // shard 2 unhealthy + workflows: []string{"wf-A", "wf-C"}, + wantShards: 3, + }, + { + name: "NOP 3", + shardStatus: toShardStatus(map[uint32]bool{0: true, 1: true, 2: true}), + workflows: []string{"wf-A", "wf-B", "wf-D"}, + wantShards: 3, + }, + } + + // Build attributed observations + aos := make([]types.AttributedObservation, 0) + for _, obs := range observations { + pbObs := &pb.Observation{ + ShardStatus: obs.shardStatus, + WorkflowIds: obs.workflows, + Now: timestamppb.Now(), + WantShards: obs.wantShards, + } + rawObs, err := proto.Marshal(pbObs) + require.NoError(t, err) + + aos = append(aos, types.AttributedObservation{ + Observation: rawObs, + Observer: commontypes.OracleID(len(aos)), + }) + } + + // Execute Outcome phase + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + require.NotNil(t, outcome) + + // Verify outcome + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Check consensus results + require.NotNil(t, outcomeProto.State) + // When bootstrapping without PreviousOutcome, we use wantShards from observations (3) + // Since consensus wantShards (3) equals bootstrap shards, no transition needed - ID stays the same + require.Equal(t, intialSeqNr, outcomeProto.State.Id, "ID should match SeqNr (no transition needed)") + t.Logf("Outcome - ID: %d, HealthyShards: %v", outcomeProto.State.Id, outcomeProto.State.GetRoutableShards()) + t.Logf("Workflows assigned: %d", len(outcomeProto.Routes)) + + // Verify all workflows are assigned + expectedWorkflows := map[string]bool{"wf-A": true, "wf-B": true, "wf-C": true, "wf-D": true} + require.Equal(t, len(expectedWorkflows), len(outcomeProto.Routes)) + for wf := range expectedWorkflows { + route, exists := outcomeProto.Routes[wf] + require.True(t, exists, "workflow %s should be assigned", wf) + require.True(t, route.Shard <= 2, "shard should be healthy (0-2)") + t.Logf(" %s → shard %d", wf, route.Shard) + } + + // Verify determinism: run again, should get same assignments + outcome2, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + outcomeProto2 := &pb.Outcome{} + err = proto.Unmarshal(outcome2, outcomeProto2) + require.NoError(t, err) + + // Same workflows → same shards + for wf, route1 := range outcomeProto.Routes { + route2, exists := outcomeProto2.Routes[wf] + require.True(t, exists) + require.Equal(t, route1.Shard, route2.Shard, "workflow %s should assign to same shard", wf) + } + }) +} + +func TestPlugin_StateTransitions(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + + config := ocr3types.ReportingPluginConfig{ + N: 4, F: 1, + } + + // Use short time to sync for testing + plugin, err := NewPlugin(store, &mockArbiter{}, config, lggr, &ConsensusConfig{ + BatchSize: 100, + TimeToSync: 1 * time.Second, + }) + require.NoError(t, err) + + ctx := t.Context() + now := time.Now() + + // Test 1: Initial state with no previous outcome + t.Run("initial_state", func(t *testing.T) { + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 1, + PreviousOutcome: nil, + } + + // Only 1 healthy shard in observations with wantShards=1 + aos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}}, + {0: {IsHealthy: true}}, + {0: {IsHealthy: true}}, + }, []string{"wf-1"}, now, 1) + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Should be in stable state with min shard count + require.NotNil(t, outcomeProto.State.GetRoutableShards()) + require.Equal(t, uint32(1), outcomeProto.State.GetRoutableShards()) + t.Logf("Initial state: %d routable shards", outcomeProto.State.GetRoutableShards()) + }) + + // Test 2: Transition triggered when wantShards changes + t.Run("transition_triggered", func(t *testing.T) { + // Start with 1 shard in stable state + priorOutcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_RoutableShards{ + RoutableShards: 1, + }, + }, + Routes: map[string]*pb.WorkflowRoute{}, + } + priorBytes, err := proto.Marshal(priorOutcome) + require.NoError(t, err) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 2, + PreviousOutcome: priorBytes, + } + + // Observations show 2 healthy shards and wantShards=2 + aos := makeObservationsWithWantShards(t, twoHealthyShards, []string{"wf-1"}, now, 2) + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Should transition to Transition state + transition := outcomeProto.State.GetTransition() + require.NotNil(t, transition, "should be in transition state") + require.Equal(t, uint32(2), transition.WantShards, "want 2 shards") + require.Equal(t, uint32(1), transition.LastStableCount, "was at 1 shard") + require.True(t, transition.ChangesSafeAfter.AsTime().After(now), "safety period should be in future") + t.Logf("Transition: %d → %d, safe after %v", transition.LastStableCount, transition.WantShards, transition.ChangesSafeAfter.AsTime()) + }) + + // Test 3: Stay in transition during safety period + t.Run("stay_in_transition", func(t *testing.T) { + safeAfter := now.Add(1 * time.Hour) + priorOutcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 2, + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{ + WantShards: 2, + LastStableCount: 1, + ChangesSafeAfter: timestamppb.New(safeAfter), + }, + }, + }, + Routes: map[string]*pb.WorkflowRoute{}, + } + priorBytes, err := proto.Marshal(priorOutcome) + require.NoError(t, err) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 3, + PreviousOutcome: priorBytes, + } + + // Still showing 2 healthy shards with wantShards=2, but safety period not elapsed + aos := makeObservationsWithWantShards(t, twoHealthyShards, []string{"wf-1"}, now, 2) + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Should still be in transition state + transition := outcomeProto.State.GetTransition() + require.NotNil(t, transition, "should still be in transition") + require.Equal(t, uint32(2), transition.WantShards) + t.Logf("Still in transition, waiting for safety period") + }) + + // Test 4: Complete transition after safety period + t.Run("complete_transition", func(t *testing.T) { + safeAfter := now.Add(-1 * time.Second) // Safety period already passed + priorOutcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 2, + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{ + WantShards: 2, + LastStableCount: 1, + ChangesSafeAfter: timestamppb.New(safeAfter), + }, + }, + }, + Routes: map[string]*pb.WorkflowRoute{}, + } + priorBytes, err := proto.Marshal(priorOutcome) + require.NoError(t, err) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 3, + PreviousOutcome: priorBytes, + } + + aos := makeObservationsWithWantShards(t, twoHealthyShards, []string{"wf-1"}, now, 2) + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Should now be in stable state with 2 shards + require.NotNil(t, outcomeProto.State.GetRoutableShards(), "should be in stable state") + require.Equal(t, uint32(2), outcomeProto.State.GetRoutableShards()) + require.Equal(t, uint64(3), outcomeProto.State.Id, "state ID should increment") + t.Logf("Transition complete: now at %d routable shards", outcomeProto.State.GetRoutableShards()) + }) + + // Test 5: Stay stable when wantShards matches current + t.Run("stay_stable", func(t *testing.T) { + priorOutcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 3, + State: &pb.RoutingState_RoutableShards{ + RoutableShards: 2, + }, + }, + Routes: map[string]*pb.WorkflowRoute{}, + } + priorBytes, err := proto.Marshal(priorOutcome) + require.NoError(t, err) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 4, + PreviousOutcome: priorBytes, + } + + // Same 2 healthy shards with wantShards=2 + aos := makeObservationsWithWantShards(t, twoHealthyShards, []string{"wf-1"}, now, 2) + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Should stay in stable state, ID unchanged + require.NotNil(t, outcomeProto.State.GetRoutableShards()) + require.Equal(t, uint32(2), outcomeProto.State.GetRoutableShards()) + require.Equal(t, uint64(3), outcomeProto.State.Id, "state ID should not change when stable") + t.Logf("Staying stable at %d routable shards", outcomeProto.State.GetRoutableShards()) + }) +} + +func makeObservations(t *testing.T, shardStatuses []map[uint32]*pb.ShardStatus, workflows []string, now time.Time) []types.AttributedObservation { + return makeObservationsWithWantShards(t, shardStatuses, workflows, now, 0) +} + +func makeObservationsWithWantShards(t *testing.T, shardStatuses []map[uint32]*pb.ShardStatus, workflows []string, now time.Time, wantShards uint32) []types.AttributedObservation { + aos := make([]types.AttributedObservation, 0, len(shardStatuses)) + for i, status := range shardStatuses { + pbObs := &pb.Observation{ + ShardStatus: status, + WorkflowIds: workflows, + Now: timestamppb.New(now), + WantShards: wantShards, + } + rawObs, err := proto.Marshal(pbObs) + require.NoError(t, err) + + aos = append(aos, types.AttributedObservation{ + Observation: rawObs, + Observer: commontypes.OracleID(i), + }) + } + return aos +} + +func TestPlugin_NewPlugin_NilArbiter(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + config := ocr3types.ReportingPluginConfig{N: 4, F: 1} + + _, err := NewPlugin(store, nil, config, lggr, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "RingOCR arbiterScaler is required") +} + +func TestPlugin_getHealthyShards(t *testing.T) { + tests := []struct { + name string + votes map[uint32]int // shardID -> vote count + f int + want int + }{ + {"all healthy", map[uint32]int{0: 2, 1: 2, 2: 2}, 1, 3}, + {"some unhealthy", map[uint32]int{0: 2, 1: 1, 2: 2}, 1, 2}, + {"none healthy", map[uint32]int{0: 1, 1: 1}, 1, 0}, + {"higher F threshold", map[uint32]int{0: 3, 1: 2, 2: 3}, 2, 2}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + plugin := &Plugin{ + store: NewStore(), + config: ocr3types.ReportingPluginConfig{F: tc.f}, + } + got := plugin.getHealthyShards(tc.votes) + require.Equal(t, tc.want, len(got)) + }) + } +} + +func TestPlugin_NoHealthyShardsFallbackToShardZero(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + + // Set all shards unhealthy - store starts in transition state + store.SetAllShardHealth(map[uint32]bool{0: false, 1: false, 2: false}) + + config := ocr3types.ReportingPluginConfig{ + N: 4, F: 1, + } + + arbiter := &mockArbiter{} + plugin, err := NewPlugin(store, arbiter, config, lggr, &ConsensusConfig{ + BatchSize: 100, + TimeToSync: 1 * time.Second, + }) + require.NoError(t, err) + + transmitter := NewTransmitter(lggr, store, arbiter, "test-account") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + + // Start a goroutine that requests allocation (will block waiting for OCR) + resultCh := make(chan uint32) + errCh := make(chan error, 1) + go func() { + shard, err := store.GetShardForWorkflow(ctx, "workflow-123") + if err != nil { + errCh <- err + return + } + resultCh <- shard + }() + + // Give goroutine time to enqueue request + time.Sleep(10 * time.Millisecond) + + // Verify request is pending for OCR consensus + pending := store.GetPendingAllocations() + require.Contains(t, pending, "workflow-123") + + // Simulate OCR round with observations showing no healthy shards + // The pending allocation "workflow-123" should be included in observation + now := time.Now() + aos := make([]types.AttributedObservation, 3) + for i := 0; i < 3; i++ { + pbObs := &pb.Observation{ + ShardStatus: toShardStatus(map[uint32]bool{0: false, 1: false, 2: false}), + WorkflowIds: []string{"workflow-123"}, + Now: timestamppb.New(now), + } + rawObs, err := proto.Marshal(pbObs) + require.NoError(t, err) + aos[i] = types.AttributedObservation{ + Observation: rawObs, + Observer: commontypes.OracleID(i), + } + } + + // Use a previous outcome in steady state so we can test the fallback + priorOutcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }, + Routes: map[string]*pb.WorkflowRoute{}, + } + priorBytes, err := proto.Marshal(priorOutcome) + require.NoError(t, err) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 2, + PreviousOutcome: priorBytes, + } + + // Run plugin Outcome phase + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + // Transmit the outcome (applies routes to store) + reports, err := plugin.Reports(ctx, 2, outcome) + require.NoError(t, err) + require.Len(t, reports, 1) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 2, reports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Blocked goroutine should now receive result from OCR - should be shard 0 (fallback) + select { + case shard := <-resultCh: + require.Equal(t, uint32(0), shard, "should fallback to shard 0 when no healthy shards") + case err := <-errCh: + t.Fatalf("unexpected error: %v", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("allocation was not fulfilled by OCR") + } + + // Verify the outcome assigned workflow-123 to shard 0 + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + route, exists := outcomeProto.Routes["workflow-123"] + require.True(t, exists, "workflow-123 should be in routes") + require.Equal(t, uint32(0), route.Shard, "workflow-123 should be assigned to shard 0 (fallback)") +} + +func TestPlugin_ObservationQuorum(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + config := ocr3types.ReportingPluginConfig{N: 4, F: 1} + plugin, err := NewPlugin(store, &mockArbiter{}, config, lggr, nil) + require.NoError(t, err) + + ctx := context.Background() + outctx := ocr3types.OutcomeContext{} + + t.Run("quorum_reached", func(t *testing.T) { + // Need 2F+1 = 3 observations for quorum with N=4, F=1 + aos := make([]types.AttributedObservation, 3) + for i := range aos { + aos[i] = types.AttributedObservation{Observer: commontypes.OracleID(i)} + } + + quorum, err := plugin.ObservationQuorum(ctx, outctx, nil, aos) + require.NoError(t, err) + require.True(t, quorum) + }) + + t.Run("quorum_not_reached", func(t *testing.T) { + // Only 2 observations - not enough for quorum + aos := make([]types.AttributedObservation, 2) + for i := range aos { + aos[i] = types.AttributedObservation{Observer: commontypes.OracleID(i)} + } + + quorum, err := plugin.ObservationQuorum(ctx, outctx, nil, aos) + require.NoError(t, err) + require.False(t, quorum) + }) + + t.Run("exact_quorum", func(t *testing.T) { + // Exactly 2F+1 = 3 observations + aos := make([]types.AttributedObservation, 3) + for i := range aos { + aos[i] = types.AttributedObservation{Observer: commontypes.OracleID(i)} + } + + quorum, err := plugin.ObservationQuorum(ctx, outctx, nil, aos) + require.NoError(t, err) + require.True(t, quorum) + }) + + t.Run("all_observations", func(t *testing.T) { + // All N=4 observations + aos := make([]types.AttributedObservation, 4) + for i := range aos { + aos[i] = types.AttributedObservation{Observer: commontypes.OracleID(i)} + } + + quorum, err := plugin.ObservationQuorum(ctx, outctx, nil, aos) + require.NoError(t, err) + require.True(t, quorum) + }) +} diff --git a/pkg/workflows/ring/state.go b/pkg/workflows/ring/state.go new file mode 100644 index 0000000000..62c26a5b18 --- /dev/null +++ b/pkg/workflows/ring/state.go @@ -0,0 +1,76 @@ +package ring + +import ( + "errors" + "time" + + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" +) + +func IsInSteadyState(state *pb.RoutingState) bool { + if state == nil { + return false + } + _, ok := state.State.(*pb.RoutingState_RoutableShards) + return ok +} + +func NextStateFromSteady(currentID uint64, currentShards, wantShards uint32, now time.Time, timeToSync time.Duration) *pb.RoutingState { + if currentShards == wantShards { + return &pb.RoutingState{ + Id: currentID, + State: &pb.RoutingState_RoutableShards{RoutableShards: currentShards}, + } + } + + return &pb.RoutingState{ + Id: currentID + 1, + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{ + WantShards: wantShards, + LastStableCount: currentShards, + ChangesSafeAfter: timestamppb.New(now.Add(timeToSync)), + }, + }, + } +} + +func NextStateFromTransition(currentID uint64, transition *pb.Transition, now time.Time) *pb.RoutingState { + safeAfter := transition.ChangesSafeAfter.AsTime() + + if now.Before(safeAfter) { + return &pb.RoutingState{ + Id: currentID, + State: &pb.RoutingState_Transition{ + Transition: transition, + }, + } + } + + return &pb.RoutingState{ + Id: currentID + 1, + State: &pb.RoutingState_RoutableShards{ + RoutableShards: transition.WantShards, + }, + } +} + +func NextState(current *pb.RoutingState, wantShards uint32, now time.Time, timeToSync time.Duration) (*pb.RoutingState, error) { + if current == nil { + return nil, errors.New("current state is nil") + } + + switch s := current.State.(type) { + case *pb.RoutingState_RoutableShards: + return NextStateFromSteady(current.Id, s.RoutableShards, wantShards, now, timeToSync), nil + + case *pb.RoutingState_Transition: + return NextStateFromTransition(current.Id, s.Transition, now), nil + + // coverage:ignore + default: + return nil, errors.New("unknown state type") + } +} diff --git a/pkg/workflows/ring/state_test.go b/pkg/workflows/ring/state_test.go new file mode 100644 index 0000000000..af6bd3f731 --- /dev/null +++ b/pkg/workflows/ring/state_test.go @@ -0,0 +1,265 @@ +package ring + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" +) + +func TestStateTransitionDeterminism(t *testing.T) { + now := time.Unix(0, 0) + timeToSync := 5 * time.Minute + + current := &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_RoutableShards{RoutableShards: 2}, + } + + // Same inputs should produce identical outputs + result1, err := NextState(current, 4, now, timeToSync) + require.NoError(t, err) + + result2, err := NextState(current, 4, now, timeToSync) + require.NoError(t, err) + + require.Equal(t, result1.Id, result2.Id) + require.Equal(t, result1.GetTransition().WantShards, result2.GetTransition().WantShards) + require.Equal(t, result1.GetTransition().LastStableCount, result2.GetTransition().LastStableCount) + require.Equal(t, result1.GetTransition().ChangesSafeAfter.AsTime(), result2.GetTransition().ChangesSafeAfter.AsTime()) +} + +// ∀ state, inputs: NextState(state, inputs).Id >= state.Id +func TestFV_StateIDMonotonicity(t *testing.T) { + timeToSync := 5 * time.Minute + baseTime := time.Unix(0, 0) + + testCases := []struct { + name string + state *pb.RoutingState + now time.Time + }{ + // Steady state cases + {"steady_same_shards", steadyState(10, 3), baseTime}, + {"steady_more_shards", steadyState(10, 3), baseTime}, + {"steady_fewer_shards", steadyState(10, 3), baseTime}, + // Transition state cases + {"transition_before_safe", transitionState(10, 3, 5, baseTime.Add(1*time.Hour)), baseTime}, + {"transition_at_safe", transitionState(10, 3, 5, baseTime), baseTime}, + {"transition_after_safe", transitionState(10, 3, 5, baseTime.Add(-1*time.Second)), baseTime}, + } + + shardCounts := []uint32{1, 2, 3, 5, 10} + + for _, tc := range testCases { + for _, wantShards := range shardCounts { + t.Run(tc.name, func(t *testing.T) { + result, err := NextState(tc.state, wantShards, tc.now, timeToSync) + require.NoError(t, err) + + // INVARIANT: ID never decreases + require.GreaterOrEqual(t, result.Id, tc.state.Id, + "state ID must be monotonically non-decreasing") + }) + } + } +} + +// The state machine only produces valid transitions: +// - Steady → Steady (when shards unchanged) +// - Steady → Transition (when shards change) +// - Transition → Transition (before safety period) +// - Transition → Steady (after safety period) +func TestFV_ValidStateTransitions(t *testing.T) { + timeToSync := 5 * time.Minute + baseTime := time.Unix(0, 0) + + t.Run("steady_to_steady_when_unchanged", func(t *testing.T) { + for _, shards := range []uint32{1, 2, 3, 5, 10} { + state := steadyState(1, shards) + result, err := NextState(state, shards, baseTime, timeToSync) + require.NoError(t, err) + + // Must remain steady with same shard count + require.True(t, IsInSteadyState(result)) + require.Equal(t, shards, result.GetRoutableShards()) + require.Equal(t, state.Id, result.Id, "ID unchanged when no transition") + } + }) + + t.Run("steady_to_transition_when_changed", func(t *testing.T) { + transitions := [][2]uint32{{1, 2}, {2, 1}, {3, 5}, {5, 3}, {1, 10}} + for _, tr := range transitions { + current, want := tr[0], tr[1] + state := steadyState(1, current) + result, err := NextState(state, want, baseTime, timeToSync) + require.NoError(t, err) + + // Must enter transition + require.False(t, IsInSteadyState(result)) + require.NotNil(t, result.GetTransition()) + require.Equal(t, want, result.GetTransition().WantShards) + require.Equal(t, current, result.GetTransition().LastStableCount) + require.Equal(t, state.Id+1, result.Id) + } + }) + + t.Run("transition_stays_before_safe_time", func(t *testing.T) { + safeAfter := baseTime.Add(1 * time.Hour) + for _, wantShards := range []uint32{1, 2, 5} { + state := transitionState(5, 2, wantShards, safeAfter) + result, err := NextState(state, wantShards, baseTime, timeToSync) + require.NoError(t, err) + + // Must remain in transition + require.False(t, IsInSteadyState(result)) + require.Equal(t, state.Id, result.Id, "ID unchanged while waiting") + } + }) + + t.Run("transition_completes_after_safe_time", func(t *testing.T) { + safeAfter := baseTime.Add(-1 * time.Second) + for _, wantShards := range []uint32{1, 2, 5} { + state := transitionState(5, 2, wantShards, safeAfter) + result, err := NextState(state, wantShards, baseTime, timeToSync) + require.NoError(t, err) + + // Must complete to steady + require.True(t, IsInSteadyState(result)) + require.Equal(t, wantShards, result.GetRoutableShards()) + require.Equal(t, state.Id+1, result.Id) + } + }) +} + +// ∀ transition: completion occurs iff now >= safeAfter +func TestFV_SafetyPeriodEnforcement(t *testing.T) { + timeToSync := 5 * time.Minute + baseTime := time.Unix(0, 0) + + // Test various time offsets relative to safeAfter + offsets := []time.Duration{ + -1 * time.Hour, + -1 * time.Minute, + -1 * time.Second, + -1 * time.Nanosecond, + 0, + 1 * time.Nanosecond, + 1 * time.Second, + 1 * time.Minute, + 1 * time.Hour, + } + + for _, offset := range offsets { + safeAfter := baseTime + now := baseTime.Add(offset) + state := transitionState(1, 2, 5, safeAfter) + + result, err := NextState(state, 5, now, timeToSync) + require.NoError(t, err) + + shouldComplete := !now.Before(safeAfter) + didComplete := IsInSteadyState(result) + + require.Equal(t, shouldComplete, didComplete, + "offset=%v: safety period enforcement failed", offset) + } +} + +// When entering transition, WantShards equals the requested shard count +// When completing transition, final shard count equals WantShards +func TestFV_TransitionPreservesTarget(t *testing.T) { + timeToSync := 5 * time.Minute + baseTime := time.Unix(0, 0) + + for _, currentShards := range []uint32{1, 2, 3, 5} { + for _, wantShards := range []uint32{1, 2, 3, 5} { + if currentShards == wantShards { + continue // No transition occurs + } + + // Step 1: Enter transition + state := steadyState(0, currentShards) + afterEnter, err := NextState(state, wantShards, baseTime, timeToSync) + require.NoError(t, err) + require.Equal(t, wantShards, afterEnter.GetTransition().WantShards, + "transition must preserve target shard count") + + // Step 2: Complete transition (after safety period) + afterComplete, err := NextState(afterEnter, wantShards, baseTime.Add(timeToSync+time.Second), timeToSync) + require.NoError(t, err) + require.Equal(t, wantShards, afterComplete.GetRoutableShards(), + "completed state must have target shard count") + } + } +} + +// ∀ transition: ∃ time t where transition completes (no infinite loops) +func TestFV_EventualCompletion(t *testing.T) { + timeToSync := 5 * time.Minute + baseTime := time.Unix(0, 0) + + state := steadyState(0, 2) + + // Enter transition + state, err := NextState(state, 5, baseTime, timeToSync) + require.NoError(t, err) + require.False(t, IsInSteadyState(state)) + + // Simulate time progression - must complete within safety period + completionTime := baseTime.Add(timeToSync) + state, err = NextState(state, 5, completionTime, timeToSync) + require.NoError(t, err) + + require.True(t, IsInSteadyState(state), "transition must eventually complete") +} + +// ∀ state: exactly one of (IsInSteadyState, IsInTransition) is true +func TestFV_StateTypeExclusivity(t *testing.T) { + states := []*pb.RoutingState{ + steadyState(0, 1), + steadyState(5, 3), + transitionState(0, 1, 2, time.Now()), + transitionState(5, 3, 5, time.Now().Add(time.Hour)), + } + + for i, state := range states { + isSteady := IsInSteadyState(state) + _, isTransition := state.State.(*pb.RoutingState_Transition) + + require.NotEqual(t, isSteady, isTransition, + "state %d: exactly one state type must be true", i) + } +} + +// IsInSteadyState(nil) = false (safe handling of nil) +// NextState(nil, ...) returns error (explicit failure) +func TestFV_NilStateSafety(t *testing.T) { + require.False(t, IsInSteadyState(nil), "nil state must not be steady") + + _, err := NextState(nil, 1, time.Now(), time.Minute) + require.Error(t, err, "NextState must reject nil input") +} + +func steadyState(id uint64, shards uint32) *pb.RoutingState { + return &pb.RoutingState{ + Id: id, + State: &pb.RoutingState_RoutableShards{RoutableShards: shards}, + } +} + +func transitionState(id uint64, lastStable, wantShards uint32, safeAfter time.Time) *pb.RoutingState { + return &pb.RoutingState{ + Id: id, + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{ + WantShards: wantShards, + LastStableCount: lastStable, + ChangesSafeAfter: timestamppb.New(safeAfter), + }, + }, + } +} diff --git a/pkg/workflows/ring/store.go b/pkg/workflows/ring/store.go new file mode 100644 index 0000000000..e7cf31d8d5 --- /dev/null +++ b/pkg/workflows/ring/store.go @@ -0,0 +1,219 @@ +package ring + +import ( + "context" + "maps" + "slices" + "sync" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" +) + +// AllocationRequest represents a pending workflow allocation request during transition +type AllocationRequest struct { + WorkflowID string + Result chan uint32 +} + +// Store manages shard routing state and workflow mappings. +// It serves as a shared data layer across three components: +// - RingOCR plugin: produces consensus-driven routing updates +// - Arbiter: provides shard health and scaling decisions +// - ShardOrchestrator: consumes routing state to direct workflow execution +type Store struct { + routingState map[string]uint32 // workflow_id -> shard_id (cache of allocated workflows) + shardHealth map[uint32]bool // shard_id -> is_healthy + healthyShards []uint32 // Sorted list of healthy shards + currentState *pb.RoutingState // Current routing state (steady or transition) + + pendingAllocs map[string][]chan uint32 // workflow_id -> waiting channels + allocRequests chan AllocationRequest // Channel for new allocation requests + + mu sync.Mutex +} + +const AllocationRequestChannelCapacity = 1000 + +func NewStore() *Store { + return &Store{ + routingState: make(map[string]uint32), + shardHealth: make(map[uint32]bool), + healthyShards: make([]uint32, 0), + pendingAllocs: make(map[string][]chan uint32), + allocRequests: make(chan AllocationRequest, AllocationRequestChannelCapacity), + mu: sync.Mutex{}, + } +} + +func (s *Store) updateHealthyShards() { + s.mu.Lock() + defer s.mu.Unlock() + + s.healthyShards = make([]uint32, 0) + + for shardID, healthy := range s.shardHealth { + if healthy { + s.healthyShards = append(s.healthyShards, shardID) + } + } + + // Sort for determinism + slices.Sort(s.healthyShards) + + // If no healthy shards, add shard 0 as fallback + if len(s.healthyShards) == 0 { + s.healthyShards = []uint32{0} + } +} + +// GetShardForWorkflow called by Workflow Registry Syncers of all shards via ShardOrchestratorService. +func (s *Store) GetShardForWorkflow(ctx context.Context, workflowID string) (uint32, error) { + s.mu.Lock() + + // Only trust the cache in steady state; during transition OCR may have invalidated it + if IsInSteadyState(s.currentState) { + // Check if already allocated in cache + if shard, ok := s.routingState[workflowID]; ok { + s.mu.Unlock() + return shard, nil + } + ring := newShardRing(s.healthyShards) + s.mu.Unlock() + return locateShard(ring, workflowID) + } + + // During transition, defer to OCR consensus for consistent shard assignment across nodes + resultCh := make(chan uint32, 1) + s.pendingAllocs[workflowID] = append(s.pendingAllocs[workflowID], resultCh) + s.mu.Unlock() + + select { + case s.allocRequests <- AllocationRequest{WorkflowID: workflowID, Result: resultCh}: + case <-ctx.Done(): + return 0, ctx.Err() + } + + select { + case shard := <-resultCh: + return shard, nil + case <-ctx.Done(): + return 0, ctx.Err() + } +} + +// SetShardForWorkflow is called by the RingOCR plugin whenever it finishes a round with allocations for a given workflow ID. +func (s *Store) SetShardForWorkflow(workflowID string, shardID uint32) { + s.mu.Lock() + defer s.mu.Unlock() + + s.routingState[workflowID] = shardID + + // Signal any waiting allocation requests + if waiters, ok := s.pendingAllocs[workflowID]; ok { + for _, ch := range waiters { + select { + case ch <- shardID: + default: + } + } + delete(s.pendingAllocs, workflowID) + } +} + +// SetRoutingState is called by the RingOCR plugin whenever a state transition happens. +func (s *Store) SetRoutingState(state *pb.RoutingState) { + s.mu.Lock() + defer s.mu.Unlock() + s.currentState = state +} + +func (s *Store) GetRoutingState() *pb.RoutingState { + s.mu.Lock() + defer s.mu.Unlock() + return s.currentState +} + +// GetPendingAllocations called by the RingOCR plugin in the observation phase +// to collect all allocation requests (only applicable to the TRANSITION phase). +func (s *Store) GetPendingAllocations() []string { + var pending []string + for { + select { + case req := <-s.allocRequests: + pending = append(pending, req.WorkflowID) + default: + return pending + } + } +} + +func (s *Store) IsInTransition() bool { + s.mu.Lock() + defer s.mu.Unlock() + return !IsInSteadyState(s.currentState) +} + +func (s *Store) GetShardHealth() map[uint32]bool { + s.mu.Lock() + defer s.mu.Unlock() + return maps.Clone(s.shardHealth) +} + +func (s *Store) SetShardHealth(shardID uint32, healthy bool) { + s.mu.Lock() + s.shardHealth[shardID] = healthy + s.mu.Unlock() + s.updateHealthyShards() +} + +func (s *Store) SetAllShardHealth(health map[uint32]bool) { + s.mu.Lock() + s.shardHealth = make(map[uint32]bool) + for k, v := range health { + s.shardHealth[k] = v + } + + // Uninitialized store must wait for OCR consensus before serving requests + if s.currentState == nil { + numHealthy := uint32(0) + for _, healthy := range health { + if healthy { + numHealthy++ + } + } + s.currentState = &pb.RoutingState{ + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{ + WantShards: numHealthy, + }, + }, + } + } + s.mu.Unlock() + + s.updateHealthyShards() +} + +func (s *Store) GetAllRoutingState() map[string]uint32 { + s.mu.Lock() + defer s.mu.Unlock() + return maps.Clone(s.routingState) +} + +func (s *Store) DeleteWorkflow(workflowID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.routingState, workflowID) +} + +func (s *Store) GetHealthyShardCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.healthyShards) +} + +func (s *Store) GetHealthyShards() []uint32 { + s.mu.Lock() + defer s.mu.Unlock() + return slices.Clone(s.healthyShards) +} diff --git a/pkg/workflows/ring/store_test.go b/pkg/workflows/ring/store_test.go new file mode 100644 index 0000000000..c986ea9f0d --- /dev/null +++ b/pkg/workflows/ring/store_test.go @@ -0,0 +1,313 @@ +package ring + +import ( + "context" + "testing" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/stretchr/testify/require" +) + +func TestStore_DeterministicHashing(t *testing.T) { + store := NewStore() + + // Set up healthy shards + store.SetAllShardHealth(map[uint32]bool{ + 0: true, + 1: true, + 2: true, + }) + // Simulate OCR having moved to steady state + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }) + + ctx := context.Background() + + // Test determinism: same workflow always gets same shard + shard1, err := store.GetShardForWorkflow(ctx, "workflow-123") + require.NoError(t, err) + shard2, err := store.GetShardForWorkflow(ctx, "workflow-123") + require.NoError(t, err) + shard3, err := store.GetShardForWorkflow(ctx, "workflow-123") + require.NoError(t, err) + + require.Equal(t, shard1, shard2, "Same workflow should get same shard (call 2)") + require.Equal(t, shard2, shard3, "Same workflow should get same shard (call 3)") + require.True(t, shard1 >= 0 && shard1 <= 2, "Shard should be in healthy set") +} + +func TestStore_ConsistentRingConsistency(t *testing.T) { + store1 := NewStore() + store2 := NewStore() + store3 := NewStore() + + // All stores with same healthy shards + healthyShards := map[uint32]bool{0: true, 1: true, 2: true} + steadyState := &pb.RoutingState{ + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + } + store1.SetAllShardHealth(healthyShards) + store1.SetRoutingState(steadyState) + store2.SetAllShardHealth(healthyShards) + store2.SetRoutingState(steadyState) + store3.SetAllShardHealth(healthyShards) + store3.SetRoutingState(steadyState) + + ctx := context.Background() + + // All compute same assignments + workflows := []string{"workflow-A", "workflow-B", "workflow-C", "workflow-D"} + for _, wf := range workflows { + s1, err := store1.GetShardForWorkflow(ctx, wf) + require.NoError(t, err) + s2, err := store2.GetShardForWorkflow(ctx, wf) + require.NoError(t, err) + s3, err := store3.GetShardForWorkflow(ctx, wf) + require.NoError(t, err) + + require.Equal(t, s1, s2, "All nodes should agree on %s assignment", wf) + require.Equal(t, s2, s3, "All nodes should agree on %s assignment", wf) + } +} + +func TestStore_Rebalancing(t *testing.T) { + store := NewStore() + ctx := context.Background() + + // Start with 3 healthy shards + store.SetAllShardHealth(map[uint32]bool{0: true, 1: true, 2: true}) + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }) + assignments1 := make(map[string]uint32) + for i := 1; i <= 10; i++ { + wfID := "workflow-" + string(rune(i)) + shard, err := store.GetShardForWorkflow(ctx, wfID) + require.NoError(t, err) + assignments1[wfID] = shard + } + + // Shard 1 fails + store.SetShardHealth(1, false) + assignments2 := make(map[string]uint32) + for i := 1; i <= 10; i++ { + wfID := "workflow-" + string(rune(i)) + shard, err := store.GetShardForWorkflow(ctx, wfID) + require.NoError(t, err) + assignments2[wfID] = shard + } + + // Check that rebalancing occurred (some workflows moved) + healthyShards := store.GetHealthyShards() + require.Equal(t, 2, len(healthyShards), "Should have 2 healthy shards") + require.NotContains(t, healthyShards, uint32(1), "Shard 1 should not be healthy") + + // Verify that workflows on healthy shards did not move + for wfID, originalShard := range assignments1 { + if originalShard == 0 || originalShard == 2 { + require.Equal(t, originalShard, assignments2[wfID], + "Workflow %s on healthy shard %d should not have moved", wfID, originalShard) + } + } +} + +func TestStore_GetHealthyShards(t *testing.T) { + store := NewStore() + + store.SetAllShardHealth(map[uint32]bool{ + 3: true, + 1: true, + 2: true, + }) + + healthyShards := store.GetHealthyShards() + require.Len(t, healthyShards, 3) + // Should be sorted + require.Equal(t, []uint32{1, 2, 3}, healthyShards) +} + +func TestStore_DistributionAcrossShards(t *testing.T) { + store := NewStore() + ctx := context.Background() + + store.SetAllShardHealth(map[uint32]bool{ + 0: true, + 1: true, + 2: true, + }) + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }) + + // Generate many workflows and check distribution + totalWorkflows := 100 + distribution := make(map[uint32]int) + for i := 0; i < totalWorkflows; i++ { + wfID := "workflow-" + string(rune(i)) + shard, err := store.GetShardForWorkflow(ctx, wfID) + require.NoError(t, err) + distribution[shard]++ + } + + require.Equal(t, totalWorkflows, sum(distribution), "Should have 100 workflows") + + // Each shard should have roughly 33% of workflows (±5%) + for shard, count := range distribution { + pct := float64(count) / 100.0 * 100 + require.GreaterOrEqual(t, pct, 28.0, "Shard %d has too few workflows: %d%%", shard, int(pct)) + require.LessOrEqual(t, pct, 38.0, "Shard %d has too many workflows: %d%%", shard, int(pct)) + } +} + +func sum(distribution map[uint32]int) int { + total := 0 + for _, count := range distribution { + total += count + } + return total +} + +func TestStore_GetShardForWorkflow_CacheHit(t *testing.T) { + store := NewStore() + ctx := context.Background() + + // Set up steady state + store.SetAllShardHealth(map[uint32]bool{0: true, 1: true, 2: true}) + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }) + + // Pre-populate cache with a specific shard assignment + store.SetShardForWorkflow("cached-workflow", 2) + + // Should return cached value, not recompute + shard, err := store.GetShardForWorkflow(ctx, "cached-workflow") + require.NoError(t, err) + require.Equal(t, uint32(2), shard) +} + +func TestStore_GetShardForWorkflow_ContextCancelledDuringSend(t *testing.T) { + store := NewStore() + + // Put store in transition state + store.SetAllShardHealth(map[uint32]bool{0: true}) + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{WantShards: 2}, + }, + }) + + // Fill up the allocRequests channel + for i := 0; i < AllocationRequestChannelCapacity; i++ { + store.allocRequests <- AllocationRequest{WorkflowID: "filler"} + } + + // Context that's already cancelled + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // Should fail: channel is full and context is cancelled + _, err := store.GetShardForWorkflow(ctx, "workflow-123") + require.ErrorIs(t, err, context.Canceled) +} + +func TestStore_PendingAllocsDuringTransition(t *testing.T) { + store := NewStore() + store.SetAllShardHealth(map[uint32]bool{0: true, 1: true}) + + // Put store in transition state + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{WantShards: 3}, + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Start a goroutine that requests allocation (will block) + resultCh := make(chan uint32) + go func() { + shard, _ := store.GetShardForWorkflow(ctx, "workflow-X") + resultCh <- shard + }() + + // Give goroutine time to enqueue request + time.Sleep(10 * time.Millisecond) + + // Verify request is pending + pending := store.GetPendingAllocations() + require.Contains(t, pending, "workflow-X") + + // Fulfill the allocation (simulates transmitter receiving OCR outcome) + store.SetShardForWorkflow("workflow-X", 2) + + // Blocked goroutine should now receive result + select { + case shard := <-resultCh: + require.Equal(t, uint32(2), shard) + case <-time.After(50 * time.Millisecond): + t.Fatal("allocation was not fulfilled") + } +} + +func TestStore_AccessorMethods(t *testing.T) { + store := NewStore() + + store.SetAllShardHealth(map[uint32]bool{0: true, 1: true, 2: false}) + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_RoutableShards{RoutableShards: 2}, + }) + store.SetShardForWorkflow("wf-1", 0) + store.SetShardForWorkflow("wf-2", 1) + + t.Run("GetRoutingState", func(t *testing.T) { + state := store.GetRoutingState() + require.NotNil(t, state) + require.Equal(t, uint32(2), state.GetRoutableShards()) + }) + + t.Run("IsInTransition_steady_state", func(t *testing.T) { + require.False(t, store.IsInTransition()) + }) + + t.Run("GetShardHealth", func(t *testing.T) { + health := store.GetShardHealth() + require.Len(t, health, 3) + require.True(t, health[0]) + require.True(t, health[1]) + require.False(t, health[2]) + }) + + t.Run("GetAllRoutingState", func(t *testing.T) { + routes := store.GetAllRoutingState() + require.Len(t, routes, 2) + require.Equal(t, uint32(0), routes["wf-1"]) + require.Equal(t, uint32(1), routes["wf-2"]) + }) + + t.Run("GetHealthyShardCount", func(t *testing.T) { + require.Equal(t, 2, store.GetHealthyShardCount()) + }) + + t.Run("DeleteWorkflow", func(t *testing.T) { + store.DeleteWorkflow("wf-1") + routes := store.GetAllRoutingState() + require.Len(t, routes, 1) + require.NotContains(t, routes, "wf-1") + }) + + t.Run("IsInTransition_transition_state", func(t *testing.T) { + store.SetRoutingState(&pb.RoutingState{ + State: &pb.RoutingState_Transition{Transition: &pb.Transition{WantShards: 3}}, + }) + require.True(t, store.IsInTransition()) + }) + + t.Run("IsInTransition_nil_state", func(t *testing.T) { + store.SetRoutingState(nil) + require.True(t, store.IsInTransition()) + }) +} diff --git a/pkg/workflows/ring/transmitter.go b/pkg/workflows/ring/transmitter.go new file mode 100644 index 0000000000..524be65be1 --- /dev/null +++ b/pkg/workflows/ring/transmitter.go @@ -0,0 +1,76 @@ +package ring + +import ( + "context" + + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" +) + +var _ ocr3types.ContractTransmitter[[]byte] = (*Transmitter)(nil) + +// Transmitter handles transmission of shard orchestration outcomes +type Transmitter struct { + lggr logger.Logger + store *Store + arbiterScaler pb.ArbiterScalerClient + fromAccount types.Account +} + +func NewTransmitter(lggr logger.Logger, store *Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter { + return &Transmitter{lggr: lggr, store: store, arbiterScaler: arbiterScaler, fromAccount: fromAccount} +} + +func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint64, r ocr3types.ReportWithInfo[[]byte], _ []types.AttributedOnchainSignature) error { + outcome := &pb.Outcome{} + if err := proto.Unmarshal(r.Report, outcome); err != nil { + t.lggr.Errorf("failed to unmarshal report") + return err + } + + if err := t.notifyArbiter(ctx, outcome.State); err != nil { + t.lggr.Errorf("failed to notify arbiter", "err", err) + return err + } + + t.store.SetRoutingState(outcome.State) + + for workflowID, route := range outcome.Routes { + t.store.SetShardForWorkflow(workflowID, route.Shard) + t.lggr.Debugw("Updated workflow shard mapping", "workflowID", workflowID, "shard", route.Shard) + } + + return nil +} + +func (t *Transmitter) notifyArbiter(ctx context.Context, state *pb.RoutingState) error { + if state == nil { + return nil + } + + var nShards uint32 + switch s := state.State.(type) { + case *pb.RoutingState_RoutableShards: + nShards = s.RoutableShards + t.lggr.Infow("Transmitting shard routing", "routableShards", nShards) + case *pb.RoutingState_Transition: + nShards = s.Transition.WantShards + t.lggr.Infow("Transmitting shard routing (in transition)", "wantShards", nShards) + } + + if t.arbiterScaler != nil && nShards > 0 { + if _, err := t.arbiterScaler.ConsensusWantShards(ctx, &pb.ConsensusWantShardsRequest{NShards: nShards}); err != nil { + return err + } + } + + return nil +} + +func (t *Transmitter) FromAccount(ctx context.Context) (types.Account, error) { + return t.fromAccount, nil +} diff --git a/pkg/workflows/ring/transmitter_test.go b/pkg/workflows/ring/transmitter_test.go new file mode 100644 index 0000000000..9fde000152 --- /dev/null +++ b/pkg/workflows/ring/transmitter_test.go @@ -0,0 +1,173 @@ +package ring + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/libocr/offchainreporting2plus/types" +) + +type mockArbiterScaler struct { + called bool + nShards uint32 + err error +} + +func (m *mockArbiterScaler) Status(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*pb.ReplicaStatus, error) { + return &pb.ReplicaStatus{}, nil +} + +func (m *mockArbiterScaler) ConsensusWantShards(ctx context.Context, req *pb.ConsensusWantShardsRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) { + m.called = true + m.nShards = req.NShards + if m.err != nil { + return nil, m.err + } + return &emptypb.Empty{}, nil +} + +func TestTransmitter_NewTransmitter(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + tx := NewTransmitter(lggr, store, nil, "test-account") + require.NotNil(t, tx) +} + +func TestTransmitter_FromAccount(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + tx := NewTransmitter(lggr, store, nil, "my-account") + + account, err := tx.FromAccount(context.Background()) + require.NoError(t, err) + require.Equal(t, types.Account("my-account"), account) +} + +func TestTransmitter_Transmit(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + mock := &mockArbiterScaler{} + tx := NewTransmitter(lggr, store, mock, "test-account") + + outcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }, + Routes: map[string]*pb.WorkflowRoute{ + "wf-1": {Shard: 0}, + "wf-2": {Shard: 1}, + }, + } + outcomeBytes, err := proto.Marshal(outcome) + require.NoError(t, err) + + report := ocr3types.ReportWithInfo[[]byte]{Report: outcomeBytes} + err = tx.Transmit(context.Background(), types.ConfigDigest{}, 0, report, nil) + require.NoError(t, err) + + // Verify arbiter was notified + require.True(t, mock.called) + require.Equal(t, uint32(3), mock.nShards) + + // Verify store was updated + require.Equal(t, uint32(3), store.GetRoutingState().GetRoutableShards()) + routes := store.GetAllRoutingState() + require.Equal(t, uint32(0), routes["wf-1"]) + require.Equal(t, uint32(1), routes["wf-2"]) +} + +func TestTransmitter_Transmit_NilArbiter(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + tx := NewTransmitter(lggr, store, nil, "test-account") + + outcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_RoutableShards{RoutableShards: 2}, + }, + Routes: map[string]*pb.WorkflowRoute{"wf-1": {Shard: 0}}, + } + outcomeBytes, _ := proto.Marshal(outcome) + + err := tx.Transmit(context.Background(), types.ConfigDigest{}, 0, ocr3types.ReportWithInfo[[]byte]{Report: outcomeBytes}, nil) + require.NoError(t, err) +} + +func TestTransmitter_Transmit_TransitionState(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + mock := &mockArbiterScaler{} + tx := NewTransmitter(lggr, store, mock, "test-account") + + outcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_Transition{ + Transition: &pb.Transition{WantShards: 5}, + }, + }, + } + outcomeBytes, _ := proto.Marshal(outcome) + + err := tx.Transmit(context.Background(), types.ConfigDigest{}, 0, ocr3types.ReportWithInfo[[]byte]{Report: outcomeBytes}, nil) + require.NoError(t, err) + require.Equal(t, uint32(5), mock.nShards) +} + +func TestTransmitter_Transmit_InvalidReport(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + tx := NewTransmitter(lggr, store, nil, "test-account") + + // Send invalid protobuf data + report := ocr3types.ReportWithInfo[[]byte]{Report: []byte("invalid protobuf")} + err := tx.Transmit(context.Background(), types.ConfigDigest{}, 0, report, nil) + require.Error(t, err) +} + +func TestTransmitter_Transmit_ArbiterError(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + mock := &mockArbiterScaler{err: context.DeadlineExceeded} + tx := NewTransmitter(lggr, store, mock, "test-account") + + outcome := &pb.Outcome{ + State: &pb.RoutingState{ + Id: 1, + State: &pb.RoutingState_RoutableShards{RoutableShards: 3}, + }, + } + outcomeBytes, _ := proto.Marshal(outcome) + + err := tx.Transmit(context.Background(), types.ConfigDigest{}, 0, ocr3types.ReportWithInfo[[]byte]{Report: outcomeBytes}, nil) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestTransmitter_Transmit_NilState(t *testing.T) { + lggr := logger.Test(t) + store := NewStore() + tx := NewTransmitter(lggr, store, nil, "test-account") + + outcome := &pb.Outcome{ + State: nil, + Routes: map[string]*pb.WorkflowRoute{"wf-1": {Shard: 0}}, + } + outcomeBytes, _ := proto.Marshal(outcome) + + err := tx.Transmit(context.Background(), types.ConfigDigest{}, 0, ocr3types.ReportWithInfo[[]byte]{Report: outcomeBytes}, nil) + require.NoError(t, err) + + // Routes should still be applied + routes := store.GetAllRoutingState() + require.Equal(t, uint32(0), routes["wf-1"]) +} diff --git a/pkg/workflows/ring/utils.go b/pkg/workflows/ring/utils.go new file mode 100644 index 0000000000..734495279b --- /dev/null +++ b/pkg/workflows/ring/utils.go @@ -0,0 +1,66 @@ +package ring + +import ( + "errors" + "slices" + "strconv" + + "github.com/buraksezer/consistent" + "github.com/cespare/xxhash/v2" +) + +var errInvalidRing = errors.New("RingOCR invalid ring for consistent hashing") +var errInvalidMember = errors.New("RingOCR invalid member for consistent hashing") + +func uniqueSorted(s []string) []string { + result := slices.Clone(s) + slices.Sort(result) + return slices.Compact(result) +} + +type xxhashHasher struct{} + +func (h xxhashHasher) Sum64(data []byte) uint64 { + return xxhash.Sum64(data) +} + +type ShardMember string + +func (m ShardMember) String() string { + return string(m) +} + +func consistentHashConfig() consistent.Config { + return consistent.Config{ + PartitionCount: 997, // Prime number for better distribution + ReplicationFactor: 50, // Number of replicas per node + Load: 1.1, // Load factor for bounded loads + Hasher: xxhashHasher{}, + } +} + +func newShardRing(healthyShards []uint32) *consistent.Consistent { + if len(healthyShards) == 0 { + return nil + } + members := make([]consistent.Member, len(healthyShards)) + for i, shardID := range healthyShards { + members[i] = ShardMember(strconv.FormatUint(uint64(shardID), 10)) + } + return consistent.New(members, consistentHashConfig()) +} + +func locateShard(ring *consistent.Consistent, workflowID string) (uint32, error) { + if ring == nil { + return 0, errInvalidRing + } + member := ring.LocateKey([]byte(workflowID)) + if member == nil { + return 0, errInvalidMember + } + shardID, err := strconv.ParseUint(member.String(), 10, 32) + if err != nil { + return 0, err + } + return uint32(shardID), nil +} diff --git a/pkg/workflows/ring/utils_test.go b/pkg/workflows/ring/utils_test.go new file mode 100644 index 0000000000..65f1c61a62 --- /dev/null +++ b/pkg/workflows/ring/utils_test.go @@ -0,0 +1,12 @@ +package ring + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUniqueSorted(t *testing.T) { + got := uniqueSorted([]string{"c", "a", "b", "a", "c"}) + require.Equal(t, []string{"a", "b", "c"}, got) +} From c4dad9e02e2ac40b13030e2a8029e0473fbc0ccc Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Mon, 5 Jan 2026 08:06:10 -0800 Subject: [PATCH 27/42] [CRE] Log more details of observations in DONTime plugin (#1750) --- pkg/workflows/dontime/plugin.go | 36 +++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/pkg/workflows/dontime/plugin.go b/pkg/workflows/dontime/plugin.go index ece80d7ddd..ffae1fdfe8 100644 --- a/pkg/workflows/dontime/plugin.go +++ b/pkg/workflows/dontime/plugin.go @@ -11,12 +11,13 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/dontime/pb" "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/libocr/quorumhelper" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/dontime/pb" ) type Plugin struct { @@ -121,7 +122,12 @@ func (p *Plugin) ObservationQuorum(_ context.Context, _ ocr3types.OutcomeContext func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ types.Query, aos []types.AttributedObservation) (ocr3types.Outcome, error) { observationCounts := map[string]int64{} // counts how many nodes reported where a new DON timestamp might be needed - var times []int64 + type timestampNodePair struct { + Timestamp int64 + NodeID int + OffsetFromMedian int64 + } + var timestampNodePairs []timestampNodePair prevOutcome := &pb.Outcome{} if err := proto.Unmarshal(outctx.PreviousOutcome, prevOutcome); err != nil { @@ -131,7 +137,7 @@ func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ t prevOutcome.ObservedDonTimes = make(map[string]*pb.ObservedDonTimes) } - for _, ao := range aos { + for idx, ao := range aos { observation := &pb.Observation{} if err := proto.Unmarshal(ao.Observation, observation); err != nil { p.lggr.Errorf("failed to unmarshal observation in Outcome phase") @@ -153,12 +159,26 @@ func (p *Plugin) Outcome(_ context.Context, outctx ocr3types.OutcomeContext, _ t } } - times = append(times, observation.Timestamp) + timestampNodePairs = append(timestampNodePairs, timestampNodePair{Timestamp: observation.Timestamp, NodeID: idx}) + } + if len(timestampNodePairs) == 0 { + return nil, errors.New("no observation contains a valid timestamp") } - p.lggr.Debugw("Observed Node Timestamps", "timestamps", times) - slices.Sort(times) - donTime := times[len(times)/2] + slices.SortFunc(timestampNodePairs, func(a, b timestampNodePair) int { + return int(a.Timestamp - b.Timestamp) + }) + donTime := timestampNodePairs[len(timestampNodePairs)/2].Timestamp + for i := range timestampNodePairs { + timestampNodePairs[i].OffsetFromMedian = timestampNodePairs[i].Timestamp - donTime + } + p.lggr.Debugw("Observed Node Timestamps", + "timestampNodePairs", timestampNodePairs, + "median", donTime, + "collectedDataPoints", len(timestampNodePairs), + "minOffsetFromMedian", timestampNodePairs[0].OffsetFromMedian, + "maxOffsetFromMedian", timestampNodePairs[len(timestampNodePairs)-1].OffsetFromMedian, + ) outcome := prevOutcome From fe69589137e482d474a27e40d431dcbe7e86cab7 Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Mon, 5 Jan 2026 20:39:31 -0800 Subject: [PATCH 28/42] FIFO order in ResourcePoolLimiter to support ordered concurrency limits (#1754) * Add FIFOResourcePoolLimiter to support ordered concurrency limits Almost entirely generated by Claude 4.5 * Merge into the existing resourcePoolLimiter instead of creating a new object --- pkg/settings/limits/resource.go | 118 +++++++++++++----- pkg/settings/limits/resource_test.go | 174 +++++++++++++++++++++++++++ 2 files changed, 265 insertions(+), 27 deletions(-) diff --git a/pkg/settings/limits/resource.go b/pkg/settings/limits/resource.go index 75e716e80f..02b2d27305 100644 --- a/pkg/settings/limits/resource.go +++ b/pkg/settings/limits/resource.go @@ -142,13 +142,23 @@ func (l *resourcePoolLimiter[N]) getLimit(ctx context.Context) (limit N) { return limit } +// waiter represents a goroutine waiting for resources in the FIFO queue. +type waiter[N Number] struct { + amount N + ready chan struct{} // closed when resources are granted +} + type resourcePoolUsage[N Number] struct { *resourcePoolLimiter[N] scope settings.Scope // optional tenant string // optional mu sync.Mutex - cond sync.Cond used N + // queue holds waiters in FIFO order; head of slice is first to be serviced + queue []*waiter[N] + // onEnqueue is an optional callback invoked (under lock) when a waiter is added to the queue. + // Used for testing to synchronize without sleeps. + onEnqueue func() recordUsage func(context.Context, N) recordLimit func(context.Context, N) @@ -165,6 +175,7 @@ type resourcePoolUsage[N Number] struct { func (l *resourcePoolLimiter[N]) newLimitUsage(opts ...metric.RecordOption) *resourcePoolUsage[N] { u := resourcePoolUsage[N]{ resourcePoolLimiter: l, + queue: make([]*waiter[N], 0), stopCh: make(services.StopChan), done: make(chan struct{}), recordUsage: func(ctx context.Context, n N) { @@ -193,7 +204,6 @@ func (l *resourcePoolLimiter[N]) newLimitUsage(opts ...metric.RecordOption) *res } }, } - u.cond.L = &u.mu return &u } @@ -207,9 +217,25 @@ func (u *resourcePoolUsage[N]) free(amount N) { defer cancel() u.recordUsage(ctx, u.used) - u.cond.Broadcast() // notify others blocked on cond.Wait + u.tryWakeWaiters() +} - return +// tryWakeWaiters attempts to wake waiters at the head of the queue +// whose resource requests can now be satisfied. +// Must be called with u.mu held. +func (u *resourcePoolUsage[N]) tryWakeWaiters() { + for len(u.queue) > 0 { + head := u.queue[0] + limit := u.getLimit(context.Background()) + if u.used+head.amount > limit { + // Not enough resources for the head waiter; stop here to preserve FIFO + break + } + // Grant resources to head waiter + u.used += head.amount + close(head.ready) + u.queue = u.queue[1:] + } } func (u *resourcePoolUsage[N]) newErrorLimitReached(limit, amount N) ErrorResourceLimited[N] { @@ -241,7 +267,6 @@ func (u *resourcePoolUsage[N]) available(ctx context.Context) (N, error) { return limit - u.used, nil } -// opt: queue instead of racing for the [sync.Mutex] & [sync.Cond] func (u *resourcePoolUsage[N]) use(ctx context.Context, amount N, block bool) error { limit, err := u.get(ctx) if err != nil { @@ -250,34 +275,73 @@ func (u *resourcePoolUsage[N]) use(ctx context.Context, amount N, block bool) er start := time.Now() u.mu.Lock() - defer u.mu.Unlock() - if u.used+amount > limit { - if !block { + // Fast path: resources available immediately and no one else waiting + if len(u.queue) == 0 && u.used+amount <= limit { + u.used += amount + u.recordUsage(ctx, u.used) + u.recordAmount(ctx, amount) + u.recordBlockTime(ctx, time.Since(start).Seconds()) + u.mu.Unlock() + return nil + } + + // Not enough resources + if !block { + u.recordDenied(ctx, amount) + err := u.newErrorLimitReached(limit, amount) + u.mu.Unlock() + return err + } + + // Slow path: need to queue up and wait (FIFO ordering) + w := &waiter[N]{ + amount: amount, + ready: make(chan struct{}), + } + u.queue = append(u.queue, w) + if u.onEnqueue != nil { + u.onEnqueue() + } + u.mu.Unlock() + + // Wait for our turn or context cancellation + select { + case <-w.ready: + // Resources have been granted to us + u.mu.Lock() + u.recordUsage(ctx, u.used) + u.recordAmount(ctx, amount) + u.recordBlockTime(ctx, time.Since(start).Seconds()) + u.mu.Unlock() + return nil + case <-ctx.Done(): + // Context cancelled - remove ourselves from queue + u.mu.Lock() + defer u.mu.Unlock() + + // Check if we were already granted resources while acquiring the lock + select { + case <-w.ready: + // We got resources just as we were cancelling; return them + u.used -= amount + u.tryWakeWaiters() u.recordDenied(ctx, amount) - return u.newErrorLimitReached(limit, amount) + return fmt.Errorf("context error (%w) after waiting %s for limit: %w", ctx.Err(), time.Since(start), u.newErrorLimitReached(limit, amount)) + default: } - // Ensure cond.Wait() yields to context expiration - stop := context.AfterFunc(ctx, func() { - u.mu.Lock() - defer u.mu.Unlock() - u.cond.Broadcast() - }) - defer stop() - start := time.Now() - for u.used+amount > limit { - u.cond.Wait() // wait until some resources are freed, or context expiration - if err := ctx.Err(); err != nil { - u.recordDenied(ctx, amount) - return fmt.Errorf("context error (%w) after waiting %s for limit: %w", err, time.Since(start), u.newErrorLimitReached(limit, amount)) + + // Remove from queue. Only needed when the context was cancelled before the element got to the head of the queue. + // Otherwise it is already removed by tryWakeWaiters(). + for i, waiter := range u.queue { + if waiter == w { + u.queue = append(u.queue[:i], u.queue[i+1:]...) + break } } + u.recordDenied(ctx, amount) + return fmt.Errorf("context error (%w) after waiting %s for limit: %w", ctx.Err(), time.Since(start), u.newErrorLimitReached(limit, amount)) } - u.used += amount - u.recordUsage(ctx, u.used) - u.recordAmount(ctx, amount) - u.recordBlockTime(ctx, time.Since(start).Seconds()) - return nil } func (u *resourcePoolUsage[N]) wait(ctx context.Context, amount N) (free func(), err error) { diff --git a/pkg/settings/limits/resource_test.go b/pkg/settings/limits/resource_test.go index b587c6a984..f5aed65444 100644 --- a/pkg/settings/limits/resource_test.go +++ b/pkg/settings/limits/resource_test.go @@ -286,3 +286,177 @@ func Test_newScopedResourcePoolLimiterFromFactory(t *testing.T) { }) require.Error(t, l.Use(ctx2, 1)) } + +// TestResourcePoolLimiter_WaitOrderPreserved confirms that ResourcePoolLimiter +// preserves FIFO ordering when multiple goroutines are waiting. +func TestResourcePoolLimiter_WaitOrderPreserved(t *testing.T) { + const numWaiters = 10 + + ctx := context.Background() + limiter := newUnscopedResourcePoolLimiter(1) + + // Channel to signal when each waiter has been enqueued + enqueued := make(chan struct{}, numWaiters) + limiter.resourcePoolUsage.setOnEnqueue(func() { + enqueued <- struct{}{} + }) + + // Acquire the single resource first + free, err := limiter.Wait(ctx, 1) + require.NoError(t, err) + + // Track the order in which waiters acquired resources + acquiredOrder := make(chan int, numWaiters) + + // Start multiple waiters sequentially, waiting for each to be enqueued before starting the next + for i := range numWaiters { + waiterID := i + go func() { + f, err := limiter.Wait(ctx, 1) + if err != nil { + return + } + acquiredOrder <- waiterID + f() + }() + // Wait for this waiter to be enqueued before starting the next + <-enqueued + } + + // All waiters are now in the queue in order 0, 1, 2, ... numWaiters-1 + // Release the resource - this should wake up waiters in FIFO order + free() + + // Collect the order in which waiters acquired the resource + acquired := make([]int, 0, numWaiters) + for range numWaiters { + select { + case id := <-acquiredOrder: + acquired = append(acquired, id) + case <-time.After(time.Second): + t.Fatalf("timed out waiting for waiter to acquire resource") + } + } + + // Verify FIFO order is preserved + for i, id := range acquired { + require.Equalf(t, i, id, "expected waiter %d at position %d, got %d (acquired order: %v)", i, i, id, acquired) + } +} + +// TestResourcePoolLimiter_ContextCancellation tests that context cancellation +// properly removes waiters from the queue without breaking FIFO order. +func TestResourcePoolLimiter_ContextCancellation(t *testing.T) { + ctx := context.Background() + limiter := newUnscopedResourcePoolLimiter(1) + + // Channel to signal when each waiter has been enqueued + enqueued := make(chan struct{}, 5) + limiter.resourcePoolUsage.setOnEnqueue(func() { + enqueued <- struct{}{} + }) + + // Acquire the single resource + free, err := limiter.Wait(ctx, 1) + require.NoError(t, err) + + // Start 5 waiters, but cancel the middle one + results := make(chan struct { + id int + err error + }, 5) + + var ctxs []context.Context + var cancels []context.CancelFunc + for range 5 { + c, cancel := context.WithCancel(ctx) + ctxs = append(ctxs, c) + cancels = append(cancels, cancel) + } + + // Start waiters sequentially, waiting for each to be enqueued + for i := range 5 { + waiterID := i + go func() { + f, err := limiter.Wait(ctxs[waiterID], 1) + if err != nil { + results <- struct { + id int + err error + }{waiterID, err} + return + } + results <- struct { + id int + err error + }{waiterID, nil} + f() + }() + // Wait for this waiter to be enqueued + <-enqueued + } + + // All 5 waiters are now in the queue in order 0, 1, 2, 3, 4 + // Cancel waiter 2 (middle of the queue) and wait for the cancellation result + cancels[2]() + cancelResult := <-results + require.Equal(t, 2, cancelResult.id) + require.Error(t, cancelResult.err) + + // Release the resource - remaining waiters should acquire in FIFO order + free() + + // Collect remaining results + var acquiredIDs []int + for range 4 { + select { + case r := <-results: + require.NoError(t, r.err) + acquiredIDs = append(acquiredIDs, r.id) + case <-time.After(time.Second): + t.Fatal("timed out waiting for results") + } + } + + // Remaining waiters should acquire in order: 0, 1, 3, 4 + assert.Equal(t, []int{0, 1, 3, 4}, acquiredIDs, "waiters should acquire in FIFO order, skipping cancelled") +} + +// TestResourcePoolLimiter_BasicUsage tests basic Use/Free functionality. +func TestResourcePoolLimiter_BasicUsage(t *testing.T) { + ctx := context.Background() + limiter := GlobalResourcePoolLimiter(5) + + // Use should work + require.NoError(t, limiter.Use(ctx, 3)) + + // Available should report 2 + avail, err := limiter.Available(ctx) + require.NoError(t, err) + assert.Equal(t, 2, avail) + + // Using more than available should fail + err = limiter.Use(ctx, 3) + require.Error(t, err) + var limitErr ErrorResourceLimited[int] + require.ErrorAs(t, err, &limitErr) + assert.Equal(t, 3, limitErr.Used) + assert.Equal(t, 5, limitErr.Limit) + assert.Equal(t, 3, limitErr.Amount) + + // Free should work + require.NoError(t, limiter.Free(ctx, 3)) + + // Now should have 5 available + avail, err = limiter.Available(ctx) + require.NoError(t, err) + assert.Equal(t, 5, avail) +} + +// setOnEnqueue sets a callback that is invoked each time a waiter is added to the queue. +// The callback is called with the mutex held. Used for testing to synchronize without sleeps. +func (u *resourcePoolUsage[N]) setOnEnqueue(fn func()) { + u.mu.Lock() + defer u.mu.Unlock() + u.onEnqueue = fn +} From 62bc87e18c82e1f5a6653a504d1598d2ec6c2b18 Mon Sep 17 00:00:00 2001 From: Patrick Date: Tue, 6 Jan 2026 08:56:07 -0500 Subject: [PATCH 29/42] adding grpc workflow metadata source client and related types (#1749) * adding grpc workflow metadata source client and related types * downgrading dependencies to keep diff clean * updating chainlink-protos * bumping protos * bumping protos + PR feedback --- go.mod | 8 +- go.sum | 16 +-- pkg/billing/workflow_client.go | 6 +- pkg/nodeauth/grpc/token_extractor.go | 43 ++++++++ pkg/services/orgresolver/linking.go | 3 +- pkg/workflows/grpcsource/client.go | 138 +++++++++++++++++++++++++ pkg/workflows/privateregistry/types.go | 35 +++++++ 7 files changed, 233 insertions(+), 16 deletions(-) create mode 100644 pkg/nodeauth/grpc/token_extractor.go create mode 100644 pkg/workflows/grpcsource/client.go create mode 100644 pkg/workflows/privateregistry/types.go diff --git a/go.mod b/go.mod index 1f89bb1a8d..6bd5db4560 100644 --- a/go.mod +++ b/go.mod @@ -46,7 +46,7 @@ require ( github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20251124151448-0448aefdaab9 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 - github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20251020004840-4638e4262066 + github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20260106052706-6dd937cb5ec6 github.com/smartcontractkit/freeport v0.1.3-0.20250716200817-cb5dfd0e369e github.com/smartcontractkit/grpc-proxy v0.0.0-20240830132753-a7e17fec5ab7 github.com/smartcontractkit/libocr v0.0.0-20250912173940-f3ab0246e23d @@ -136,14 +136,14 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.65.0 // indirect github.com/prometheus/procfs v0.16.1 // indirect - github.com/rogpeppe/go-internal v1.13.1 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/ryanuber/go-glob v1.0.0 // indirect github.com/sanity-io/litter v1.5.5 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 // indirect go.opentelemetry.io/proto/otlp v1.6.0 // indirect go.uber.org/multierr v1.11.0 // indirect @@ -154,7 +154,7 @@ require ( golang.org/x/term v0.37.0 // indirect golang.org/x/text v0.31.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 8a39debc83..58493b2eec 100644 --- a/go.sum +++ b/go.sum @@ -308,8 +308,8 @@ github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzM github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= @@ -338,8 +338,8 @@ github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-202510021 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 h1:B7itmjy+CMJ26elVw/cAJqqhBQ3Xa/mBYWK0/rQ5MuI= github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0/go.mod h1:h6kqaGajbNRrezm56zhx03p0mVmmA2xxj7E/M4ytLUA= -github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20251020004840-4638e4262066 h1:Lrc0+uegqasIFgsGXHy4tzdENT+zH2AbkTV4F7e3otU= -github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20251020004840-4638e4262066/go.mod h1:HIpGvF6nKCdtZ30xhdkKWGM9+4Z4CVqJH8ZBL1FTEiY= +github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20260106052706-6dd937cb5ec6 h1:BXMylId1EoFxuAy++JRifxUF+P/I7v5BEBh0wECtrEM= +github.com/smartcontractkit/chainlink-protos/workflows/go v0.0.0-20260106052706-6dd937cb5ec6/go.mod h1:GTpDgyK0OObf7jpch6p8N281KxN92wbB8serZhU9yRc= github.com/smartcontractkit/freeport v0.1.3-0.20250716200817-cb5dfd0e369e h1:Hv9Mww35LrufCdM9wtS9yVi/rEWGI1UnjHbcKKU0nVY= github.com/smartcontractkit/freeport v0.1.3-0.20250716200817-cb5dfd0e369e/go.mod h1:T4zH9R8R8lVWKfU7tUvYz2o2jMv1OpGCdpY2j2QZXzU= github.com/smartcontractkit/grpc-proxy v0.0.0-20240830132753-a7e17fec5ab7 h1:12ijqMM9tvYVEm+nR826WsrNi6zCKpwBhuApq127wHs= @@ -380,8 +380,8 @@ github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= -go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= @@ -588,8 +588,8 @@ google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoA google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20210401141331-865547bb08e2/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= -google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU= -google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA= +google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY= +google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE= google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE= google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= diff --git a/pkg/billing/workflow_client.go b/pkg/billing/workflow_client.go index ecae475ba5..c307cd10b9 100644 --- a/pkg/billing/workflow_client.go +++ b/pkg/billing/workflow_client.go @@ -11,10 +11,10 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/emptypb" + "github.com/smartcontractkit/chainlink-common/pkg/logger" + nodeauthgrpc "github.com/smartcontractkit/chainlink-common/pkg/nodeauth/grpc" auth "github.com/smartcontractkit/chainlink-common/pkg/nodeauth/jwt" pb "github.com/smartcontractkit/chainlink-protos/billing/go" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" ) // WorkflowClient is a specialized interface for the Workflow node use-case. @@ -139,7 +139,7 @@ func (wc *workflowClient) addJWTAuth(ctx context.Context, req any) (context.Cont } // Add JWT to Authorization header - return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+jwtToken), nil + return metadata.AppendToOutgoingContext(ctx, nodeauthgrpc.AuthorizationHeader, nodeauthgrpc.BearerPrefix+jwtToken), nil } func (wc *workflowClient) GetOrganizationCreditsByWorkflow(ctx context.Context, req *pb.GetOrganizationCreditsByWorkflowRequest) (*pb.GetOrganizationCreditsByWorkflowResponse, error) { diff --git a/pkg/nodeauth/grpc/token_extractor.go b/pkg/nodeauth/grpc/token_extractor.go new file mode 100644 index 0000000000..4d8c942fcd --- /dev/null +++ b/pkg/nodeauth/grpc/token_extractor.go @@ -0,0 +1,43 @@ +package grpc + +import ( + "context" + "errors" + "strings" + + "google.golang.org/grpc/metadata" +) + +const ( + // AuthorizationHeader is the lowercase header key for authorization + AuthorizationHeader = "authorization" + // BearerPrefix is the prefix for Bearer tokens + BearerPrefix = "Bearer " +) + +var ( + ErrMissingMetadata = errors.New("missing metadata") + ErrMissingAuthHeader = errors.New("missing authorization header") + ErrInvalidAuthFormat = errors.New("invalid authorization header format") +) + +// ExtractBearerToken extracts a Bearer token from gRPC incoming context metadata. +// Used by servers requiring the JWT authentication this package provides. +func ExtractBearerToken(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", ErrMissingMetadata + } + + authHeaders := md.Get(AuthorizationHeader) + if len(authHeaders) == 0 { + return "", ErrMissingAuthHeader + } + + authHeader := authHeaders[0] + if !strings.HasPrefix(authHeader, BearerPrefix) { + return "", ErrInvalidAuthFormat + } + + return strings.TrimPrefix(authHeader, BearerPrefix), nil +} diff --git a/pkg/services/orgresolver/linking.go b/pkg/services/orgresolver/linking.go index 40be90ce41..24e069b060 100644 --- a/pkg/services/orgresolver/linking.go +++ b/pkg/services/orgresolver/linking.go @@ -11,6 +11,7 @@ import ( "google.golang.org/grpc/metadata" log "github.com/smartcontractkit/chainlink-common/pkg/logger" + nodeauthgrpc "github.com/smartcontractkit/chainlink-common/pkg/nodeauth/grpc" "github.com/smartcontractkit/chainlink-common/pkg/services" linkingclient "github.com/smartcontractkit/chainlink-protos/linking-service/go/v1" ) @@ -100,7 +101,7 @@ func (o *orgResolver) addJWTAuth(ctx context.Context, req any) (context.Context, } // Add JWT to Authorization header - return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+jwtToken), nil + return metadata.AppendToOutgoingContext(ctx, nodeauthgrpc.AuthorizationHeader, nodeauthgrpc.BearerPrefix+jwtToken), nil } func (o *orgResolver) Get(ctx context.Context, owner string) (string, error) { diff --git a/pkg/workflows/grpcsource/client.go b/pkg/workflows/grpcsource/client.go new file mode 100644 index 0000000000..dc86f4e7cd --- /dev/null +++ b/pkg/workflows/grpcsource/client.go @@ -0,0 +1,138 @@ +package grpcsource + +import ( + "context" + "crypto/tls" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + + nodeauthgrpc "github.com/smartcontractkit/chainlink-common/pkg/nodeauth/grpc" + auth "github.com/smartcontractkit/chainlink-common/pkg/nodeauth/jwt" + pb "github.com/smartcontractkit/chainlink-protos/workflows/go/sources" +) + +// Client is a GRPC client for the WorkflowMetadataSourceService. +type Client struct { + conn *grpc.ClientConn + client pb.WorkflowMetadataSourceServiceClient + name string + jwtGenerator auth.JWTGenerator +} + +// clientConfig holds configuration for the client +type clientConfig struct { + tlsEnabled bool + jwtGenerator auth.JWTGenerator +} + +// ClientOption configures the Client +type ClientOption func(*clientConfig) + +func WithTLS(enabled bool) ClientOption { + return func(c *clientConfig) { + c.tlsEnabled = enabled + } +} + +func WithJWTGenerator(generator auth.JWTGenerator) ClientOption { + return func(c *clientConfig) { + c.jwtGenerator = generator + } +} + +// NewClient creates a new GRPC client for the WorkflowMetadataSourceService. +// addr is the GRPC endpoint address (e.g., "localhost:50051"). +// name is a human-readable identifier for logging. +// opts are optional configuration options. +func NewClient(addr string, name string, opts ...ClientOption) (*Client, error) { + cfg := &clientConfig{} + for _, opt := range opts { + opt(cfg) + } + + var dialOpts []grpc.DialOption + if cfg.tlsEnabled { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{}))) + } else { + dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + conn, err := grpc.NewClient(addr, dialOpts...) + if err != nil { + return nil, err + } + + return &Client{ + conn: conn, + client: pb.NewWorkflowMetadataSourceServiceClient(conn), + name: name, + jwtGenerator: cfg.jwtGenerator, + }, nil +} + +// NewClientWithOptions creates a new GRPC client with custom dial options. +// This is useful for testing or when custom options are needed. +func NewClientWithOptions(addr string, name string, opts ...grpc.DialOption) (*Client, error) { + conn, err := grpc.NewClient(addr, opts...) + if err != nil { + return nil, err + } + + return &Client{ + conn: conn, + client: pb.NewWorkflowMetadataSourceServiceClient(conn), + name: name, + }, nil +} + +func (c *Client) addJWTAuth(ctx context.Context, req any) (context.Context, error) { + if c.jwtGenerator == nil { + return ctx, nil // Skip if no generator configured + } + + jwtToken, err := c.jwtGenerator.CreateJWTForRequest(req) + if err != nil { + return nil, fmt.Errorf("failed to create JWT: %w", err) + } + + return metadata.AppendToOutgoingContext(ctx, nodeauthgrpc.AuthorizationHeader, nodeauthgrpc.BearerPrefix+jwtToken), nil +} + +// ListWorkflowMetadata fetches workflow metadata from the GRPC source. +// families is the list of DON families to filter workflows by. +// start is the pagination offset (0-indexed). +// limit is the maximum number of workflows to return per page. +// Returns workflows, hasMore flag indicating if more pages exist, and error. +func (c *Client) ListWorkflowMetadata(ctx context.Context, families []string, start, limit int64) ([]*pb.WorkflowMetadata, bool, error) { + req := &pb.ListWorkflowMetadataRequest{ + DonFamilies: families, + Start: start, + Limit: limit, + } + + // Inject JWT auth + ctx, err := c.addJWTAuth(ctx, req) + if err != nil { + return nil, false, err + } + + resp, err := c.client.ListWorkflowMetadata(ctx, req) + if err != nil { + return nil, false, err + } + return resp.Workflows, resp.HasMore, nil +} + +// Close closes the underlying GRPC connection. +func (c *Client) Close() error { + return c.conn.Close() +} + +// Name returns the human-readable name of this client. +func (c *Client) Name() string { + return c.name +} diff --git a/pkg/workflows/privateregistry/types.go b/pkg/workflows/privateregistry/types.go new file mode 100644 index 0000000000..fb3809b5cc --- /dev/null +++ b/pkg/workflows/privateregistry/types.go @@ -0,0 +1,35 @@ +package privateregistry + +import "context" + +// WorkflowDeploymentAction defines operations for managing workflows in a workflow source. +// This interface is implemented by both the mock server (for testing) and the actual +// private workflow registry server (when built). +type WorkflowDeploymentAction interface { + // AddWorkflow registers a new workflow with the source + AddWorkflow(ctx context.Context, workflow *WorkflowRegistration) error + + // UpdateWorkflow updates the workflow's status configuration + UpdateWorkflow(ctx context.Context, workflowID [32]byte, config *WorkflowStatusConfig) error + + // DeleteWorkflow removes the workflow from the source + DeleteWorkflow(ctx context.Context, workflowID [32]byte) error +} + +// WorkflowStatusConfig contains the desired state for a workflow's status +type WorkflowStatusConfig struct { + // Paused indicates whether the workflow should be paused (true) or active (false) + Paused bool +} + +// WorkflowRegistration contains the data needed to register a workflow +type WorkflowRegistration struct { + WorkflowID [32]byte + Owner []byte + WorkflowName string + BinaryURL string + ConfigURL string + DonFamily string + Tag string + Attributes []byte +} From f0414b873a1da61600464cbab21421ec907841d7 Mon Sep 17 00:00:00 2001 From: Matthew Pendrey Date: Tue, 6 Jan 2026 15:14:37 +0000 Subject: [PATCH 30/42] trigger wrapper generator changes to support capability errors (#1755) * trigger wrapper generator changes to support capability errors * update other trigger capability API errors * update utils to use caperrors.Error * update utils unit test --- pkg/capabilities/utils.go | 3 ++- pkg/capabilities/utils_test.go | 17 +++++++++++------ .../evm/server/client_server_gen.go | 4 ++-- .../solana/server/client_server_gen.go | 4 ++-- .../v2/protoc/pkg/templates/server.go.tmpl | 4 ++-- .../server/action_and_trigger_server_gen.go | 4 ++-- .../server/basic_trigger_server_gen.go | 5 +++-- .../triggers/cron/server/trigger_server_gen.go | 9 +++++---- .../triggers/http/server/trigger_server_gen.go | 5 +++-- 9 files changed, 32 insertions(+), 23 deletions(-) diff --git a/pkg/capabilities/utils.go b/pkg/capabilities/utils.go index 039b5cf593..143dde7d0c 100644 --- a/pkg/capabilities/utils.go +++ b/pkg/capabilities/utils.go @@ -8,6 +8,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-protos/cre/go/values" ) @@ -124,7 +125,7 @@ func RegisterTrigger[I, O proto.Message]( triggerType string, request TriggerRegistrationRequest, message I, - fn func(context.Context, string, RequestMetadata, I) (<-chan TriggerAndId[O], error), + fn func(context.Context, string, RequestMetadata, I) (<-chan TriggerAndId[O], caperrors.Error), ) (<-chan TriggerResponse, error) { migrated, err := FromValueOrAny(request.Config, request.Payload, message) if err != nil { diff --git a/pkg/capabilities/utils_test.go b/pkg/capabilities/utils_test.go index 02ae69fdf0..80ac39ea2c 100644 --- a/pkg/capabilities/utils_test.go +++ b/pkg/capabilities/utils_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" "github.com/smartcontractkit/chainlink-protos/cre/go/values" ) @@ -260,7 +261,7 @@ func TestRegisterTrigger(t *testing.T) { "type", req, &pb.TriggerEvent{}, - func(_ context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], error) { + func(_ context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], caperrors.Error) { assert.Equal(t, "workflow-id", m.WorkflowID) assert.Equal(t, "reg", r.Id) return eventCh, nil @@ -322,7 +323,7 @@ func TestRegisterTrigger(t *testing.T) { "type", req, &pb.TriggerEvent{}, - func(_ context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], error) { + func(_ context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], caperrors.Error) { assert.Equal(t, "workflow-id", m.WorkflowID) assert.Equal(t, "reg", r.Id) return eventCh, nil @@ -367,8 +368,8 @@ func TestRegisterTrigger(t *testing.T) { "type", req, &pb.TriggerEvent{}, - func(ctx context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], error) { - return nil, ctx.Err() + func(ctx context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], caperrors.Error) { + return nil, caperrors.NewPublicSystemError(ctx.Err(), caperrors.Internal) }, ) require.Error(t, err) @@ -412,10 +413,14 @@ func TestRegisterTrigger(t *testing.T) { "type", req, &pb.TriggerEvent{}, - func(ctx context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], error) { + func(ctx context.Context, triggerID string, m capabilities.RequestMetadata, r *pb.TriggerEvent) (<-chan capabilities.TriggerAndId[*pb.TriggerEvent], caperrors.Error) { assert.Equal(t, "workflow-id", m.WorkflowID) assert.Equal(t, "reg", r.Id) - return eventCh, ctx.Err() + if ctx.Err() != nil { + return nil, caperrors.NewPublicSystemError(ctx.Err(), caperrors.Internal) + } else { + return eventCh, nil + } }, ) require.NoError(t, err) diff --git a/pkg/capabilities/v2/chain-capabilities/evm/server/client_server_gen.go b/pkg/capabilities/v2/chain-capabilities/evm/server/client_server_gen.go index c44e1c707b..ecd7b2944e 100644 --- a/pkg/capabilities/v2/chain-capabilities/evm/server/client_server_gen.go +++ b/pkg/capabilities/v2/chain-capabilities/evm/server/client_server_gen.go @@ -34,8 +34,8 @@ type ClientCapability interface { HeaderByNumber(ctx context.Context, metadata capabilities.RequestMetadata, input *evm.HeaderByNumberRequest) (*capabilities.ResponseAndMetadata[*evm.HeaderByNumberReply], caperrors.Error) - RegisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *evm.FilterLogTriggerRequest) (<-chan capabilities.TriggerAndId[*evm.Log], error) - UnregisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *evm.FilterLogTriggerRequest) error + RegisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *evm.FilterLogTriggerRequest) (<-chan capabilities.TriggerAndId[*evm.Log], caperrors.Error) + UnregisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *evm.FilterLogTriggerRequest) caperrors.Error WriteReport(ctx context.Context, metadata capabilities.RequestMetadata, input *evm.WriteReportRequest) (*capabilities.ResponseAndMetadata[*evm.WriteReportReply], caperrors.Error) diff --git a/pkg/capabilities/v2/chain-capabilities/solana/server/client_server_gen.go b/pkg/capabilities/v2/chain-capabilities/solana/server/client_server_gen.go index 086ef5058e..b2ba3a5d5c 100644 --- a/pkg/capabilities/v2/chain-capabilities/solana/server/client_server_gen.go +++ b/pkg/capabilities/v2/chain-capabilities/solana/server/client_server_gen.go @@ -36,8 +36,8 @@ type ClientCapability interface { GetTransaction(ctx context.Context, metadata capabilities.RequestMetadata, input *solana.GetTransactionRequest) (*capabilities.ResponseAndMetadata[*solana.GetTransactionReply], caperrors.Error) - RegisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *solana.FilterLogTriggerRequest) (<-chan capabilities.TriggerAndId[*solana.Log], error) - UnregisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *solana.FilterLogTriggerRequest) error + RegisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *solana.FilterLogTriggerRequest) (<-chan capabilities.TriggerAndId[*solana.Log], caperrors.Error) + UnregisterLogTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *solana.FilterLogTriggerRequest) caperrors.Error WriteReport(ctx context.Context, metadata capabilities.RequestMetadata, input *solana.WriteReportRequest) (*capabilities.ResponseAndMetadata[*solana.WriteReportReply], caperrors.Error) diff --git a/pkg/capabilities/v2/protoc/pkg/templates/server.go.tmpl b/pkg/capabilities/v2/protoc/pkg/templates/server.go.tmpl index 6b1c6ba9f1..55f1ee30c1 100644 --- a/pkg/capabilities/v2/protoc/pkg/templates/server.go.tmpl +++ b/pkg/capabilities/v2/protoc/pkg/templates/server.go.tmpl @@ -36,8 +36,8 @@ type {{.GoName}}Capability interface { {{- range .Methods}} {{- if isTrigger . }} {{ $hasTriggers = true }} - Register{{.GoName}}(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *{{ImportAlias .Input.GoIdent.GoImportPath}}.{{.Input.GoIdent.GoName}}) (<- chan capabilities.TriggerAndId[*{{ImportAlias .Output.GoIdent.GoImportPath}}.{{.Output.GoIdent.GoName}}], error) - Unregister{{.GoName}}(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *{{ImportAlias .Input.GoIdent.GoImportPath}}.{{.Input.GoIdent.GoName}}) error + Register{{.GoName}}(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *{{ImportAlias .Input.GoIdent.GoImportPath}}.{{.Input.GoIdent.GoName}}) (<- chan capabilities.TriggerAndId[*{{ImportAlias .Output.GoIdent.GoImportPath}}.{{.Output.GoIdent.GoName}}], caperrors.Error) + Unregister{{.GoName}}(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *{{ImportAlias .Input.GoIdent.GoImportPath}}.{{.Input.GoIdent.GoName}}) caperrors.Error {{- else }} {{ $hasActions = true }} {{.GoName}}(ctx context.Context, metadata capabilities.RequestMetadata, input *{{ImportAlias .Input.GoIdent.GoImportPath}}.{{.Input.GoIdent.GoName}} {{if ne "emptypb.Empty" (ConfigType $service)}}, {{(ConfigType $service)}}{{ end }}) (*capabilities.ResponseAndMetadata[*{{ImportAlias .Output.GoIdent.GoImportPath}}.{{.Output.GoIdent.GoName}}], caperrors.Error) diff --git a/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger/server/action_and_trigger_server_gen.go b/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger/server/action_and_trigger_server_gen.go index 7ea875038c..7931b021da 100644 --- a/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger/server/action_and_trigger_server_gen.go +++ b/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger/server/action_and_trigger_server_gen.go @@ -21,8 +21,8 @@ var _ = emptypb.Empty{} type BasicCapability interface { Action(ctx context.Context, metadata capabilities.RequestMetadata, input *actionandtrigger.Input) (*capabilities.ResponseAndMetadata[*actionandtrigger.Output], caperrors.Error) - RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *actionandtrigger.Config) (<-chan capabilities.TriggerAndId[*actionandtrigger.TriggerEvent], error) - UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *actionandtrigger.Config) error + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *actionandtrigger.Config) (<-chan capabilities.TriggerAndId[*actionandtrigger.TriggerEvent], caperrors.Error) + UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *actionandtrigger.Config) caperrors.Error Start(ctx context.Context) error Close() error diff --git a/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger/server/basic_trigger_server_gen.go b/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger/server/basic_trigger_server_gen.go index 74545198f5..6247e51c5c 100644 --- a/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger/server/basic_trigger_server_gen.go +++ b/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger/server/basic_trigger_server_gen.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-common/pkg/types/core" ) @@ -18,8 +19,8 @@ import ( var _ = emptypb.Empty{} type BasicCapability interface { - RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *basictrigger.Config) (<-chan capabilities.TriggerAndId[*basictrigger.Outputs], error) - UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *basictrigger.Config) error + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *basictrigger.Config) (<-chan capabilities.TriggerAndId[*basictrigger.Outputs], caperrors.Error) + UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *basictrigger.Config) caperrors.Error Start(ctx context.Context) error Close() error diff --git a/pkg/capabilities/v2/triggers/cron/server/trigger_server_gen.go b/pkg/capabilities/v2/triggers/cron/server/trigger_server_gen.go index a74d49b8c2..83ec64de96 100644 --- a/pkg/capabilities/v2/triggers/cron/server/trigger_server_gen.go +++ b/pkg/capabilities/v2/triggers/cron/server/trigger_server_gen.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-common/pkg/types/core" ) @@ -18,11 +19,11 @@ import ( var _ = emptypb.Empty{} type CronCapability interface { - RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) (<-chan capabilities.TriggerAndId[*cron.Payload], error) - UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) error + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) (<-chan capabilities.TriggerAndId[*cron.Payload], caperrors.Error) + UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) caperrors.Error - RegisterLegacyTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) (<-chan capabilities.TriggerAndId[*cron.LegacyPayload], error) - UnregisterLegacyTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) error + RegisterLegacyTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) (<-chan capabilities.TriggerAndId[*cron.LegacyPayload], caperrors.Error) + UnregisterLegacyTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *cron.Config) caperrors.Error Start(ctx context.Context) error Close() error diff --git a/pkg/capabilities/v2/triggers/http/server/trigger_server_gen.go b/pkg/capabilities/v2/triggers/http/server/trigger_server_gen.go index 0d6fd2dd40..9e58d62666 100644 --- a/pkg/capabilities/v2/triggers/http/server/trigger_server_gen.go +++ b/pkg/capabilities/v2/triggers/http/server/trigger_server_gen.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-common/pkg/types/core" ) @@ -18,8 +19,8 @@ import ( var _ = emptypb.Empty{} type HTTPCapability interface { - RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *http.Config) (<-chan capabilities.TriggerAndId[*http.Payload], error) - UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *http.Config) error + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *http.Config) (<-chan capabilities.TriggerAndId[*http.Payload], caperrors.Error) + UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *http.Config) caperrors.Error Start(ctx context.Context) error Close() error From b31c6a3e39216c2f0122a4994a2dce6cf9391da8 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Tue, 6 Jan 2026 11:01:38 -0500 Subject: [PATCH 31/42] pkg/settings: add per-chain overridable EVM.GasLimit (#1757) --- pkg/settings/cresettings/defaults.json | 9 ++++++++- pkg/settings/cresettings/defaults.toml | 7 +++++++ pkg/settings/cresettings/settings.go | 9 ++++++++- pkg/settings/limits/bound.go | 19 ++++++++++--------- pkg/settings/limits/factory.go | 10 ++++------ pkg/settings/limits/gate.go | 16 ++++++++-------- pkg/settings/map.go | 18 ++++++++++++++++++ pkg/settings/settings.go | 23 +++++++++++++++++++++++ 8 files changed, 86 insertions(+), 25 deletions(-) diff --git a/pkg/settings/cresettings/defaults.json b/pkg/settings/cresettings/defaults.json index 62bbfc34c8..654069ee99 100644 --- a/pkg/settings/cresettings/defaults.json +++ b/pkg/settings/cresettings/defaults.json @@ -57,7 +57,14 @@ "TargetsLimit": "10", "ReportSizeLimit": "5kb", "EVM": { - "TransactionGasLimit": "5000000" + "TransactionGasLimit": "5000000", + "GasLimit": { + "Default": "5000000", + "Values": { + "12922642891491394802": "50000000", + "3379446385462418246": "10000000" + } + } } }, "ChainRead": { diff --git a/pkg/settings/cresettings/defaults.toml b/pkg/settings/cresettings/defaults.toml index 013475840d..687bd74d8b 100644 --- a/pkg/settings/cresettings/defaults.toml +++ b/pkg/settings/cresettings/defaults.toml @@ -61,6 +61,13 @@ ReportSizeLimit = '5kb' [PerWorkflow.ChainWrite.EVM] TransactionGasLimit = '5000000' +[PerWorkflow.ChainWrite.EVM.GasLimit] +Default = '5000000' + +[PerWorkflow.ChainWrite.EVM.GasLimit.Values] +12922642891491394802 = '50000000' +3379446385462418246 = '10000000' + [PerWorkflow.ChainRead] CallLimit = '10' LogQueryBlockLimit = '100' diff --git a/pkg/settings/cresettings/settings.go b/pkg/settings/cresettings/settings.go index 2fd0964197..4cb2fe69d8 100644 --- a/pkg/settings/cresettings/settings.go +++ b/pkg/settings/cresettings/settings.go @@ -123,6 +123,12 @@ var Default = Schema{ ReportSizeLimit: Size(5 * config.KByte), EVM: evmChainWrite{ TransactionGasLimit: Uint64(5_000_000), + GasLimit: PerChainSelector(Uint64(5_000_000), map[string]uint64{ + // geth-testnet + "3379446385462418246": 10_000_000, + // geth-devnet2 + "12922642891491394802": 50_000_000, + }), }, }, ChainRead: chainRead{ @@ -225,7 +231,8 @@ type chainWrite struct { EVM evmChainWrite } type evmChainWrite struct { - TransactionGasLimit Setting[uint64] `unit:"{gas}"` + TransactionGasLimit Setting[uint64] `unit:"{gas}"` // Deprecated + GasLimit SettingMap[uint64] `unit:"{gas}"` } type chainRead struct { CallLimit Setting[int] `unit:"{call}"` diff --git a/pkg/settings/limits/bound.go b/pkg/settings/limits/bound.go index 815bd722c3..e3f34e099e 100644 --- a/pkg/settings/limits/bound.go +++ b/pkg/settings/limits/bound.go @@ -49,13 +49,13 @@ func (s *simpleBoundLimiter[N]) Limit(ctx context.Context) (N, error) { return s.bound, nil } -func newBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimiter[N], error) { +func newBoundLimiter[N Number](f Factory, bound settings.SettingSpec[N]) (BoundLimiter[N], error) { b := &boundLimiter[N]{ updater: newUpdater[N](nil, func(ctx context.Context) (N, error) { return bound.GetOrDefault(ctx, f.Settings) }, nil), - key: bound.Key, - scope: bound.Scope, + key: bound.GetKey(), + scope: bound.GetScope(), } b.updater.recordLimit = func(ctx context.Context, n N) { b.recordBound(ctx, n) } @@ -63,23 +63,24 @@ func newBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimit if b.key == "" { return nil, errors.New("metrics require Key to be set") } - newGauge, newHist := metricConstructors[N](f.Meter, bound.Unit) + newGauge, newHist := metricConstructors[N](f.Meter, bound.GetUnit()) - limitGauge, err := newGauge("bound." + bound.Key + ".limit") + key := bound.GetKey() + limitGauge, err := newGauge("bound." + key + ".limit") if err != nil { return nil, err } b.recordBound = func(ctx context.Context, value N, options ...metric.RecordOption) { limitGauge.Record(ctx, value, options...) } - usageHist, err := newHist("bound." + bound.Key + ".usage") + usageHist, err := newHist("bound." + key + ".usage") if err != nil { return nil, err } b.recordUsage = func(ctx context.Context, value N, options ...metric.RecordOption) { usageHist.Record(ctx, value, options...) } - deniedHist, err := newHist("bound." + bound.Key + ".denied") + deniedHist, err := newHist("bound." + key + ".denied") if err != nil { return nil, err } @@ -93,7 +94,7 @@ func newBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimit } if f.Logger != nil { - b.lggr = logger.Sugared(f.Logger).Named("BoundLimiter").With("key", bound.Key) + b.lggr = logger.Sugared(f.Logger).Named("BoundLimiter").With("key", bound.GetKey()) } if f.Settings != nil { @@ -104,7 +105,7 @@ func newBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimit } } - if bound.Scope == settings.ScopeGlobal { + if bound.GetScope() == settings.ScopeGlobal { b.updateCRE(contexts.CRE{}) go b.updateLoop(contexts.CRE{}) } diff --git a/pkg/settings/limits/factory.go b/pkg/settings/limits/factory.go index 97c876d03b..849ec11dad 100644 --- a/pkg/settings/limits/factory.go +++ b/pkg/settings/limits/factory.go @@ -82,8 +82,8 @@ func MakeResourcePoolLimiter[N Number](f Factory, limit settings.Setting[N]) (Re // - bound.*.limit - gauge // - bound.*.usage - histogram // - bound.*.denied - histogram -func MakeBoundLimiter[N Number](f Factory, bound settings.Setting[N]) (BoundLimiter[N], error) { - return newBoundLimiter(f, bound) +func MakeBoundLimiter[N Number](f Factory, bound settings.IsSetting[N]) (BoundLimiter[N], error) { + return newBoundLimiter(f, bound.GetSpec()) } // MakeQueueLimiter returns a QueueLimiter for the given limit and configured by the Factory. @@ -103,8 +103,6 @@ func MakeQueueLimiter[T any](f Factory, limit settings.Setting[int]) (QueueLimit // - gate.*.limit - int gauge // - gate.*.usage - int counter // - gate.*.denied - int counter -// -// OPT: accept an interface for limit -func MakeGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, error) { - return newGateLimiter(f, limit) +func MakeGateLimiter(f Factory, limit settings.IsSetting[bool]) (GateLimiter, error) { + return newGateLimiter(f, limit.GetSpec()) } diff --git a/pkg/settings/limits/gate.go b/pkg/settings/limits/gate.go index 20f9d0ceaf..ceed5f0a71 100644 --- a/pkg/settings/limits/gate.go +++ b/pkg/settings/limits/gate.go @@ -44,14 +44,13 @@ func (s *simpleGateLimiter) AllowErr(ctx context.Context) error { return nil } -// OPT: interface satisfied by Setting[bool] & SettingMap[bool] -func newGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, error) { +func newGateLimiter(f Factory, limit settings.SettingSpec[bool]) (GateLimiter, error) { g := &gateLimiter{ updater: newUpdater[bool](nil, func(ctx context.Context) (bool, error) { return limit.GetOrDefault(ctx, f.Settings) }, nil), - key: limit.Default.Key, - scope: limit.Default.Scope, + key: limit.GetKey(), + scope: limit.GetScope(), } g.updater.recordLimit = func(ctx context.Context, b bool) { g.recordStatus(ctx, b) } @@ -59,7 +58,8 @@ func newGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, er if g.key == "" { return nil, errors.New("metrics require Key to be set") } - limitGauge, err := f.Meter.Int64Gauge("gate."+g.key+".limit", metric.WithUnit(limit.Default.Unit)) + unit := limit.GetUnit() + limitGauge, err := f.Meter.Int64Gauge("gate."+g.key+".limit", metric.WithUnit(unit)) if err != nil { return nil, err } @@ -70,14 +70,14 @@ func newGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, er } limitGauge.Record(ctx, val, options...) } - usageCounter, err := f.Meter.Int64Counter("gate."+g.key+".usage", metric.WithUnit(limit.Default.Unit)) + usageCounter, err := f.Meter.Int64Counter("gate."+g.key+".usage", metric.WithUnit(unit)) if err != nil { return nil, err } g.recordUsage = func(ctx context.Context, options ...metric.AddOption) { usageCounter.Add(ctx, 1, options...) } - deniedCounter, err := f.Meter.Int64Counter("gate."+g.key+".denied", metric.WithUnit(limit.Default.Unit)) + deniedCounter, err := f.Meter.Int64Counter("gate."+g.key+".denied", metric.WithUnit(unit)) if err != nil { return nil, err } @@ -91,7 +91,7 @@ func newGateLimiter(f Factory, limit settings.SettingMap[bool]) (GateLimiter, er } if f.Logger != nil { - g.lggr = logger.Sugared(f.Logger).Named("GateLimiter").With("key", limit.Default.Key) + g.lggr = logger.Sugared(f.Logger).Named("GateLimiter").With("key", limit.GetKey()) } // OPT: support settings.Registry subscriptions diff --git a/pkg/settings/map.go b/pkg/settings/map.go index c758d9c169..6340820363 100644 --- a/pkg/settings/map.go +++ b/pkg/settings/map.go @@ -23,12 +23,22 @@ func PerChainSelector[T any](defaultValue Setting[T], vals map[string]T) Setting } } +var _ IsSetting[int] = SettingMap[int]{} + type SettingMap[T any] struct { Default Setting[T] Values map[string]string // unparsed KeyFromCtx func(context.Context) (uint64, error) `json:"-" toml:"-"` } +func (s SettingMap[T]) GetSpec() SettingSpec[T] { return &s } + +func (s *SettingMap[T]) GetKey() string { return s.Default.Key } + +func (s *SettingMap[T]) GetScope() Scope { return s.Default.Scope } + +func (s *SettingMap[T]) GetUnit() string { return s.Default.Unit } + func (s *SettingMap[T]) initSetting(key string, scope Scope, unit *string) error { if s.KeyFromCtx == nil { return errors.New("missing KeyFromCtx func") @@ -86,3 +96,11 @@ func (s *SettingMap[T]) GetOrDefault(ctx context.Context, g Getter) (value T, er } return } + +func (s *SettingMap[T]) Subscribe(ctx context.Context, registry Registry) (<-chan Update[T], func()) { + //TODO subscribe to Values & Default + + // no-op + ch := make(chan Update[T]) + return ch, func() { close(ch) } +} diff --git a/pkg/settings/settings.go b/pkg/settings/settings.go index e8b955bbbb..688d18d2fd 100644 --- a/pkg/settings/settings.go +++ b/pkg/settings/settings.go @@ -30,6 +30,21 @@ type Registry interface { SubscribeScoped(ctx context.Context, scope Scope, key string) (updates <-chan Update[string], stop func()) } +//TODO use this everywhere +type IsSetting[T any] interface { + GetSpec() SettingSpec[T] +} + +type SettingSpec[T any] interface { + GetKey() string + GetScope() Scope + GetUnit() string + GetOrDefault(context.Context, Getter) (T, error) + Subscribe(context.Context, Registry) (<-chan Update[T], func()) +} + +var _ IsSetting[int] = Setting[int]{} + // Setting holds a key, default value, and parsing function for a particular setting. // Use Setting.GetOrDefault with a Getter to look up settings. // Use Setting.Subscribe with a Registry to have updates pushed over a channel. @@ -41,6 +56,14 @@ type Setting[T any] struct { Unit string } +func (s Setting[T]) GetSpec() SettingSpec[T] { return &s } + +func (s *Setting[T]) GetKey() string { return s.Key } + +func (s *Setting[T]) GetScope() Scope { return s.Scope } + +func (s *Setting[T]) GetUnit() string { return s.Unit } + func (s Setting[T]) MarshalText() ([]byte, error) { return fmt.Appendf(nil, "%v", s.DefaultValue), nil } From fb3aa9e16323bb02ceb33451f092cba3abfc1aa2 Mon Sep 17 00:00:00 2001 From: mchain0 Date: Wed, 7 Jan 2026 05:57:04 +0100 Subject: [PATCH 32/42] cre-1626: minor rest refactor (#1751) --- pkg/workflows/ring/factory_test.go | 57 +++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/pkg/workflows/ring/factory_test.go b/pkg/workflows/ring/factory_test.go index 5a81c2a0fc..557f872b5b 100644 --- a/pkg/workflows/ring/factory_test.go +++ b/pkg/workflows/ring/factory_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/stretchr/testify/require" ) @@ -14,24 +15,46 @@ func TestFactory_NewFactory(t *testing.T) { store := NewStore() arbiter := &mockArbiter{} - t.Run("with_nil_config", func(t *testing.T) { - f, err := NewFactory(store, arbiter, lggr, nil) - require.NoError(t, err) - require.NotNil(t, f) - }) + tests := []struct { + name string + arbiter pb.ArbiterScalerClient + config *ConsensusConfig + wantErr bool + errSubstr string + }{ + { + name: "with_nil_config", + arbiter: arbiter, + config: nil, + wantErr: false, + }, + { + name: "with_custom_config", + arbiter: arbiter, + config: &ConsensusConfig{BatchSize: 50}, + wantErr: false, + }, + { + name: "nil_arbiter_returns_error", + arbiter: nil, + config: nil, + wantErr: true, + errSubstr: "arbiterScaler is required", + }, + } - t.Run("with_custom_config", func(t *testing.T) { - cfg := &ConsensusConfig{BatchSize: 50} - f, err := NewFactory(store, arbiter, lggr, cfg) - require.NoError(t, err) - require.NotNil(t, f) - }) - - t.Run("nil_arbiter_returns_error", func(t *testing.T) { - _, err := NewFactory(store, nil, lggr, nil) - require.Error(t, err) - require.Contains(t, err.Error(), "arbiterScaler is required") - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f, err := NewFactory(store, tt.arbiter, lggr, tt.config) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errSubstr) + } else { + require.NoError(t, err) + require.NotNil(t, f) + } + }) + } } func TestFactory_NewReportingPlugin(t *testing.T) { From f0a00aff9f390b19b8837621a2afb11acec202be Mon Sep 17 00:00:00 2001 From: pavel-raykov <165708424+pavel-raykov@users.noreply.github.com> Date: Wed, 7 Jan 2026 10:56:48 +0100 Subject: [PATCH 33/42] [CRE-491] Move chainaccessor event ccip types to chainlink-common. (#1724) * Minor. * Minor. --- pkg/types/ccipocr3/chainaccessor_event.go | 61 +++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 pkg/types/ccipocr3/chainaccessor_event.go diff --git a/pkg/types/ccipocr3/chainaccessor_event.go b/pkg/types/ccipocr3/chainaccessor_event.go new file mode 100644 index 0000000000..178ed1947b --- /dev/null +++ b/pkg/types/ccipocr3/chainaccessor_event.go @@ -0,0 +1,61 @@ +package ccipocr3 + +import ( + "math/big" +) + +// --------------------------------------------------- +// The following types match the structs defined in the EVM contracts are used to decode these +// on-chain events. + +// SendRequestedEvent represents the contents of the event emitted by the CCIP OnRamp when a +// message is sent. +type SendRequestedEvent struct { + DestChainSelector ChainSelector + SequenceNumber SeqNum + Message Message +} + +// CommitReportAcceptedEvent represents the contents of the event emitted by the CCIP OffRamp when a +// commit report is accepted. +type CommitReportAcceptedEvent struct { + BlessedMerkleRoots []MerkleRoot + UnblessedMerkleRoots []MerkleRoot + PriceUpdates AccessorPriceUpdates +} + +// ExecutionStateChangedEvent represents the contents of the event emitted by the CCIP OffRamp +type ExecutionStateChangedEvent struct { + SourceChainSelector ChainSelector + SequenceNumber SeqNum + MessageID Bytes32 + MessageHash Bytes32 + State uint8 + ReturnData Bytes + GasUsed big.Int +} + +type MerkleRoot struct { + SourceChainSelector uint64 + OnRampAddress UnknownAddress + MinSeqNr uint64 + MaxSeqNr uint64 + MerkleRoot Bytes32 +} + +type TokenPriceUpdate struct { + SourceToken UnknownAddress + UsdPerToken *big.Int +} + +type GasPriceUpdate struct { + // DestChainSelector is the chain that the gas price is for (some plugin source chain). + // Not the chain that the gas price is stored on. + DestChainSelector uint64 + UsdPerUnitGas *big.Int +} + +type AccessorPriceUpdates struct { + TokenPriceUpdates []TokenPriceUpdate + GasPriceUpdates []GasPriceUpdate +} From e5e4627009dd3f595bc2c0278f32b92bd79ce204 Mon Sep 17 00:00:00 2001 From: pavel-raykov <165708424+pavel-raykov@users.noreply.github.com> Date: Wed, 7 Jan 2026 18:53:02 +0100 Subject: [PATCH 34/42] [ARCH-327] Address security comments (#1758) * Minor. * Minor. --- keystore/admin.go | 22 ++++++++++++++++++---- keystore/admin_test.go | 31 ++++++++++++++++++++++++++++++- keystore/internal/raw.go | 6 ++++++ keystore/keystore.go | 9 +++++++-- keystore/reader.go | 7 +------ 5 files changed, 62 insertions(+), 13 deletions(-) diff --git a/keystore/admin.go b/keystore/admin.go index f6754ba053..4686cc0214 100644 --- a/keystore/admin.go +++ b/keystore/admin.go @@ -9,6 +9,7 @@ import ( "encoding/json" "fmt" "maps" + "slices" "time" gethkeystore "github.com/ethereum/go-ethereum/accounts/keystore" @@ -21,6 +22,11 @@ import ( "github.com/smartcontractkit/chainlink-common/keystore/serialization" ) +const ( + MaxKeyNameLength = 1000 + MaxMetadataLength = 1024 * 1024 // 1mb +) + var ( ErrKeyAlreadyExists = fmt.Errorf("key already exists") ErrInvalidKeyName = fmt.Errorf("invalid key name") @@ -165,8 +171,8 @@ func ValidKeyName(name string) error { return fmt.Errorf("key name cannot be empty") } // Just a sanity bound. - if len(name) > 1_000 { - return fmt.Errorf("key name cannot be longer than 1000 characters") + if len(name) > MaxKeyNameLength { + return fmt.Errorf("key name cannot be longer than %d characters", MaxKeyNameLength) } return nil } @@ -202,9 +208,8 @@ func (ks *keystore) CreateKeys(ctx context.Context, req CreateKeysRequest) (Crea if err != nil { return CreateKeysResponse{}, fmt.Errorf("failed to generate ECDSA_S256 key: %w", err) } - // Must copy the private key into 32 byte slice because leading zeros are stripped. privateKeyBytes := make([]byte, 32) - copy(privateKeyBytes, privateKey.D.Bytes()) + privateKey.D.FillBytes(privateKeyBytes) publicKey, err := publicKeyFromPrivateKey(internal.NewRaw(privateKeyBytes), keyReq.KeyType) if err != nil { return CreateKeysResponse{}, fmt.Errorf("failed to get public key from private key: %w", err) @@ -291,6 +296,9 @@ func (ks *keystore) ImportKeys(ctx context.Context, req ImportKeysRequest) (Impo } pkRaw := internal.NewRaw(keypb.PrivateKey) keyType := KeyType(keypb.KeyType) + if !slices.Contains(AllKeyTypes, keyType) { + return ImportKeysResponse{}, fmt.Errorf("%w: %s, available key types: %s", ErrUnsupportedKeyType, keyType, AllKeyTypes.String()) + } publicKey, err := publicKeyFromPrivateKey(pkRaw, keyType) if err != nil { return ImportKeysResponse{}, fmt.Errorf("key num = %d, failed to get public key from private key: %w", i, err) @@ -301,6 +309,9 @@ func (ks *keystore) ImportKeys(ctx context.Context, req ImportKeysRequest) (Impo if metadata == nil { metadata = []byte{} } + if len(metadata) > MaxMetadataLength { + return ImportKeysResponse{}, fmt.Errorf("key num = %d, metadata of length %d exceeds maximum length of %d bytes", i, len(metadata), MaxMetadataLength) + } keyName := keyReq.NewKeyName if keyName == "" { @@ -366,6 +377,9 @@ func (ks *keystore) SetMetadata(ctx context.Context, req SetMetadataRequest) (Se ksCopy := maps.Clone(ks.keystore) for _, metReq := range req.Updates { + if len(metReq.Metadata) > MaxMetadataLength { + return SetMetadataResponse{}, fmt.Errorf("metadata for key %s exceeds maximum length of %d bytes", metReq.KeyName, MaxMetadataLength) + } key, ok := ksCopy[metReq.KeyName] if !ok { return SetMetadataResponse{}, fmt.Errorf("%w: %s", ErrKeyNotFound, metReq.KeyName) diff --git a/keystore/admin_test.go b/keystore/admin_test.go index 6a72d261db..de8f51abad 100644 --- a/keystore/admin_test.go +++ b/keystore/admin_test.go @@ -3,6 +3,7 @@ package keystore_test import ( "context" "fmt" + "math/big" "sort" "sync" "testing" @@ -10,6 +11,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + gethcrypto "github.com/ethereum/go-ethereum/crypto" + "github.com/smartcontractkit/chainlink-common/keystore" ) @@ -258,7 +261,16 @@ func TestKeystore_ExportImport(t *testing.T) { key1ks1, err := ks1.GetKeys(t.Context(), keystore.GetKeysRequest{KeyNames: []string{"key1"}}) require.NoError(t, err) key1ks2, err := ks2.GetKeys(t.Context(), keystore.GetKeysRequest{KeyNames: []string{"key1"}}) - require.Equal(t, key1ks1, key1ks2) + require.NoError(t, err) + // Test equality of the keys except of the CreatedAt field. + require.Len(t, key1ks1.Keys, 1) + require.Len(t, key1ks2.Keys, 1) + key1ks1Info := key1ks1.Keys[0].KeyInfo + key1ks2Info := key1ks2.Keys[0].KeyInfo + require.Equal(t, key1ks1Info.Name, key1ks2Info.Name) + require.Equal(t, key1ks1Info.PublicKey, key1ks2Info.PublicKey) + require.Equal(t, key1ks1Info.KeyType, key1ks2Info.KeyType) + require.Equal(t, key1ks1Info.Metadata, key1ks2Info.Metadata) testData := []byte("hello world") signature, err := ks2.Sign(t.Context(), keystore.SignRequest{ @@ -411,3 +423,20 @@ func TestKeystore_RenameKey(t *testing.T) { require.EqualError(t, err, "key not found: key1") }) } + +func TestECDSA_Serialization_WithPadding(t *testing.T) { + // This test ensures that ECDSA private keys that serialize to less than 32 bytes + // are correctly padded with leading zeros during serialization and deserialization. + // This is important for compatibility with Ethereum's crypto library which expects + // 32-byte private keys. + + // The example key has been found randomly such that it has 2 leading zero bytes when serialized. + key, ok := big.NewInt(0).SetString("57269542458293433845411819226400606954116463824740942170224417652371448", 10) + require.True(t, ok) + privateKeyBytes := make([]byte, 32) + key.FillBytes(privateKeyBytes) + require.Equal(t, []byte{0, 0, 8, 76, 62, 209, 247, 104, 97, 108, 141, 217, 255, 150, 114, 196, 223, 66, 254, 101, 209, 14, 233, 174, 149, 89, 207, 141, 2, 188, 111, 248}, privateKeyBytes) + deserializedKey, err := gethcrypto.ToECDSA(privateKeyBytes) + require.NoError(t, err) + require.Equal(t, key, deserializedKey.D) +} diff --git a/keystore/internal/raw.go b/keystore/internal/raw.go index ae97519851..bde7aad8e1 100644 --- a/keystore/internal/raw.go +++ b/keystore/internal/raw.go @@ -2,6 +2,8 @@ // only available for use in the keystore sub-tree. package internal +import "fmt" + // Raw is a wrapper type that holds private key bytes // and is designed to prevent accidental logging. // The only way to access the internal bytes (without reflection) is to use Bytes, @@ -22,6 +24,10 @@ func (raw Raw) GoString() string { return raw.String() } +func (raw Raw) Format(state fmt.State, _ rune) { + _, _ = fmt.Fprint(state, raw.String()) +} + // Bytes is a func for accessing the internal bytes field of Raw. // It is not declared as a method, because that would allow access from callers which cannot otherwise access this internal package. func Bytes(raw Raw) []byte { return raw.bytes } diff --git a/keystore/keystore.go b/keystore/keystore.go index 206ee4685a..15306cda22 100644 --- a/keystore/keystore.go +++ b/keystore/keystore.go @@ -8,14 +8,13 @@ import ( "errors" "fmt" "io" + "log/slog" "slices" "strings" "sync" "testing" "time" - "log/slog" - "golang.org/x/crypto/curve25519" gethkeystore "github.com/ethereum/go-ethereum/accounts/keystore" @@ -201,6 +200,9 @@ type EncryptionParams struct { func publicKeyFromPrivateKey(privateKeyBytes internal.Raw, keyType KeyType) ([]byte, error) { switch keyType { case Ed25519: + if len(internal.Bytes(privateKeyBytes)) != ed25519.PrivateKeySize { + return nil, fmt.Errorf("invalid Ed25519 private key size: %d", len(internal.Bytes(privateKeyBytes))) + } privateKey := ed25519.PrivateKey(internal.Bytes(privateKeyBytes)) publicKey := privateKey.Public().(ed25519.PublicKey) return publicKey, nil @@ -216,6 +218,9 @@ func publicKeyFromPrivateKey(privateKeyBytes internal.Raw, keyType KeyType) ([]b pubKey := gethcrypto.FromECDSAPub(&privateKey.PublicKey) return pubKey, nil case X25519: + if len(internal.Bytes(privateKeyBytes)) != curve25519.ScalarSize { + return nil, fmt.Errorf("invalid X25519 private key size: %d", len(internal.Bytes(privateKeyBytes))) + } pubKey, err := curve25519.X25519(internal.Bytes(privateKeyBytes)[:], curve25519.Basepoint) if err != nil { return nil, fmt.Errorf("failed to derive shared secret: %w", err) diff --git a/keystore/reader.go b/keystore/reader.go index f05c1509ca..a35023ac55 100644 --- a/keystore/reader.go +++ b/keystore/reader.go @@ -63,12 +63,7 @@ func (k *keystore) GetKeys(ctx context.Context, req GetKeysRequest) (GetKeysResp } seen[name] = true responses = append(responses, GetKeyResponse{ - KeyInfo: KeyInfo{ - Name: name, - KeyType: key.keyType, - PublicKey: key.publicKey, - Metadata: key.metadata, - }, + KeyInfo: newKeyInfo(name, key.keyType, key.createdAt, key.publicKey, key.metadata), }) } sort.Slice(responses, func(i, j int) bool { return responses[i].KeyInfo.Name < responses[j].KeyInfo.Name }) From 8ff43d65e56f52f7304c5c5b5fa7b54530700ce6 Mon Sep 17 00:00:00 2001 From: Gabriel Paradiso Date: Thu, 8 Jan 2026 11:07:21 +0100 Subject: [PATCH 35/42] [CRE-1601] shard-orchestrator implementation (#1747) * feat: implement shard-orchestrator logic * chore: use state constants instead of strings * test: add pluging <-> shardorchestrator integration tests * tests: add server tests * feat: add client to communicate with shard 0 * chore: remove unused timestamps * fix: remove reduntant check --------- Co-authored-by: mchain0 --- pkg/workflows/ring/factory.go | 26 +- pkg/workflows/ring/factory_test.go | 13 +- pkg/workflows/ring/pb/generate.go | 1 - pkg/workflows/ring/plugin_test.go | 163 +++++++++++- pkg/workflows/ring/state.go | 17 ++ pkg/workflows/ring/transmitter.go | 82 +++++- pkg/workflows/ring/transmitter_test.go | 21 +- pkg/workflows/shardorchestrator/client.go | 78 ++++++ .../shardorchestrator/client_test.go | 212 ++++++++++++++++ .../shardorchestrator/pb/generate.go | 3 + .../pb/shard_orchestrator.pb.go | 112 +++------ .../pb/shard_orchestrator.proto | 13 +- .../pb/shard_orchestrator_grpc.pb.go | 6 +- pkg/workflows/shardorchestrator/service.go | 108 ++++++++ .../shardorchestrator/service_test.go | 118 +++++++++ pkg/workflows/shardorchestrator/store.go | 235 ++++++++++++++++++ pkg/workflows/shardorchestrator/store_test.go | 198 +++++++++++++++ 17 files changed, 1282 insertions(+), 124 deletions(-) create mode 100644 pkg/workflows/shardorchestrator/client.go create mode 100644 pkg/workflows/shardorchestrator/client_test.go create mode 100644 pkg/workflows/shardorchestrator/pb/generate.go rename pkg/workflows/{ring => shardorchestrator}/pb/shard_orchestrator.pb.go (68%) rename pkg/workflows/{ring => shardorchestrator}/pb/shard_orchestrator.proto (78%) rename pkg/workflows/{ring => shardorchestrator}/pb/shard_orchestrator_grpc.pb.go (96%) create mode 100644 pkg/workflows/shardorchestrator/service.go create mode 100644 pkg/workflows/shardorchestrator/service_test.go create mode 100644 pkg/workflows/shardorchestrator/store.go create mode 100644 pkg/workflows/shardorchestrator/store_test.go diff --git a/pkg/workflows/ring/factory.go b/pkg/workflows/ring/factory.go index 8b85d02c86..5e8710f9dc 100644 --- a/pkg/workflows/ring/factory.go +++ b/pkg/workflows/ring/factory.go @@ -4,11 +4,13 @@ import ( "context" "errors" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) const ( @@ -20,15 +22,16 @@ const ( var _ core.OCR3ReportingPluginFactory = &Factory{} type Factory struct { - store *Store - arbiterScaler pb.ArbiterScalerClient - config *ConsensusConfig - lggr logger.Logger + ringStore *Store + shardOrchestratorStore *shardorchestrator.Store + arbiterScaler pb.ArbiterScalerClient + config *ConsensusConfig + lggr logger.Logger services.StateMachine } -func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) { +func NewFactory(s *Store, shardOrchestratorStore *shardorchestrator.Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) { if arbiterScaler == nil { return nil, errors.New("arbiterScaler is required") } @@ -38,15 +41,16 @@ func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logg } } return &Factory{ - store: s, - arbiterScaler: arbiterScaler, - config: cfg, - lggr: logger.Named(lggr, "RingPluginFactory"), + ringStore: s, + shardOrchestratorStore: shardOrchestratorStore, + arbiterScaler: arbiterScaler, + config: cfg, + lggr: logger.Named(lggr, "RingPluginFactory"), }, nil } func (o *Factory) NewReportingPlugin(_ context.Context, config ocr3types.ReportingPluginConfig) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) { - plugin, err := NewPlugin(o.store, o.arbiterScaler, config, o.lggr, o.config) + plugin, err := NewPlugin(o.ringStore, o.arbiterScaler, config, o.lggr, o.config) pluginInfo := ocr3types.ReportingPluginInfo{ Name: "RingPlugin", Limits: ocr3types.ReportingPluginLimits{ diff --git a/pkg/workflows/ring/factory_test.go b/pkg/workflows/ring/factory_test.go index 557f872b5b..f2a1993f43 100644 --- a/pkg/workflows/ring/factory_test.go +++ b/pkg/workflows/ring/factory_test.go @@ -4,15 +4,18 @@ import ( "context" "testing" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) func TestFactory_NewFactory(t *testing.T) { lggr := logger.Test(t) store := NewStore() + shardOrchestratorStore := shardorchestrator.NewStore(lggr) arbiter := &mockArbiter{} tests := []struct { @@ -45,7 +48,7 @@ func TestFactory_NewFactory(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f, err := NewFactory(store, tt.arbiter, lggr, tt.config) + f, err := NewFactory(store, shardOrchestratorStore, tt.arbiter, lggr, tt.config) if tt.wantErr { require.Error(t, err) require.Contains(t, err.Error(), tt.errSubstr) @@ -60,7 +63,7 @@ func TestFactory_NewFactory(t *testing.T) { func TestFactory_NewReportingPlugin(t *testing.T) { lggr := logger.Test(t) store := NewStore() - f, err := NewFactory(store, &mockArbiter{}, lggr, nil) + f, err := NewFactory(store, nil, &mockArbiter{}, lggr, nil) require.NoError(t, err) config := ocr3types.ReportingPluginConfig{N: 4, F: 1} @@ -75,7 +78,7 @@ func TestFactory_NewReportingPlugin(t *testing.T) { func TestFactory_Lifecycle(t *testing.T) { lggr := logger.Test(t) store := NewStore() - f, err := NewFactory(store, &mockArbiter{}, lggr, nil) + f, err := NewFactory(store, nil, &mockArbiter{}, lggr, nil) require.NoError(t, err) err = f.Start(context.Background()) diff --git a/pkg/workflows/ring/pb/generate.go b/pkg/workflows/ring/pb/generate.go index 850f3eeb44..bff63fddde 100644 --- a/pkg/workflows/ring/pb/generate.go +++ b/pkg/workflows/ring/pb/generate.go @@ -1,6 +1,5 @@ //go:generate protoc --go_out=. --go_opt=paths=source_relative shared.proto //go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative arbiter.proto -//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative shard_orchestrator.proto //go:generate protoc --go_out=. --go_opt=paths=source_relative consensus.proto package pb diff --git a/pkg/workflows/ring/plugin_test.go b/pkg/workflows/ring/plugin_test.go index 99f70d1b9d..094e32c93a 100644 --- a/pkg/workflows/ring/plugin_test.go +++ b/pkg/workflows/ring/plugin_test.go @@ -12,10 +12,12 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) type mockArbiter struct { @@ -443,7 +445,7 @@ func TestPlugin_NoHealthyShardsFallbackToShardZero(t *testing.T) { }) require.NoError(t, err) - transmitter := NewTransmitter(lggr, store, arbiter, "test-account") + transmitter := NewTransmitter(lggr, store, nil, arbiter, "test-account") ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() @@ -591,3 +593,158 @@ func TestPlugin_ObservationQuorum(t *testing.T) { require.True(t, quorum) }) } + +func TestPlugin_ShardOrchestratorIntegration(t *testing.T) { + lggr := logger.Test(t) + + // Create both stores + ringStore := NewStore() + orchestratorStore := shardorchestrator.NewStore(lggr) + + // Initialize ring store with healthy shards + ringStore.SetAllShardHealth(map[uint32]bool{0: true, 1: true, 2: true}) + + config := ocr3types.ReportingPluginConfig{ + N: 4, F: 1, + } + + arbiter := &mockArbiter{} + plugin, err := NewPlugin(ringStore, arbiter, config, lggr, &ConsensusConfig{ + BatchSize: 100, + TimeToSync: 1 * time.Second, + }) + require.NoError(t, err) + + // Create transmitter with both stores + transmitter := NewTransmitter(lggr, ringStore, orchestratorStore, arbiter, "test-account") + + ctx := context.Background() + now := time.Now() + + t.Run("initial_workflow_assignments", func(t *testing.T) { + // Create observations with workflows + workflows := []string{"wf-A", "wf-B", "wf-C"} + aos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + }, workflows, now, 3) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 1, + PreviousOutcome: nil, + } + + // Generate outcome + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + // Generate report and transmit + reports, err := plugin.Reports(ctx, 1, outcome) + require.NoError(t, err) + require.Len(t, reports, 1) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 1, reports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Verify ring store was updated + for _, wf := range workflows { + shard, err := ringStore.GetShardForWorkflow(ctx, wf) + require.NoError(t, err) + require.LessOrEqual(t, shard, uint32(2), "workflow should be assigned to valid shard") + t.Logf("Ring store: %s → shard %d", wf, shard) + } + + // Verify orchestrator store was updated with correct state + for _, wf := range workflows { + mapping, err := orchestratorStore.GetWorkflowMapping(ctx, wf) + require.NoError(t, err) + require.Equal(t, wf, mapping.WorkflowID) + require.LessOrEqual(t, mapping.NewShardID, uint32(2)) + require.Equal(t, uint32(0), mapping.OldShardID, "initial assignment should have oldShardID=0") + require.Equal(t, shardorchestrator.StateSteady, mapping.TransitionState, "initial assignment should be steady") + t.Logf("Orchestrator store: %s → shard %d (state: %s)", wf, mapping.NewShardID, mapping.TransitionState.String()) + } + + // Verify version tracking + version := orchestratorStore.GetMappingVersion() + require.Equal(t, uint64(1), version, "version should increment after first update") + }) + + t.Run("workflow_transition_detected", func(t *testing.T) { + // First, establish a baseline with workflows distributed across 3 shards + // Use wantShards=3 to ensure workflows actually get assigned to shard 2 + baselineAos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + }, []string{"wf-A", "wf-B", "wf-C", "wf-D", "wf-E"}, now, 3) + + baselineOutcome, err := plugin.Outcome(ctx, ocr3types.OutcomeContext{SeqNr: 2}, nil, baselineAos) + require.NoError(t, err) + + baselineReports, err := plugin.Reports(ctx, 2, baselineOutcome) + require.NoError(t, err) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 2, baselineReports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Parse baseline to see which workflows were on shard 2 + baselineProto := &pb.Outcome{} + err = proto.Unmarshal(baselineOutcome, baselineProto) + require.NoError(t, err) + + workflowsOnShard2 := []string{} + for wfID, route := range baselineProto.Routes { + if route.Shard == 2 { + workflowsOnShard2 = append(workflowsOnShard2, wfID) + } + t.Logf("Baseline: %s on shard %d", wfID, route.Shard) + } + require.NotEmpty(t, workflowsOnShard2, "at least one workflow should be on shard 2 for this test") + + // Now scale down to 2 shards - workflows on shard 2 MUST move + transitionAos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + }, []string{"wf-A", "wf-B", "wf-C", "wf-D", "wf-E"}, now, 2) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 3, + PreviousOutcome: baselineOutcome, + } + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, transitionAos) + require.NoError(t, err) + + reports, err := plugin.Reports(ctx, 3, outcome) + require.NoError(t, err) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 3, reports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Verify orchestrator store shows transition state for workflows that moved from shard 2 + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Workflows that were on shard 2 must have moved and should show TransitionState + for _, wfID := range workflowsOnShard2 { + mapping, err := orchestratorStore.GetWorkflowMapping(ctx, wfID) + require.NoError(t, err) + + newRoute := outcomeProto.Routes[wfID] + require.NotEqual(t, uint32(2), newRoute.Shard, "workflow should have moved from shard 2") + require.Equal(t, shardorchestrator.StateTransitioning, mapping.TransitionState, + "workflow %s moved from shard 2 to shard %d, should be transitioning", wfID, newRoute.Shard) + require.Equal(t, uint32(2), mapping.OldShardID, "should track old shard") + require.Equal(t, newRoute.Shard, mapping.NewShardID, "should track new shard") + t.Logf("Workflow %s transitioned: shard 2 → %d", wfID, newRoute.Shard) + } + + // Verify version incremented + version := orchestratorStore.GetMappingVersion() + require.Equal(t, uint64(3), version, "version should increment after update") + }) +} diff --git a/pkg/workflows/ring/state.go b/pkg/workflows/ring/state.go index 62c26a5b18..f1751bf809 100644 --- a/pkg/workflows/ring/state.go +++ b/pkg/workflows/ring/state.go @@ -7,8 +7,25 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) +// TransitionStateFromBool converts a proto bool (in_transition) to TransitionState +func TransitionStateFromBool(inTransition bool) shardorchestrator.TransitionState { + if inTransition { + return shardorchestrator.StateTransitioning + } + return shardorchestrator.StateSteady +} + +// TransitionStateFromRoutingState returns the TransitionState based on RoutingState +func TransitionStateFromRoutingState(state *pb.RoutingState) shardorchestrator.TransitionState { + if IsInSteadyState(state) { + return shardorchestrator.StateSteady + } + return shardorchestrator.StateTransitioning +} + func IsInSteadyState(state *pb.RoutingState) bool { if state == nil { return false diff --git a/pkg/workflows/ring/transmitter.go b/pkg/workflows/ring/transmitter.go index 524be65be1..5e5f72708b 100644 --- a/pkg/workflows/ring/transmitter.go +++ b/pkg/workflows/ring/transmitter.go @@ -5,24 +5,33 @@ import ( "google.golang.org/protobuf/proto" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) var _ ocr3types.ContractTransmitter[[]byte] = (*Transmitter)(nil) // Transmitter handles transmission of shard orchestration outcomes type Transmitter struct { - lggr logger.Logger - store *Store - arbiterScaler pb.ArbiterScalerClient - fromAccount types.Account + lggr logger.Logger + ringStore *Store + shardOrchestratorStore *shardorchestrator.Store + arbiterScaler pb.ArbiterScalerClient + fromAccount types.Account } -func NewTransmitter(lggr logger.Logger, store *Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter { - return &Transmitter{lggr: lggr, store: store, arbiterScaler: arbiterScaler, fromAccount: fromAccount} +func NewTransmitter(lggr logger.Logger, ringStore *Store, shardOrchestratorStore *shardorchestrator.Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter { + return &Transmitter{ + lggr: lggr, + ringStore: ringStore, + shardOrchestratorStore: shardOrchestratorStore, + arbiterScaler: arbiterScaler, + fromAccount: fromAccount, + } } func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint64, r ocr3types.ReportWithInfo[[]byte], _ []types.AttributedOnchainSignature) error { @@ -37,10 +46,63 @@ func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint return err } - t.store.SetRoutingState(outcome.State) + // Update Ring Store + t.ringStore.SetRoutingState(outcome.State) + + // Determine if system is in transition state + systemInTransition := false + if outcome.State != nil { + if _, ok := outcome.State.State.(*pb.RoutingState_Transition); ok { + systemInTransition = true + } + } + + // Update ShardOrchestrator store if available + if t.shardOrchestratorStore != nil { + mappings := make([]*shardorchestrator.WorkflowMappingState, 0, len(outcome.Routes)) + for workflowID, route := range outcome.Routes { + // Get the current shard assignment for this workflow to detect changes + var oldShardID uint32 + var transitionState shardorchestrator.TransitionState + + existingMapping, err := t.shardOrchestratorStore.GetWorkflowMapping(ctx, workflowID) + if err != nil { + // New workflow - no previous assignment + oldShardID = 0 + transitionState = shardorchestrator.StateSteady + } else if existingMapping.NewShardID != route.Shard { + // Workflow is moving to a different shard + oldShardID = existingMapping.NewShardID + transitionState = shardorchestrator.StateTransitioning + } else { + // Same shard - but might be in system transition + oldShardID = existingMapping.NewShardID + if systemInTransition { + transitionState = shardorchestrator.StateTransitioning + } else { + transitionState = shardorchestrator.StateSteady + } + } + + mappings = append(mappings, &shardorchestrator.WorkflowMappingState{ + WorkflowID: workflowID, + OldShardID: oldShardID, + NewShardID: route.Shard, + TransitionState: transitionState, + }) + } + + if err := t.shardOrchestratorStore.BatchUpdateWorkflowMappings(ctx, mappings); err != nil { + t.lggr.Errorw("failed to update ShardOrchestrator store", "err", err, "workflowCount", len(mappings)) + // Don't fail the entire transmission if ShardOrchestrator update fails + } else { + t.lggr.Debugw("Updated ShardOrchestrator store", "workflowCount", len(mappings)) + } + } + // Update Ring Store workflow mappings for workflowID, route := range outcome.Routes { - t.store.SetShardForWorkflow(workflowID, route.Shard) + t.ringStore.SetShardForWorkflow(workflowID, route.Shard) t.lggr.Debugw("Updated workflow shard mapping", "workflowID", workflowID, "shard", route.Shard) } diff --git a/pkg/workflows/ring/transmitter_test.go b/pkg/workflows/ring/transmitter_test.go index 9fde000152..087f4aaf52 100644 --- a/pkg/workflows/ring/transmitter_test.go +++ b/pkg/workflows/ring/transmitter_test.go @@ -9,10 +9,11 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/emptypb" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" ) type mockArbiterScaler struct { @@ -37,14 +38,14 @@ func (m *mockArbiterScaler) ConsensusWantShards(ctx context.Context, req *pb.Con func TestTransmitter_NewTransmitter(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") require.NotNil(t, tx) } func TestTransmitter_FromAccount(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "my-account") + tx := NewTransmitter(lggr, store, nil, nil, "my-account") account, err := tx.FromAccount(context.Background()) require.NoError(t, err) @@ -55,7 +56,7 @@ func TestTransmitter_Transmit(t *testing.T) { lggr := logger.Test(t) store := NewStore() mock := &mockArbiterScaler{} - tx := NewTransmitter(lggr, store, mock, "test-account") + tx := NewTransmitter(lggr, store, nil, mock, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -88,7 +89,7 @@ func TestTransmitter_Transmit(t *testing.T) { func TestTransmitter_Transmit_NilArbiter(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -107,7 +108,7 @@ func TestTransmitter_Transmit_TransitionState(t *testing.T) { lggr := logger.Test(t) store := NewStore() mock := &mockArbiterScaler{} - tx := NewTransmitter(lggr, store, mock, "test-account") + tx := NewTransmitter(lggr, store, nil, mock, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -127,7 +128,7 @@ func TestTransmitter_Transmit_TransitionState(t *testing.T) { func TestTransmitter_Transmit_InvalidReport(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") // Send invalid protobuf data report := ocr3types.ReportWithInfo[[]byte]{Report: []byte("invalid protobuf")} @@ -139,7 +140,7 @@ func TestTransmitter_Transmit_ArbiterError(t *testing.T) { lggr := logger.Test(t) store := NewStore() mock := &mockArbiterScaler{err: context.DeadlineExceeded} - tx := NewTransmitter(lggr, store, mock, "test-account") + tx := NewTransmitter(lggr, store, nil, mock, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -156,7 +157,7 @@ func TestTransmitter_Transmit_ArbiterError(t *testing.T) { func TestTransmitter_Transmit_NilState(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") outcome := &pb.Outcome{ State: nil, diff --git a/pkg/workflows/shardorchestrator/client.go b/pkg/workflows/shardorchestrator/client.go new file mode 100644 index 0000000000..c7c9ed0f85 --- /dev/null +++ b/pkg/workflows/shardorchestrator/client.go @@ -0,0 +1,78 @@ +package shardorchestrator + +import ( + "context" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +// Client wraps gRPC client for communicating with shard 0's orchestrator service +type Client struct { + conn *grpc.ClientConn + client pb.ShardOrchestratorServiceClient + logger logger.Logger +} + +// NewClient creates a new gRPC client to communicate with the shard orchestrator on shard 0 +func NewClient(ctx context.Context, address string, lggr logger.Logger) (*Client, error) { + conn, err := grpc.NewClient(address, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create shard orchestrator client for %s: %w", address, err) + } + + return &Client{ + conn: conn, + client: pb.NewShardOrchestratorServiceClient(conn), + logger: logger.Named(lggr, "ShardOrchestratorClient"), + }, nil +} + +// GetWorkflowShardMapping queries shard 0 for workflow-to-shard mappings +func (c *Client) GetWorkflowShardMapping(ctx context.Context, workflowIDs []string) (*pb.GetWorkflowShardMappingResponse, error) { + c.logger.Debugw("Calling GetWorkflowShardMapping", "workflowCount", len(workflowIDs)) + + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: workflowIDs, + } + + resp, err := c.client.GetWorkflowShardMapping(ctx, req) + if err != nil { + return nil, fmt.Errorf("gRPC GetWorkflowShardMapping failed: %w", err) + } + + c.logger.Debugw("GetWorkflowShardMapping response received", + "mappingCount", len(resp.Mappings), + "version", resp.MappingVersion) + + return resp, nil +} + +// ReportWorkflowTriggerRegistration reports workflow trigger registration to shard 0 +func (c *Client) ReportWorkflowTriggerRegistration(ctx context.Context, req *pb.ReportWorkflowTriggerRegistrationRequest) (*pb.ReportWorkflowTriggerRegistrationResponse, error) { + c.logger.Debugw("Calling ReportWorkflowTriggerRegistration", + "shardID", req.SourceShardId, + "workflowCount", len(req.RegisteredWorkflows)) + + resp, err := c.client.ReportWorkflowTriggerRegistration(ctx, req) + if err != nil { + return nil, fmt.Errorf("gRPC ReportWorkflowTriggerRegistration failed: %w", err) + } + + c.logger.Debugw("ReportWorkflowTriggerRegistration response received", + "success", resp.Success) + + return resp, nil +} + +// Close closes the gRPC connection +func (c *Client) Close() error { + c.logger.Info("Closing ShardOrchestrator gRPC client") + return c.conn.Close() +} diff --git a/pkg/workflows/shardorchestrator/client_test.go b/pkg/workflows/shardorchestrator/client_test.go new file mode 100644 index 0000000000..a090c531f1 --- /dev/null +++ b/pkg/workflows/shardorchestrator/client_test.go @@ -0,0 +1,212 @@ +package shardorchestrator + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +const bufSize = 1024 * 1024 + +// mockShardOrchestratorServer implements the gRPC server for testing +type mockShardOrchestratorServer struct { + pb.UnimplementedShardOrchestratorServiceServer + mappings map[string]uint32 + registrationCalled bool +} + +func (m *mockShardOrchestratorServer) GetWorkflowShardMapping(ctx context.Context, req *pb.GetWorkflowShardMappingRequest) (*pb.GetWorkflowShardMappingResponse, error) { + mappings := make(map[string]uint32) + mappingStates := make(map[string]*pb.WorkflowMappingState) + + for _, wfID := range req.WorkflowIds { + if shardID, ok := m.mappings[wfID]; ok { + mappings[wfID] = shardID + mappingStates[wfID] = &pb.WorkflowMappingState{ + OldShardId: 0, + NewShardId: shardID, + InTransition: false, + } + } + } + + return &pb.GetWorkflowShardMappingResponse{ + Mappings: mappings, + MappingStates: mappingStates, + MappingVersion: 1, + }, nil +} + +func (m *mockShardOrchestratorServer) ReportWorkflowTriggerRegistration(ctx context.Context, req *pb.ReportWorkflowTriggerRegistrationRequest) (*pb.ReportWorkflowTriggerRegistrationResponse, error) { + m.registrationCalled = true + return &pb.ReportWorkflowTriggerRegistrationResponse{ + Success: true, + }, nil +} + +// setupTestServer creates a test gRPC server using bufconn +func setupTestServer(t *testing.T, mock *mockShardOrchestratorServer) (*grpc.Server, *bufconn.Listener) { + lis := bufconn.Listen(bufSize) + s := grpc.NewServer() + pb.RegisterShardOrchestratorServiceServer(s, mock) + + go func() { + if err := s.Serve(lis); err != nil { + t.Logf("Server exited with error: %v", err) + } + }() + + return s, lis +} + +// createTestClient creates a client connected to the test server +func createTestClient(t *testing.T, lis *bufconn.Listener) *Client { + conn, err := grpc.NewClient("passthrough://bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return lis.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + lggr := logger.Test(t) + return &Client{ + conn: conn, + client: pb.NewShardOrchestratorServiceClient(conn), + logger: logger.Named(lggr, "TestClient"), + } +} + +func TestClient_GetWorkflowShardMapping(t *testing.T) { + ctx := context.Background() + + mock := &mockShardOrchestratorServer{ + mappings: map[string]uint32{ + "workflow-1": 0, + "workflow-2": 1, + "workflow-3": 2, + }, + } + + grpcServer, lis := setupTestServer(t, mock) + defer grpcServer.Stop() + + client := createTestClient(t, lis) + defer client.Close() + + t.Run("successful mapping query", func(t *testing.T) { + workflowIDs := []string{"workflow-1", "workflow-2", "workflow-3"} + resp, err := client.GetWorkflowShardMapping(ctx, workflowIDs) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Len(t, resp.Mappings, 3) + assert.Equal(t, uint32(0), resp.Mappings["workflow-1"]) + assert.Equal(t, uint32(1), resp.Mappings["workflow-2"]) + assert.Equal(t, uint32(2), resp.Mappings["workflow-3"]) + + assert.Len(t, resp.MappingStates, 3) + assert.Equal(t, uint64(1), resp.MappingVersion) + }) + + t.Run("partial workflow query", func(t *testing.T) { + workflowIDs := []string{"workflow-1", "workflow-unknown"} + resp, err := client.GetWorkflowShardMapping(ctx, workflowIDs) + require.NoError(t, err) + require.NotNil(t, resp) + + // Should only return mappings for known workflows + assert.Len(t, resp.Mappings, 1) + assert.Equal(t, uint32(0), resp.Mappings["workflow-1"]) + _, exists := resp.Mappings["workflow-unknown"] + assert.False(t, exists) + }) + + t.Run("empty workflow list", func(t *testing.T) { + resp, err := client.GetWorkflowShardMapping(ctx, []string{}) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Empty(t, resp.Mappings) + }) +} + +func TestClient_ReportWorkflowTriggerRegistration(t *testing.T) { + ctx := context.Background() + + mock := &mockShardOrchestratorServer{ + mappings: map[string]uint32{}, + } + + grpcServer, lis := setupTestServer(t, mock) + defer grpcServer.Stop() + + client := createTestClient(t, lis) + defer client.Close() + + t.Run("successful registration report", func(t *testing.T) { + req := &pb.ReportWorkflowTriggerRegistrationRequest{ + SourceShardId: 1, + RegisteredWorkflows: map[string]uint32{ + "workflow-1": 1, + "workflow-2": 1, + }, + TotalActiveWorkflows: 2, + } + + resp, err := client.ReportWorkflowTriggerRegistration(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.True(t, resp.Success) + assert.True(t, mock.registrationCalled) + }) +} + +func TestClient_Close(t *testing.T) { + mock := &mockShardOrchestratorServer{ + mappings: map[string]uint32{}, + } + + grpcServer, lis := setupTestServer(t, mock) + defer grpcServer.Stop() + + client := createTestClient(t, lis) + + err := client.Close() + assert.NoError(t, err) + + // Verify connection is closed by attempting to use it + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = client.GetWorkflowShardMapping(ctx, []string{"test"}) + assert.Error(t, err, "should fail after client is closed") +} + +func TestNewClient(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + + t.Run("creates client successfully", func(t *testing.T) { + // Note: This creates a client but doesn't connect immediately with grpc.NewClient + client, err := NewClient(ctx, "localhost:50051", lggr) + require.NoError(t, err) + require.NotNil(t, client) + defer client.Close() + + assert.NotNil(t, client.conn) + assert.NotNil(t, client.client) + assert.NotNil(t, client.logger) + }) +} diff --git a/pkg/workflows/shardorchestrator/pb/generate.go b/pkg/workflows/shardorchestrator/pb/generate.go new file mode 100644 index 0000000000..be40262383 --- /dev/null +++ b/pkg/workflows/shardorchestrator/pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative shard_orchestrator.proto diff --git a/pkg/workflows/ring/pb/shard_orchestrator.pb.go b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.pb.go similarity index 68% rename from pkg/workflows/ring/pb/shard_orchestrator.pb.go rename to pkg/workflows/shardorchestrator/pb/shard_orchestrator.pb.go index 7a3f8491e8..ba90ab901c 100644 --- a/pkg/workflows/ring/pb/shard_orchestrator.pb.go +++ b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.pb.go @@ -9,7 +9,6 @@ package pb import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" unsafe "unsafe" @@ -71,7 +70,6 @@ type WorkflowMappingState struct { OldShardId uint32 `protobuf:"varint,1,opt,name=old_shard_id,json=oldShardId,proto3" json:"old_shard_id,omitempty"` NewShardId uint32 `protobuf:"varint,2,opt,name=new_shard_id,json=newShardId,proto3" json:"new_shard_id,omitempty"` InTransition bool `protobuf:"varint,3,opt,name=in_transition,json=inTransition,proto3" json:"in_transition,omitempty"` - LastUpdated *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=last_updated,json=lastUpdated,proto3" json:"last_updated,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -127,19 +125,11 @@ func (x *WorkflowMappingState) GetInTransition() bool { return false } -func (x *WorkflowMappingState) GetLastUpdated() *timestamppb.Timestamp { - if x != nil { - return x.LastUpdated - } - return nil -} - type GetWorkflowShardMappingResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Mappings map[string]uint32 `protobuf:"bytes,1,rep,name=mappings,proto3" json:"mappings,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"` MappingStates map[string]*WorkflowMappingState `protobuf:"bytes,2,rep,name=mapping_states,json=mappingStates,proto3" json:"mapping_states,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` - Timestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` - MappingVersion uint64 `protobuf:"varint,4,opt,name=mapping_version,json=mappingVersion,proto3" json:"mapping_version,omitempty"` + MappingVersion uint64 `protobuf:"varint,3,opt,name=mapping_version,json=mappingVersion,proto3" json:"mapping_version,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -188,13 +178,6 @@ func (x *GetWorkflowShardMappingResponse) GetMappingStates() map[string]*Workflo return nil } -func (x *GetWorkflowShardMappingResponse) GetTimestamp() *timestamppb.Timestamp { - if x != nil { - return x.Timestamp - } - return nil -} - func (x *GetWorkflowShardMappingResponse) GetMappingVersion() uint64 { if x != nil { return x.MappingVersion @@ -206,8 +189,7 @@ type ReportWorkflowTriggerRegistrationRequest struct { state protoimpl.MessageState `protogen:"open.v1"` SourceShardId uint32 `protobuf:"varint,1,opt,name=source_shard_id,json=sourceShardId,proto3" json:"source_shard_id,omitempty"` RegisteredWorkflows map[string]uint32 `protobuf:"bytes,2,rep,name=registered_workflows,json=registeredWorkflows,proto3" json:"registered_workflows,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"` - ReportTimestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=report_timestamp,json=reportTimestamp,proto3" json:"report_timestamp,omitempty"` - TotalActiveWorkflows uint32 `protobuf:"varint,4,opt,name=total_active_workflows,json=totalActiveWorkflows,proto3" json:"total_active_workflows,omitempty"` + TotalActiveWorkflows uint32 `protobuf:"varint,3,opt,name=total_active_workflows,json=totalActiveWorkflows,proto3" json:"total_active_workflows,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -256,13 +238,6 @@ func (x *ReportWorkflowTriggerRegistrationRequest) GetRegisteredWorkflows() map[ return nil } -func (x *ReportWorkflowTriggerRegistrationRequest) GetReportTimestamp() *timestamppb.Timestamp { - if x != nil { - return x.ReportTimestamp - } - return nil -} - func (x *ReportWorkflowTriggerRegistrationRequest) GetTotalActiveWorkflows() uint32 { if x != nil { return x.TotalActiveWorkflows @@ -318,40 +293,37 @@ var File_shard_orchestrator_proto protoreflect.FileDescriptor const file_shard_orchestrator_proto_rawDesc = "" + "\n" + - "\x18shard_orchestrator.proto\x12\x04ring\x1a\x1fgoogle/protobuf/timestamp.proto\"C\n" + + "\x18shard_orchestrator.proto\x12\x11shardorchestrator\"C\n" + "\x1eGetWorkflowShardMappingRequest\x12!\n" + - "\fworkflow_ids\x18\x01 \x03(\tR\vworkflowIds\"\xbe\x01\n" + + "\fworkflow_ids\x18\x01 \x03(\tR\vworkflowIds\"\x7f\n" + "\x14WorkflowMappingState\x12 \n" + "\fold_shard_id\x18\x01 \x01(\rR\n" + "oldShardId\x12 \n" + "\fnew_shard_id\x18\x02 \x01(\rR\n" + "newShardId\x12#\n" + - "\rin_transition\x18\x03 \x01(\bR\finTransition\x12=\n" + - "\flast_updated\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\vlastUpdated\"\xd1\x03\n" + - "\x1fGetWorkflowShardMappingResponse\x12O\n" + - "\bmappings\x18\x01 \x03(\v23.ring.GetWorkflowShardMappingResponse.MappingsEntryR\bmappings\x12_\n" + - "\x0emapping_states\x18\x02 \x03(\v28.ring.GetWorkflowShardMappingResponse.MappingStatesEntryR\rmappingStates\x128\n" + - "\ttimestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12'\n" + - "\x0fmapping_version\x18\x04 \x01(\x04R\x0emappingVersion\x1a;\n" + + "\rin_transition\x18\x03 \x01(\bR\finTransition\"\xbe\x03\n" + + "\x1fGetWorkflowShardMappingResponse\x12\\\n" + + "\bmappings\x18\x01 \x03(\v2@.shardorchestrator.GetWorkflowShardMappingResponse.MappingsEntryR\bmappings\x12l\n" + + "\x0emapping_states\x18\x02 \x03(\v2E.shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntryR\rmappingStates\x12'\n" + + "\x0fmapping_version\x18\x03 \x01(\x04R\x0emappingVersion\x1a;\n" + "\rMappingsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\x1a\\\n" + + "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\x1ai\n" + "\x12MappingStatesEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x120\n" + - "\x05value\x18\x02 \x01(\v2\x1a.ring.WorkflowMappingStateR\x05value:\x028\x01\"\x93\x03\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12=\n" + + "\x05value\x18\x02 \x01(\v2'.shardorchestrator.WorkflowMappingStateR\x05value:\x028\x01\"\xda\x02\n" + "(ReportWorkflowTriggerRegistrationRequest\x12&\n" + - "\x0fsource_shard_id\x18\x01 \x01(\rR\rsourceShardId\x12z\n" + - "\x14registered_workflows\x18\x02 \x03(\v2G.ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntryR\x13registeredWorkflows\x12E\n" + - "\x10report_timestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\x0freportTimestamp\x124\n" + - "\x16total_active_workflows\x18\x04 \x01(\rR\x14totalActiveWorkflows\x1aF\n" + + "\x0fsource_shard_id\x18\x01 \x01(\rR\rsourceShardId\x12\x87\x01\n" + + "\x14registered_workflows\x18\x02 \x03(\v2T.shardorchestrator.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntryR\x13registeredWorkflows\x124\n" + + "\x16total_active_workflows\x18\x03 \x01(\rR\x14totalActiveWorkflows\x1aF\n" + "\x18RegisteredWorkflowsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\"E\n" + ")ReportWorkflowTriggerRegistrationResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess2\x89\x02\n" + - "\x18ShardOrchestratorService\x12f\n" + - "\x17GetWorkflowShardMapping\x12$.ring.GetWorkflowShardMappingRequest\x1a%.ring.GetWorkflowShardMappingResponse\x12\x84\x01\n" + - "!ReportWorkflowTriggerRegistration\x12..ring.ReportWorkflowTriggerRegistrationRequest\x1a/.ring.ReportWorkflowTriggerRegistrationResponseBDZBgithub.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pbb\x06proto3" + "\asuccess\x18\x01 \x01(\bR\asuccess2\xbe\x02\n" + + "\x18ShardOrchestratorService\x12\x80\x01\n" + + "\x17GetWorkflowShardMapping\x121.shardorchestrator.GetWorkflowShardMappingRequest\x1a2.shardorchestrator.GetWorkflowShardMappingResponse\x12\x9e\x01\n" + + "!ReportWorkflowTriggerRegistration\x12;.shardorchestrator.ReportWorkflowTriggerRegistrationRequest\x1a<.shardorchestrator.ReportWorkflowTriggerRegistrationResponseBQZOgithub.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pbb\x06proto3" var ( file_shard_orchestrator_proto_rawDescOnce sync.Once @@ -367,33 +339,29 @@ func file_shard_orchestrator_proto_rawDescGZIP() []byte { var file_shard_orchestrator_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_shard_orchestrator_proto_goTypes = []any{ - (*GetWorkflowShardMappingRequest)(nil), // 0: ring.GetWorkflowShardMappingRequest - (*WorkflowMappingState)(nil), // 1: ring.WorkflowMappingState - (*GetWorkflowShardMappingResponse)(nil), // 2: ring.GetWorkflowShardMappingResponse - (*ReportWorkflowTriggerRegistrationRequest)(nil), // 3: ring.ReportWorkflowTriggerRegistrationRequest - (*ReportWorkflowTriggerRegistrationResponse)(nil), // 4: ring.ReportWorkflowTriggerRegistrationResponse - nil, // 5: ring.GetWorkflowShardMappingResponse.MappingsEntry - nil, // 6: ring.GetWorkflowShardMappingResponse.MappingStatesEntry - nil, // 7: ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry - (*timestamppb.Timestamp)(nil), // 8: google.protobuf.Timestamp + (*GetWorkflowShardMappingRequest)(nil), // 0: shardorchestrator.GetWorkflowShardMappingRequest + (*WorkflowMappingState)(nil), // 1: shardorchestrator.WorkflowMappingState + (*GetWorkflowShardMappingResponse)(nil), // 2: shardorchestrator.GetWorkflowShardMappingResponse + (*ReportWorkflowTriggerRegistrationRequest)(nil), // 3: shardorchestrator.ReportWorkflowTriggerRegistrationRequest + (*ReportWorkflowTriggerRegistrationResponse)(nil), // 4: shardorchestrator.ReportWorkflowTriggerRegistrationResponse + nil, // 5: shardorchestrator.GetWorkflowShardMappingResponse.MappingsEntry + nil, // 6: shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntry + nil, // 7: shardorchestrator.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry } var file_shard_orchestrator_proto_depIdxs = []int32{ - 8, // 0: ring.WorkflowMappingState.last_updated:type_name -> google.protobuf.Timestamp - 5, // 1: ring.GetWorkflowShardMappingResponse.mappings:type_name -> ring.GetWorkflowShardMappingResponse.MappingsEntry - 6, // 2: ring.GetWorkflowShardMappingResponse.mapping_states:type_name -> ring.GetWorkflowShardMappingResponse.MappingStatesEntry - 8, // 3: ring.GetWorkflowShardMappingResponse.timestamp:type_name -> google.protobuf.Timestamp - 7, // 4: ring.ReportWorkflowTriggerRegistrationRequest.registered_workflows:type_name -> ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry - 8, // 5: ring.ReportWorkflowTriggerRegistrationRequest.report_timestamp:type_name -> google.protobuf.Timestamp - 1, // 6: ring.GetWorkflowShardMappingResponse.MappingStatesEntry.value:type_name -> ring.WorkflowMappingState - 0, // 7: ring.ShardOrchestratorService.GetWorkflowShardMapping:input_type -> ring.GetWorkflowShardMappingRequest - 3, // 8: ring.ShardOrchestratorService.ReportWorkflowTriggerRegistration:input_type -> ring.ReportWorkflowTriggerRegistrationRequest - 2, // 9: ring.ShardOrchestratorService.GetWorkflowShardMapping:output_type -> ring.GetWorkflowShardMappingResponse - 4, // 10: ring.ShardOrchestratorService.ReportWorkflowTriggerRegistration:output_type -> ring.ReportWorkflowTriggerRegistrationResponse - 9, // [9:11] is the sub-list for method output_type - 7, // [7:9] is the sub-list for method input_type - 7, // [7:7] is the sub-list for extension type_name - 7, // [7:7] is the sub-list for extension extendee - 0, // [0:7] is the sub-list for field type_name + 5, // 0: shardorchestrator.GetWorkflowShardMappingResponse.mappings:type_name -> shardorchestrator.GetWorkflowShardMappingResponse.MappingsEntry + 6, // 1: shardorchestrator.GetWorkflowShardMappingResponse.mapping_states:type_name -> shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntry + 7, // 2: shardorchestrator.ReportWorkflowTriggerRegistrationRequest.registered_workflows:type_name -> shardorchestrator.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry + 1, // 3: shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntry.value:type_name -> shardorchestrator.WorkflowMappingState + 0, // 4: shardorchestrator.ShardOrchestratorService.GetWorkflowShardMapping:input_type -> shardorchestrator.GetWorkflowShardMappingRequest + 3, // 5: shardorchestrator.ShardOrchestratorService.ReportWorkflowTriggerRegistration:input_type -> shardorchestrator.ReportWorkflowTriggerRegistrationRequest + 2, // 6: shardorchestrator.ShardOrchestratorService.GetWorkflowShardMapping:output_type -> shardorchestrator.GetWorkflowShardMappingResponse + 4, // 7: shardorchestrator.ShardOrchestratorService.ReportWorkflowTriggerRegistration:output_type -> shardorchestrator.ReportWorkflowTriggerRegistrationResponse + 6, // [6:8] is the sub-list for method output_type + 4, // [4:6] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_shard_orchestrator_proto_init() } diff --git a/pkg/workflows/ring/pb/shard_orchestrator.proto b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.proto similarity index 78% rename from pkg/workflows/ring/pb/shard_orchestrator.proto rename to pkg/workflows/shardorchestrator/pb/shard_orchestrator.proto index c7e3c1668e..1d9fe6a6cf 100644 --- a/pkg/workflows/ring/pb/shard_orchestrator.proto +++ b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.proto @@ -1,10 +1,8 @@ syntax = "proto3"; -package ring; +package shardorchestrator; -import "google/protobuf/timestamp.proto"; - -option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"; +option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb"; message GetWorkflowShardMappingRequest { repeated string workflow_ids = 1; @@ -14,21 +12,18 @@ message WorkflowMappingState { uint32 old_shard_id = 1; uint32 new_shard_id = 2; bool in_transition = 3; - google.protobuf.Timestamp last_updated = 4; } message GetWorkflowShardMappingResponse { map mappings = 1; map mapping_states = 2; - google.protobuf.Timestamp timestamp = 3; - uint64 mapping_version = 4; + uint64 mapping_version = 3; } message ReportWorkflowTriggerRegistrationRequest { uint32 source_shard_id = 1; map registered_workflows = 2; - google.protobuf.Timestamp report_timestamp = 3; - uint32 total_active_workflows = 4; + uint32 total_active_workflows = 3; } message ReportWorkflowTriggerRegistrationResponse { diff --git a/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go b/pkg/workflows/shardorchestrator/pb/shard_orchestrator_grpc.pb.go similarity index 96% rename from pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go rename to pkg/workflows/shardorchestrator/pb/shard_orchestrator_grpc.pb.go index d2ab234c3a..d099fd0aca 100644 --- a/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go +++ b/pkg/workflows/shardorchestrator/pb/shard_orchestrator_grpc.pb.go @@ -19,8 +19,8 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName = "/ring.ShardOrchestratorService/GetWorkflowShardMapping" - ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName = "/ring.ShardOrchestratorService/ReportWorkflowTriggerRegistration" + ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName = "/shardorchestrator.ShardOrchestratorService/GetWorkflowShardMapping" + ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName = "/shardorchestrator.ShardOrchestratorService/ReportWorkflowTriggerRegistration" ) // ShardOrchestratorServiceClient is the client API for ShardOrchestratorService service. @@ -143,7 +143,7 @@ func _ShardOrchestratorService_ReportWorkflowTriggerRegistration_Handler(srv int // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) var ShardOrchestratorService_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "ring.ShardOrchestratorService", + ServiceName: "shardorchestrator.ShardOrchestratorService", HandlerType: (*ShardOrchestratorServiceServer)(nil), Methods: []grpc.MethodDesc{ { diff --git a/pkg/workflows/shardorchestrator/service.go b/pkg/workflows/shardorchestrator/service.go new file mode 100644 index 0000000000..4d0b8cb62b --- /dev/null +++ b/pkg/workflows/shardorchestrator/service.go @@ -0,0 +1,108 @@ +package shardorchestrator + +import ( + "context" + "fmt" + + "google.golang.org/grpc" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +// Server implements the gRPC ShardOrchestratorService +// This runs on shard zero and serves requests from other shards +type Server struct { + pb.UnimplementedShardOrchestratorServiceServer + store *Store + logger logger.Logger +} + +func NewServer(store *Store, lggr logger.Logger) *Server { + return &Server{ + store: store, + logger: logger.Named(lggr, "ShardOrchestratorServer"), + } +} + +// RegisterWithGRPCServer registers this service with a gRPC server +func (s *Server) RegisterWithGRPCServer(grpcServer *grpc.Server) { + pb.RegisterShardOrchestratorServiceServer(grpcServer, s) + s.logger.Info("Registered ShardOrchestrator gRPC service") +} + +// GetWorkflowShardMapping handles batch requests for workflow-to-shard mappings +// This is called by other shards to determine where to route workflow executions +func (s *Server) GetWorkflowShardMapping(ctx context.Context, req *pb.GetWorkflowShardMappingRequest) (*pb.GetWorkflowShardMappingResponse, error) { + s.logger.Debugw("GetWorkflowShardMapping called", "workflowCount", len(req.WorkflowIds)) + + if len(req.WorkflowIds) == 0 { + return nil, fmt.Errorf("workflow_ids is required and must not be empty") + } + + // Retrieve batch from store + mappings, version, err := s.store.GetWorkflowMappingsBatch(ctx, req.WorkflowIds) + if err != nil { + s.logger.Errorw("Failed to get workflow mappings", "error", err) + return nil, fmt.Errorf("failed to get workflow mappings: %w", err) + } + + // Build simple mappings map (workflow_id -> shard_id) + simpleMappings := make(map[string]uint32, len(mappings)) + // Build detailed mapping states + mappingStates := make(map[string]*pb.WorkflowMappingState, len(mappings)) + + for workflowID, mapping := range mappings { + // Simple mapping: just the current shard + simpleMappings[workflowID] = mapping.NewShardID + + // Detailed state: includes transition information + mappingStates[workflowID] = &pb.WorkflowMappingState{ + OldShardId: mapping.OldShardID, + NewShardId: mapping.NewShardID, + InTransition: mapping.TransitionState.InTransition(), + } + } + + return &pb.GetWorkflowShardMappingResponse{ + Mappings: simpleMappings, + MappingStates: mappingStates, + MappingVersion: version, + }, nil +} + +// ReportWorkflowTriggerRegistration handles shard registration reports +// Shards call this to inform shard zero about which workflows they have loaded +func (s *Server) ReportWorkflowTriggerRegistration(ctx context.Context, req *pb.ReportWorkflowTriggerRegistrationRequest) (*pb.ReportWorkflowTriggerRegistrationResponse, error) { + s.logger.Debugw("ReportWorkflowTriggerRegistration called", + "shardID", req.SourceShardId, + "workflowCount", len(req.RegisteredWorkflows), + "totalActive", req.TotalActiveWorkflows, + ) + + // Extract workflow IDs from the map + workflowIDs := make([]string, 0, len(req.RegisteredWorkflows)) + for workflowID := range req.RegisteredWorkflows { + workflowIDs = append(workflowIDs, workflowID) + } + + err := s.store.ReportShardRegistration(ctx, req.SourceShardId, workflowIDs) + if err != nil { + s.logger.Errorw("Failed to update shard registrations", + "shardID", req.SourceShardId, + "error", err, + ) + return &pb.ReportWorkflowTriggerRegistrationResponse{ + Success: false, + }, nil + } + + s.logger.Infow("Successfully registered workflows", + "shardID", req.SourceShardId, + "workflowCount", len(workflowIDs), + ) + + return &pb.ReportWorkflowTriggerRegistrationResponse{ + Success: true, + }, nil +} diff --git a/pkg/workflows/shardorchestrator/service_test.go b/pkg/workflows/shardorchestrator/service_test.go new file mode 100644 index 0000000000..bbe8766cc7 --- /dev/null +++ b/pkg/workflows/shardorchestrator/service_test.go @@ -0,0 +1,118 @@ +package shardorchestrator_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +func TestServer_GetWorkflowShardMapping(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + + t.Run("returns_mappings_for_multiple_workflows", func(t *testing.T) { + store := shardorchestrator.NewStore(lggr) + server := shardorchestrator.NewServer(store, lggr) + + // Set up some workflow mappings + mappings := []*shardorchestrator.WorkflowMappingState{ + { + WorkflowID: "wf-alpha", + OldShardID: 0, + NewShardID: 1, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "wf-beta", + OldShardID: 0, + NewShardID: 2, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "wf-gamma", + OldShardID: 1, + NewShardID: 0, + TransitionState: shardorchestrator.StateTransitioning, + }, + } + err := store.BatchUpdateWorkflowMappings(ctx, mappings) + require.NoError(t, err) + + // Request all three workflows + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: []string{"wf-alpha", "wf-beta", "wf-gamma"}, + } + + resp, err := server.GetWorkflowShardMapping(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify simple mappings + require.Len(t, resp.Mappings, 3) + require.Equal(t, uint32(1), resp.Mappings["wf-alpha"]) + require.Equal(t, uint32(2), resp.Mappings["wf-beta"]) + require.Equal(t, uint32(0), resp.Mappings["wf-gamma"]) + + // Verify detailed mapping states + require.Len(t, resp.MappingStates, 3) + + // wf-alpha: steady state + alphaState := resp.MappingStates["wf-alpha"] + require.Equal(t, uint32(0), alphaState.OldShardId) + require.Equal(t, uint32(1), alphaState.NewShardId) + require.False(t, alphaState.InTransition, "steady state should not be in transition") + + // wf-gamma: transitioning state + gammaState := resp.MappingStates["wf-gamma"] + require.Equal(t, uint32(1), gammaState.OldShardId) + require.Equal(t, uint32(0), gammaState.NewShardId) + require.True(t, gammaState.InTransition, "transitioning state should be in transition") + + // Verify version + require.Equal(t, uint64(1), resp.MappingVersion) + }) + + t.Run("rejects_empty_workflow_ids", func(t *testing.T) { + store := shardorchestrator.NewStore(lggr) + server := shardorchestrator.NewServer(store, lggr) + + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: []string{}, + } + + resp, err := server.GetWorkflowShardMapping(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "required") + }) + + t.Run("handles_partial_results_for_nonexistent_workflows", func(t *testing.T) { + store := shardorchestrator.NewStore(lggr) + server := shardorchestrator.NewServer(store, lggr) + + // Add one workflow + err := store.BatchUpdateWorkflowMappings(ctx, []*shardorchestrator.WorkflowMappingState{ + {WorkflowID: "exists", NewShardID: 1, TransitionState: shardorchestrator.StateSteady}, + }) + require.NoError(t, err) + + // Request one that exists and one that doesn't - batch query silently skips missing workflows + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: []string{"exists", "does-not-exist"}, + } + + resp, err := server.GetWorkflowShardMapping(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Only the existing workflow is returned + require.Len(t, resp.Mappings, 1) + require.Equal(t, uint32(1), resp.Mappings["exists"]) + require.NotContains(t, resp.Mappings, "does-not-exist") + }) +} diff --git a/pkg/workflows/shardorchestrator/store.go b/pkg/workflows/shardorchestrator/store.go new file mode 100644 index 0000000000..4da391babc --- /dev/null +++ b/pkg/workflows/shardorchestrator/store.go @@ -0,0 +1,235 @@ +package shardorchestrator + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// TransitionState represents the state of a workflow's shard assignment +type TransitionState uint8 + +const ( + StateSteady TransitionState = iota + StateTransitioning +) + +// String returns the string representation of the TransitionState +func (s TransitionState) String() string { + switch s { + case StateSteady: + return "steady" + case StateTransitioning: + return "transitioning" + default: + return "unknown" + } +} + +// InTransition returns true if the state is transitioning +func (s TransitionState) InTransition() bool { + return s == StateTransitioning +} + +// WorkflowMappingState represents the state of a workflow assignment +type WorkflowMappingState struct { + WorkflowID string + OldShardID uint32 + NewShardID uint32 + TransitionState TransitionState + UpdatedAt time.Time +} + +// Store manages workflow-to-shard mappings that will be exposed via gRPC +// RingOCR plugin updates this store, and the gRPC service reads from it +type Store struct { + // workflowMappings tracks the current shard assignment for each workflow + workflowMappings map[string]*WorkflowMappingState // workflow_id -> mapping state + + // shardRegistrations tracks what workflows each shard has registered + // This is populated by ReportWorkflowTriggerRegistration calls from shards + shardRegistrations map[uint32]map[string]bool // shard_id -> set of workflow_ids + + // mappingVersion increments on any change to workflowMappings + // Used by clients for cache invalidation + mappingVersion uint64 + + // lastUpdateTime tracks when mappings were last modified + lastUpdateTime time.Time + + mu sync.RWMutex + logger logger.Logger +} + +func NewStore(lggr logger.Logger) *Store { + return &Store{ + workflowMappings: make(map[string]*WorkflowMappingState), + shardRegistrations: make(map[uint32]map[string]bool), + mappingVersion: 0, + lastUpdateTime: time.Now(), + logger: logger.Named(lggr, "ShardOrchestratorStore"), + } +} + +// UpdateWorkflowMapping is called by RingOCR to update workflow assignments +// This is the primary data source for shard orchestration +func (s *Store) UpdateWorkflowMapping(ctx context.Context, workflowID string, oldShardID, newShardID uint32, state TransitionState) error { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + s.workflowMappings[workflowID] = &WorkflowMappingState{ + WorkflowID: workflowID, + OldShardID: oldShardID, + NewShardID: newShardID, + TransitionState: state, + UpdatedAt: now, + } + + s.mappingVersion++ + s.lastUpdateTime = now + + s.logger.Debugw("Updated workflow mapping", + "workflowID", workflowID, + "oldShardID", oldShardID, + "newShardID", newShardID, + "state", state.String(), + "version", s.mappingVersion, + ) + + return nil +} + +// BatchUpdateWorkflowMappings allows RingOCR to update multiple mappings atomically +func (s *Store) BatchUpdateWorkflowMappings(ctx context.Context, mappings []*WorkflowMappingState) error { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for _, mapping := range mappings { + s.workflowMappings[mapping.WorkflowID] = &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: now, + } + } + + s.mappingVersion++ + s.lastUpdateTime = now + + s.logger.Debugw("Batch updated workflow mappings", "count", len(mappings), "version", s.mappingVersion) + return nil +} + +// GetWorkflowMapping retrieves the shard assignment for a specific workflow +// This is called by the gRPC service to respond to GetWorkflowShardMapping requests +func (s *Store) GetWorkflowMapping(ctx context.Context, workflowID string) (*WorkflowMappingState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + mapping, ok := s.workflowMappings[workflowID] + if !ok { + return nil, fmt.Errorf("workflow %s not found in shard mappings", workflowID) + } + + // Return a copy to avoid external mutations + return &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: mapping.UpdatedAt, + }, nil +} + +// GetAllWorkflowMappings returns all current workflow-to-shard assignments +func (s *Store) GetAllWorkflowMappings(ctx context.Context) ([]*WorkflowMappingState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + mappings := make([]*WorkflowMappingState, 0, len(s.workflowMappings)) + for _, mapping := range s.workflowMappings { + mappings = append(mappings, &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: mapping.UpdatedAt, + }) + } + + return mappings, nil +} + +// ReportShardRegistration is called when a shard reports its registered workflows +// This helps track which workflows each shard has successfully loaded +func (s *Store) ReportShardRegistration(ctx context.Context, shardID uint32, workflowIDs []string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Clear and update + s.shardRegistrations[shardID] = make(map[string]bool) + for _, wfID := range workflowIDs { + s.shardRegistrations[shardID][wfID] = true + } + + s.logger.Debugw("Updated shard registrations", + "shardID", shardID, + "workflowCount", len(workflowIDs), + ) + + return nil +} + +// GetShardRegistrations returns the workflows registered on a specific shard +func (s *Store) GetShardRegistrations(ctx context.Context, shardID uint32) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + workflows, ok := s.shardRegistrations[shardID] + if !ok { + return []string{}, nil + } + + result := make([]string, 0, len(workflows)) + for wfID := range workflows { + result = append(result, wfID) + } + + return result, nil +} + +// GetWorkflowMappingsBatch retrieves mappings for multiple workflows +func (s *Store) GetWorkflowMappingsBatch(ctx context.Context, workflowIDs []string) (map[string]*WorkflowMappingState, uint64, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + result := make(map[string]*WorkflowMappingState, len(workflowIDs)) + + for _, workflowID := range workflowIDs { + if mapping, ok := s.workflowMappings[workflowID]; ok { + // Return a copy to avoid external mutations + result[workflowID] = &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: mapping.UpdatedAt, + } + } + } + + return result, s.mappingVersion, nil +} + +// GetMappingVersion returns the current version of the mapping set +func (s *Store) GetMappingVersion() uint64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.mappingVersion +} diff --git a/pkg/workflows/shardorchestrator/store_test.go b/pkg/workflows/shardorchestrator/store_test.go new file mode 100644 index 0000000000..6ad2af5a04 --- /dev/null +++ b/pkg/workflows/shardorchestrator/store_test.go @@ -0,0 +1,198 @@ +package shardorchestrator_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" +) + +func TestStore_BatchUpdateAndQuery(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Create and insert multiple workflow mappings + mappings := []*shardorchestrator.WorkflowMappingState{ + { + WorkflowID: "workflow-1", + OldShardID: 0, + NewShardID: 1, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "workflow-2", + OldShardID: 0, + NewShardID: 2, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "workflow-3", + OldShardID: 0, + NewShardID: 1, + TransitionState: shardorchestrator.StateSteady, + }, + } + + err := store.BatchUpdateWorkflowMappings(ctx, mappings) + require.NoError(t, err) + + // Query individual workflow + mapping1, err := store.GetWorkflowMapping(ctx, "workflow-1") + require.NoError(t, err) + assert.Equal(t, uint32(1), mapping1.NewShardID) + assert.Equal(t, shardorchestrator.StateSteady, mapping1.TransitionState) + + // Query all workflows + allMappings, err := store.GetAllWorkflowMappings(ctx) + require.NoError(t, err) + assert.Len(t, allMappings, 3) + + // Query batch + batchMappings, version, err := store.GetWorkflowMappingsBatch(ctx, []string{"workflow-1", "workflow-2"}) + require.NoError(t, err) + assert.Len(t, batchMappings, 2) + assert.Equal(t, uint64(1), version) // First update +} + +func TestStore_WorkflowTransition(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Initial assignment + err := store.UpdateWorkflowMapping(ctx, "workflow-123", 0, 1, shardorchestrator.StateSteady) + require.NoError(t, err) + + mapping, err := store.GetWorkflowMapping(ctx, "workflow-123") + require.NoError(t, err) + assert.Equal(t, uint32(1), mapping.NewShardID) + assert.Equal(t, shardorchestrator.StateSteady, mapping.TransitionState) + + // Move to different shard (transitioning) + err = store.UpdateWorkflowMapping(ctx, "workflow-123", 1, 3, shardorchestrator.StateTransitioning) + require.NoError(t, err) + + mapping, err = store.GetWorkflowMapping(ctx, "workflow-123") + require.NoError(t, err) + assert.Equal(t, uint32(1), mapping.OldShardID) + assert.Equal(t, uint32(3), mapping.NewShardID) + assert.Equal(t, shardorchestrator.StateTransitioning, mapping.TransitionState) + + // Complete transition + err = store.UpdateWorkflowMapping(ctx, "workflow-123", 1, 3, shardorchestrator.StateSteady) + require.NoError(t, err) + + mapping, err = store.GetWorkflowMapping(ctx, "workflow-123") + require.NoError(t, err) + assert.Equal(t, uint32(3), mapping.NewShardID) + assert.Equal(t, shardorchestrator.StateSteady, mapping.TransitionState) +} + +func TestStore_VersionTracking(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Initial version should be 0 + assert.Equal(t, uint64(0), store.GetMappingVersion()) + + // First update increments version + err := store.UpdateWorkflowMapping(ctx, "wf-1", 0, 1, shardorchestrator.StateSteady) + require.NoError(t, err) + assert.Equal(t, uint64(1), store.GetMappingVersion()) + + // Batch update increments version + err = store.BatchUpdateWorkflowMappings(ctx, []*shardorchestrator.WorkflowMappingState{ + {WorkflowID: "wf-2", NewShardID: 2, TransitionState: shardorchestrator.StateSteady}, + }) + require.NoError(t, err) + assert.Equal(t, uint64(2), store.GetMappingVersion()) + + // Version is included in batch query response + _, version, err := store.GetWorkflowMappingsBatch(ctx, []string{"wf-1", "wf-2"}) + require.NoError(t, err) + assert.Equal(t, uint64(2), version) +} + +func TestStore_ShardRegistrations(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Shard 1 reports its workflows + err := store.ReportShardRegistration(ctx, 1, []string{"workflow-1", "workflow-3"}) + require.NoError(t, err) + + // Shard 2 reports its workflows + err = store.ReportShardRegistration(ctx, 2, []string{"workflow-2"}) + require.NoError(t, err) + + // Query shard registrations + shard1Workflows, err := store.GetShardRegistrations(ctx, 1) + require.NoError(t, err) + assert.Len(t, shard1Workflows, 2) + assert.Contains(t, shard1Workflows, "workflow-1") + assert.Contains(t, shard1Workflows, "workflow-3") + + shard2Workflows, err := store.GetShardRegistrations(ctx, 2) + require.NoError(t, err) + assert.Len(t, shard2Workflows, 1) + assert.Contains(t, shard2Workflows, "workflow-2") + + // Query non-existent shard returns empty + shard3Workflows, err := store.GetShardRegistrations(ctx, 3) + require.NoError(t, err) + assert.Empty(t, shard3Workflows) + + // Re-reporting replaces previous registrations + err = store.ReportShardRegistration(ctx, 1, []string{"workflow-1"}) + require.NoError(t, err) + + shard1Workflows, err = store.GetShardRegistrations(ctx, 1) + require.NoError(t, err) + assert.Len(t, shard1Workflows, 1) + assert.Contains(t, shard1Workflows, "workflow-1") + assert.NotContains(t, shard1Workflows, "workflow-3") // Removed +} + +func TestStore_NotFoundError(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Query non-existent workflow + _, err := store.GetWorkflowMapping(ctx, "non-existent") + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") +} + +func TestStore_BatchQueryPartialResults(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Insert only some workflows + err := store.UpdateWorkflowMapping(ctx, "exists-1", 0, 1, shardorchestrator.StateSteady) + require.NoError(t, err) + err = store.UpdateWorkflowMapping(ctx, "exists-2", 0, 2, shardorchestrator.StateSteady) + require.NoError(t, err) + + // Query mix of existing and non-existing workflows + results, _, err := store.GetWorkflowMappingsBatch(ctx, []string{ + "exists-1", + "non-existent", + "exists-2", + }) + require.NoError(t, err) + + // Should only return existing ones + assert.Len(t, results, 2) + assert.Contains(t, results, "exists-1") + assert.Contains(t, results, "exists-2") + assert.NotContains(t, results, "non-existent") +} From 4af655f669474c7f5086e83492a57a034f633cec Mon Sep 17 00:00:00 2001 From: Connor Stein Date: Thu, 8 Jan 2026 14:03:37 -0500 Subject: [PATCH 36/42] Fix keystore CLI embedding (#1761) * Fix CLI embedding * nits --- keystore/cli/cli.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/keystore/cli/cli.go b/keystore/cli/cli.go index 912301eaff..3d4c5b4c43 100644 --- a/keystore/cli/cli.go +++ b/keystore/cli/cli.go @@ -25,7 +25,7 @@ const ( func NewRootCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "./keystore ", + Use: "keys", Long: ` CLI for managing keystore keys. Must specify KEYSTORE_FILE_PATH or KEYSTORE_DB_URL and KEYSTORE_PASSWORD in order to load the keystore. @@ -48,9 +48,9 @@ KEYSTORE_PASSWORD is the password used to encrypt the key material before storag Short: "CLI for managing keystore keys", SilenceUsage: true, } - cmd.PersistentFlags().String("keystore-file-path", "", "Overrides KEYSTORE_FILE_PATH environment variable") - cmd.PersistentFlags().String("keystore-db-url", "", "Overrides KEYSTORE_DB_URL environment variable") - cmd.PersistentFlags().String("keystore-password", "", "Overrides KEYSTORE_PASSWORD environment variable. Not recommended as will leave shell traces.") + cmd.PersistentFlags().String("file-path", "", "Overrides KEYSTORE_FILE_PATH environment variable") + cmd.PersistentFlags().String("db-url", "", "Overrides KEYSTORE_DB_URL environment variable") + cmd.PersistentFlags().String("password", "", "Overrides KEYSTORE_PASSWORD environment variable. Not recommended as will leave shell traces.") cmd.AddCommand(NewListCmd(), NewGetCmd(), NewCreateCmd(), NewDeleteCmd(), NewExportCmd(), NewImportCmd(), NewSetMetadataCmd(), NewSignCmd(), NewVerifyCmd(), NewEncryptCmd(), NewDecryptCmd()) return cmd @@ -336,16 +336,18 @@ func NewDecryptCmd() *cobra.Command { } func loadKeystore(ctx context.Context, cmd *cobra.Command) (ks.Keystore, error) { - root := cmd.Root() - filePath, err := root.Flags().GetString("keystore-file-path") + // Use parent command which has the persistent flags. + // This works whether keystore CLI is standalone or embedded as a subcommand. + parent := cmd.Parent() + filePath, err := parent.Flags().GetString("file-path") if err != nil { return nil, err } - dbURL, err := root.Flags().GetString("keystore-db-url") + dbURL, err := parent.Flags().GetString("db-url") if err != nil { return nil, err } - password, err := cmd.Flags().GetString("keystore-password") + password, err := parent.Flags().GetString("password") if err != nil { return nil, err } From 8b4dfe94eb9850fefa43f0d75fb918ca370e3802 Mon Sep 17 00:00:00 2001 From: Bolek <1416262+bolekk@users.noreply.github.com> Date: Fri, 9 Jan 2026 08:00:15 -0800 Subject: [PATCH 37/42] [CRE][Limits] Handle limit flip to zero and back correctly (#1762) --- pkg/settings/limits/resource.go | 24 +++++++++++++ pkg/settings/limits/resource_test.go | 53 ++++++++++++++++++++++++++++ pkg/settings/limits/updater.go | 21 +++++++---- 3 files changed, 92 insertions(+), 6 deletions(-) diff --git a/pkg/settings/limits/resource.go b/pkg/settings/limits/resource.go index 02b2d27305..04f68ea512 100644 --- a/pkg/settings/limits/resource.go +++ b/pkg/settings/limits/resource.go @@ -90,6 +90,10 @@ type resourcePoolLimiter[N Number] struct { recordDenied func(context.Context, N, ...metric.RecordOption) // optional } +func (l *resourcePoolLimiter[N]) setOnLimitUpdate(fn func(ctx context.Context)) { + l.updater.onLimitUpdate = fn +} + func (l *resourcePoolLimiter[N]) createGauges(meter metric.Meter, unit string) error { if l.key == "" { return errors.New("metrics require Key to be set") @@ -172,6 +176,14 @@ type resourcePoolUsage[N Number] struct { cancelSub func() // optional } +// onLimitUpdate is invoked when the configured limit changes. It attempts to +// wake queued waiters using the new limit. +func (u *resourcePoolUsage[N]) onLimitUpdate() { + u.mu.Lock() + defer u.mu.Unlock() + u.tryWakeWaiters() +} + func (l *resourcePoolLimiter[N]) newLimitUsage(opts ...metric.RecordOption) *resourcePoolUsage[N] { u := resourcePoolUsage[N]{ resourcePoolLimiter: l, @@ -367,6 +379,9 @@ func newUnscopedResourcePoolLimiter[N Number](defaultLimit N) *unscopedResourceP }, } l.resourcePoolUsage = l.newLimitUsage() + l.setOnLimitUpdate(func(context.Context) { + l.resourcePoolUsage.onLimitUpdate() + }) return l } @@ -489,6 +504,15 @@ func newScopedResourcePoolLimiter[N Number](scope settings.Scope, key string, de }, scope: scope, } + l.setOnLimitUpdate(func(ctx context.Context) { + tenant := l.scope.Value(ctx) + if tenant == "" { + return + } + if usage, ok := l.used.Load(tenant); ok { + usage.(*resourcePoolUsage[N]).onLimitUpdate() + } + }) return l } diff --git a/pkg/settings/limits/resource_test.go b/pkg/settings/limits/resource_test.go index f5aed65444..ac568aa4ec 100644 --- a/pkg/settings/limits/resource_test.go +++ b/pkg/settings/limits/resource_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "sync/atomic" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/attribute" @@ -453,6 +455,57 @@ func TestResourcePoolLimiter_BasicUsage(t *testing.T) { assert.Equal(t, 5, avail) } +// TestResourcePoolLimiter_LimitFlapToZeroDoesNotDeadlock verifies that a waiter +// is woken up when the limit is reduced to zero and then increased again. +func TestResourcePoolLimiter_LimitFlapToZeroDoesNotDeadlock(t *testing.T) { + t.Parallel() + + var limit atomic.Int64 + limit.Store(1) + + limiter := newUnscopedResourcePoolLimiter(1) + limiter.getLimitFn = func(context.Context) (int, error) { + return int(limit.Load()), nil + } + go limiter.updateLoop(contexts.CRE{}) + t.Cleanup(func() { assert.NoError(t, limiter.Close()) }) + + ctx := t.Context() + + // Consume the single available resource to force the next waiter to enqueue. + freeFirst, err := limiter.Wait(ctx, 1) + require.NoError(t, err) + + enqueued := make(chan struct{}, 1) + limiter.resourcePoolUsage.setOnEnqueue(func() { enqueued <- struct{}{} }) + + waitErr := make(chan error, 1) + go func() { + _, err := limiter.Wait(t.Context(), 1) + waitErr <- err + }() + + // Ensure the waiter is queued before mutating the limit. + <-enqueued + + // Drop the limit to zero, then free the first resource. The queued waiter + // remains blocked because tryWakeWaiters sees a zero limit. + limit.Store(0) + freeFirst() + + // Raise the limit again; the queued waiter should be woken by the update. + limit.Store(1) + + select { + case err := <-waitErr: + require.NoError(t, err) + // release to avoid affecting subsequent waits + _ = limiter.Free(ctx, 1) + case <-time.After(pollPeriod * 3): + t.Fatal("waiter did not return after limit flap") + } +} + // setOnEnqueue sets a callback that is invoked each time a waiter is added to the queue. // The callback is called with the mutex held. Used for testing to synchronize without sleeps. func (u *resourcePoolUsage[N]) setOnEnqueue(fn func()) { diff --git a/pkg/settings/limits/updater.go b/pkg/settings/limits/updater.go index 3d2b9971a7..b1ac35343b 100644 --- a/pkg/settings/limits/updater.go +++ b/pkg/settings/limits/updater.go @@ -15,10 +15,11 @@ import ( // updater monitors limit updates via subscriptions or polling and reports them via recordLimit. // If an updateLoop goroutine is spawned, then Close must be called. type updater[N any] struct { - lggr logger.Logger - getLimitFn func(context.Context) (N, error) - subFn func(ctx context.Context) (<-chan settings.Update[N], func()) // optional - recordLimit func(context.Context, N) + lggr logger.Logger + getLimitFn func(context.Context) (N, error) + subFn func(ctx context.Context) (<-chan settings.Update[N], func()) // optional + recordLimit func(context.Context, N) + onLimitUpdate func(context.Context) creCh chan struct{} cre atomic.Value @@ -99,13 +100,21 @@ func (u *updater[N]) updateLoop(cre contexts.CRE) { if err != nil { u.lggr.Errorw("Failed to get limit. Using default value", "default", limit, "err", err) } - u.recordLimit(contexts.WithCRE(ctx, cre), limit) + rcCtx := contexts.WithCRE(ctx, cre) + u.recordLimit(rcCtx, limit) + if u.onLimitUpdate != nil { + u.onLimitUpdate(rcCtx) + } case update := <-updates: if update.Err != nil { u.lggr.Errorw("Failed to update limit. Using default value", "default", update.Value, "err", update.Err) } - u.recordLimit(contexts.WithCRE(ctx, cre), update.Value) + rcCtx := contexts.WithCRE(ctx, cre) + u.recordLimit(rcCtx, update.Value) + if u.onLimitUpdate != nil { + u.onLimitUpdate(rcCtx) + } case <-u.creCh: cre = u.cre.Load().(contexts.CRE) From d01ca26cd06fb8ec6a1beb5aaf000419f9a58133 Mon Sep 17 00:00:00 2001 From: pavel-raykov <165708424+pavel-raykov@users.noreply.github.com> Date: Fri, 9 Jan 2026 21:32:10 +0100 Subject: [PATCH 38/42] [ARCH-327] Address security comments. 2 (#1760) * Minor. * Minor. * Minor. * Minor. * Minor. * Minor. * Minor. --------- Co-authored-by: Connor Stein --- keystore/file.go | 4 +- keystore/go.mod | 1 - keystore/go.sum | 2 - keystore/internal/atomicfile/write.go | 71 ++++++++++++++++++++++ keystore/internal/atomicfile/write_test.go | 24 ++++++++ keystore/internal/raw_test.go | 2 + 6 files changed, 99 insertions(+), 5 deletions(-) create mode 100644 keystore/internal/atomicfile/write.go create mode 100644 keystore/internal/atomicfile/write_test.go diff --git a/keystore/file.go b/keystore/file.go index 3d8cdf82b2..9af0e99ff3 100644 --- a/keystore/file.go +++ b/keystore/file.go @@ -5,7 +5,7 @@ import ( "context" "os" - "github.com/natefinch/atomic" + "github.com/smartcontractkit/chainlink-common/keystore/internal/atomicfile" ) var _ Storage = &FileStorage{} @@ -26,5 +26,5 @@ func (f *FileStorage) GetEncryptedKeystore(ctx context.Context) ([]byte, error) } func (f *FileStorage) PutEncryptedKeystore(ctx context.Context, encryptedKeystore []byte) error { - return atomic.WriteFile(f.name, bytes.NewReader(encryptedKeystore)) + return atomicfile.WriteFile(f.name, bytes.NewReader(encryptedKeystore), 0600) } diff --git a/keystore/go.mod b/keystore/go.mod index 584165782b..e8cf4168bb 100644 --- a/keystore/go.mod +++ b/keystore/go.mod @@ -6,7 +6,6 @@ require ( github.com/ethereum/go-ethereum v1.16.2 github.com/jmoiron/sqlx v1.4.0 github.com/lib/pq v1.10.9 - github.com/natefinch/atomic v1.0.1 github.com/smartcontractkit/chainlink-common v0.9.6-0.20251107154219-ec6d8370ebbf github.com/smartcontractkit/libocr v0.0.0-20250912173940-f3ab0246e23d github.com/spf13/cobra v1.8.1 diff --git a/keystore/go.sum b/keystore/go.sum index a191c9fab8..157ad9061c 100644 --- a/keystore/go.sum +++ b/keystore/go.sum @@ -235,8 +235,6 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= -github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= diff --git a/keystore/internal/atomicfile/write.go b/keystore/internal/atomicfile/write.go new file mode 100644 index 0000000000..944623c0b2 --- /dev/null +++ b/keystore/internal/atomicfile/write.go @@ -0,0 +1,71 @@ +package atomicfile + +import ( + "fmt" + "io" + "os" + "path/filepath" +) + +// WriteFile atomically writes the contents of r to the specified filepath with the given mode. +// This is a copy of https://github.com/natefinch/atomic/blob/master/atomic.go with minor modifications allowing +// to set mode of the written file. If the file already exists, its mode is preserved. +func WriteFile(filename string, r io.Reader, mode os.FileMode) (err error) { + // write to a temp file first, then we'll atomically replace the target file + // with the temp file. + dir, file := filepath.Split(filename) + if dir == "" { + dir = "." + } + + f, err := os.CreateTemp(dir, file) + if err != nil { + return fmt.Errorf("cannot create temp file: %w", err) + } + defer func() { + if err != nil { + // Don't leave the temp file lying around on error. + _ = os.Remove(f.Name()) // yes, ignore the error, not much we can do about it. + } + }() + // ensure we always close f. Note that this does not conflict with the close below, as close is idempotent while + // it returns an error for repeating close operations. + defer f.Close() //nolint:errcheck + name := f.Name() + if _, err = io.Copy(f, r); err != nil { + return fmt.Errorf("cannot write data to tempfile %q: %w", name, err) + } + // fsync is important, otherwise os.Rename could rename a zero-length file + if err = f.Sync(); err != nil { + return fmt.Errorf("can't flush tempfile %q: %w", name, err) + } + if err = f.Close(); err != nil { + return fmt.Errorf("can't close tempfile %q: %w", name, err) + } + + // get the file mode from the original file and use that for the replacement file, too. + destInfo, err := os.Stat(filename) + if os.IsNotExist(err) { + // no original file + if err = os.Chmod(name, mode); err != nil { + return fmt.Errorf("can't set filemode on tempfile %q: %w", name, err) + } + } else if err != nil { + return err + } else { + sourceInfo, err := os.Stat(name) + if err != nil { + return err + } + + if sourceInfo.Mode() != destInfo.Mode() { + if err = os.Chmod(name, destInfo.Mode()); err != nil { + return fmt.Errorf("can't set filemode on tempfile %q: %w", name, err) + } + } + } + if err := os.Rename(name, filename); err != nil { + return fmt.Errorf("cannot replace %q with tempfile %q: %w", filename, name, err) + } + return nil +} diff --git a/keystore/internal/atomicfile/write_test.go b/keystore/internal/atomicfile/write_test.go new file mode 100644 index 0000000000..d4677676c3 --- /dev/null +++ b/keystore/internal/atomicfile/write_test.go @@ -0,0 +1,24 @@ +package atomicfile + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteFile_WriteAndRead(t *testing.T) { + mode := os.FileMode(0600) + path := filepath.Join(t.TempDir(), "out.txt") + data := []byte("test") + err := WriteFile(path, bytes.NewReader(data), mode) + require.NoError(t, err) + readData, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, readData, data) + info, err := os.Stat(path) + require.NoError(t, err) + require.Equal(t, mode, info.Mode()) +} diff --git a/keystore/internal/raw_test.go b/keystore/internal/raw_test.go index 127ffcc533..f930ec4358 100644 --- a/keystore/internal/raw_test.go +++ b/keystore/internal/raw_test.go @@ -23,6 +23,8 @@ func TestRaw_nonprintable(t *testing.T) { assert.Equal(t, exp, fmt.Sprintf("%#v", r)) + assert.Equal(t, exp, fmt.Sprintf("%x", r)) + assert.Equal(t, exp, fmt.Sprintf("%s", r)) //nolint:gosimple // S1025 deliberately testing formatting verbs got, err := json.Marshal(r) //nolint:staticcheck // SA9005 deliberately testing marshalling From 2194556035f86403a7f26befcaae878daa505896 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Tue, 13 Jan 2026 03:55:20 -0500 Subject: [PATCH 39/42] pkg/contexts: don't change case of org IDs (#1766) --- pkg/contexts/contexts.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/contexts/contexts.go b/pkg/contexts/contexts.go index b414a92e9b..8898dcf5a9 100644 --- a/pkg/contexts/contexts.go +++ b/pkg/contexts/contexts.go @@ -39,8 +39,7 @@ type CRE struct { // Normalized returns a possibly modified CRE with normalized values. func (c CRE) Normalized() CRE { c.Org = strings.TrimPrefix(c.Org, "org_") - c.Org = strings.TrimPrefix(c.Org, "0x") - c.Org = strings.ToLower(c.Org) + // not hex like the others, so don't look for 0x or change case c.Owner = strings.TrimPrefix(c.Owner, "owner_") c.Owner = strings.TrimPrefix(c.Owner, "0x") From 5539000a4c1a853ea29d91a418a609c45e0147ae Mon Sep 17 00:00:00 2001 From: cawthorne Date: Tue, 13 Jan 2026 12:46:10 +0000 Subject: [PATCH 40/42] feat: add authorization capability tests --- .../v2/triggers/streams/server/authorized_capability_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go b/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go index 437566b6f0..a54ed4d733 100644 --- a/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go +++ b/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go @@ -260,4 +260,3 @@ func (m *mockStreamsCapability) Ready() error { func (m *mockStreamsCapability) Initialise(ctx context.Context, deps core.StandardCapabilitiesDependencies) error { return nil } - From 3a71f495b907eebc65bb71c2f2107f812c72296e Mon Sep 17 00:00:00 2001 From: cawthorne Date: Tue, 13 Jan 2026 12:51:40 +0000 Subject: [PATCH 41/42] fix: update to caperrors.Error return type after rebase, remove authorization (handled by CRE per-capability limits) --- .../v2/triggers/streams/authorization.go | 181 -------- .../v2/triggers/streams/authorization_test.go | 405 ------------------ .../server/authorized_capability_test.go | 262 ----------- .../streams/server/authorized_server.go | 104 ----- .../streams/server/trigger_server_gen.go | 3 +- .../v2/triggers/streams/streams_test.go | 7 +- 6 files changed, 6 insertions(+), 956 deletions(-) delete mode 100644 pkg/capabilities/v2/triggers/streams/authorization.go delete mode 100644 pkg/capabilities/v2/triggers/streams/authorization_test.go delete mode 100644 pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go delete mode 100644 pkg/capabilities/v2/triggers/streams/server/authorized_server.go diff --git a/pkg/capabilities/v2/triggers/streams/authorization.go b/pkg/capabilities/v2/triggers/streams/authorization.go deleted file mode 100644 index 2d992c6203..0000000000 --- a/pkg/capabilities/v2/triggers/streams/authorization.go +++ /dev/null @@ -1,181 +0,0 @@ -package streams - -import ( - "fmt" - "regexp" - - "github.com/smartcontractkit/chainlink-common/pkg/capabilities" -) - -// Authorizer handles workflow authorization for streams trigger -// Ensures only authorized workflows (e.g., Data Feeds) can use the trigger -type Authorizer struct { - allowedWorkflowIDs map[string]bool - allowedWorkflowPattern *regexp.Regexp - allowedWorkflowOwners map[string]bool - allowedWorkflowNamePattern *regexp.Regexp - enabled bool -} - -// AuthConfig configures authorization rules for the streams trigger -type AuthConfig struct { - // Enable authorization checks (set to false to disable authorization) - Enabled bool - - // AllowedWorkflowIDs is an explicit allowlist of workflow IDs - AllowedWorkflowIDs []string - - // AllowedWorkflowPattern is a regex pattern for allowed workflow IDs - // Example: "^df-.*" allows all workflows starting with "df-" - AllowedWorkflowPattern string - - // AllowedWorkflowOwners is an explicit allowlist of workflow owner addresses - // Example: ["0xDFOwner1", "0xDFOwner2"] - AllowedWorkflowOwners []string - - // AllowedWorkflowNamePattern is a regex pattern for allowed workflow names - // Example: "^data-feed-.*" for workflow names starting with "data-feed-" - AllowedWorkflowNamePattern string -} - -// NewAuthorizer creates a new authorizer with the given configuration -func NewAuthorizer(config AuthConfig) (*Authorizer, error) { - auth := &Authorizer{ - enabled: config.Enabled, - allowedWorkflowIDs: make(map[string]bool), - allowedWorkflowOwners: make(map[string]bool), - } - - // If authorization is disabled, return early - if !config.Enabled { - return auth, nil - } - - // Build workflow ID allowlist map for O(1) lookups - for _, id := range config.AllowedWorkflowIDs { - auth.allowedWorkflowIDs[id] = true - } - - // Build workflow owner allowlist map for O(1) lookups - for _, owner := range config.AllowedWorkflowOwners { - auth.allowedWorkflowOwners[owner] = true - } - - // Compile workflow ID pattern if provided - if config.AllowedWorkflowPattern != "" { - pattern, err := regexp.Compile(config.AllowedWorkflowPattern) - if err != nil { - return nil, fmt.Errorf("invalid workflow ID pattern '%s': %w", config.AllowedWorkflowPattern, err) - } - auth.allowedWorkflowPattern = pattern - } - - // Compile workflow name pattern if provided - if config.AllowedWorkflowNamePattern != "" { - pattern, err := regexp.Compile(config.AllowedWorkflowNamePattern) - if err != nil { - return nil, fmt.Errorf("invalid workflow name pattern '%s': %w", config.AllowedWorkflowNamePattern, err) - } - auth.allowedWorkflowNamePattern = pattern - } - - return auth, nil -} - -// NewDefaultDataFeedsAuthorizer creates an authorizer for Data Feeds workflows -// This is a convenience function for the common case -// Allows workflows with IDs starting with "df-" or names containing "data-feed" -func NewDefaultDataFeedsAuthorizer() (*Authorizer, error) { - return NewAuthorizer(AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*", // Allow workflow IDs starting with "df-" - AllowedWorkflowNamePattern: "data-feed", // Allow workflow names containing "data-feed" - }) -} - -// IsAuthorized checks if a workflow is authorized to use the streams trigger -// Returns nil if authorized, error otherwise -// Authorization checks (in order): -// 1. Explicit workflow ID allowlist -// 2. Workflow ID pattern matching -// 3. Workflow owner address allowlist -// 4. Workflow name pattern matching -// If ANY check passes, the workflow is authorized -func (a *Authorizer) IsAuthorized(metadata capabilities.RequestMetadata) error { - // If authorization is disabled, allow all - if !a.enabled { - return nil - } - - workflowID := metadata.WorkflowID - workflowName := metadata.WorkflowName - if workflowName == "" { - workflowName = metadata.DecodedWorkflowName - } - workflowOwner := metadata.WorkflowOwner - - // If no checks configured, deny by default - if len(a.allowedWorkflowIDs) == 0 && a.allowedWorkflowPattern == nil && - len(a.allowedWorkflowOwners) == 0 && a.allowedWorkflowNamePattern == nil { - return fmt.Errorf("workflow %s: no authorization checks configured, denying by default", workflowID) - } - - // Check 1: Explicit workflow ID allowlist - if len(a.allowedWorkflowIDs) > 0 { - if a.allowedWorkflowIDs[workflowID] { - return nil // Authorized - } - } - - // Check 2: Workflow ID pattern matching - if a.allowedWorkflowPattern != nil { - if a.allowedWorkflowPattern.MatchString(workflowID) { - return nil // Authorized - } - } - - // Check 3: Workflow owner allowlist - if len(a.allowedWorkflowOwners) > 0 && workflowOwner != "" { - if a.allowedWorkflowOwners[workflowOwner] { - return nil // Authorized - } - } - - // Check 4: Workflow name pattern matching - if a.allowedWorkflowNamePattern != nil && workflowName != "" { - if a.allowedWorkflowNamePattern.MatchString(workflowName) { - return nil // Authorized - } - } - - // None of the checks passed - return fmt.Errorf("workflow %s (name: %s, owner: %s) not authorized", workflowID, workflowName, workflowOwner) -} - -// String returns a human-readable description of the authorization rules -func (a *Authorizer) String() string { - if !a.enabled { - return "Authorization: Disabled (all workflows allowed)" - } - - desc := "Authorization: Enabled\n" - - if len(a.allowedWorkflowIDs) > 0 { - desc += fmt.Sprintf(" - Workflow ID allowlist: %d entries\n", len(a.allowedWorkflowIDs)) - } - - if a.allowedWorkflowPattern != nil { - desc += fmt.Sprintf(" - Workflow ID pattern: %s\n", a.allowedWorkflowPattern.String()) - } - - if len(a.allowedWorkflowOwners) > 0 { - desc += fmt.Sprintf(" - Workflow owner allowlist: %d entries\n", len(a.allowedWorkflowOwners)) - } - - if a.allowedWorkflowNamePattern != nil { - desc += fmt.Sprintf(" - Workflow name pattern: %s\n", a.allowedWorkflowNamePattern.String()) - } - - return desc -} - diff --git a/pkg/capabilities/v2/triggers/streams/authorization_test.go b/pkg/capabilities/v2/triggers/streams/authorization_test.go deleted file mode 100644 index 9c5592efcb..0000000000 --- a/pkg/capabilities/v2/triggers/streams/authorization_test.go +++ /dev/null @@ -1,405 +0,0 @@ -package streams_test - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink-common/pkg/capabilities" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" -) - -func TestAuthorizerDisabled(t *testing.T) { - config := streams.AuthConfig{ - Enabled: false, - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - // Any workflow should be allowed when disabled - metadata := capabilities.RequestMetadata{ - WorkflowID: "any-workflow", - WorkflowName: "anything", - WorkflowOwner: "0xAnyOwner", - } - - err = auth.IsAuthorized(metadata) - assert.NoError(t, err, "Should allow all workflows when authorization is disabled") -} - -func TestAuthorizerWorkflowIDAllowlist(t *testing.T) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowIDs: []string{ - "df-prod-1", - "df-prod-2", - "df-staging-1", - }, - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - tests := []struct { - name string - workflowID string - expectError bool - }{ - {"workflow in allowlist 1", "df-prod-1", false}, - {"workflow in allowlist 2", "df-prod-2", false}, - {"workflow in allowlist 3", "df-staging-1", false}, - {"workflow not in allowlist", "other-workflow", true}, - {"workflow similar but not exact", "df-prod-10", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata := capabilities.RequestMetadata{ - WorkflowID: tt.workflowID, - } - err := auth.IsAuthorized(metadata) - if tt.expectError { - assert.Error(t, err) - assert.Contains(t, err.Error(), "not authorized") - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAuthorizerWorkflowIDPattern(t *testing.T) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*-mainnet$", - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - tests := []struct { - name string - workflowID string - expectError bool - }{ - {"matches pattern 1", "df-btc-mainnet", false}, - {"matches pattern 2", "df-eth-mainnet", false}, - {"matches pattern 3", "df-link-usd-mainnet", false}, - {"doesn't match - no prefix", "other-mainnet", true}, - {"doesn't match - no suffix", "df-btc", true}, - {"doesn't match - wrong suffix", "df-btc-testnet", true}, - {"doesn't match - completely different", "malicious-workflow", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata := capabilities.RequestMetadata{ - WorkflowID: tt.workflowID, - } - err := auth.IsAuthorized(metadata) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAuthorizerWorkflowOwner(t *testing.T) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowOwners: []string{ - "0xDFOwner1", - "0xDFOwner2", - }, - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - tests := []struct { - name string - workflowOwner string - expectError bool - }{ - {"owner in allowlist 1", "0xDFOwner1", false}, - {"owner in allowlist 2", "0xDFOwner2", false}, - {"owner not in allowlist", "0xOtherOwner", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata := capabilities.RequestMetadata{ - WorkflowID: "some-workflow", - WorkflowOwner: tt.workflowOwner, - } - err := auth.IsAuthorized(metadata) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAuthorizerWorkflowNamePattern(t *testing.T) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowNamePattern: "data-feed", - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - tests := []struct { - name string - workflowName string - expectError bool - }{ - {"matches pattern 1", "data-feed-btc-usd", false}, - {"matches pattern 2", "mainnet-data-feed", false}, - {"matches pattern 3", "data-feed", false}, - {"doesn't match", "other-workflow", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata := capabilities.RequestMetadata{ - WorkflowID: "some-id", - WorkflowName: tt.workflowName, - } - err := auth.IsAuthorized(metadata) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAuthorizerInvalidPattern(t *testing.T) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "[invalid(regex", - } - - _, err := streams.NewAuthorizer(config) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid workflow ID pattern") -} - -func TestAuthorizerInvalidNamePattern(t *testing.T) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowNamePattern: "[invalid(regex", - } - - _, err := streams.NewAuthorizer(config) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid workflow name pattern") -} - -func TestAuthorizerCombinedChecksAnyMatch(t *testing.T) { - // If ANY check passes, workflow is authorized - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*", - AllowedWorkflowOwners: []string{"0xDFOwner"}, - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - tests := []struct { - name string - workflowID string - workflowOwner string - expectError bool - }{ - { - name: "matches ID pattern", - workflowID: "df-prod-1", - workflowOwner: "0xOther", - expectError: false, // Passes ID pattern check - }, - { - name: "matches owner", - workflowID: "other-workflow", - workflowOwner: "0xDFOwner", - expectError: false, // Passes owner check - }, - { - name: "matches both", - workflowID: "df-prod-1", - workflowOwner: "0xDFOwner", - expectError: false, // Passes both checks - }, - { - name: "matches neither", - workflowID: "other-workflow", - workflowOwner: "0xOther", - expectError: true, // Fails both checks - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata := capabilities.RequestMetadata{ - WorkflowID: tt.workflowID, - WorkflowOwner: tt.workflowOwner, - } - err := auth.IsAuthorized(metadata) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAuthorizerNoChecksConfigured(t *testing.T) { - // If no checks configured, deny by default - config := streams.AuthConfig{ - Enabled: true, - // No checks configured - } - - auth, err := streams.NewAuthorizer(config) - require.NoError(t, err) - - metadata := capabilities.RequestMetadata{ - WorkflowID: "any-workflow", - } - - err = auth.IsAuthorized(metadata) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no authorization checks configured") -} - -func TestDefaultDataFeedsAuthorizer(t *testing.T) { - auth, err := streams.NewDefaultDataFeedsAuthorizer() - require.NoError(t, err) - - tests := []struct { - name string - workflowID string - workflowName string - expectError bool - }{ - { - name: "DF workflow ID", - workflowID: "df-btc-usd", - workflowName: "", - expectError: false, - }, - { - name: "DF workflow name", - workflowID: "other-id", - workflowName: "data-feed-eth-usd", - expectError: false, - }, - { - name: "Neither matches", - workflowID: "other-id", - workflowName: "other-name", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - metadata := capabilities.RequestMetadata{ - WorkflowID: tt.workflowID, - WorkflowName: tt.workflowName, - } - err := auth.IsAuthorized(metadata) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestAuthorizerString(t *testing.T) { - // Test disabled - auth1, _ := streams.NewAuthorizer(streams.AuthConfig{Enabled: false}) - str := auth1.String() - assert.Contains(t, str, "Disabled") - - // Test with ID allowlist - auth2, _ := streams.NewAuthorizer(streams.AuthConfig{ - Enabled: true, - AllowedWorkflowIDs: []string{"id1", "id2"}, - }) - str = auth2.String() - assert.Contains(t, str, "Enabled") - assert.Contains(t, str, "Workflow ID allowlist") - - // Test with pattern - auth3, _ := streams.NewAuthorizer(streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*", - }) - str = auth3.String() - assert.Contains(t, str, "pattern") - - // Test with owner allowlist - auth4, _ := streams.NewAuthorizer(streams.AuthConfig{ - Enabled: true, - AllowedWorkflowOwners: []string{"0xOwner1"}, - }) - str = auth4.String() - assert.Contains(t, str, "owner") -} - -// BenchmarkAuthorizerCheck benchmarks the authorization check -func BenchmarkAuthorizerCheck(b *testing.B) { - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*", - } - - auth, _ := streams.NewAuthorizer(config) - - metadata := capabilities.RequestMetadata{ - WorkflowID: "df-prod-1", - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = auth.IsAuthorized(metadata) - } -} - -func BenchmarkAuthorizerCheckAllowlist(b *testing.B) { - // Create large allowlist - allowlist := make([]string, 1000) - for i := 0; i < 1000; i++ { - allowlist[i] = fmt.Sprintf("df-workflow-%d", i) - } - - config := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowIDs: allowlist, - } - - auth, _ := streams.NewAuthorizer(config) - - metadata := capabilities.RequestMetadata{ - WorkflowID: "df-workflow-500", // Middle of allowlist - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = auth.IsAuthorized(metadata) - } -} diff --git a/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go b/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go deleted file mode 100644 index a54ed4d733..0000000000 --- a/pkg/capabilities/v2/triggers/streams/server/authorized_capability_test.go +++ /dev/null @@ -1,262 +0,0 @@ -package server_test - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink-common/pkg/capabilities" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams/server" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/types/core" -) - -func TestAuthorizedCapabilityBlocksUnauthorizedWorkflows(t *testing.T) { - lggr, _ := logger.New() - mockCap := &mockStreamsCapability{} - - // Create authorized capability with DF authorization - authCap, err := server.NewDefaultDataFeedsCapability(mockCap, lggr) - require.NoError(t, err) - - // Test 1: Authorized DF workflow (by ID pattern) - should succeed - ch, err := authCap.RegisterTrigger( - context.Background(), - "trigger-1", - capabilities.RequestMetadata{ - WorkflowID: "df-btc-usd", - WorkflowOwner: "0xDF001", - WorkflowName: "Bitcoin Data Feed", - }, - &streams.Config{FeedIds: []string{"0x001"}}, - ) - assert.NoError(t, err) - assert.NotNil(t, ch) - - // Test 2: Authorized DF workflow (by name pattern) - should succeed - ch, err = authCap.RegisterTrigger( - context.Background(), - "trigger-2", - capabilities.RequestMetadata{ - WorkflowID: "workflow-123", - WorkflowOwner: "0xDF002", - WorkflowName: "mainnet-data-feed-eth", - }, - &streams.Config{FeedIds: []string{"0x002"}}, - ) - assert.NoError(t, err) - assert.NotNil(t, ch) - - // Test 3: Unauthorized workflow (doesn't match ID or name) - should fail with auth error - ch, err = authCap.RegisterTrigger( - context.Background(), - "trigger-3", - capabilities.RequestMetadata{ - WorkflowID: "other-workflow", - WorkflowOwner: "0xOTHER", - WorkflowName: "Other Workflow", - }, - &streams.Config{FeedIds: []string{"0x003"}}, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "authorization failed") - assert.Nil(t, ch) -} - -func TestAuthorizedCapabilityWithCustomConfig(t *testing.T) { - lggr, _ := logger.New() - mockCap := &mockStreamsCapability{} - - // Custom authorization: specific allowlist - authConfig := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowIDs: []string{ - "df-prod-btc-usd", - "df-prod-eth-usd", - }, - } - - authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) - require.NoError(t, err) - - // Test allowed workflow - ch, err := authCap.RegisterTrigger( - context.Background(), - "trigger-1", - capabilities.RequestMetadata{ - WorkflowID: "df-prod-btc-usd", - }, - &streams.Config{}, - ) - assert.NoError(t, err) - assert.NotNil(t, ch) - - // Test non-allowed workflow - ch, err = authCap.RegisterTrigger( - context.Background(), - "trigger-2", - capabilities.RequestMetadata{ - WorkflowID: "df-prod-link-usd", // Not in allowlist - }, - &streams.Config{}, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "authorization failed") - assert.Nil(t, ch) -} - -func TestAuthorizedCapabilityWorkflowOwnerAllowlist(t *testing.T) { - lggr, _ := logger.New() - mockCap := &mockStreamsCapability{} - - // Authorization by owner address - authConfig := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowOwners: []string{ - "0xDFOwner1", - "0xDFOwner2", - }, - } - - authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) - require.NoError(t, err) - - // Test allowed owner - ch, err := authCap.RegisterTrigger( - context.Background(), - "trigger-1", - capabilities.RequestMetadata{ - WorkflowID: "any-workflow-id", - WorkflowOwner: "0xDFOwner1", - }, - &streams.Config{}, - ) - assert.NoError(t, err) - assert.NotNil(t, ch) - - // Test non-allowed owner - ch, err = authCap.RegisterTrigger( - context.Background(), - "trigger-2", - capabilities.RequestMetadata{ - WorkflowID: "any-workflow-id", - WorkflowOwner: "0xOtherOwner", - }, - &streams.Config{}, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "authorization failed") - assert.Nil(t, ch) -} - -func TestAuthorizedCapabilityDisabled(t *testing.T) { - lggr, _ := logger.New() - mockCap := &mockStreamsCapability{} - - // Disable authorization - authConfig := streams.AuthConfig{ - Enabled: false, - } - - authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) - require.NoError(t, err) - - // Any workflow should be allowed - ch, err := authCap.RegisterTrigger( - context.Background(), - "trigger-1", - capabilities.RequestMetadata{ - WorkflowID: "any-workflow", - WorkflowOwner: "0xAnyone", - WorkflowName: "Anything", - }, - &streams.Config{}, - ) - assert.NoError(t, err) - assert.NotNil(t, ch) -} - -func TestAuthorizedCapabilityUnregisterAlsoChecksAuth(t *testing.T) { - lggr, _ := logger.New() - mockCap := &mockStreamsCapability{} - - // Authorization enabled - authConfig := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*", - } - - authCap, err := server.NewAuthorizedStreamsCapability(mockCap, authConfig, lggr) - require.NoError(t, err) - - // Test authorized unregister - err = authCap.UnregisterTrigger( - context.Background(), - "trigger-1", - capabilities.RequestMetadata{ - WorkflowID: "df-prod-btc", - }, - &streams.Config{}, - ) - assert.NoError(t, err) - - // Test unauthorized unregister - err = authCap.UnregisterTrigger( - context.Background(), - "trigger-2", - capabilities.RequestMetadata{ - WorkflowID: "other-workflow", - }, - &streams.Config{}, - ) - require.Error(t, err) - assert.Contains(t, err.Error(), "authorization failed") -} - -// Mock implementations for testing - -type mockStreamsCapability struct { - registerCalled bool -} - -func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) { - m.registerCalled = true - ch := make(chan capabilities.TriggerAndId[*streams.Feed]) - close(ch) - return ch, nil -} - -func (m *mockStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error { - return nil -} - -func (m *mockStreamsCapability) Start(ctx context.Context) error { - return nil -} - -func (m *mockStreamsCapability) Close() error { - return nil -} - -func (m *mockStreamsCapability) HealthReport() map[string]error { - return map[string]error{} -} - -func (m *mockStreamsCapability) Name() string { - return "MockStreams" -} - -func (m *mockStreamsCapability) Description() string { - return "Mock" -} - -func (m *mockStreamsCapability) Ready() error { - return nil -} - -func (m *mockStreamsCapability) Initialise(ctx context.Context, deps core.StandardCapabilitiesDependencies) error { - return nil -} diff --git a/pkg/capabilities/v2/triggers/streams/server/authorized_server.go b/pkg/capabilities/v2/triggers/streams/server/authorized_server.go deleted file mode 100644 index 4d5ba3c4cc..0000000000 --- a/pkg/capabilities/v2/triggers/streams/server/authorized_server.go +++ /dev/null @@ -1,104 +0,0 @@ -package server - -import ( - "context" - "fmt" - - "github.com/smartcontractkit/chainlink-common/pkg/capabilities" - "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" - "github.com/smartcontractkit/chainlink-common/pkg/logger" -) - -// AuthorizedStreamsCapability wraps StreamsCapability with authorization checks -type AuthorizedStreamsCapability struct { - StreamsCapability - authorizer *streams.Authorizer - lggr logger.Logger -} - -// NewAuthorizedStreamsCapability creates a new capability with authorization enabled -func NewAuthorizedStreamsCapability(capability StreamsCapability, authConfig streams.AuthConfig, lggr logger.Logger) (*AuthorizedStreamsCapability, error) { - authorizer, err := streams.NewAuthorizer(authConfig) - if err != nil { - return nil, fmt.Errorf("failed to create authorizer: %w", err) - } - - return &AuthorizedStreamsCapability{ - StreamsCapability: capability, - authorizer: authorizer, - lggr: logger.Named(lggr, "AuthorizedStreamsCapability"), - }, nil -} - -// NewDefaultDataFeedsCapability creates a capability with default Data Feeds authorization -// Only workflows with IDs starting with "df-" or names containing "data-feed" will be allowed -func NewDefaultDataFeedsCapability(capability StreamsCapability, lggr logger.Logger) (*AuthorizedStreamsCapability, error) { - authConfig := streams.AuthConfig{ - Enabled: true, - AllowedWorkflowPattern: "^df-.*", - AllowedWorkflowNamePattern: "data-feed", - } - - return NewAuthorizedStreamsCapability(capability, authConfig, lggr) -} - -// RegisterTrigger wraps the base RegisterTrigger with authorization check -func (a *AuthorizedStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) { - // Authorization check - if err := a.authorizer.IsAuthorized(metadata); err != nil { - a.lggr.Warnw("Unauthorized trigger registration attempt", - "workflowID", metadata.WorkflowID, - "workflowOwner", metadata.WorkflowOwner, - "error", err, - ) - return nil, fmt.Errorf("authorization failed: %w", err) - } - - a.lggr.Debugw("Authorized trigger registration", - "workflowID", metadata.WorkflowID, - "triggerID", triggerID, - ) - - // Call the underlying implementation - return a.StreamsCapability.RegisterTrigger(ctx, triggerID, metadata, input) -} - -// UnregisterTrigger wraps the base UnregisterTrigger with authorization check -func (a *AuthorizedStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error { - // Authorization check - if err := a.authorizer.IsAuthorized(metadata); err != nil { - a.lggr.Warnw("Unauthorized trigger unregistration attempt", - "workflowID", metadata.WorkflowID, - "error", err, - ) - return fmt.Errorf("authorization failed: %w", err) - } - - a.lggr.Debugw("Authorized trigger unregistration", - "workflowID", metadata.WorkflowID, - "triggerID", triggerID, - ) - - // Call the underlying implementation - return a.StreamsCapability.UnregisterTrigger(ctx, triggerID, metadata, input) -} - -// NewAuthorizedStreamsServer creates a server wrapping an authorized capability -func NewAuthorizedStreamsServer(capability StreamsCapability, authConfig streams.AuthConfig, lggr logger.Logger) (*StreamsServer, error) { - authCap, err := NewAuthorizedStreamsCapability(capability, authConfig, lggr) - if err != nil { - return nil, err - } - - return NewStreamsServer(authCap), nil -} - -// NewDefaultDataFeedsServer creates a server with default Data Feeds authorization -func NewDefaultDataFeedsServer(capability StreamsCapability, lggr logger.Logger) (*StreamsServer, error) { - authCap, err := NewDefaultDataFeedsCapability(capability, lggr) - if err != nil { - return nil, err - } - - return NewStreamsServer(authCap), nil -} diff --git a/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go b/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go index a6d7490790..6b4e34d1ed 100644 --- a/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go +++ b/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go @@ -11,6 +11,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-common/pkg/types/core" ) @@ -18,7 +19,7 @@ import ( var _ = emptypb.Empty{} type StreamsCapability interface { - RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], caperrors.Error) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error Start(ctx context.Context) error diff --git a/pkg/capabilities/v2/triggers/streams/streams_test.go b/pkg/capabilities/v2/triggers/streams/streams_test.go index 718b4632d0..9aaafce06b 100644 --- a/pkg/capabilities/v2/triggers/streams/streams_test.go +++ b/pkg/capabilities/v2/triggers/streams/streams_test.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + caperrors "github.com/smartcontractkit/chainlink-common/pkg/capabilities/errors" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streams/server" "github.com/smartcontractkit/chainlink-common/pkg/types/core" @@ -93,7 +94,7 @@ type mockStreamsCapability struct { closeCalled bool } -func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], error) { +func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], caperrors.Error) { m.registerCalled = true ch := make(chan capabilities.TriggerAndId[*streams.Feed], 1) return ch, nil @@ -186,8 +187,8 @@ func TestTriggerRegistration(t *testing.T) { assert.True(t, mock.registerCalled) // Test unregister - err = mock.UnregisterTrigger(ctx, triggerID, metadata, config) - assert.NoError(t, err) + unregErr := mock.UnregisterTrigger(ctx, triggerID, metadata, config) + assert.NoError(t, unregErr) assert.True(t, mock.unregisterCalled) } From 4ca4ef9b5d0af84b67a3d876546f5e7473731f77 Mon Sep 17 00:00:00 2001 From: cawthorne Date: Tue, 13 Jan 2026 16:08:47 +0000 Subject: [PATCH 42/42] chore: regenerate streams trigger protos after simplification - Update capability ID to streams-trigger@2.0.0 - Regenerate after proto structure simplification - Remove authorization code (handled by CRE per-capability limits) - Reduces generated code size by ~195 lines --- .../streams/server/trigger_server_gen.go | 25 +- .../v2/triggers/streams/streams_test.go | 170 ++++---- .../v2/triggers/streams/trigger.pb.go | 402 +++++------------- 3 files changed, 201 insertions(+), 396 deletions(-) diff --git a/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go b/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go index 6b4e34d1ed..0d7ff2058d 100644 --- a/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go +++ b/pkg/capabilities/v2/triggers/streams/server/trigger_server_gen.go @@ -19,8 +19,8 @@ import ( var _ = emptypb.Empty{} type StreamsCapability interface { - RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], caperrors.Error) - UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error + RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Report], caperrors.Error) + UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) caperrors.Error Start(ctx context.Context) error Close() error @@ -55,7 +55,7 @@ func (c *StreamsServer) Initialise(ctx context.Context, dependencies core.Standa if err := dependencies.CapabilityRegistry.Add(ctx, &streamsCapability{ StreamsCapability: c.StreamsCapability, }); err != nil { - return fmt.Errorf("error when adding %s to the registry: %w", "streams-trigger@1.0.0", err) + return fmt.Errorf("error when adding %s to the registry: %w", "streams-trigger@2.0.0", err) } return nil @@ -66,7 +66,7 @@ func (c *StreamsServer) Close() error { defer cancel() if c.capabilityRegistry != nil { - if err := c.capabilityRegistry.Remove(ctx, "streams-trigger@1.0.0"); err != nil { + if err := c.capabilityRegistry.Remove(ctx, "streams-trigger@2.0.0"); err != nil { return err } } @@ -93,21 +93,18 @@ type streamsCapability struct { func (c *streamsCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { // Maybe we do need to split it out, even if the user doesn't see it - return capabilities.NewCapabilityInfo("streams-trigger@1.0.0", capabilities.CapabilityTypeCombined, c.StreamsCapability.Description()) + return capabilities.NewCapabilityInfo("streams-trigger@2.0.0", capabilities.CapabilityTypeCombined, c.StreamsCapability.Description()) } var _ capabilities.ExecutableAndTriggerCapability = (*streamsCapability)(nil) -const StreamsID = "streams-trigger@1.0.0" +const StreamsID = "streams-trigger@2.0.0" func (c *streamsCapability) RegisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) { switch request.Method { case "Trigger": input := &streams.Config{} - return capabilities.RegisterTrigger(ctx, c.stopCh, "streams-trigger@1.0.0", request, input, c.StreamsCapability.RegisterTrigger) - case "": - input := &streams.Config{} - return capabilities.RegisterTrigger(ctx, c.stopCh, "streams-trigger@1.0.0", request, input, c.StreamsCapability.RegisterTrigger) + return capabilities.RegisterTrigger(ctx, c.stopCh, "streams-trigger@2.0.0", request, input, c.StreamsCapability.RegisterTrigger) default: return nil, fmt.Errorf("trigger %s not found", request.Method) } @@ -122,13 +119,6 @@ func (c *streamsCapability) UnregisterTrigger(ctx context.Context, request capab return err } return c.StreamsCapability.UnregisterTrigger(ctx, request.TriggerID, request.Metadata, input) - case "": - input := &streams.Config{} - _, err := capabilities.FromValueOrAny(request.Config, request.Payload, input) - if err != nil { - return err - } - return c.StreamsCapability.UnregisterTrigger(ctx, request.TriggerID, request.Metadata, input) default: return fmt.Errorf("method %s not found", request.Method) } @@ -145,4 +135,3 @@ func (c *streamsCapability) UnregisterFromWorkflow(ctx context.Context, request func (c *streamsCapability) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { return capabilities.CapabilityResponse{}, fmt.Errorf("method %s not found", request.Method) } - diff --git a/pkg/capabilities/v2/triggers/streams/streams_test.go b/pkg/capabilities/v2/triggers/streams/streams_test.go index 9aaafce06b..787b19ff28 100644 --- a/pkg/capabilities/v2/triggers/streams/streams_test.go +++ b/pkg/capabilities/v2/triggers/streams/streams_test.go @@ -21,63 +21,66 @@ import ( func TestProtoTypesExist(t *testing.T) { // Config type config := &streams.Config{ - FeedIds: []string{"0x0001", "0x0002"}, + StreamIds: []uint32{1, 2, 3}, MaxFrequencyMs: 5000, } assert.NotNil(t, config) - assert.Len(t, config.FeedIds, 2) + assert.Len(t, config.StreamIds, 3) assert.Equal(t, uint64(5000), config.MaxFrequencyMs) - // Feed type - feed := &streams.Feed{ - Timestamp: 1234567890, - Metadata: &streams.SignersMetadata{ - Signers: []string{"signer1", "signer2"}, - MinRequiredSignatures: 2, - }, - Payload: []*streams.FeedReport{ + // Report type + report := &streams.Report{ + ConfigDigest: []byte{1, 2, 3, 4}, + SeqNr: 42, + Report: []byte("report-data"), + Sigs: []*streams.OCRSignature{ + { + Signer: 1, + Signature: []byte("sig1"), + }, { - FeedId: "0x0001", - FullReport: []byte("report-data"), - ReportContext: []byte("context"), - Signatures: [][]byte{[]byte("sig1")}, - BenchmarkPrice: []byte("price"), - ObservationTimestamp: 1234567890, + Signer: 2, + Signature: []byte("sig2"), }, }, } - assert.NotNil(t, feed) - assert.Equal(t, int64(1234567890), feed.Timestamp) - assert.Len(t, feed.Payload, 1) + assert.NotNil(t, report) + assert.Equal(t, []byte{1, 2, 3, 4}, report.ConfigDigest) + assert.Equal(t, uint64(42), report.SeqNr) + assert.Len(t, report.Sigs, 2) } // TestConfigGetters verifies getter methods work func TestConfigGetters(t *testing.T) { config := &streams.Config{ - FeedIds: []string{"0xfeed1", "0xfeed2", "0xfeed3"}, + StreamIds: []uint32{1, 2, 3}, MaxFrequencyMs: 10000, } - assert.Equal(t, []string{"0xfeed1", "0xfeed2", "0xfeed3"}, config.GetFeedIds()) + assert.Equal(t, []uint32{1, 2, 3}, config.GetStreamIds()) assert.Equal(t, uint64(10000), config.GetMaxFrequencyMs()) } -// TestFeedGetters verifies Feed getter methods -func TestFeedGetters(t *testing.T) { - metadata := &streams.SignersMetadata{ - Signers: []string{"signer1"}, - MinRequiredSignatures: 1, +// TestReportGetters verifies Report getter methods +func TestReportGetters(t *testing.T) { + sigs := []*streams.OCRSignature{ + { + Signer: 1, + Signature: []byte("sig1"), + }, } - feed := &streams.Feed{ - Timestamp: 9999999999, - Metadata: metadata, - Payload: []*streams.FeedReport{}, + report := &streams.Report{ + ConfigDigest: []byte{1, 2, 3, 4}, + SeqNr: 99, + Report: []byte("test-report"), + Sigs: sigs, } - assert.Equal(t, int64(9999999999), feed.GetTimestamp()) - assert.Equal(t, metadata, feed.GetMetadata()) - assert.NotNil(t, feed.GetPayload()) + assert.Equal(t, []byte{1, 2, 3, 4}, report.GetConfigDigest()) + assert.Equal(t, uint64(99), report.GetSeqNr()) + assert.Equal(t, []byte("test-report"), report.GetReport()) + assert.Equal(t, sigs, report.GetSigs()) } // TestStreamsCapabilityInterface verifies the server interface @@ -94,13 +97,13 @@ type mockStreamsCapability struct { closeCalled bool } -func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Feed], caperrors.Error) { +func (m *mockStreamsCapability) RegisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) (<-chan capabilities.TriggerAndId[*streams.Report], caperrors.Error) { m.registerCalled = true - ch := make(chan capabilities.TriggerAndId[*streams.Feed], 1) + ch := make(chan capabilities.TriggerAndId[*streams.Report], 1) return ch, nil } -func (m *mockStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) error { +func (m *mockStreamsCapability) UnregisterTrigger(ctx context.Context, triggerID string, metadata capabilities.RequestMetadata, input *streams.Config) caperrors.Error { m.unregisterCalled = true return nil } @@ -177,7 +180,7 @@ func TestTriggerRegistration(t *testing.T) { } config := &streams.Config{ - FeedIds: []string{"0x0001"}, + StreamIds: []uint32{1}, MaxFrequencyMs: 1000, } @@ -192,34 +195,35 @@ func TestTriggerRegistration(t *testing.T) { assert.True(t, mock.unregisterCalled) } -// TestFeedReportStructure tests the FeedReport structure -func TestFeedReportStructure(t *testing.T) { - report := &streams.FeedReport{ - FeedId: "0xfeedid12345", - FullReport: []byte("full-report-bytes"), - ReportContext: []byte("report-context"), - Signatures: [][]byte{[]byte("sig1"), []byte("sig2")}, - BenchmarkPrice: []byte("benchmark-price-bytes"), - ObservationTimestamp: 1700000000, +// TestReportStructure tests the Report structure +func TestReportStructure(t *testing.T) { + sigs := []*streams.OCRSignature{ + {Signer: 1, Signature: []byte("sig1")}, + {Signer: 2, Signature: []byte("sig2")}, + } + + report := &streams.Report{ + ConfigDigest: []byte{1, 2, 3, 4, 5}, + SeqNr: 123, + Report: []byte("full-report-bytes"), + Sigs: sigs, } - assert.Equal(t, "0xfeedid12345", report.GetFeedId()) - assert.Equal(t, []byte("full-report-bytes"), report.GetFullReport()) - assert.Equal(t, []byte("report-context"), report.GetReportContext()) - assert.Len(t, report.GetSignatures(), 2) - assert.Equal(t, []byte("benchmark-price-bytes"), report.GetBenchmarkPrice()) - assert.Equal(t, int64(1700000000), report.GetObservationTimestamp()) -} - -// TestSignersMetadata tests the SignersMetadata structure -func TestSignersMetadata(t *testing.T) { - metadata := &streams.SignersMetadata{ - Signers: []string{"0xsigner1", "0xsigner2", "0xsigner3"}, - MinRequiredSignatures: 2, + assert.Equal(t, []byte{1, 2, 3, 4, 5}, report.GetConfigDigest()) + assert.Equal(t, uint64(123), report.GetSeqNr()) + assert.Equal(t, []byte("full-report-bytes"), report.GetReport()) + assert.Len(t, report.GetSigs(), 2) +} + +// TestOCRSignature tests the OCRSignature structure +func TestOCRSignature(t *testing.T) { + sig := &streams.OCRSignature{ + Signer: 5, + Signature: []byte("signature-bytes"), } - assert.Len(t, metadata.GetSigners(), 3) - assert.Equal(t, int64(2), metadata.GetMinRequiredSignatures()) + assert.Equal(t, uint32(5), sig.GetSigner()) + assert.Equal(t, []byte("signature-bytes"), sig.GetSignature()) } // TestConfigValidation tests configuration validation scenarios @@ -230,17 +234,17 @@ func TestConfigValidation(t *testing.T) { expectValid bool }{ { - name: "valid config with single feed", + name: "valid config with single stream", config: &streams.Config{ - FeedIds: []string{"0x0001"}, + StreamIds: []uint32{1}, MaxFrequencyMs: 1000, }, expectValid: true, }, { - name: "valid config with multiple feeds", + name: "valid config with multiple streams", config: &streams.Config{ - FeedIds: []string{"0x0001", "0x0002", "0x0003"}, + StreamIds: []uint32{1, 2, 3}, MaxFrequencyMs: 5000, }, expectValid: true, @@ -248,7 +252,7 @@ func TestConfigValidation(t *testing.T) { { name: "high frequency", config: &streams.Config{ - FeedIds: []string{"0x0001"}, + StreamIds: []uint32{1}, MaxFrequencyMs: 100, }, expectValid: true, @@ -259,7 +263,7 @@ func TestConfigValidation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { // Basic validation - config should be creatable assert.NotNil(t, tt.config) - assert.NotEmpty(t, tt.config.FeedIds) + assert.NotEmpty(t, tt.config.StreamIds) assert.Greater(t, tt.config.MaxFrequencyMs, uint64(0)) }) } @@ -352,33 +356,31 @@ func TestServerLifecycle(t *testing.T) { infos, err := srv.Infos(ctx) require.NoError(t, err) require.Len(t, infos, 1) - assert.Equal(t, "streams-trigger@1.0.0", infos[0].ID) + assert.Equal(t, "streams-trigger@2.0.0", infos[0].ID) // Close err = srv.Close() require.NoError(t, err) assert.True(t, mock.closeCalled, "Close should be called") assert.Len(t, mockRegistry.removed, 1, "Capability should be unregistered") - assert.Equal(t, "streams-trigger@1.0.0", mockRegistry.removed[0]) + assert.Equal(t, "streams-trigger@2.0.0", mockRegistry.removed[0]) } -// BenchmarkFeedCreation benchmarks creating Feed objects -func BenchmarkFeedCreation(b *testing.B) { +// BenchmarkReportCreation benchmarks creating Report objects +func BenchmarkReportCreation(b *testing.B) { for i := 0; i < b.N; i++ { - _ = &streams.Feed{ - Timestamp: int64(i), - Metadata: &streams.SignersMetadata{ - Signers: []string{"signer1", "signer2"}, - MinRequiredSignatures: 2, - }, - Payload: []*streams.FeedReport{ + _ = &streams.Report{ + ConfigDigest: []byte{1, 2, 3, 4}, + SeqNr: uint64(i), + Report: []byte("report-data"), + Sigs: []*streams.OCRSignature{ + { + Signer: 1, + Signature: []byte("sig1"), + }, { - FeedId: "0x0001", - FullReport: []byte("report"), - ReportContext: []byte("context"), - Signatures: [][]byte{[]byte("sig")}, - BenchmarkPrice: []byte("price"), - ObservationTimestamp: int64(i), + Signer: 2, + Signature: []byte("sig2"), }, }, } diff --git a/pkg/capabilities/v2/triggers/streams/trigger.pb.go b/pkg/capabilities/v2/triggers/streams/trigger.pb.go index 7ae1104d07..995e82d567 100644 --- a/pkg/capabilities/v2/triggers/streams/trigger.pb.go +++ b/pkg/capabilities/v2/triggers/streams/trigger.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.34.2 +// protoc-gen-go v1.36.10 // protoc v5.27.3 // source: cre/capabilities/streams/v1/trigger.proto @@ -12,6 +12,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -21,26 +22,25 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// Configuration for the Streams Trigger +// Configuration for the Streams LLO Trigger +// This matches the existing LLOTriggerConfig structure type Config struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // The IDs of the data feeds that will have their reports included in the trigger event. - // Feed IDs are hex-encoded strings (e.g., "0x000..."). - FeedIds []string `protobuf:"bytes,1,rep,name=feed_ids,json=feedIds,proto3" json:"feed_ids,omitempty"` - // The interval in milliseconds after which a new trigger event is generated. + state protoimpl.MessageState `protogen:"open.v1"` + // The IDs of the LLO data streams to subscribe to. + // Stream IDs are uint32 values that identify specific feeds. + StreamIds []uint32 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` + // The minimum interval in milliseconds between trigger events. + // Trigger will only emit events at most once per this interval. MaxFrequencyMs uint64 `protobuf:"varint,2,opt,name=max_frequency_ms,json=maxFrequencyMs,proto3" json:"max_frequency_ms,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Config) Reset() { *x = Config{} - if protoimpl.UnsafeEnabled { - mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Config) String() string { @@ -51,7 +51,7 @@ func (*Config) ProtoMessage() {} func (x *Config) ProtoReflect() protoreflect.Message { mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -66,9 +66,9 @@ func (*Config) Descriptor() ([]byte, []int) { return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{0} } -func (x *Config) GetFeedIds() []string { +func (x *Config) GetStreamIds() []uint32 { if x != nil { - return x.FeedIds + return x.StreamIds } return nil } @@ -80,36 +80,33 @@ func (x *Config) GetMaxFrequencyMs() uint64 { return 0 } -// Metadata about the signers that produced the reports -type SignersMetadata struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +// An attributed onchain signature +type OCRSignature struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The signer index + Signer uint32 `protobuf:"varint,1,opt,name=signer,proto3" json:"signer,omitempty"` + // The signature bytes + Signature []byte `protobuf:"bytes,2,opt,name=signature,proto3" json:"signature,omitempty"` unknownFields protoimpl.UnknownFields - - // The IDs of the signers - Signers []string `protobuf:"bytes,1,rep,name=signers,proto3" json:"signers,omitempty"` - // The minimum number of signatures required to validate a report - MinRequiredSignatures int64 `protobuf:"varint,2,opt,name=min_required_signatures,json=minRequiredSignatures,proto3" json:"min_required_signatures,omitempty"` + sizeCache protoimpl.SizeCache } -func (x *SignersMetadata) Reset() { - *x = SignersMetadata{} - if protoimpl.UnsafeEnabled { - mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +func (x *OCRSignature) Reset() { + *x = OCRSignature{} + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } -func (x *SignersMetadata) String() string { +func (x *OCRSignature) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SignersMetadata) ProtoMessage() {} +func (*OCRSignature) ProtoMessage() {} -func (x *SignersMetadata) ProtoReflect() protoreflect.Message { +func (x *OCRSignature) ProtoReflect() protoreflect.Message { mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -119,63 +116,57 @@ func (x *SignersMetadata) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SignersMetadata.ProtoReflect.Descriptor instead. -func (*SignersMetadata) Descriptor() ([]byte, []int) { +// Deprecated: Use OCRSignature.ProtoReflect.Descriptor instead. +func (*OCRSignature) Descriptor() ([]byte, []int) { return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{1} } -func (x *SignersMetadata) GetSigners() []string { +func (x *OCRSignature) GetSigner() uint32 { if x != nil { - return x.Signers + return x.Signer } - return nil + return 0 } -func (x *SignersMetadata) GetMinRequiredSignatures() int64 { +func (x *OCRSignature) GetSignature() []byte { if x != nil { - return x.MinRequiredSignatures + return x.Signature } - return 0 + return nil } -// A single feed report containing data and signatures -type FeedReport struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +// OCR Trigger Event payload +// This matches the existing OCRTriggerEvent structure that the transmitter emits +type Report struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Configuration digest for the OCR round + ConfigDigest []byte `protobuf:"bytes,1,opt,name=config_digest,json=configDigest,proto3" json:"config_digest,omitempty"` + // Sequence number of the report + SeqNr uint64 `protobuf:"varint,2,opt,name=seq_nr,json=seqNr,proto3" json:"seq_nr,omitempty"` + // The report bytes (raw OCR report) + Report []byte `protobuf:"bytes,3,opt,name=report,proto3" json:"report,omitempty"` + // Attributed onchain signatures + Sigs []*OCRSignature `protobuf:"bytes,4,rep,name=sigs,proto3" json:"sigs,omitempty"` unknownFields protoimpl.UnknownFields - - // The ID of the data feed (hex-encoded) - FeedId string `protobuf:"bytes,1,opt,name=feed_id,json=feedId,proto3" json:"feed_id,omitempty"` - // The full report as raw bytes - FullReport []byte `protobuf:"bytes,2,opt,name=full_report,json=fullReport,proto3" json:"full_report,omitempty"` - // Report context required to validate signatures - ReportContext []byte `protobuf:"bytes,3,opt,name=report_context,json=reportContext,proto3" json:"report_context,omitempty"` - // Signatures over the full report and report context - Signatures [][]byte `protobuf:"bytes,4,rep,name=signatures,proto3" json:"signatures,omitempty"` - // The benchmark price extracted from the full report - BenchmarkPrice []byte `protobuf:"bytes,5,opt,name=benchmark_price,json=benchmarkPrice,proto3" json:"benchmark_price,omitempty"` - // Timestamp when the observation was made - ObservationTimestamp int64 `protobuf:"varint,6,opt,name=observation_timestamp,json=observationTimestamp,proto3" json:"observation_timestamp,omitempty"` + sizeCache protoimpl.SizeCache } -func (x *FeedReport) Reset() { - *x = FeedReport{} - if protoimpl.UnsafeEnabled { - mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +func (x *Report) Reset() { + *x = Report{} + mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } -func (x *FeedReport) String() string { +func (x *Report) String() string { return protoimpl.X.MessageStringOf(x) } -func (*FeedReport) ProtoMessage() {} +func (*Report) ProtoMessage() {} -func (x *FeedReport) ProtoReflect() protoreflect.Message { +func (x *Report) ProtoReflect() protoreflect.Message { mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -185,212 +176,86 @@ func (x *FeedReport) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use FeedReport.ProtoReflect.Descriptor instead. -func (*FeedReport) Descriptor() ([]byte, []int) { +// Deprecated: Use Report.ProtoReflect.Descriptor instead. +func (*Report) Descriptor() ([]byte, []int) { return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{2} } -func (x *FeedReport) GetFeedId() string { - if x != nil { - return x.FeedId - } - return "" -} - -func (x *FeedReport) GetFullReport() []byte { - if x != nil { - return x.FullReport - } - return nil -} - -func (x *FeedReport) GetReportContext() []byte { - if x != nil { - return x.ReportContext - } - return nil -} - -func (x *FeedReport) GetSignatures() [][]byte { +func (x *Report) GetConfigDigest() []byte { if x != nil { - return x.Signatures + return x.ConfigDigest } return nil } -func (x *FeedReport) GetBenchmarkPrice() []byte { +func (x *Report) GetSeqNr() uint64 { if x != nil { - return x.BenchmarkPrice - } - return nil -} - -func (x *FeedReport) GetObservationTimestamp() int64 { - if x != nil { - return x.ObservationTimestamp + return x.SeqNr } return 0 } -// The payload emitted by the Streams Trigger containing feed data -type Feed struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Timestamp when the trigger event was generated - Timestamp int64 `protobuf:"varint,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` - // Metadata about the signers - Metadata *SignersMetadata `protobuf:"bytes,2,opt,name=metadata,proto3" json:"metadata,omitempty"` - // Array of feed reports - Payload []*FeedReport `protobuf:"bytes,3,rep,name=payload,proto3" json:"payload,omitempty"` -} - -func (x *Feed) Reset() { - *x = Feed{} - if protoimpl.UnsafeEnabled { - mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Feed) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Feed) ProtoMessage() {} - -func (x *Feed) ProtoReflect() protoreflect.Message { - mi := &file_cre_capabilities_streams_v1_trigger_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Feed.ProtoReflect.Descriptor instead. -func (*Feed) Descriptor() ([]byte, []int) { - return file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP(), []int{3} -} - -func (x *Feed) GetTimestamp() int64 { +func (x *Report) GetReport() []byte { if x != nil { - return x.Timestamp - } - return 0 -} - -func (x *Feed) GetMetadata() *SignersMetadata { - if x != nil { - return x.Metadata + return x.Report } return nil } -func (x *Feed) GetPayload() []*FeedReport { +func (x *Report) GetSigs() []*OCRSignature { if x != nil { - return x.Payload + return x.Sigs } return nil } var File_cre_capabilities_streams_v1_trigger_proto protoreflect.FileDescriptor -var file_cre_capabilities_streams_v1_trigger_proto_rawDesc = []byte{ - 0x0a, 0x29, 0x63, 0x72, 0x65, 0x2f, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, - 0x65, 0x73, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2f, 0x76, 0x31, 0x2f, 0x74, 0x72, - 0x69, 0x67, 0x67, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x17, 0x63, 0x61, 0x70, - 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, - 0x73, 0x2e, 0x76, 0x31, 0x1a, 0x2a, 0x74, 0x6f, 0x6f, 0x6c, 0x73, 0x2f, 0x67, 0x65, 0x6e, 0x65, - 0x72, 0x61, 0x74, 0x6f, 0x72, 0x2f, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x2f, 0x63, 0x72, - 0x65, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x22, 0x4d, 0x0a, 0x06, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x19, 0x0a, 0x08, 0x66, 0x65, - 0x65, 0x64, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x66, 0x65, - 0x65, 0x64, 0x49, 0x64, 0x73, 0x12, 0x28, 0x0a, 0x10, 0x6d, 0x61, 0x78, 0x5f, 0x66, 0x72, 0x65, - 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x5f, 0x6d, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, - 0x0e, 0x6d, 0x61, 0x78, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x4d, 0x73, 0x22, - 0x63, 0x0a, 0x0f, 0x53, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x12, 0x36, 0x0a, 0x17, - 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x5f, 0x73, 0x69, 0x67, - 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x15, 0x6d, - 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x74, - 0x75, 0x72, 0x65, 0x73, 0x22, 0xeb, 0x01, 0x0a, 0x0a, 0x46, 0x65, 0x65, 0x64, 0x52, 0x65, 0x70, - 0x6f, 0x72, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x65, 0x65, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x66, 0x65, 0x65, 0x64, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, - 0x66, 0x75, 0x6c, 0x6c, 0x5f, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x25, 0x0a, - 0x0e, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0d, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x6f, 0x6e, - 0x74, 0x65, 0x78, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, - 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x0a, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, - 0x75, 0x72, 0x65, 0x73, 0x12, 0x27, 0x0a, 0x0f, 0x62, 0x65, 0x6e, 0x63, 0x68, 0x6d, 0x61, 0x72, - 0x6b, 0x5f, 0x70, 0x72, 0x69, 0x63, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0e, 0x62, - 0x65, 0x6e, 0x63, 0x68, 0x6d, 0x61, 0x72, 0x6b, 0x50, 0x72, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, - 0x15, 0x6f, 0x62, 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x74, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6f, 0x62, - 0x73, 0x65, 0x72, 0x76, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, - 0x6d, 0x70, 0x22, 0xa9, 0x01, 0x0a, 0x04, 0x46, 0x65, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x74, - 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x44, 0x0a, 0x08, 0x6d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x28, 0x2e, 0x63, 0x61, - 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, - 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x65, 0x72, 0x73, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, - 0x3d, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x23, 0x2e, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2e, - 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x46, 0x65, 0x65, 0x64, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x32, 0x75, - 0x0a, 0x07, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x12, 0x4b, 0x0a, 0x07, 0x54, 0x72, 0x69, - 0x67, 0x67, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, - 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x1a, 0x1d, 0x2e, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, 0x6c, 0x69, - 0x74, 0x69, 0x65, 0x73, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x76, 0x31, 0x2e, - 0x46, 0x65, 0x65, 0x64, 0x30, 0x01, 0x1a, 0x1d, 0x82, 0xb5, 0x18, 0x19, 0x08, 0x01, 0x12, 0x15, - 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2d, 0x74, 0x72, 0x69, 0x67, 0x67, 0x65, 0x72, 0x40, - 0x31, 0x2e, 0x30, 0x2e, 0x30, 0x42, 0x53, 0x5a, 0x51, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x73, 0x6d, 0x61, 0x72, 0x74, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x61, 0x63, - 0x74, 0x6b, 0x69, 0x74, 0x2f, 0x63, 0x68, 0x61, 0x69, 0x6e, 0x6c, 0x69, 0x6e, 0x6b, 0x2d, 0x63, - 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x63, 0x61, 0x70, 0x61, 0x62, 0x69, - 0x6c, 0x69, 0x74, 0x69, 0x65, 0x73, 0x2f, 0x76, 0x32, 0x2f, 0x74, 0x72, 0x69, 0x67, 0x67, 0x65, - 0x72, 0x73, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, -} +const file_cre_capabilities_streams_v1_trigger_proto_rawDesc = "" + + "\n" + + ")cre/capabilities/streams/v1/trigger.proto\x12\x17capabilities.streams.v1\x1a.cre/tools/generator/v1alpha/cre_metadata.proto\"Q\n" + + "\x06Config\x12\x1d\n" + + "\n" + + "stream_ids\x18\x01 \x03(\rR\tstreamIds\x12(\n" + + "\x10max_frequency_ms\x18\x02 \x01(\x04R\x0emaxFrequencyMs\"D\n" + + "\fOCRSignature\x12\x16\n" + + "\x06signer\x18\x01 \x01(\rR\x06signer\x12\x1c\n" + + "\tsignature\x18\x02 \x01(\fR\tsignature\"\x97\x01\n" + + "\x06Report\x12#\n" + + "\rconfig_digest\x18\x01 \x01(\fR\fconfigDigest\x12\x15\n" + + "\x06seq_nr\x18\x02 \x01(\x04R\x05seqNr\x12\x16\n" + + "\x06report\x18\x03 \x01(\fR\x06report\x129\n" + + "\x04sigs\x18\x04 \x03(\v2%.capabilities.streams.v1.OCRSignatureR\x04sigs2w\n" + + "\aStreams\x12M\n" + + "\aTrigger\x12\x1f.capabilities.streams.v1.Config\x1a\x1f.capabilities.streams.v1.Report0\x01\x1a\x1d\x82\xb5\x18\x19\b\x01\x12\x15streams-trigger@2.0.0BSZQgithub.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/triggers/streamsb\x06proto3" var ( file_cre_capabilities_streams_v1_trigger_proto_rawDescOnce sync.Once - file_cre_capabilities_streams_v1_trigger_proto_rawDescData = file_cre_capabilities_streams_v1_trigger_proto_rawDesc + file_cre_capabilities_streams_v1_trigger_proto_rawDescData []byte ) func file_cre_capabilities_streams_v1_trigger_proto_rawDescGZIP() []byte { file_cre_capabilities_streams_v1_trigger_proto_rawDescOnce.Do(func() { - file_cre_capabilities_streams_v1_trigger_proto_rawDescData = protoimpl.X.CompressGZIP(file_cre_capabilities_streams_v1_trigger_proto_rawDescData) + file_cre_capabilities_streams_v1_trigger_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_cre_capabilities_streams_v1_trigger_proto_rawDesc), len(file_cre_capabilities_streams_v1_trigger_proto_rawDesc))) }) return file_cre_capabilities_streams_v1_trigger_proto_rawDescData } -var file_cre_capabilities_streams_v1_trigger_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_cre_capabilities_streams_v1_trigger_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_cre_capabilities_streams_v1_trigger_proto_goTypes = []any{ - (*Config)(nil), // 0: capabilities.streams.v1.Config - (*SignersMetadata)(nil), // 1: capabilities.streams.v1.SignersMetadata - (*FeedReport)(nil), // 2: capabilities.streams.v1.FeedReport - (*Feed)(nil), // 3: capabilities.streams.v1.Feed + (*Config)(nil), // 0: capabilities.streams.v1.Config + (*OCRSignature)(nil), // 1: capabilities.streams.v1.OCRSignature + (*Report)(nil), // 2: capabilities.streams.v1.Report } var file_cre_capabilities_streams_v1_trigger_proto_depIdxs = []int32{ - 1, // 0: capabilities.streams.v1.Feed.metadata:type_name -> capabilities.streams.v1.SignersMetadata - 2, // 1: capabilities.streams.v1.Feed.payload:type_name -> capabilities.streams.v1.FeedReport - 0, // 2: capabilities.streams.v1.Streams.Trigger:input_type -> capabilities.streams.v1.Config - 3, // 3: capabilities.streams.v1.Streams.Trigger:output_type -> capabilities.streams.v1.Feed - 3, // [3:4] is the sub-list for method output_type - 2, // [2:3] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 1, // 0: capabilities.streams.v1.Report.sigs:type_name -> capabilities.streams.v1.OCRSignature + 0, // 1: capabilities.streams.v1.Streams.Trigger:input_type -> capabilities.streams.v1.Config + 2, // 2: capabilities.streams.v1.Streams.Trigger:output_type -> capabilities.streams.v1.Report + 2, // [2:3] is the sub-list for method output_type + 1, // [1:2] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_cre_capabilities_streams_v1_trigger_proto_init() } @@ -398,63 +263,13 @@ func file_cre_capabilities_streams_v1_trigger_proto_init() { if File_cre_capabilities_streams_v1_trigger_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_cre_capabilities_streams_v1_trigger_proto_msgTypes[0].Exporter = func(v any, i int) any { - switch v := v.(*Config); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_cre_capabilities_streams_v1_trigger_proto_msgTypes[1].Exporter = func(v any, i int) any { - switch v := v.(*SignersMetadata); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_cre_capabilities_streams_v1_trigger_proto_msgTypes[2].Exporter = func(v any, i int) any { - switch v := v.(*FeedReport); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_cre_capabilities_streams_v1_trigger_proto_msgTypes[3].Exporter = func(v any, i int) any { - switch v := v.(*Feed); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_cre_capabilities_streams_v1_trigger_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_cre_capabilities_streams_v1_trigger_proto_rawDesc), len(file_cre_capabilities_streams_v1_trigger_proto_rawDesc)), NumEnums: 0, - NumMessages: 4, + NumMessages: 3, NumExtensions: 0, NumServices: 1, }, @@ -463,7 +278,6 @@ func file_cre_capabilities_streams_v1_trigger_proto_init() { MessageInfos: file_cre_capabilities_streams_v1_trigger_proto_msgTypes, }.Build() File_cre_capabilities_streams_v1_trigger_proto = out.File - file_cre_capabilities_streams_v1_trigger_proto_rawDesc = nil file_cre_capabilities_streams_v1_trigger_proto_goTypes = nil file_cre_capabilities_streams_v1_trigger_proto_depIdxs = nil }