Skip to content

Commit fa9a4e1

Browse files
committed
Fixed comments and add some unit tests for ed25519 utils
1 parent 67d5517 commit fa9a4e1

10 files changed

Lines changed: 247 additions & 67 deletions

File tree

cmd/mpcium-cli/generate-initiator.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,17 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error {
4040
algorithm := c.String("algorithm")
4141

4242
if algorithm == "" {
43-
algorithm = string(types.KeyTypeEd25519)
43+
algorithm = string(types.EventInitiatorKeyTypeEd25519)
4444
}
4545

4646
if !slices.Contains(
47-
[]string{string(types.KeyTypeEd25519), string(types.KeyTypeP256)},
47+
[]string{string(types.EventInitiatorKeyTypeEd25519), string(types.EventInitiatorKeyTypeP256)},
4848
algorithm,
4949
) {
5050
return fmt.Errorf("invalid algorithm: %s. Must be %s or %s",
5151
algorithm,
52-
types.KeyTypeEd25519,
53-
types.KeyTypeP256,
52+
types.EventInitiatorKeyTypeEd25519,
53+
types.EventInitiatorKeyTypeP256,
5454
)
5555
}
5656

@@ -90,9 +90,9 @@ func generateInitiatorIdentity(ctx context.Context, c *cli.Command) error {
9090
var keyData encryption.KeyData
9191
var err error
9292

93-
if algorithm == string(types.KeyTypeEd25519) {
93+
if algorithm == string(types.EventInitiatorKeyTypeEd25519) {
9494
keyData, err = generateEd25519Keys()
95-
} else if algorithm == string(types.KeyTypeP256) {
95+
} else if algorithm == string(types.EventInitiatorKeyTypeP256) {
9696
keyData, err = encryption.GenerateP256Keys()
9797
}
9898

examples/generate/main.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,19 @@ func main() {
3333

3434
algorithm := viper.GetString("event_initiator_algorithm")
3535
if algorithm == "" {
36-
algorithm = string(types.KeyTypeEd25519)
36+
algorithm = string(types.EventInitiatorKeyTypeEd25519)
3737
}
3838

3939
if !slices.Contains(
40-
[]string{string(types.KeyTypeEd25519), string(types.KeyTypeP256)},
40+
[]string{string(types.EventInitiatorKeyTypeEd25519), string(types.EventInitiatorKeyTypeP256)},
4141
algorithm,
4242
) {
4343
logger.Fatal(
4444
fmt.Sprintf(
4545
"invalid algorithm: %s. Must be %s or %s",
4646
algorithm,
47-
types.KeyTypeEd25519,
48-
types.KeyTypeP256,
47+
types.EventInitiatorKeyTypeEd25519,
48+
types.EventInitiatorKeyTypeP256,
4949
),
5050
nil,
5151
)

examples/reshare/main.go

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@ func main() {
2424

2525
algorithm := viper.GetString("event_initiator_algorithm")
2626
if algorithm == "" {
27-
algorithm = string(types.KeyTypeEd25519)
27+
algorithm = string(types.EventInitiatorKeyTypeEd25519)
2828
}
2929

3030
// Validate algorithm
3131
if !slices.Contains(
32-
[]string{string(types.KeyTypeEd25519), string(types.KeyTypeP256)},
32+
[]string{string(types.EventInitiatorKeyTypeEd25519), string(types.EventInitiatorKeyTypeP256)},
3333
algorithm,
3434
) {
3535
logger.Fatal(
3636
fmt.Sprintf(
3737
"invalid algorithm: %s. Must be %s or %s",
3838
algorithm,
39-
types.KeyTypeEd25519,
40-
types.KeyTypeP256,
39+
types.EventInitiatorKeyTypeEd25519,
40+
types.EventInitiatorKeyTypeP256,
4141
),
4242
nil,
4343
)
@@ -68,16 +68,6 @@ func main() {
6868
if err != nil {
6969
logger.Fatal("Failed to subscribe to OnResharingResult", err)
7070
}
71-
// Determine key type based on algorithm
72-
var keyType types.KeyType
73-
switch algorithm {
74-
case string(types.KeyTypeEd25519):
75-
keyType = types.KeyTypeEd25519
76-
case string(types.KeyTypeP256):
77-
keyType = types.KeyTypeP256
78-
default:
79-
logger.Fatal("Unsupported algorithm", nil)
80-
}
8171

8272
resharingMsg := &types.ResharingMessage{
8373
SessionID: uuid.NewString(),
@@ -88,7 +78,7 @@ func main() {
8878
}, // new peer IDs
8979

9080
NewThreshold: 1, // t+1 <= len(NodeIDs)
91-
KeyType: keyType,
81+
KeyType: types.KeyTypeEd25519,
9282
}
9383
err = mpcClient.Resharing(resharingMsg)
9484
if err != nil {

examples/sign/main.go

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,20 @@ func main() {
2424

2525
algorithm := viper.GetString("event_initiator_algorithm")
2626
if algorithm == "" {
27-
algorithm = string(types.KeyTypeEd25519)
27+
algorithm = string(types.EventInitiatorKeyTypeEd25519)
2828
}
2929

3030
// Validate algorithm
3131
if !slices.Contains(
32-
[]string{string(types.KeyTypeEd25519), string(types.KeyTypeP256)},
32+
[]string{string(types.EventInitiatorKeyTypeEd25519), string(types.EventInitiatorKeyTypeP256)},
3333
algorithm,
3434
) {
3535
logger.Fatal(
3636
fmt.Sprintf(
3737
"invalid algorithm: %s. Must be %s or %s",
3838
algorithm,
39-
types.KeyTypeEd25519,
40-
types.KeyTypeP256,
39+
types.EventInitiatorKeyTypeEd25519,
40+
types.EventInitiatorKeyTypeP256,
4141
),
4242
nil,
4343
)
@@ -60,19 +60,8 @@ func main() {
6060
txID := uuid.New().String()
6161
dummyTx := []byte("deadbeef") // replace with real transaction bytes
6262

63-
// Determine key type based on algorithm
64-
var keyType types.KeyType
65-
switch algorithm {
66-
case string(types.KeyTypeEd25519):
67-
keyType = types.KeyTypeEd25519
68-
case string(types.KeyTypeP256):
69-
keyType = types.KeyTypeP256
70-
default:
71-
logger.Fatal("Unsupported algorithm", nil)
72-
}
73-
7463
txMsg := &types.SignTxMessage{
75-
KeyType: keyType,
64+
KeyType: types.KeyTypeEd25519,
7665
WalletID: "ad24f678-b04b-4149-bcf6-bf9c90df8e63", // Use the generated wallet ID
7766
NetworkInternalCode: "solana-devnet",
7867
TxID: txID,

pkg/client/client.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,8 @@ func NewMPCClient(opts Options) MPCClient {
201201
genKeySuccessQueue: genKeySuccessQueue,
202202
signResultQueue: signResultQueue,
203203
reshareSuccessQueue: reshareSuccessQueue,
204-
// privKey: priv,
205-
// privKeyECDSA: privECDSA,
206-
initiatorPrivKey: initiatorPrivKey,
207-
algorithm: opts.Algorithm,
204+
initiatorPrivKey: initiatorPrivKey,
205+
algorithm: opts.Algorithm,
208206
}
209207
}
210208

pkg/encryption/ed25519.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package encryption
2+
3+
import (
4+
"crypto/ed25519"
5+
"encoding/hex"
6+
"fmt"
7+
)
8+
9+
// ParseEd25519PublicKeyFromHex parses a hex-encoded Ed25519 public key and validates it.
10+
// Returns the public key as []byte and an error if invalid.
11+
func ParseEd25519PublicKeyFromHex(hexKey string) ([]byte, error) {
12+
if hexKey == "" {
13+
return nil, fmt.Errorf("public key hex string is empty")
14+
}
15+
16+
// Decode hex string to bytes
17+
keyBytes, err := hex.DecodeString(hexKey)
18+
if err != nil {
19+
return nil, fmt.Errorf("invalid hex format: %w", err)
20+
}
21+
22+
// Validate the key
23+
if err := ValidateEd25519PublicKey(keyBytes); err != nil {
24+
return nil, err
25+
}
26+
27+
return keyBytes, nil
28+
}
29+
30+
// ValidateEd25519PublicKey validates an existing byte slice as a valid Ed25519 public key
31+
func ValidateEd25519PublicKey(keyBytes []byte) error {
32+
if len(keyBytes) != ed25519.PublicKeySize {
33+
return fmt.Errorf("invalid Ed25519 public key length: expected %d bytes, got %d",
34+
ed25519.PublicKeySize, len(keyBytes))
35+
}
36+
37+
// Create and validate Ed25519 public key
38+
pubKey := ed25519.PublicKey(keyBytes)
39+
40+
// Basic validation - attempt to use the key
41+
// Invalid curve points will cause verification to behave predictably
42+
dummyMsg := []byte("validation_test")
43+
dummySig := make([]byte, ed25519.SignatureSize)
44+
ed25519.Verify(pubKey, dummyMsg, dummySig) // This won't panic on invalid keys
45+
46+
return nil
47+
}

pkg/encryption/ed25519_test.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package encryption
2+
3+
import (
4+
"crypto/ed25519"
5+
"crypto/rand"
6+
"encoding/hex"
7+
"strings"
8+
"testing"
9+
)
10+
11+
var (
12+
// Test data shared across tests
13+
testValidKey ed25519.PublicKey
14+
testValidHex string
15+
testAllZeros = make([]byte, 32)
16+
testAllMax = func() []byte {
17+
b := make([]byte, 32)
18+
for i := range b {
19+
b[i] = 0xFF
20+
}
21+
return b
22+
}()
23+
)
24+
25+
func init() {
26+
// Generate a single valid key for all tests
27+
testValidKey, _, _ = ed25519.GenerateKey(rand.Reader)
28+
testValidHex = hex.EncodeToString(testValidKey)
29+
}
30+
31+
// Helper function to check error expectations
32+
func checkError(t *testing.T, err error, wantError bool, errorMsg string) {
33+
t.Helper()
34+
if wantError {
35+
if err == nil {
36+
t.Errorf("expected error but got none")
37+
} else if errorMsg != "" && !strings.Contains(err.Error(), errorMsg) {
38+
t.Errorf("error = %v, want error containing %v", err, errorMsg)
39+
}
40+
} else if err != nil {
41+
t.Errorf("unexpected error = %v", err)
42+
}
43+
}
44+
45+
func TestParseEd25519PublicKeyFromHex(t *testing.T) {
46+
tests := []struct {
47+
name string
48+
hexKey string
49+
wantError bool
50+
errorMsg string
51+
}{
52+
{"valid hex key", testValidHex, false, ""},
53+
{"empty hex string", "", true, "public key hex string is empty"},
54+
{"invalid hex characters", strings.Repeat("g", 64), true, "invalid hex format"},
55+
{"too short hex string", "abcdef1234567890", true, "invalid Ed25519 public key length: expected 32 bytes, got 8"},
56+
{"too long hex string", strings.Repeat("ab", 40), true, "invalid Ed25519 public key length: expected 32 bytes, got 40"},
57+
{"odd length hex string", "abc", true, "invalid hex format"},
58+
{"all zeros", hex.EncodeToString(testAllZeros), false, ""},
59+
{"all max bytes", hex.EncodeToString(testAllMax), false, ""},
60+
}
61+
62+
for _, tt := range tests {
63+
t.Run(tt.name, func(t *testing.T) {
64+
result, err := ParseEd25519PublicKeyFromHex(tt.hexKey)
65+
66+
checkError(t, err, tt.wantError, tt.errorMsg)
67+
68+
if !tt.wantError {
69+
if result == nil {
70+
t.Errorf("expected non-nil result")
71+
} else if len(result) != ed25519.PublicKeySize {
72+
t.Errorf("result length = %d, want %d", len(result), ed25519.PublicKeySize)
73+
}
74+
} else if result != nil {
75+
t.Errorf("expected nil result on error, got %v", result)
76+
}
77+
})
78+
}
79+
}
80+
81+
func TestValidateEd25519PublicKey(t *testing.T) {
82+
tests := []struct {
83+
name string
84+
keyBytes []byte
85+
wantError bool
86+
errorMsg string
87+
}{
88+
{"valid public key", testValidKey, false, ""},
89+
{"nil key bytes", nil, true, "invalid Ed25519 public key length: expected 32 bytes, got 0"},
90+
{"empty key bytes", []byte{}, true, "invalid Ed25519 public key length: expected 32 bytes, got 0"},
91+
{"too short key", make([]byte, 16), true, "invalid Ed25519 public key length: expected 32 bytes, got 16"},
92+
{"too long key", make([]byte, 64), true, "invalid Ed25519 public key length: expected 32 bytes, got 64"},
93+
{"all zeros", testAllZeros, false, ""},
94+
{"all max bytes", testAllMax, false, ""},
95+
}
96+
97+
for _, tt := range tests {
98+
t.Run(tt.name, func(t *testing.T) {
99+
err := ValidateEd25519PublicKey(tt.keyBytes)
100+
checkError(t, err, tt.wantError, tt.errorMsg)
101+
})
102+
}
103+
}
104+
105+
func TestParseAndValidateIntegration(t *testing.T) {
106+
testKeys := []ed25519.PublicKey{testValidKey}
107+
108+
// Generate a few more keys for testing
109+
for i := 0; i < 3; i++ {
110+
pubKey, _, err := ed25519.GenerateKey(rand.Reader)
111+
if err != nil {
112+
t.Fatalf("Failed to generate test key %d: %v", i, err)
113+
}
114+
testKeys = append(testKeys, pubKey)
115+
}
116+
117+
for i, validPubKey := range testKeys {
118+
validHex := hex.EncodeToString(validPubKey)
119+
120+
parsedKey, err := ParseEd25519PublicKeyFromHex(validHex)
121+
if err != nil {
122+
t.Errorf("ParseEd25519PublicKeyFromHex() failed for key %d: %v", i, err)
123+
continue
124+
}
125+
126+
if err := ValidateEd25519PublicKey(parsedKey); err != nil {
127+
t.Errorf("ValidateEd25519PublicKey() failed for key %d: %v", i, err)
128+
}
129+
130+
if !compareBytes(validPubKey, parsedKey) {
131+
t.Errorf("Key %d: parsed key differs from original", i)
132+
}
133+
}
134+
}
135+
136+
// Helper function to compare byte slices
137+
func compareBytes(a, b []byte) bool {
138+
if len(a) != len(b) {
139+
return false
140+
}
141+
for i, v := range a {
142+
if v != b[i] {
143+
return false
144+
}
145+
}
146+
return true
147+
}
148+
149+
func BenchmarkParseEd25519PublicKeyFromHex(b *testing.B) {
150+
for i := 0; i < b.N; i++ {
151+
_, _ = ParseEd25519PublicKeyFromHex(testValidHex)
152+
}
153+
}
154+
155+
func BenchmarkValidateEd25519PublicKey(b *testing.B) {
156+
for i := 0; i < b.N; i++ {
157+
_ = ValidateEd25519PublicKey(testValidKey)
158+
}
159+
}

0 commit comments

Comments
 (0)