Skip to content

Commit df2bd34

Browse files
committed
update: add cmp test
1 parent 58feb9f commit df2bd34

1 file changed

Lines changed: 114 additions & 0 deletions

File tree

pkg/mpc/taurus/cmp_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package taurus
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"math/big"
7+
"sync"
8+
"testing"
9+
10+
"github.com/fystack/mpcium/pkg/logger"
11+
"github.com/taurusgroup/multi-party-sig/pkg/party"
12+
)
13+
14+
type cmpTest struct {
15+
parties []*CmpParty
16+
results map[string]chan any
17+
}
18+
19+
func newCmpTest(sid string, ids []party.ID) *cmpTest {
20+
t := &cmpTest{
21+
results: map[string]chan any{
22+
"keygen": make(chan any, len(ids)),
23+
"sign": make(chan any, len(ids)),
24+
"reshare": make(chan any, len(ids)),
25+
},
26+
}
27+
28+
// Create all memory transports first
29+
transports := make([]*Memory, len(ids))
30+
for i, id := range ids {
31+
transports[i] = NewMemoryParty(string(id))
32+
}
33+
34+
// Link all peers together
35+
LinkPeers(transports...)
36+
37+
// Create parties with linked transports
38+
for i, id := range ids {
39+
adapter := NewTaurusNetworkAdapter(sid, id, transports[i], ids)
40+
t.parties = append(t.parties, NewCmpParty(sid, id, ids, 2,
41+
nil, adapter, nil, nil))
42+
}
43+
44+
return t
45+
}
46+
47+
func (t *cmpTest) runAll(fn func(*CmpParty) (any, error), key string) {
48+
var wg sync.WaitGroup
49+
for _, p := range t.parties {
50+
wg.Add(1)
51+
go func(p *CmpParty) {
52+
defer wg.Done()
53+
res, err := fn(p)
54+
if err != nil {
55+
logger.Error("operation failed", err)
56+
return
57+
}
58+
t.results[key] <- res
59+
}(p)
60+
}
61+
wg.Wait()
62+
}
63+
64+
func TestCmpParty(t *testing.T) {
65+
sid := "test-session-123"
66+
ids := []party.ID{"node0", "node1", "node2"}
67+
test := newCmpTest(sid, ids)
68+
69+
// --- Keygen ---
70+
test.runAll(func(p *CmpParty) (any, error) {
71+
return p.Keygen(context.Background())
72+
}, "keygen")
73+
74+
// --- Sign 1 ---
75+
msg := big.NewInt(1)
76+
test.runAll(func(p *CmpParty) (any, error) {
77+
return p.Sign(context.Background(), msg)
78+
}, "sign")
79+
80+
sigs := drain[[]byte](test.results["sign"])
81+
assertAllBytesEqual(t, sigs)
82+
83+
// // --- Reshare ---
84+
// test.runAll(func(p *CmpParty) (any, error) {
85+
// return p.Reshare(context.Background())
86+
// }, "reshare")
87+
88+
// // // --- Sign 2 ---
89+
// msg = big.NewInt(2)
90+
// test.runAll(func(p *CmpParty) (any, error) {
91+
// return p.Sign(context.Background(), msg)
92+
// }, "sign")
93+
}
94+
95+
func drain[T any](ch chan any) []T {
96+
n := len(ch)
97+
out := make([]T, n)
98+
for i := 0; i < n; i++ {
99+
out[i] = (<-ch).(T)
100+
}
101+
return out
102+
}
103+
104+
func assertAllBytesEqual(t *testing.T, vals [][]byte) {
105+
if len(vals) == 0 {
106+
t.Fatal("no values to compare")
107+
}
108+
first := vals[0]
109+
for i, v := range vals[1:] {
110+
if !bytes.Equal(first, v) {
111+
t.Fatalf("byte slices not equal at index %d", i+1)
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)