Skip to content

Commit af49629

Browse files
committed
Introduce Element/ELEMENT_SIZE_IN_BYTES; verify zero padding; tidy decode error
1 parent df2dcef commit af49629

1 file changed

Lines changed: 30 additions & 14 deletions

File tree

fastcrypto-tbls/src/threshold_schnorr/reed_solomon.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
use crate::polynomial::{Eval, MonicLinear, Poly};
55
use crate::threshold_schnorr::S;
66
use crate::types::{to_scalar, ShareIndex};
7-
use fastcrypto::error::FastCryptoError::{
8-
InputLengthWrong, InputTooShort, InvalidInput, TooManyErrors,
9-
};
7+
use fastcrypto::error::FastCryptoError::{InputLengthWrong, InvalidInput, TooManyErrors};
108
use fastcrypto::error::FastCryptoResult;
119
use itertools::Itertools;
1210
use reed_solomon_erasure::galois_16::ReedSolomon;
@@ -130,6 +128,12 @@ impl RSDecoder {
130128
/// A wrapper struct for the Reed-Solomon erasure coding library.
131129
pub struct ErasureCoder(ReedSolomon);
132130

131+
/// An element of `GF(2^16)` as represented by the underlying coder.
132+
type Element = [u8; ELEMENT_SIZE_IN_BYTES];
133+
134+
/// Size in bytes of one `GF(2^16)` element.
135+
const ELEMENT_SIZE_IN_BYTES: usize = 2;
136+
133137
#[derive(Clone, Debug, Serialize, Deserialize)]
134138
#[serde(transparent)]
135139
pub struct Shard(pub(crate) Vec<u8>);
@@ -167,14 +171,16 @@ impl ErasureCoder {
167171
if data.is_empty() {
168172
return Err(InvalidInput);
169173
}
170-
// GF(2^16) elements are pairs of bytes; size each shard to a whole number of pairs.
171-
let shard_size = data.len().div_ceil(2 * self.0.data_shard_count());
172-
let bytes_per_shard = 2 * shard_size;
174+
// Size each shard to a whole number of field elements.
175+
let shard_size = data
176+
.len()
177+
.div_ceil(ELEMENT_SIZE_IN_BYTES * self.0.data_shard_count());
178+
let bytes_per_shard = ELEMENT_SIZE_IN_BYTES * shard_size;
173179
let mut data = data.to_vec();
174180
data.resize(bytes_per_shard * self.0.total_shard_count(), 0);
175-
let mut shards: Vec<Vec<[u8; 2]>> = data
181+
let mut shards: Vec<Vec<Element>> = data
176182
.chunks_exact(bytes_per_shard)
177-
.map(bytes_to_elems)
183+
.map(bytes_to_elements)
178184
.collect::<FastCryptoResult<_>>()?;
179185
self.0.encode(&mut shards).map_err(|_| InvalidInput)?;
180186
Ok(shards
@@ -192,16 +198,16 @@ impl ErasureCoder {
192198
expected_len: usize,
193199
) -> FastCryptoResult<Vec<u8>> {
194200
if shards.len() != self.0.total_shard_count() {
195-
return Err(InputTooShort(self.0.total_shard_count()));
201+
return Err(InputLengthWrong(self.0.total_shard_count()));
196202
}
197203

198204
if shards.iter().filter(|s| s.is_none()).count() > self.0.parity_shard_count() {
199205
return Err(InvalidInput);
200206
}
201207

202-
let mut shards: Vec<Option<Vec<[u8; 2]>>> = shards
208+
let mut shards: Vec<Option<Vec<Element>>> = shards
203209
.into_iter()
204-
.map(|s| s.map(|s| bytes_to_elems(&s.0)).transpose())
210+
.map(|s| s.map(|s| bytes_to_elements(&s.0)).transpose())
205211
.collect::<FastCryptoResult<_>>()?;
206212
self.0.reconstruct(&mut shards).map_err(|_| InvalidInput)?;
207213
let shards = shards
@@ -223,16 +229,26 @@ impl ErasureCoder {
223229
if data.len() < expected_len {
224230
return Err(InvalidInput);
225231
}
232+
// The bytes past `expected_len` are zero-padding inserted by `encode`; reject anything
233+
// that doesn't match.
234+
if data[expected_len..].iter().any(|&b| b != 0) {
235+
return Err(InvalidInput);
236+
}
226237
data.truncate(expected_len);
227238
Ok(data)
228239
}
229240
}
230241

231-
fn bytes_to_elems(bytes: &[u8]) -> FastCryptoResult<Vec<[u8; 2]>> {
232-
if !bytes.len().is_multiple_of(2) {
242+
/// Reinterpret `bytes` as a sequence of [Element]s. Fails with [`InvalidInput`] if the input
243+
/// length is not a multiple of [`ELEMENT_SIZE_IN_BYTES`].
244+
fn bytes_to_elements(bytes: &[u8]) -> FastCryptoResult<Vec<Element>> {
245+
if !bytes.len().is_multiple_of(ELEMENT_SIZE_IN_BYTES) {
233246
return Err(InvalidInput);
234247
}
235-
Ok(bytes.chunks_exact(2).map(|p| [p[0], p[1]]).collect())
248+
Ok(bytes
249+
.chunks_exact(ELEMENT_SIZE_IN_BYTES)
250+
.map(|p| p.try_into().expect("chunk has ELEMENT_SIZE_IN_BYTES bytes"))
251+
.collect())
236252
}
237253

238254
#[cfg(test)]

0 commit comments

Comments
 (0)