44use crate :: polynomial:: { Eval , MonicLinear , Poly } ;
55use crate :: threshold_schnorr:: S ;
66use crate :: types:: { to_scalar, ShareIndex } ;
7- use fastcrypto:: error:: FastCryptoError :: {
8- InputLengthWrong , InputTooShort , InvalidInput , TooManyErrors ,
9- } ;
7+ use fastcrypto:: error:: FastCryptoError :: { InputLengthWrong , InvalidInput , TooManyErrors } ;
108use fastcrypto:: error:: FastCryptoResult ;
119use itertools:: Itertools ;
1210use reed_solomon_erasure:: galois_16:: ReedSolomon ;
@@ -130,6 +128,12 @@ impl RSDecoder {
130128/// A wrapper struct for the Reed-Solomon erasure coding library.
131129pub struct ErasureCoder ( ReedSolomon ) ;
132130
131+ /// An element of `GF(2^16)` as represented by the underlying coder.
132+ type Element = [ u8 ; ELEMENT_SIZE_IN_BYTES ] ;
133+
134+ /// Size in bytes of one `GF(2^16)` element.
135+ const ELEMENT_SIZE_IN_BYTES : usize = 2 ;
136+
133137#[ derive( Clone , Debug , Serialize , Deserialize ) ]
134138#[ serde( transparent) ]
135139pub struct Shard ( pub ( crate ) Vec < u8 > ) ;
@@ -167,14 +171,16 @@ impl ErasureCoder {
167171 if data. is_empty ( ) {
168172 return Err ( InvalidInput ) ;
169173 }
170- // GF(2^16) elements are pairs of bytes; size each shard to a whole number of pairs.
171- let shard_size = data. len ( ) . div_ceil ( 2 * self . 0 . data_shard_count ( ) ) ;
172- let bytes_per_shard = 2 * shard_size;
174+ // Size each shard to a whole number of field elements.
175+ let shard_size = data
176+ . len ( )
177+ . div_ceil ( ELEMENT_SIZE_IN_BYTES * self . 0 . data_shard_count ( ) ) ;
178+ let bytes_per_shard = ELEMENT_SIZE_IN_BYTES * shard_size;
173179 let mut data = data. to_vec ( ) ;
174180 data. resize ( bytes_per_shard * self . 0 . total_shard_count ( ) , 0 ) ;
175- let mut shards: Vec < Vec < [ u8 ; 2 ] > > = data
181+ let mut shards: Vec < Vec < Element > > = data
176182 . chunks_exact ( bytes_per_shard)
177- . map ( bytes_to_elems )
183+ . map ( bytes_to_elements )
178184 . collect :: < FastCryptoResult < _ > > ( ) ?;
179185 self . 0 . encode ( & mut shards) . map_err ( |_| InvalidInput ) ?;
180186 Ok ( shards
@@ -192,16 +198,16 @@ impl ErasureCoder {
192198 expected_len : usize ,
193199 ) -> FastCryptoResult < Vec < u8 > > {
194200 if shards. len ( ) != self . 0 . total_shard_count ( ) {
195- return Err ( InputTooShort ( self . 0 . total_shard_count ( ) ) ) ;
201+ return Err ( InputLengthWrong ( self . 0 . total_shard_count ( ) ) ) ;
196202 }
197203
198204 if shards. iter ( ) . filter ( |s| s. is_none ( ) ) . count ( ) > self . 0 . parity_shard_count ( ) {
199205 return Err ( InvalidInput ) ;
200206 }
201207
202- let mut shards: Vec < Option < Vec < [ u8 ; 2 ] > > > = shards
208+ let mut shards: Vec < Option < Vec < Element > > > = shards
203209 . into_iter ( )
204- . map ( |s| s. map ( |s| bytes_to_elems ( & s. 0 ) ) . transpose ( ) )
210+ . map ( |s| s. map ( |s| bytes_to_elements ( & s. 0 ) ) . transpose ( ) )
205211 . collect :: < FastCryptoResult < _ > > ( ) ?;
206212 self . 0 . reconstruct ( & mut shards) . map_err ( |_| InvalidInput ) ?;
207213 let shards = shards
@@ -223,16 +229,26 @@ impl ErasureCoder {
223229 if data. len ( ) < expected_len {
224230 return Err ( InvalidInput ) ;
225231 }
232+ // The bytes past `expected_len` are zero-padding inserted by `encode`; reject anything
233+ // that doesn't match.
234+ if data[ expected_len..] . iter ( ) . any ( |& b| b != 0 ) {
235+ return Err ( InvalidInput ) ;
236+ }
226237 data. truncate ( expected_len) ;
227238 Ok ( data)
228239 }
229240}
230241
231- fn bytes_to_elems ( bytes : & [ u8 ] ) -> FastCryptoResult < Vec < [ u8 ; 2 ] > > {
232- if !bytes. len ( ) . is_multiple_of ( 2 ) {
242+ /// Reinterpret `bytes` as a sequence of [Element]s. Fails with [`InvalidInput`] if the input
243+ /// length is not a multiple of [`ELEMENT_SIZE_IN_BYTES`].
244+ fn bytes_to_elements ( bytes : & [ u8 ] ) -> FastCryptoResult < Vec < Element > > {
245+ if !bytes. len ( ) . is_multiple_of ( ELEMENT_SIZE_IN_BYTES ) {
233246 return Err ( InvalidInput ) ;
234247 }
235- Ok ( bytes. chunks_exact ( 2 ) . map ( |p| [ p[ 0 ] , p[ 1 ] ] ) . collect ( ) )
248+ Ok ( bytes
249+ . chunks_exact ( ELEMENT_SIZE_IN_BYTES )
250+ . map ( |p| p. try_into ( ) . expect ( "chunk has ELEMENT_SIZE_IN_BYTES bytes" ) )
251+ . collect ( ) )
236252}
237253
238254#[ cfg( test) ]
0 commit comments