Skip to content

Commit ea7a39a

Browse files
feat(wasm-mps): use message domain separators
Ticket: HSM-396
1 parent 2ac8afb commit ea7a39a

2 files changed

Lines changed: 266 additions & 24 deletions

File tree

packages/wasm-mps/src/lib.rs

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ mod mps {
116116
pub chaincode: [u8; 32],
117117
}
118118

119+
fn rem_prefix(prefix: &str, data: &Vec<u8>) -> Result<Vec<u8>, MpsError> {
120+
Ok(data
121+
.as_slice()
122+
.strip_prefix(prefix.as_bytes())
123+
.ok_or(MpsError::InvalidInput)?
124+
.to_vec())
125+
}
126+
127+
fn add_prefix(prefix: &str, data: &Vec<u8>) -> Vec<u8> {
128+
[prefix.as_bytes(), data.as_slice()].concat()
129+
}
130+
119131
fn internal_dkg_round0_process<G>(
120132
party_id: u8,
121133
decryption_key: &[u8; 32],
@@ -193,7 +205,16 @@ mod mps {
193205
encryption_keys: &[Vec<u8>; 2],
194206
seed: &[u8; 32],
195207
) -> Result<MsgState, MpsError> {
196-
internal_dkg_round0_process::<EdwardsPoint>(party_id, decryption_key, encryption_keys, seed)
208+
let result = internal_dkg_round0_process::<EdwardsPoint>(
209+
party_id,
210+
decryption_key,
211+
encryption_keys,
212+
seed,
213+
)?;
214+
Ok(MsgState {
215+
msg: add_prefix("mps-ed25519-dkg-round1-message$", &result.msg),
216+
state: add_prefix("mps-ed25519-dkg-round1-state$", &result.state),
217+
})
197218
}
198219

199220
fn internal_dkg_round1_process<G>(
@@ -240,7 +261,14 @@ mod mps {
240261
round1_messages: &[Vec<u8>; 2],
241262
state: &[u8],
242263
) -> Result<MsgState, MpsError> {
243-
internal_dkg_round1_process::<EdwardsPoint>(round1_messages, state)
264+
let i0_msg1 = rem_prefix("mps-ed25519-dkg-round1-message$", &round1_messages[0])?;
265+
let i1_msg1 = rem_prefix("mps-ed25519-dkg-round1-message$", &round1_messages[1])?;
266+
let state = rem_prefix("mps-ed25519-dkg-round1-state$", &state.to_vec())?;
267+
let result = internal_dkg_round1_process::<EdwardsPoint>(&[i0_msg1, i1_msg1], &state)?;
268+
Ok(MsgState {
269+
msg: add_prefix("mps-ed25519-dkg-round2-message$", &result.msg),
270+
state: add_prefix("mps-ed25519-dkg-round2-state$", &result.state),
271+
})
244272
}
245273

246274
fn internal_dkg_round2_process<G>(
@@ -277,7 +305,10 @@ mod mps {
277305
round2_messages: &[Vec<u8>; 2],
278306
state: &[u8],
279307
) -> Result<Share, MpsError> {
280-
let share = internal_dkg_round2_process::<EdwardsPoint>(round2_messages, state)?;
308+
let i0_msg2 = rem_prefix("mps-ed25519-dkg-round2-message$", &round2_messages[0])?;
309+
let i1_msg2 = rem_prefix("mps-ed25519-dkg-round2-message$", &round2_messages[1])?;
310+
let state = rem_prefix("mps-ed25519-dkg-round2-state$", &state.to_vec())?;
311+
let share = internal_dkg_round2_process::<EdwardsPoint>(&[i0_msg2, i1_msg2], &state)?;
281312
Ok(Share {
282313
share: bincode::serialize(&share).map_err(|_| MpsError::SerializationError)?,
283314
pk: share.public_key.compress().to_bytes(),
@@ -328,7 +359,12 @@ mod mps {
328359
&mut rand::thread_rng(),
329360
);
330361

331-
internal_dsg_round0_process(p0)
362+
let result = internal_dsg_round0_process(p0)?;
363+
364+
Ok(MsgState {
365+
msg: add_prefix("mps-ed25519-dsg-round1-message$", &result.msg),
366+
state: add_prefix("mps-ed25519-dsg-round1-state$", &result.state),
367+
})
332368
}
333369

334370
fn internal_dsg_round1_process<G>(
@@ -373,7 +409,17 @@ mod mps {
373409
round1_message: &[u8],
374410
state: &[u8],
375411
) -> Result<MsgState, MpsError> {
376-
internal_dsg_round1_process::<EdwardsPoint>(round1_message, state)
412+
let round1_message =
413+
rem_prefix("mps-ed25519-dsg-round1-message$", &round1_message.to_vec())?;
414+
let state = rem_prefix("mps-ed25519-dsg-round1-state$", &state.to_vec())?;
415+
let result = internal_dsg_round1_process::<EdwardsPoint>(
416+
round1_message.as_slice(),
417+
state.as_slice(),
418+
)?;
419+
Ok(MsgState {
420+
msg: add_prefix("mps-ed25519-dsg-round2-message$", &result.msg),
421+
state: add_prefix("mps-ed25519-dsg-round2-state$", &result.state),
422+
})
377423
}
378424

379425
/// Process round 2 of DSG protocol.
@@ -383,13 +429,18 @@ mod mps {
383429
round2_message: &[u8],
384430
state: &[u8],
385431
) -> Result<MsgState, MpsError> {
432+
// Strip prefix
433+
let round2_message =
434+
rem_prefix("mps-ed25519-dsg-round2-message$", &round2_message.to_vec())?;
435+
let state = rem_prefix("mps-ed25519-dsg-round2-state$", &state.to_vec())?;
436+
386437
// Parse state
387438
let state: DsgStateR2<EdwardsPoint> =
388-
bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?;
439+
bincode::deserialize(&state).map_err(|_| MpsError::DeserializationError)?;
389440

390441
// Parse messages
391442
let i0_msg2: SignMsg2<EdwardsPoint> =
392-
bincode::deserialize(round2_message).map_err(|_| MpsError::DeserializationError)?;
443+
bincode::deserialize(&round2_message).map_err(|_| MpsError::DeserializationError)?;
393444
let msgs = vec![i0_msg2, state.msg];
394445

395446
// Process all round2 messages together
@@ -408,8 +459,14 @@ mod mps {
408459
};
409460

410461
Ok(MsgState {
411-
msg: bincode::serialize(&msg3).map_err(|_| MpsError::SerializationError)?,
412-
state: bincode::serialize(&state).map_err(|_| MpsError::SerializationError)?,
462+
msg: add_prefix(
463+
"mps-ed25519-dsg-round3-message$",
464+
&bincode::serialize(&msg3).map_err(|_| MpsError::SerializationError)?,
465+
),
466+
state: add_prefix(
467+
"mps-ed25519-dsg-round3-state$",
468+
&bincode::serialize(&state).map_err(|_| MpsError::SerializationError)?,
469+
),
413470
})
414471
}
415472

@@ -420,13 +477,18 @@ mod mps {
420477
round3_message: &[u8],
421478
state: &[u8],
422479
) -> Result<Vec<u8>, MpsError> {
480+
// Strip prefix
481+
let round3_message =
482+
rem_prefix("mps-ed25519-dsg-round3-message$", &round3_message.to_vec())?;
483+
let state = rem_prefix("mps-ed25519-dsg-round3-state$", &state.to_vec())?;
484+
423485
// Parse state
424486
let state: DsgStateR3<EdwardsPoint> =
425-
bincode::deserialize(state).map_err(|_| MpsError::DeserializationError)?;
487+
bincode::deserialize(&state).map_err(|_| MpsError::DeserializationError)?;
426488

427489
// Parse messages
428490
let i0_msg3: SignMsg3<EdwardsPoint> =
429-
bincode::deserialize(round3_message).map_err(|_| MpsError::DeserializationError)?;
491+
bincode::deserialize(&round3_message).map_err(|_| MpsError::DeserializationError)?;
430492
let msgs = vec![i0_msg3, state.msg];
431493

432494
// Process all round2 messages together

0 commit comments

Comments
 (0)