Skip to content

Commit f87446c

Browse files
committed
Refactor ecdh service to become resilient to node disconnect and rejoin
1 parent 46d5c1c commit f87446c

7 files changed

Lines changed: 189 additions & 148 deletions

File tree

cmd/mpcium/main.go

Lines changed: 7 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
129129
if err != nil {
130130
logger.Fatal("Failed to connect to NATS", err)
131131
}
132-
defer natsConn.Close()
133132

134133
pubsub := messaging.NewNATSPubSub(natsConn)
135134
keygenBroker, err := messaging.NewJetStreamBroker(ctx, natsConn, event.KeygenBrokerStream, []string{
@@ -162,7 +161,7 @@ func runNode(ctx context.Context, c *cli.Command) error {
162161
logger.Info("Node is running", "ID", nodeID, "name", nodeName)
163162

164163
peerNodeIDs := GetPeerIDs(peers)
165-
peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV(), directMessaging)
164+
peerRegistry := mpc.NewRegistry(nodeID, peerNodeIDs, consulClient.KV(), directMessaging, pubsub, identityStore)
166165

167166
mpcNode := mpc.NewNode(
168167
nodeID,
@@ -176,9 +175,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
176175
)
177176
defer mpcNode.Close()
178177

179-
// ECDH session for DH key exchange
180-
ecdhSession := mpcNode.GetECDHSession()
181-
182178
eventConsumer := eventconsumer.NewEventConsumer(
183179
mpcNode,
184180
pubsub,
@@ -206,12 +202,7 @@ func runNode(ctx context.Context, c *cli.Command) error {
206202
}
207203
logger.Info("[READY] Node is ready", "nodeID", nodeID)
208204

209-
logger.Info("Waiting for ECDH key exchange to complete...", "nodeID", nodeID)
210-
if err := ecdhSession.WaitForExchangeComplete(); err != nil {
211-
logger.Fatal("ECDH exchange failed", err)
212-
}
213-
214-
logger.Info("ECDH key exchange completed successfully, starting consumers...", "nodeID", nodeID)
205+
logger.Info("Starting consumers", "nodeID", nodeID)
215206
appContext, cancel := context.WithCancel(context.Background())
216207
//Setup signal handling to cancel context on termination signals.
217208
go func() {
@@ -221,6 +212,11 @@ func runNode(ctx context.Context, c *cli.Command) error {
221212
logger.Warn("Shutdown signal received, canceling context...")
222213
cancel()
223214

215+
// Resign from peer registry first (before closing NATS)
216+
if err := peerRegistry.Resign(); err != nil {
217+
logger.Error("Failed to resign from peer registry", err)
218+
}
219+
224220
// Gracefully close consumers
225221
if err := keygenConsumer.Close(); err != nil {
226222
logger.Error("Failed to close keygen consumer", err)
@@ -229,10 +225,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
229225
logger.Error("Failed to close signing consumer", err)
230226
}
231227

232-
if err := ecdhSession.Close(); err != nil {
233-
logger.Error("Failed to close ECDH session", err)
234-
}
235-
236228
err := natsConn.Drain()
237229
if err != nil {
238230
logger.Error("Failed to drain NATS connection", err)
@@ -264,21 +256,6 @@ func runNode(ctx context.Context, c *cli.Command) error {
264256
logger.Info("Signing consumer finished successfully")
265257
}()
266258

267-
go func() {
268-
for {
269-
select {
270-
case <-appContext.Done():
271-
return
272-
case err := <-ecdhSession.ErrChan():
273-
if err != nil {
274-
logger.Error("ECDH session error", err)
275-
errChan <- fmt.Errorf("ecdh session error: %w", err)
276-
return
277-
}
278-
}
279-
}
280-
}()
281-
282259
go func() {
283260
wg.Wait()
284261
logger.Info("All consumers have finished")

pkg/eventconsumer/keygen_consumer.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ func (sc *keygenConsumer) waitForAllPeersReadyToGenKey(ctx context.Context) erro
8383
func (sc *keygenConsumer) Run(ctx context.Context) error {
8484
// Wait for sufficient peers before starting to consume messages
8585
if err := sc.waitForAllPeersReadyToGenKey(ctx); err != nil {
86+
if err == context.Canceled {
87+
return nil
88+
}
8689
return fmt.Errorf("failed to wait for sufficient peers: %w", err)
8790
}
8891

pkg/eventconsumer/sign_consumer.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ func (sc *signingConsumer) waitForSufficientPeers(ctx context.Context) error {
9393
func (sc *signingConsumer) Run(ctx context.Context) error {
9494
// Wait for sufficient peers before starting to consume messages
9595
if err := sc.waitForSufficientPeers(ctx); err != nil {
96+
if err == context.Canceled {
97+
return nil
98+
}
9699
return fmt.Errorf("failed to wait for sufficient peers: %w", err)
97100
}
98101

pkg/identity/identity.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ type Store interface {
4444

4545
SetSymmetricKey(peerID string, key []byte)
4646
GetSymmetricKey(peerID string) ([]byte, error)
47+
RemoveSymmetricKey(peerID string)
48+
GetSymetricKeyCount() int
4749
CheckSymmetricKeyComplete(desired int) bool
4850

4951
EncryptMessage(plaintext []byte, peerID string) ([]byte, error)
@@ -238,18 +240,22 @@ func (s *fileStore) GetSymmetricKey(peerID string) ([]byte, error) {
238240
return nil, fmt.Errorf("SymmetricKey key not found for node ID: %s", peerID)
239241
}
240242

241-
func (s *fileStore) CheckSymmetricKeyComplete(desired int) bool {
243+
func (s *fileStore) RemoveSymmetricKey(peerID string) {
244+
s.mu.Lock()
245+
defer s.mu.Unlock()
246+
delete(s.symmetricKeys, peerID)
247+
}
248+
249+
func (s *fileStore) GetSymetricKeyCount() int {
242250
s.mu.RLock()
243251
defer s.mu.RUnlock()
252+
return len(s.symmetricKeys)
253+
}
244254

245-
completeCount := 0
246-
for _, value := range s.symmetricKeys {
247-
if len(value) > 0 {
248-
completeCount++
249-
}
250-
}
251-
252-
return completeCount == desired
255+
func (s *fileStore) CheckSymmetricKeyComplete(desired int) bool {
256+
s.mu.RLock()
257+
defer s.mu.RUnlock()
258+
return len(s.symmetricKeys) == desired
253259
}
254260

255261
// GetPublicKey retrieves a node's public key by its ID

pkg/mpc/key_exchange_session.go

Lines changed: 26 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ import (
1717

1818
"encoding/json"
1919

20-
"sync"
21-
2220
"github.com/nats-io/nats.go"
2321
)
2422

@@ -30,24 +28,21 @@ const (
3028
type ECDHSession interface {
3129
ListenKeyExchange() error
3230
BroadcastPublicKey() error
33-
WaitForExchangeComplete() error
34-
ResetLocalKeys()
31+
RemovePeer(peerID string)
32+
GetReadyPeersCount() int
3533
ErrChan() <-chan error
3634
Close() error
3735
}
3836

3937
type ecdhSession struct {
40-
nodeID string
41-
peerIDs []string
42-
pubSub messaging.PubSub
43-
ecdhSub messaging.Subscription
44-
identityStore identity.Store
45-
privateKey *ecdh.PrivateKey
46-
publicKey *ecdh.PublicKey
47-
exchangeComplete chan struct{}
48-
errCh chan error
49-
exchangeDone bool
50-
mu sync.RWMutex
38+
nodeID string
39+
peerIDs []string
40+
pubSub messaging.PubSub
41+
ecdhSub messaging.Subscription
42+
identityStore identity.Store
43+
privateKey *ecdh.PrivateKey
44+
publicKey *ecdh.PublicKey
45+
errCh chan error
5146
}
5247

5348
func NewECDHSession(
@@ -57,20 +52,24 @@ func NewECDHSession(
5752
identityStore identity.Store,
5853
) *ecdhSession {
5954
return &ecdhSession{
60-
nodeID: nodeID,
61-
peerIDs: peerIDs,
62-
pubSub: pubSub,
63-
identityStore: identityStore,
64-
exchangeComplete: make(chan struct{}, 1),
65-
errCh: make(chan error, 1),
55+
nodeID: nodeID,
56+
peerIDs: peerIDs,
57+
pubSub: pubSub,
58+
identityStore: identityStore,
59+
errCh: make(chan error, 1),
6660
}
6761
}
6862

69-
func (e *ecdhSession) ResetLocalKeys() {
70-
// Set a specific key to an empty []byte
71-
for _, peerID := range e.peerIDs {
72-
e.identityStore.SetSymmetricKey(peerID, []byte{})
73-
}
63+
func (e *ecdhSession) RemovePeer(peerID string) {
64+
e.identityStore.RemoveSymmetricKey(peerID)
65+
}
66+
67+
func (e *ecdhSession) GetReadyPeersCount() int {
68+
return e.identityStore.GetSymetricKeyCount()
69+
}
70+
71+
func (e *ecdhSession) ErrChan() <-chan error {
72+
return e.errCh
7473
}
7574

7675
func (e *ecdhSession) ListenKeyExchange() error {
@@ -114,21 +113,7 @@ func (e *ecdhSession) ListenKeyExchange() error {
114113
// Derive symmetric key using HKDF
115114
symmetricKey := e.deriveSymmetricKey(sharedSecret, ecdhMsg.From)
116115
e.identityStore.SetSymmetricKey(ecdhMsg.From, symmetricKey)
117-
118-
requiredKeyCount := len(e.peerIDs) - 1
119-
logger.Info("ECDH progress", "peer", ecdhMsg.From, "required", requiredKeyCount)
120-
121-
if e.identityStore.CheckSymmetricKeyComplete(requiredKeyCount) {
122-
logger.Info("Completed ECDH!", "symmetric key counts of peers", requiredKeyCount)
123-
logger.Info("ALL PEERS ARE READY! Starting to accept MPC requests")
124-
125-
e.mu.Lock()
126-
if !e.exchangeDone {
127-
e.exchangeDone = true
128-
e.exchangeComplete <- struct{}{}
129-
}
130-
e.mu.Unlock()
131-
}
116+
logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", e.identityStore.GetSymetricKeyCount())
132117
})
133118

134119
e.ecdhSub = sub
@@ -138,10 +123,6 @@ func (e *ecdhSession) ListenKeyExchange() error {
138123
return nil
139124
}
140125

141-
func (s *ecdhSession) ErrChan() <-chan error {
142-
return s.errCh
143-
}
144-
145126
func (s *ecdhSession) Close() error {
146127
err := s.ecdhSub.Unsubscribe()
147128
if err != nil {
@@ -173,25 +154,6 @@ func (e *ecdhSession) BroadcastPublicKey() error {
173154
return nil
174155
}
175156

176-
func (e *ecdhSession) WaitForExchangeComplete() error {
177-
e.mu.RLock()
178-
if e.exchangeDone {
179-
e.mu.RUnlock()
180-
return nil
181-
}
182-
e.mu.RUnlock()
183-
timeout := time.After(ECDHExchangeTimeout) // 2 minutes timeout
184-
185-
select {
186-
case <-e.exchangeComplete:
187-
return nil
188-
case err := <-e.errCh:
189-
return err
190-
case <-timeout:
191-
return fmt.Errorf("ECDH exchange timeout!")
192-
}
193-
}
194-
195157
func deriveConsistentInfo(a, b string) []byte {
196158
if a < b {
197159
return []byte(a + b)

pkg/mpc/node.go

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ type Node struct {
3939
identityStore identity.Store
4040

4141
peerRegistry PeerRegistry
42-
ecdhSession ECDHSession
4342
}
4443

4544
func NewNode(
@@ -55,11 +54,6 @@ func NewNode(
5554
start := time.Now()
5655
elapsed := time.Since(start)
5756
logger.Info("Starting new node, preparams is generated successfully!", "elapsed", elapsed.Milliseconds())
58-
// Each node initiates the DH key exchange listener at the beginning and invoke message sending when all peers are ready
59-
dhSession := NewECDHSession(nodeID, peerIDs, pubSub, identityStore)
60-
if err := dhSession.ListenKeyExchange(); err != nil {
61-
logger.Fatal("Failed to listen to DH key exchange", err)
62-
}
6357

6458
node := &Node{
6559
nodeID: nodeID,
@@ -70,22 +64,11 @@ func NewNode(
7064
keyinfoStore: keyinfoStore,
7165
peerRegistry: peerRegistry,
7266
identityStore: identityStore,
73-
ecdhSession: dhSession,
7467
}
7568
node.ecdsaPreParams = node.generatePreParams()
7669

77-
// we define two types of tasks, initTask and resetTask
78-
ecdhTasks := func(isInit bool) {
79-
if isInit {
80-
if err := dhSession.BroadcastPublicKey(); err != nil {
81-
logger.Fatal("DH key broadcast failed", err)
82-
}
83-
} else {
84-
dhSession.ResetLocalKeys()
85-
}
86-
}
87-
88-
go peerRegistry.WatchPeersReady(ecdhTasks)
70+
// Start watching peers - ECDH is now handled by the registry
71+
go peerRegistry.WatchPeersReady()
8972
return node
9073
}
9174

@@ -430,9 +413,6 @@ func (p *Node) Close() {
430413
}
431414
}
432415

433-
func (p *Node) GetECDHSession() ECDHSession {
434-
return p.ecdhSession
435-
}
436416

437417
func (p *Node) generatePreParams() []*keygen.LocalPreParams {
438418
start := time.Now()

0 commit comments

Comments
 (0)