Skip to content

Commit 2c5c731

Browse files
committed
Add example for usage with kms signer
1 parent cd141b1 commit 2c5c731

1 file changed

Lines changed: 167 additions & 0 deletions

File tree

examples/generate/kms/main.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"flag"
6+
"fmt"
7+
"os"
8+
"os/signal"
9+
"sync"
10+
"sync/atomic"
11+
"syscall"
12+
"time"
13+
14+
"github.com/fystack/mpcium/pkg/client"
15+
"github.com/fystack/mpcium/pkg/config"
16+
"github.com/fystack/mpcium/pkg/event"
17+
"github.com/fystack/mpcium/pkg/logger"
18+
"github.com/fystack/mpcium/pkg/types"
19+
"github.com/google/uuid"
20+
"github.com/nats-io/nats.go"
21+
"github.com/spf13/viper"
22+
)
23+
24+
func main() {
25+
const environment = "development"
26+
const awsRegion = "ap-southeast-1"
27+
const kmsKeyID = "48e76117-fd08-4dc0-bd10-b1c7d01de748"
28+
29+
numWallets := flag.Int("n", 1, "Number of wallets to generate")
30+
31+
flag.Parse()
32+
33+
config.InitViperConfig()
34+
logger.Init(environment, false)
35+
36+
// KMS signer only supports P256
37+
38+
natsURL := viper.GetString("nats.url")
39+
natsConn, err := nats.Connect(natsURL)
40+
if err != nil {
41+
logger.Fatal("Failed to connect to NATS", err)
42+
}
43+
defer natsConn.Drain()
44+
defer natsConn.Close()
45+
46+
// For AWS production, use:
47+
kmsSigner, err := client.NewKMSSigner(types.EventInitiatorKeyTypeP256, client.KMSSignerOptions{
48+
Region: awsRegion,
49+
KeyID: kmsKeyID,
50+
EndpointURL: "http://localhost:4566", // LocalStack endpoint
51+
AccessKeyID: "test", // LocalStack dummy credentials
52+
SecretAccessKey: "test", // LocalStack dummy credentials
53+
})
54+
if err != nil {
55+
logger.Fatal("Failed to create KMS signer", err)
56+
}
57+
58+
// Log the public key for verification
59+
pubKey, err := kmsSigner.PublicKey()
60+
if err != nil {
61+
logger.Fatal("Failed to get public key from KMS signer", err)
62+
}
63+
logger.Info("Public key", "key", pubKey)
64+
65+
mpcClient := client.NewMPCClient(client.Options{
66+
NatsConn: natsConn,
67+
Signer: kmsSigner,
68+
})
69+
70+
var walletStartTimes sync.Map
71+
var walletIDs []string
72+
var walletIDsMu sync.Mutex
73+
var wg sync.WaitGroup
74+
var completedCount int32
75+
76+
startAll := time.Now()
77+
78+
// STEP 1: Pre-generate wallet IDs and store start times
79+
for i := 0; i < *numWallets; i++ {
80+
walletID := uuid.New().String()
81+
walletStartTimes.Store(walletID, time.Now())
82+
83+
walletIDsMu.Lock()
84+
walletIDs = append(walletIDs, walletID)
85+
walletIDsMu.Unlock()
86+
}
87+
88+
// STEP 2: Register the result handler AFTER all walletIDs are stored
89+
err = mpcClient.OnWalletCreationResult(func(event event.KeygenResultEvent) {
90+
logger.Info("Received wallet creation result", "event", event)
91+
now := time.Now()
92+
startTimeAny, ok := walletStartTimes.Load(event.WalletID)
93+
if ok {
94+
startTime := startTimeAny.(time.Time)
95+
duration := now.Sub(startTime).Seconds()
96+
accumulated := now.Sub(startAll).Seconds()
97+
countSoFar := atomic.AddInt32(&completedCount, 1)
98+
99+
logger.Info("Wallet created",
100+
"walletID", event.WalletID,
101+
"duration_seconds", fmt.Sprintf("%.3f", duration),
102+
"accumulated_time_seconds", fmt.Sprintf("%.3f", accumulated),
103+
"count_so_far", countSoFar,
104+
)
105+
106+
walletStartTimes.Delete(event.WalletID)
107+
} else {
108+
logger.Warn("Received wallet result but no start time found", "walletID", event.WalletID)
109+
}
110+
wg.Done()
111+
})
112+
if err != nil {
113+
logger.Fatal("Failed to subscribe to wallet-creation results", err)
114+
}
115+
116+
// STEP 3: Create wallets
117+
for _, walletID := range walletIDs {
118+
wg.Add(1) // Add to WaitGroup BEFORE attempting to create wallet
119+
120+
if err := mpcClient.CreateWallet(walletID); err != nil {
121+
logger.Error("CreateWallet failed", err)
122+
walletStartTimes.Delete(walletID)
123+
wg.Done() // Now this is safe since we added 1 above
124+
continue
125+
}
126+
127+
logger.Info("CreateWallet sent, awaiting result...", "walletID", walletID)
128+
}
129+
130+
// Wait until all wallet creations complete
131+
go func() {
132+
wg.Wait()
133+
totalDuration := time.Since(startAll).Seconds()
134+
logger.Info(
135+
"All wallets generated using KMS signer",
136+
"count",
137+
completedCount,
138+
"total_duration_seconds",
139+
fmt.Sprintf("%.3f", totalDuration),
140+
"kms_key_id",
141+
kmsKeyID,
142+
)
143+
144+
// Save wallet IDs to wallets.json
145+
walletIDsMu.Lock()
146+
data, err := json.MarshalIndent(walletIDs, "", " ")
147+
walletIDsMu.Unlock()
148+
if err != nil {
149+
logger.Error("Failed to marshal wallet IDs", err)
150+
} else {
151+
err = os.WriteFile("wallets.json", data, 0600)
152+
if err != nil {
153+
logger.Error("Failed to write wallets.json", err)
154+
} else {
155+
logger.Info("wallets.json written", "count", len(walletIDs))
156+
}
157+
}
158+
os.Exit(0)
159+
}()
160+
161+
// Block on SIGINT/SIGTERM (Ctrl+C etc.)
162+
stop := make(chan os.Signal, 1)
163+
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
164+
<-stop
165+
166+
fmt.Println("Shutting down.")
167+
}

0 commit comments

Comments
 (0)