diff --git a/pkg/capabilities/registry/base.go b/pkg/capabilities/registry/base.go index 8599ec646..39b864186 100644 --- a/pkg/capabilities/registry/base.go +++ b/pkg/capabilities/registry/base.go @@ -6,8 +6,11 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "github.com/Masterminds/semver/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -18,8 +21,22 @@ var ( ErrCapabilityAlreadyExists = errors.New("capability already exists") ) +// atomicBaseCapability extends [capabilities.BaseCapability] to support atomic updates and forward client state checks. +type atomicBaseCapability interface { + capabilities.BaseCapability + Update(capabilities.BaseCapability) error + StateGetter +} + +var _ StateGetter = (*grpc.ClientConn)(nil) + +// StateGetter is implemented by GRPC client connections. +type StateGetter interface { + GetState() connectivity.State +} + type baseRegistry struct { - m map[string]capabilities.BaseCapability + m map[string]atomicBaseCapability lggr logger.Logger mu sync.RWMutex } @@ -28,7 +45,7 @@ var _ core.CapabilitiesRegistryBase = (*baseRegistry)(nil) func NewBaseRegistry(lggr logger.Logger) core.CapabilitiesRegistryBase { return &baseRegistry{ - m: map[string]capabilities.BaseCapability{}, + m: map[string]atomicBaseCapability{}, lggr: logger.Named(lggr, "registries.basic"), } } @@ -142,33 +159,35 @@ func (r *baseRegistry) Add(ctx context.Context, c capabilities.BaseCapability) e return err } - switch info.CapabilityType { - case capabilities.CapabilityTypeTrigger: - _, ok := c.(capabilities.TriggerCapability) - if !ok { - return errors.New("trigger capability does not satisfy TriggerCapability interface") + id := info.ID + bc, ok := r.m[id] + if ok { + switch state := bc.GetState(); state { + case connectivity.Shutdown, connectivity.TransientFailure, connectivity.Idle: + // allow replace + default: + return fmt.Errorf("%w: id %s found in registry: state %s", ErrCapabilityAlreadyExists, id, state) } - case capabilities.CapabilityTypeAction, capabilities.CapabilityTypeConsensus, capabilities.CapabilityTypeTarget: - _, ok := c.(capabilities.ExecutableCapability) - if !ok { - return errors.New("action does not satisfy ExecutableCapability interface") + if err := bc.Update(c); err != nil { + return fmt.Errorf("failed to update capability %s: %w", id, err) } - case capabilities.CapabilityTypeCombined: - _, ok := c.(capabilities.ExecutableAndTriggerCapability) - if !ok { - return errors.New("target capability does not satisfy ExecutableAndTriggerCapability interface") + } else { + var ac atomicBaseCapability + switch info.CapabilityType { + case capabilities.CapabilityTypeTrigger: + ac = &atomicTriggerCapability{} + case capabilities.CapabilityTypeAction, capabilities.CapabilityTypeConsensus, capabilities.CapabilityTypeTarget: + ac = &atomicExecuteCapability{} + case capabilities.CapabilityTypeCombined: + ac = &atomicExecuteAndTriggerCapability{} + default: + return fmt.Errorf("unknown capability type: %s", info.CapabilityType) } - default: - return fmt.Errorf("unknown capability type: %s", info.CapabilityType) - } - - id := info.ID - _, ok := r.m[id] - if ok { - return fmt.Errorf("%w: id %s found in registry", ErrCapabilityAlreadyExists, id) + if err := ac.Update(c); err != nil { + return err + } + r.m[id] = ac } - - r.m[id] = c r.lggr.Infow("capability added", "id", id, "type", info.CapabilityType, "description", info.Description, "version", info.Version()) return nil } @@ -176,12 +195,207 @@ func (r *baseRegistry) Add(ctx context.Context, c capabilities.BaseCapability) e func (r *baseRegistry) Remove(_ context.Context, id string) error { r.mu.Lock() defer r.mu.Unlock() - _, ok := r.m[id] + ac, ok := r.m[id] if !ok { return fmt.Errorf("unable to remove, capability not found: %s", id) } - - delete(r.m, id) + if err := ac.Update(nil); err != nil { + return fmt.Errorf("failed to remove capability %s: %w", id, err) + } r.lggr.Infow("capability removed", "id", id) return nil } + +var _ capabilities.TriggerCapability = &atomicTriggerCapability{} + +type atomicTriggerCapability struct { + atomic.Pointer[capabilities.TriggerCapability] +} + +func (a *atomicTriggerCapability) Update(c capabilities.BaseCapability) error { + if c == nil { + a.Store(nil) + return nil + } + tc, ok := c.(capabilities.TriggerCapability) + if !ok { + return errors.New("trigger capability does not satisfy TriggerCapability interface") + } + a.Store(&tc) + return nil +} + +func (a *atomicTriggerCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { + c := a.Load() + if c == nil { + return capabilities.CapabilityInfo{}, errors.New("capability unavailable") + } + return (*c).Info(ctx) +} + +func (a *atomicTriggerCapability) GetState() connectivity.State { + c := a.Load() + if c == nil { + return connectivity.Shutdown + } + if sg, ok := (*c).(StateGetter); ok { + return sg.GetState() + } + return connectivity.State(-1) // unknown +} + +func (a *atomicTriggerCapability) RegisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) { + c := a.Load() + if c == nil { + return nil, errors.New("capability unavailable") + } + return (*c).RegisterTrigger(ctx, request) +} + +func (a *atomicTriggerCapability) UnregisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) error { + c := a.Load() + if c == nil { + return errors.New("capability unavailable") + } + return (*c).UnregisterTrigger(ctx, request) +} + +var _ capabilities.ExecutableCapability = &atomicExecuteCapability{} + +type atomicExecuteCapability struct { + atomic.Pointer[capabilities.ExecutableCapability] +} + +func (a *atomicExecuteCapability) Update(c capabilities.BaseCapability) error { + if c == nil { + a.Store(nil) + return nil + } + tc, ok := c.(capabilities.ExecutableCapability) + if !ok { + return errors.New("action does not satisfy ExecutableCapability interface") + } + a.Store(&tc) + return nil +} + +func (a *atomicExecuteCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { + c := a.Load() + if c == nil { + return capabilities.CapabilityInfo{}, errors.New("capability unavailable") + } + return (*c).Info(ctx) +} + +func (a *atomicExecuteCapability) GetState() connectivity.State { + c := a.Load() + if c == nil { + return connectivity.Shutdown + } + if sg, ok := (*c).(StateGetter); ok { + return sg.GetState() + } + return connectivity.State(-1) // unknown +} + +func (a *atomicExecuteCapability) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error { + c := a.Load() + if c == nil { + return errors.New("capability unavailable") + } + return (*c).RegisterToWorkflow(ctx, request) +} + +func (a *atomicExecuteCapability) UnregisterFromWorkflow(ctx context.Context, request capabilities.UnregisterFromWorkflowRequest) error { + c := a.Load() + if c == nil { + return errors.New("capability unavailable") + } + return (*c).UnregisterFromWorkflow(ctx, request) +} + +func (a *atomicExecuteCapability) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + c := a.Load() + if c == nil { + return capabilities.CapabilityResponse{}, errors.New("capability unavailable") + } + return (*c).Execute(ctx, request) +} + +var _ capabilities.ExecutableAndTriggerCapability = &atomicExecuteAndTriggerCapability{} + +type atomicExecuteAndTriggerCapability struct { + atomic.Pointer[capabilities.ExecutableAndTriggerCapability] +} + +func (a *atomicExecuteAndTriggerCapability) Update(c capabilities.BaseCapability) error { + if c == nil { + a.Store(nil) + return nil + } + tc, ok := c.(capabilities.ExecutableAndTriggerCapability) + if !ok { + return errors.New("target capability does not satisfy ExecutableAndTriggerCapability interface") + } + a.Store(&tc) + return nil +} + +func (a *atomicExecuteAndTriggerCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { + c := a.Load() + if c == nil { + return capabilities.CapabilityInfo{}, errors.New("capability unavailable") + } + return (*c).Info(ctx) +} + +func (a *atomicExecuteAndTriggerCapability) GetState() connectivity.State { + c := a.Load() + if c == nil { + return connectivity.Shutdown + } + if sg, ok := (*c).(StateGetter); ok { + return sg.GetState() + } + return connectivity.State(-1) // unknown +} + +func (a *atomicExecuteAndTriggerCapability) RegisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) { + c := a.Load() + if c == nil { + return nil, errors.New("capability unavailable") + } + return (*c).RegisterTrigger(ctx, request) +} + +func (a *atomicExecuteAndTriggerCapability) UnregisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) error { + c := a.Load() + if c == nil { + return errors.New("capability unavailable") + } + return (*c).UnregisterTrigger(ctx, request) +} + +func (a *atomicExecuteAndTriggerCapability) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error { + c := a.Load() + if c == nil { + return errors.New("capability unavailable") + } + return (*c).RegisterToWorkflow(ctx, request) +} + +func (a *atomicExecuteAndTriggerCapability) UnregisterFromWorkflow(ctx context.Context, request capabilities.UnregisterFromWorkflowRequest) error { + c := a.Load() + if c == nil { + return errors.New("capability unavailable") + } + return (*c).UnregisterFromWorkflow(ctx, request) +} + +func (a *atomicExecuteAndTriggerCapability) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) { + c := a.Load() + if c == nil { + return capabilities.CapabilityResponse{}, errors.New("capability unavailable") + } + return (*c).Execute(ctx, request) +} diff --git a/pkg/capabilities/registry/base_test.go b/pkg/capabilities/registry/base_test.go index 4e32004dd..ba8ad26ee 100644 --- a/pkg/capabilities/registry/base_test.go +++ b/pkg/capabilities/registry/base_test.go @@ -50,13 +50,17 @@ func TestRegistry(t *testing.T) { gc, err := r.Get(ctx, id) require.NoError(t, err) + info, err := gc.Info(t.Context()) + require.NoError(t, err) - assert.Equal(t, c, gc) + assert.Equal(t, c.CapabilityInfo, info) cs, err := r.List(ctx) require.NoError(t, err) assert.Len(t, cs, 1) - assert.Equal(t, c, cs[0]) + info, err = cs[0].Info(t.Context()) + require.NoError(t, err) + assert.Equal(t, c.CapabilityInfo, info) } func TestRegistryCompatibleVersions(t *testing.T) { diff --git a/pkg/loop/internal/core/services/capability/capabilities.go b/pkg/loop/internal/core/services/capability/capabilities.go index f01423486..65224424b 100644 --- a/pkg/loop/internal/core/services/capability/capabilities.go +++ b/pkg/loop/internal/core/services/capability/capabilities.go @@ -8,6 +8,7 @@ import ( "sync" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" @@ -22,7 +23,7 @@ type TriggerCapabilityClient struct { *baseCapabilityClient } -func NewTriggerCapabilityClient(brokerExt *net.BrokerExt, conn grpc.ClientConnInterface) capabilities.TriggerCapability { +func NewTriggerCapabilityClient(brokerExt *net.BrokerExt, conn net.ClientConnInterface) capabilities.TriggerCapability { return &TriggerCapabilityClient{ triggerExecutableClient: newTriggerExecutableClient(brokerExt, conn), baseCapabilityClient: newBaseCapabilityClient(brokerExt, conn), @@ -39,7 +40,7 @@ type ExecutableCapability interface { capabilities.BaseCapability } -func NewExecutableCapabilityClient(brokerExt *net.BrokerExt, conn grpc.ClientConnInterface) ExecutableCapability { +func NewExecutableCapabilityClient(brokerExt *net.BrokerExt, conn net.ClientConnInterface) ExecutableCapability { return &ExecutableCapabilityClient{ executableClient: newExecutableClient(brokerExt, conn), baseCapabilityClient: newBaseCapabilityClient(brokerExt, conn), @@ -52,7 +53,7 @@ type CombinedCapabilityClient struct { *triggerExecutableClient } -func NewCombinedCapabilityClient(brokerExt *net.BrokerExt, conn grpc.ClientConnInterface) ExecutableCapability { +func NewCombinedCapabilityClient(brokerExt *net.BrokerExt, conn net.ClientConnInterface) ExecutableCapability { return &CombinedCapabilityClient{ executableClient: newExecutableClient(brokerExt, conn), baseCapabilityClient: newBaseCapabilityClient(brokerExt, conn), @@ -135,14 +136,18 @@ func InfoToReply(info capabilities.CapabilityInfo) *capabilitiespb.CapabilityInf } type baseCapabilityClient struct { + c net.ClientConnInterface grpc capabilitiespb.BaseCapabilityClient *net.BrokerExt } var _ capabilities.BaseCapability = (*baseCapabilityClient)(nil) -func newBaseCapabilityClient(brokerExt *net.BrokerExt, conn grpc.ClientConnInterface) *baseCapabilityClient { - return &baseCapabilityClient{grpc: capabilitiespb.NewBaseCapabilityClient(conn), BrokerExt: brokerExt} +func newBaseCapabilityClient(brokerExt *net.BrokerExt, conn net.ClientConnInterface) *baseCapabilityClient { + return &baseCapabilityClient{c: conn, grpc: capabilitiespb.NewBaseCapabilityClient(conn), BrokerExt: brokerExt} +} +func (c *baseCapabilityClient) GetState() connectivity.State { + return c.c.GetState() } func (c *baseCapabilityClient) Info(ctx context.Context) (capabilities.CapabilityInfo, error) { diff --git a/pkg/loop/internal/core/services/capability/capabilities_registry.go b/pkg/loop/internal/core/services/capability/capabilities_registry.go index a74a7f305..8c9cee12b 100644 --- a/pkg/loop/internal/core/services/capability/capabilities_registry.go +++ b/pkg/loop/internal/core/services/capability/capabilities_registry.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities" capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/registry" "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/net" "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/pb" "github.com/smartcontractkit/chainlink-common/pkg/types/core" @@ -606,6 +607,10 @@ func (c *capabilitiesRegistryServer) List(ctx context.Context, _ *emptypb.Empty) return reply, nil } +var _ registry.StateGetter = (*TriggerCapabilityClient)(nil) +var _ registry.StateGetter = (*ExecutableCapabilityClient)(nil) +var _ registry.StateGetter = (*CombinedCapabilityClient)(nil) + func (c *capabilitiesRegistryServer) Add(ctx context.Context, request *pb.AddRequest) (*emptypb.Empty, error) { conn, err := c.Dial(request.CapabilityID) if err != nil { diff --git a/pkg/loop/internal/core/services/capability/capabilities_registry_test.go b/pkg/loop/internal/core/services/capability/capabilities_registry_test.go index f2d37347c..0d6853741 100644 --- a/pkg/loop/internal/core/services/capability/capabilities_registry_test.go +++ b/pkg/loop/internal/core/services/capability/capabilities_registry_test.go @@ -7,13 +7,12 @@ import ( "testing" "github.com/hashicorp/go-plugin" + p2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/grpc" - p2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" - "github.com/smartcontractkit/chainlink-protos/cre/go/values" "github.com/smartcontractkit/chainlink-common/pkg/capabilities" diff --git a/pkg/loop/internal/core/services/capability/standard/standard_capabilities.go b/pkg/loop/internal/core/services/capability/standard/standard_capabilities.go index 2fb6af142..5d6897765 100644 --- a/pkg/loop/internal/core/services/capability/standard/standard_capabilities.go +++ b/pkg/loop/internal/core/services/capability/standard/standard_capabilities.go @@ -62,8 +62,12 @@ func NewStandardCapabilitiesClient(brokerCfg net.BrokerConfig) *StandardCapabili // Reinitialise calls Initialise with cached deps from the previous call, if one was already made. func (c *StandardCapabilitiesClient) Reinitialise(ctx context.Context) error { if c.initializeDeps == nil { + c.Logger.Debug("No dependencies to re-initialise") return nil } + c.CloseAll(c.resources...) + c.resources = nil + c.Logger.Info("Re-initialising dependencies") return c.Initialise(ctx, *c.initializeDeps) } diff --git a/pkg/loop/internal/core/services/oracle/server.go b/pkg/loop/internal/core/services/oracle/server.go index 94f77aec0..958302aa0 100644 --- a/pkg/loop/internal/core/services/oracle/server.go +++ b/pkg/loop/internal/core/services/oracle/server.go @@ -38,7 +38,7 @@ func NewServer(log logger.Logger, impl core.Oracle, broker *net.BrokerExt) (*ser } func (s *server) Close() error { - return nil + return s.impl.Close(context.Background()) } func (s *server) CloseOracle(ctx context.Context, e *emptypb.Empty) (*emptypb.Empty, error) { diff --git a/pkg/loop/internal/core/services/oraclefactory/server.go b/pkg/loop/internal/core/services/oraclefactory/server.go index f87023146..a83b76b8c 100644 --- a/pkg/loop/internal/core/services/oraclefactory/server.go +++ b/pkg/loop/internal/core/services/oraclefactory/server.go @@ -108,7 +108,6 @@ func (s *server) NewOracle(ctx context.Context, req *oraclefactorypb.NewOracleRe return nil, fmt.Errorf("failed to serve new oracle: %w", err) } resources = append(resources, oracleRes) - s.resources = append(s.resources, resources...) return &oraclefactorypb.NewOracleReply{OracleId: oracleID}, nil } diff --git a/pkg/loop/internal/net/client.go b/pkg/loop/internal/net/client.go index 713cf0bfb..380758408 100644 --- a/pkg/loop/internal/net/client.go +++ b/pkg/loop/internal/net/client.go @@ -9,18 +9,30 @@ import ( "github.com/jpillora/backoff" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/loop/internal/pb" ) -var _ grpc.ClientConnInterface = (*AtomicClient)(nil) +var _ ClientConnInterface = (*grpc.ClientConn)(nil) + +type ClientConnInterface interface { + grpc.ClientConnInterface + GetState() connectivity.State +} + +var _ ClientConnInterface = (*AtomicClient)(nil) // An AtomicClient implements [grpc.ClientConnInterface] and is backed by a swappable [*grpc.ClientConn]. type AtomicClient struct { cc atomic.Pointer[grpc.ClientConn] } +func (a *AtomicClient) GetState() connectivity.State { + return a.cc.Load().GetState() +} + func (a *AtomicClient) Store(cc *grpc.ClientConn) { a.cc.Store(cc) } func (a *AtomicClient) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { @@ -31,7 +43,7 @@ func (a *AtomicClient) NewStream(ctx context.Context, desc *grpc.StreamDesc, met return a.cc.Load().NewStream(ctx, desc, method, opts...) } -var _ grpc.ClientConnInterface = (*clientConn)(nil) +var _ ClientConnInterface = (*clientConn)(nil) // newClientFn returns a new client connection id to dial, and a set of Resource dependencies to close. type newClientFn func(context.Context) (id uint32, deps Resources, err error) @@ -49,6 +61,17 @@ type clientConn struct { cc *grpc.ClientConn } +func (c *clientConn) GetState() connectivity.State { + c.mu.RLock() + cc := c.cc + c.mu.RUnlock() + if cc != nil { + return cc.GetState() + } + // fall back to Shutdown to reflect underlying state + return connectivity.Shutdown +} + func (c *clientConn) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { c.mu.RLock() cc := c.cc diff --git a/pkg/loop/standard_capabilities.go b/pkg/loop/standard_capabilities.go index df5c86366..14e5abd72 100644 --- a/pkg/loop/standard_capabilities.go +++ b/pkg/loop/standard_capabilities.go @@ -37,13 +37,13 @@ func (p *StandardCapabilitiesLoop) GRPCServer(broker *plugin.GRPCBroker, server return standardcapability.RegisterStandardCapabilitiesServer(server, broker, p.BrokerConfig, p.PluginServer) } -func (p *StandardCapabilitiesLoop) GRPCClient(_ context.Context, broker *plugin.GRPCBroker, conn *grpc.ClientConn) (any, error) { +func (p *StandardCapabilitiesLoop) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, conn *grpc.ClientConn) (any, error) { if p.pluginClient == nil { p.pluginClient = standardcapability.NewStandardCapabilitiesClient(p.BrokerConfig) } p.pluginClient.Refresh(broker, conn) - return StandardCapabilities(p.pluginClient), nil + return StandardCapabilities(p.pluginClient), p.pluginClient.Reinitialise(ctx) } func (p *StandardCapabilitiesLoop) ClientConfig() *plugin.ClientConfig { @@ -73,7 +73,7 @@ func NewStandardCapabilitiesService(lggr logger.Logger, grpcOpts GRPCOpts, cmd f if !ok { return nil, nil, fmt.Errorf("expected StandardCapabilities but got %T", instance) } - return scs, scs, scs.Reinitialise(ctx) + return scs, scs, nil } stopCh := make(chan struct{}) lggr = logger.Named(lggr, "StandardCapabilities")