diff --git a/pkg/loop/internal/net/broker.go b/pkg/loop/internal/net/broker.go index fe671f0ff6..5874272cc4 100644 --- a/pkg/loop/internal/net/broker.go +++ b/pkg/loop/internal/net/broker.go @@ -65,12 +65,19 @@ type BrokerConfig struct { type BrokerExt struct { Broker Broker BrokerConfig + + onRefreshComplete func(ctx context.Context) error + hooksMu sync.RWMutex } // WithName returns a new [*BrokerExt] with Name added to the logger. func (b *BrokerExt) WithName(name string) *BrokerExt { bn := *b bn.Logger = logger.Named(b.Logger, name) + + // Don't share hooks mutex or onRefreshComplete between copies + bn.onRefreshComplete = nil + bn.hooksMu = sync.RWMutex{} return &bn } @@ -147,6 +154,25 @@ func (b *BrokerExt) CloseAll(deps ...Resource) { } } +// SetOnRefreshComplete sets a hook to be called after successful connection refresh. +func (b *BrokerExt) SetOnRefreshComplete(hook func(ctx context.Context) error) { + b.hooksMu.Lock() + defer b.hooksMu.Unlock() + b.onRefreshComplete = hook +} + +// executeOnRefreshComplete executes the refresh completion hook if it exists. +func (b *BrokerExt) executeOnRefreshComplete(ctx context.Context) error { + b.hooksMu.RLock() + hook := b.onRefreshComplete + b.hooksMu.RUnlock() + + if hook != nil { + return hook(ctx) + } + return nil +} + type Resource struct { io.Closer Name string diff --git a/pkg/loop/internal/net/client.go b/pkg/loop/internal/net/client.go index fbe5978f60..786ba38c5a 100644 --- a/pkg/loop/internal/net/client.go +++ b/pkg/loop/internal/net/client.go @@ -128,6 +128,13 @@ func (c *clientConn) refresh(ctx context.Context, orig *grpc.ClientConn) *grpc.C c.CloseAll(c.deps...) return false } + + // Execute refresh completion hook after successful connection but before returning to caller. + if err := c.BrokerExt.executeOnRefreshComplete(ctx); err != nil { + // Don't fail the refresh, but log the error + c.Logger.Errorw("Refresh completion hook failed", "err", err, "conn", c.name) + } + return true } diff --git a/pkg/loop/internal/relayer/pluginprovider/ext/ccipocr3/chainaccessor.go b/pkg/loop/internal/relayer/pluginprovider/ext/ccipocr3/chainaccessor.go index ea402484ca..ee14487fc7 100644 --- a/pkg/loop/internal/relayer/pluginprovider/ext/ccipocr3/chainaccessor.go +++ b/pkg/loop/internal/relayer/pluginprovider/ext/ccipocr3/chainaccessor.go @@ -2,6 +2,8 @@ package ccipocr3 import ( "context" + "fmt" + "sync" "time" "google.golang.org/grpc" @@ -19,13 +21,22 @@ var _ ccipocr3.ChainAccessor = (*chainAccessorClient)(nil) type chainAccessorClient struct { *net.BrokerExt grpc ccipocr3pb.ChainAccessorClient + + // Local persistence for refresh functionality + mu sync.RWMutex + syncedContracts map[string]ccipocr3.UnknownAddress // contractName -> contractAddress } func NewChainAccessorClient(broker *net.BrokerExt, cc grpc.ClientConnInterface) ccipocr3.ChainAccessor { - return &chainAccessorClient{ - BrokerExt: broker, - grpc: ccipocr3pb.NewChainAccessorClient(cc), + client := &chainAccessorClient{ + BrokerExt: broker, + grpc: ccipocr3pb.NewChainAccessorClient(cc), + syncedContracts: make(map[string]ccipocr3.UnknownAddress), } + + broker.SetOnRefreshComplete(client.restoreStateOnRefresh) + + return client } // AllAccessors methods @@ -82,9 +93,61 @@ func (c *chainAccessorClient) Sync(ctx context.Context, contractName string, con ContractName: contractName, ContractAddress: contractAddress, }) + + if err == nil { + // Persist the synced contract locally for client refresh + c.mu.Lock() + c.syncedContracts[contractName] = contractAddress + c.mu.Unlock() + c.Logger.Debugw("Persisted synced contract", "contractName", contractName, "contractAddress", contractAddress) + } + return err } +// restoreStateOnRefresh is called after successful relayer refresh to restore synced contracts. +// +// TODO: right now this only supports re-syncing previously synced contracts. In the future this should support +// re-establishing any arbitrary serializable state. +func (c *chainAccessorClient) restoreStateOnRefresh(ctx context.Context) error { + c.mu.RLock() + contractsToRestore := make(map[string]ccipocr3.UnknownAddress) + for name, addr := range c.syncedContracts { + contractsToRestore[name] = addr + } + c.mu.RUnlock() + + if len(contractsToRestore) == 0 { + c.Logger.Debug("No synced contracts to restore") + return nil + } + + c.Logger.Infow("Restoring synced contracts after refresh", "count", len(contractsToRestore)) + + // Re-sync all previously synced contracts + var restoreErrors []error + for contractName, contractAddress := range contractsToRestore { + if err := c.Sync(ctx, contractName, contractAddress); err != nil { + c.Logger.Errorw("Failed to restore synced contract", + "contractName", contractName, + "contractAddress", contractAddress, + "err", err) + restoreErrors = append(restoreErrors, fmt.Errorf("failed to restore contract %s: %w", contractName, err)) + } else { + c.Logger.Debugw("Successfully restored synced contract", + "contractName", contractName, + "contractAddress", contractAddress) + } + } + + if len(restoreErrors) > 0 { + return fmt.Errorf("failed to restore %d/%d contracts: %v", len(restoreErrors), len(contractsToRestore), restoreErrors) + } + + c.Logger.Infow("Successfully restored all synced contracts", "count", len(contractsToRestore)) + return nil +} + // DestinationAccessor methods func (c *chainAccessorClient) CommitReportsGTETimestamp( ctx context.Context,