Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions pkg/mpc/key_exchange_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ type ECDHSession interface {
GetReadyPeersCount() int
ErrChan() <-chan error
Close() error
OnKeyExchangeComplete(callback func())
}

type ecdhSession struct {
nodeID string
peerIDs []string
pubSub messaging.PubSub
ecdhSub messaging.Subscription
identityStore identity.Store
privateKey *ecdh.PrivateKey
publicKey *ecdh.PublicKey
errCh chan error
nodeID string
peerIDs []string
pubSub messaging.PubSub
ecdhSub messaging.Subscription
identityStore identity.Store
privateKey *ecdh.PrivateKey
publicKey *ecdh.PublicKey
errCh chan error
onKeyExchangeComplete func()
}

func NewECDHSession(
Expand All @@ -51,6 +53,7 @@ func NewECDHSession(
pubSub messaging.PubSub,
identityStore identity.Store,
) *ecdhSession {
logger.Info("Creating ECDH session", "nodeID", nodeID, "peerIDs", peerIDs, "expectedKeys", len(peerIDs))
return &ecdhSession{
nodeID: nodeID,
peerIDs: peerIDs,
Expand All @@ -72,6 +75,10 @@ func (e *ecdhSession) ErrChan() <-chan error {
return e.errCh
}

func (e *ecdhSession) OnKeyExchangeComplete(callback func()) {
e.onKeyExchangeComplete = callback
}

func (e *ecdhSession) ListenKeyExchange() error {
// Generate an ephemeral ECDH key pair
privateKey, err := ecdh.X25519().GenerateKey(rand.Reader)
Expand All @@ -86,15 +93,19 @@ func (e *ecdhSession) ListenKeyExchange() error {
sub, err := e.pubSub.Subscribe(ECDHExchangeTopic, func(natMsg *nats.Msg) {
var ecdhMsg types.ECDHMessage
if err := json.Unmarshal(natMsg.Data, &ecdhMsg); err != nil {
logger.Error("Failed to unmarshal ECDH message", err)
return
}

if ecdhMsg.From == e.nodeID {
return
}

logger.Debug("Received ECDH message", "from", ecdhMsg.From, "to", e.nodeID)

//TODO: consider how to avoid replay attack
if err := e.identityStore.VerifySignature(&ecdhMsg); err != nil {
logger.Error("ECDH signature verification failed", err, "from", ecdhMsg.From)
e.errCh <- err
return
}
Expand All @@ -113,7 +124,15 @@ func (e *ecdhSession) ListenKeyExchange() error {
// Derive symmetric key using HKDF
symmetricKey := e.deriveSymmetricKey(sharedSecret, ecdhMsg.From)
e.identityStore.SetSymmetricKey(ecdhMsg.From, symmetricKey)
logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", e.identityStore.GetSymetricKeyCount())

currentKeyCount := e.identityStore.GetSymetricKeyCount()
logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", currentKeyCount, "expected", len(e.peerIDs))

// Check if ECDH exchange is complete and notify callback
if currentKeyCount == len(e.peerIDs) && e.onKeyExchangeComplete != nil {
logger.Info("ECDH key exchange completed successfully", "totalKeys", currentKeyCount)
e.onKeyExchangeComplete()
}
})

e.ecdhSub = sub
Expand Down
39 changes: 33 additions & 6 deletions pkg/mpc/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ func NewRegistry(
pubSub messaging.PubSub,
identityStore identity.Store,
) *registry {
ecdhSession := NewECDHSession(nodeID, peerNodeIDs, pubSub, identityStore)
// ECDH session should only exchange keys with other peers, not self
peerIDsExceptSelf := getPeerIDsExceptSelf(nodeID, peerNodeIDs)
ecdhSession := NewECDHSession(nodeID, peerIDsExceptSelf, pubSub, identityStore)
mpcThreshold := viper.GetInt("mpc_threshold")
if mpcThreshold < 1 {
logger.Fatal("mpc_threshold must be greater than 0", nil)
Expand All @@ -75,7 +77,7 @@ func NewRegistry(
reg := &registry{
consulKV: consulKV,
nodeID: nodeID,
peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs),
peerNodeIDs: peerIDsExceptSelf,
readyMap: make(map[string]bool),
readyCount: 1, // self
healthCheck: directMessaging,
Expand All @@ -85,6 +87,11 @@ func NewRegistry(
mpcThreshold: mpcThreshold,
}

// Set up callback to check ready state when ECDH completes
ecdhSession.OnKeyExchangeComplete(func() {
reg.checkAndUpdateReadyState()
})

go reg.consumeECDHErrors()

return reg
Expand Down Expand Up @@ -126,11 +133,28 @@ func (r *registry) registerReadyPairs(peerIDs []string) {
r.readyMap[peerID] = true
}

if len(peerIDs) == len(r.peerNodeIDs) && !r.ready {
// Check if we should update ready state
r.checkAndUpdateReadyState()
}

// checkAndUpdateReadyState checks if all conditions are met to mark the registry as ready
func (r *registry) checkAndUpdateReadyState() {
// Count ready peers in readyMap
readyPeersCount := 0
for _, isReady := range r.readyMap {
if isReady {
readyPeersCount++
}
}

// Only mark as ready when both conditions are met:
// 1. All peers are registered in Consul
// 2. ECDH key exchange is complete
if readyPeersCount == len(r.peerNodeIDs) && r.isECDHReady() && !r.ready {
r.mu.Lock()
r.ready = true
r.mu.Unlock()
logger.Info("All peers are ready including ECDH exchange completion")
logger.Info("[READY] All peers are ready including ECDH exchange completion")
}
}

Expand Down Expand Up @@ -163,12 +187,15 @@ func (r *registry) Ready() error {
}

_, err = r.healthCheck.Listen(r.composeHealthCheckTopic(r.nodeID), func(data []byte) {
peerID, isEcdhReady, _ := parseHealthDataSplit(string(data))
peerID, isEcdhReady, parseErr := parseHealthDataSplit(string(data))
if parseErr != nil {
logger.Error("Failed to parse health check data", parseErr, "data", string(data))
return
}
logger.Debug("Health check ok", "peerID", peerID, "isEcdhReady", isEcdhReady)
if !isEcdhReady {
logger.Info("[ECDH exchange retriggerd] not all peers are ready", "peerID", peerID)
go r.triggerECDHExchange()

}
})
if err != nil {
Expand Down
Loading