Skip to content

Commit 4c8529f

Browse files
authored
Fix ECDH key exchange race condition in distributed deployments (#111)
* Fix ECDH key exchange race condition in distributed deployments * Fix ECDH session parameter
1 parent c7e0875 commit 4c8529f

2 files changed

Lines changed: 61 additions & 15 deletions

File tree

pkg/mpc/key_exchange_session.go

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,19 @@ type ECDHSession interface {
3232
GetReadyPeersCount() int
3333
ErrChan() <-chan error
3434
Close() error
35+
OnKeyExchangeComplete(callback func())
3536
}
3637

3738
type ecdhSession struct {
38-
nodeID string
39-
peerIDs []string
40-
pubSub messaging.PubSub
41-
ecdhSub messaging.Subscription
42-
identityStore identity.Store
43-
privateKey *ecdh.PrivateKey
44-
publicKey *ecdh.PublicKey
45-
errCh chan error
39+
nodeID string
40+
peerIDs []string
41+
pubSub messaging.PubSub
42+
ecdhSub messaging.Subscription
43+
identityStore identity.Store
44+
privateKey *ecdh.PrivateKey
45+
publicKey *ecdh.PublicKey
46+
errCh chan error
47+
onKeyExchangeComplete func()
4648
}
4749

4850
func NewECDHSession(
@@ -51,6 +53,7 @@ func NewECDHSession(
5153
pubSub messaging.PubSub,
5254
identityStore identity.Store,
5355
) *ecdhSession {
56+
logger.Info("Creating ECDH session", "nodeID", nodeID, "peerIDs", peerIDs, "expectedKeys", len(peerIDs))
5457
return &ecdhSession{
5558
nodeID: nodeID,
5659
peerIDs: peerIDs,
@@ -72,6 +75,10 @@ func (e *ecdhSession) ErrChan() <-chan error {
7275
return e.errCh
7376
}
7477

78+
func (e *ecdhSession) OnKeyExchangeComplete(callback func()) {
79+
e.onKeyExchangeComplete = callback
80+
}
81+
7582
func (e *ecdhSession) ListenKeyExchange() error {
7683
// Generate an ephemeral ECDH key pair
7784
privateKey, err := ecdh.X25519().GenerateKey(rand.Reader)
@@ -86,15 +93,19 @@ func (e *ecdhSession) ListenKeyExchange() error {
8693
sub, err := e.pubSub.Subscribe(ECDHExchangeTopic, func(natMsg *nats.Msg) {
8794
var ecdhMsg types.ECDHMessage
8895
if err := json.Unmarshal(natMsg.Data, &ecdhMsg); err != nil {
96+
logger.Error("Failed to unmarshal ECDH message", err)
8997
return
9098
}
9199

92100
if ecdhMsg.From == e.nodeID {
93101
return
94102
}
95103

104+
logger.Debug("Received ECDH message", "from", ecdhMsg.From, "to", e.nodeID)
105+
96106
//TODO: consider how to avoid replay attack
97107
if err := e.identityStore.VerifySignature(&ecdhMsg); err != nil {
108+
logger.Error("ECDH signature verification failed", err, "from", ecdhMsg.From)
98109
e.errCh <- err
99110
return
100111
}
@@ -113,7 +124,15 @@ func (e *ecdhSession) ListenKeyExchange() error {
113124
// Derive symmetric key using HKDF
114125
symmetricKey := e.deriveSymmetricKey(sharedSecret, ecdhMsg.From)
115126
e.identityStore.SetSymmetricKey(ecdhMsg.From, symmetricKey)
116-
logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", e.identityStore.GetSymetricKeyCount())
127+
128+
currentKeyCount := e.identityStore.GetSymetricKeyCount()
129+
logger.Debug("ECDH progress", "peer", ecdhMsg.From, "current", currentKeyCount, "expected", len(e.peerIDs))
130+
131+
// Check if ECDH exchange is complete and notify callback
132+
if currentKeyCount == len(e.peerIDs) && e.onKeyExchangeComplete != nil {
133+
logger.Info("ECDH key exchange completed successfully", "totalKeys", currentKeyCount)
134+
e.onKeyExchangeComplete()
135+
}
117136
})
118137

119138
e.ecdhSub = sub

pkg/mpc/registry.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ func NewRegistry(
6666
pubSub messaging.PubSub,
6767
identityStore identity.Store,
6868
) *registry {
69-
ecdhSession := NewECDHSession(nodeID, peerNodeIDs, pubSub, identityStore)
69+
// ECDH session should only exchange keys with other peers, not self
70+
peerIDsExceptSelf := getPeerIDsExceptSelf(nodeID, peerNodeIDs)
71+
ecdhSession := NewECDHSession(nodeID, peerIDsExceptSelf, pubSub, identityStore)
7072
mpcThreshold := viper.GetInt("mpc_threshold")
7173
if mpcThreshold < 1 {
7274
logger.Fatal("mpc_threshold must be greater than 0", nil)
@@ -75,7 +77,7 @@ func NewRegistry(
7577
reg := &registry{
7678
consulKV: consulKV,
7779
nodeID: nodeID,
78-
peerNodeIDs: getPeerIDsExceptSelf(nodeID, peerNodeIDs),
80+
peerNodeIDs: peerIDsExceptSelf,
7981
readyMap: make(map[string]bool),
8082
readyCount: 1, // self
8183
healthCheck: directMessaging,
@@ -85,6 +87,11 @@ func NewRegistry(
8587
mpcThreshold: mpcThreshold,
8688
}
8789

90+
// Set up callback to check ready state when ECDH completes
91+
ecdhSession.OnKeyExchangeComplete(func() {
92+
reg.checkAndUpdateReadyState()
93+
})
94+
8895
go reg.consumeECDHErrors()
8996

9097
return reg
@@ -126,11 +133,28 @@ func (r *registry) registerReadyPairs(peerIDs []string) {
126133
r.readyMap[peerID] = true
127134
}
128135

129-
if len(peerIDs) == len(r.peerNodeIDs) && !r.ready {
136+
// Check if we should update ready state
137+
r.checkAndUpdateReadyState()
138+
}
139+
140+
// checkAndUpdateReadyState checks if all conditions are met to mark the registry as ready
141+
func (r *registry) checkAndUpdateReadyState() {
142+
// Count ready peers in readyMap
143+
readyPeersCount := 0
144+
for _, isReady := range r.readyMap {
145+
if isReady {
146+
readyPeersCount++
147+
}
148+
}
149+
150+
// Only mark as ready when both conditions are met:
151+
// 1. All peers are registered in Consul
152+
// 2. ECDH key exchange is complete
153+
if readyPeersCount == len(r.peerNodeIDs) && r.isECDHReady() && !r.ready {
130154
r.mu.Lock()
131155
r.ready = true
132156
r.mu.Unlock()
133-
logger.Info("All peers are ready including ECDH exchange completion")
157+
logger.Info("[READY] All peers are ready including ECDH exchange completion")
134158
}
135159
}
136160

@@ -163,12 +187,15 @@ func (r *registry) Ready() error {
163187
}
164188

165189
_, err = r.healthCheck.Listen(r.composeHealthCheckTopic(r.nodeID), func(data []byte) {
166-
peerID, isEcdhReady, _ := parseHealthDataSplit(string(data))
190+
peerID, isEcdhReady, parseErr := parseHealthDataSplit(string(data))
191+
if parseErr != nil {
192+
logger.Error("Failed to parse health check data", parseErr, "data", string(data))
193+
return
194+
}
167195
logger.Debug("Health check ok", "peerID", peerID, "isEcdhReady", isEcdhReady)
168196
if !isEcdhReady {
169197
logger.Info("[ECDH exchange retriggerd] not all peers are ready", "peerID", peerID)
170198
go r.triggerECDHExchange()
171-
172199
}
173200
})
174201
if err != nil {

0 commit comments

Comments
 (0)