diff --git a/pkg/mpc/key_exchange_session.go b/pkg/mpc/key_exchange_session.go index 2065f032..8da1774b 100644 --- a/pkg/mpc/key_exchange_session.go +++ b/pkg/mpc/key_exchange_session.go @@ -32,17 +32,19 @@ type ECDHSession interface { GetReadyPeersCount() int ErrChan() <-chan error Close() error + OnKeyExchangeComplete(callback func()) } type ecdhSession struct { - nodeID string - peerIDs []string - pubSub messaging.PubSub - ecdhSub messaging.Subscription - identityStore identity.Store - privateKey *ecdh.PrivateKey - publicKey *ecdh.PublicKey - errCh chan error + nodeID string + peerIDs []string + pubSub messaging.PubSub + ecdhSub messaging.Subscription + identityStore identity.Store + privateKey *ecdh.PrivateKey + publicKey *ecdh.PublicKey + errCh chan error + onKeyExchangeComplete func() } func NewECDHSession( @@ -51,6 +53,7 @@ func NewECDHSession( pubSub messaging.PubSub, identityStore identity.Store, ) *ecdhSession { + logger.Info("Creating ECDH session", "nodeID", nodeID, "peerIDs", peerIDs, "expectedKeys", len(peerIDs)) return &ecdhSession{ nodeID: nodeID, peerIDs: peerIDs, @@ -72,6 +75,10 @@ func (e *ecdhSession) ErrChan() <-chan error { return e.errCh } +func (e *ecdhSession) OnKeyExchangeComplete(callback func()) { + e.onKeyExchangeComplete = callback +} + func (e *ecdhSession) ListenKeyExchange() error { // Generate an ephemeral ECDH key pair privateKey, err := ecdh.X25519().GenerateKey(rand.Reader) @@ -86,6 +93,7 @@ func (e *ecdhSession) ListenKeyExchange() error { sub, err := e.pubSub.Subscribe(ECDHExchangeTopic, func(natMsg *nats.Msg) { var ecdhMsg types.ECDHMessage if err := json.Unmarshal(natMsg.Data, &ecdhMsg); err != nil { + logger.Error("Failed to unmarshal ECDH message", err) return } @@ -93,8 +101,11 @@ func (e *ecdhSession) ListenKeyExchange() error { return } + logger.Debug("Received ECDH message", "from", ecdhMsg.From, "to", e.nodeID) + //TODO: consider how to avoid replay attack if err := e.identityStore.VerifySignature(&ecdhMsg); err != nil { + logger.Error("ECDH signature verification failed", err, "from", ecdhMsg.From) e.errCh <- err return } @@ -113,7 +124,15 @@ func (e *ecdhSession) ListenKeyExchange() error { // Derive symmetric key using HKDF symmetricKey := e.deriveSymmetricKey(sharedSecret, ecdhMsg.From) e.identityStore.SetSymmetricKey(ecdhMsg.From, symmetricKey) - logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", e.identityStore.GetSymetricKeyCount()) + + currentKeyCount := e.identityStore.GetSymetricKeyCount() + logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", currentKeyCount, "expected", len(e.peerIDs)) + + // Check if ECDH exchange is complete and notify callback + if currentKeyCount == len(e.peerIDs) && e.onKeyExchangeComplete != nil { + logger.Info("ECDH key exchange completed successfully", "totalKeys", currentKeyCount) + e.onKeyExchangeComplete() + } }) e.ecdhSub = sub diff --git a/pkg/mpc/registry.go b/pkg/mpc/registry.go index c27b12dd..9c4868aa 100644 --- a/pkg/mpc/registry.go +++ b/pkg/mpc/registry.go @@ -66,7 +66,9 @@ func NewRegistry( pubSub messaging.PubSub, identityStore identity.Store, ) *registry { - ecdhSession := NewECDHSession(nodeID, peerNodeIDs, pubSub, identityStore) + // ECDH session should only exchange keys with other peers, not self + peerIDsExceptSelf := getPeerIDsExceptSelf(nodeID, peerNodeIDs) + ecdhSession := NewECDHSession(nodeID, peerIDsExceptSelf, pubSub, identityStore) mpcThreshold := viper.GetInt("mpc_threshold") if mpcThreshold < 1 { logger.Fatal("mpc_threshold must be greater than 0", nil) @@ -75,7 +77,7 @@ func NewRegistry( reg := ®istry{ consulKV: consulKV, nodeID: nodeID, - peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs), + peerNodeIDs: peerIDsExceptSelf, readyMap: make(map[string]bool), readyCount: 1, // self healthCheck: directMessaging, @@ -85,6 +87,11 @@ func NewRegistry( mpcThreshold: mpcThreshold, } + // Set up callback to check ready state when ECDH completes + ecdhSession.OnKeyExchangeComplete(func() { + reg.checkAndUpdateReadyState() + }) + go reg.consumeECDHErrors() return reg @@ -126,11 +133,28 @@ func (r *registry) registerReadyPairs(peerIDs []string) { r.readyMap[peerID] = true } - if len(peerIDs) == len(r.peerNodeIDs) && !r.ready { + // Check if we should update ready state + r.checkAndUpdateReadyState() +} + +// checkAndUpdateReadyState checks if all conditions are met to mark the registry as ready +func (r *registry) checkAndUpdateReadyState() { + // Count ready peers in readyMap + readyPeersCount := 0 + for _, isReady := range r.readyMap { + if isReady { + readyPeersCount++ + } + } + + // Only mark as ready when both conditions are met: + // 1. All peers are registered in Consul + // 2. ECDH key exchange is complete + if readyPeersCount == len(r.peerNodeIDs) && r.isECDHReady() && !r.ready { r.mu.Lock() r.ready = true r.mu.Unlock() - logger.Info("All peers are ready including ECDH exchange completion") + logger.Info("[READY] All peers are ready including ECDH exchange completion") } } @@ -163,12 +187,15 @@ func (r *registry) Ready() error { } _, err = r.healthCheck.Listen(r.composeHealthCheckTopic(r.nodeID), func(data []byte) { - peerID, isEcdhReady, _ := parseHealthDataSplit(string(data)) + peerID, isEcdhReady, parseErr := parseHealthDataSplit(string(data)) + if parseErr != nil { + logger.Error("Failed to parse health check data", parseErr, "data", string(data)) + return + } logger.Debug("Health check ok", "peerID", peerID, "isEcdhReady", isEcdhReady) if !isEcdhReady { logger.Info("[ECDH exchange retriggerd] not all peers are ready", "peerID", peerID) go r.triggerECDHExchange() - } }) if err != nil {