1111
1212use core:: ops:: SubAssign ;
1313
14- use super :: WeightedError ;
14+ use super :: WeightError ;
1515use crate :: Distribution ;
1616use alloc:: vec:: Vec ;
1717use rand:: distributions:: uniform:: { SampleBorrow , SampleUniform } ;
@@ -98,15 +98,19 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
9898 WeightedTreeIndex < W >
9999{
100100 /// Creates a new [`WeightedTreeIndex`] from a slice of weights.
101- pub fn new < I > ( weights : I ) -> Result < Self , WeightedError >
101+ ///
102+ /// Error cases:
103+ /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
104+ /// - [`WeightError::Overflow`] when the sum of all weights overflows.
105+ pub fn new < I > ( weights : I ) -> Result < Self , WeightError >
102106 where
103107 I : IntoIterator ,
104108 I :: Item : SampleBorrow < W > ,
105109 {
106110 let mut subtotals: Vec < W > = weights. into_iter ( ) . map ( |x| x. borrow ( ) . clone ( ) ) . collect ( ) ;
107111 for weight in subtotals. iter ( ) {
108- if * weight < W :: ZERO {
109- return Err ( WeightedError :: InvalidWeight ) ;
112+ if ! ( * weight >= W :: ZERO ) {
113+ return Err ( WeightError :: InvalidWeight ) ;
110114 }
111115 }
112116 let n = subtotals. len ( ) ;
@@ -115,7 +119,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
115119 let parent = ( i - 1 ) / 2 ;
116120 subtotals[ parent]
117121 . checked_add_assign ( & w)
118- . map_err ( |( ) | WeightedError :: Overflow ) ?;
122+ . map_err ( |( ) | WeightError :: Overflow ) ?;
119123 }
120124 Ok ( Self { subtotals } )
121125 }
@@ -164,14 +168,18 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
164168 }
165169
166170 /// Appends a new weight at the end.
167- pub fn push ( & mut self , weight : W ) -> Result < ( ) , WeightedError > {
168- if weight < W :: ZERO {
169- return Err ( WeightedError :: InvalidWeight ) ;
171+ ///
172+ /// Error cases:
173+ /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
174+ /// - [`WeightError::Overflow`] when the sum of all weights overflows.
175+ pub fn push ( & mut self , weight : W ) -> Result < ( ) , WeightError > {
176+ if !( weight >= W :: ZERO ) {
177+ return Err ( WeightError :: InvalidWeight ) ;
170178 }
171179 if let Some ( total) = self . subtotals . first ( ) {
172180 let mut total = total. clone ( ) ;
173181 if total. checked_add_assign ( & weight) . is_err ( ) {
174- return Err ( WeightedError :: Overflow ) ;
182+ return Err ( WeightError :: Overflow ) ;
175183 }
176184 }
177185 let mut index = self . len ( ) ;
@@ -184,9 +192,13 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
184192 }
185193
186194 /// Updates the weight at an index.
187- pub fn update ( & mut self , mut index : usize , weight : W ) -> Result < ( ) , WeightedError > {
188- if weight < W :: ZERO {
189- return Err ( WeightedError :: InvalidWeight ) ;
195+ ///
196+ /// Error cases:
197+ /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative.
198+ /// - [`WeightError::Overflow`] when the sum of all weights overflows.
199+ pub fn update ( & mut self , mut index : usize , weight : W ) -> Result < ( ) , WeightError > {
200+ if !( weight >= W :: ZERO ) {
201+ return Err ( WeightError :: InvalidWeight ) ;
190202 }
191203 let old_weight = self . get ( index) ;
192204 if weight > old_weight {
@@ -195,7 +207,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
195207 if let Some ( total) = self . subtotals . first ( ) {
196208 let mut total = total. clone ( ) ;
197209 if total. checked_add_assign ( & difference) . is_err ( ) {
198- return Err ( WeightedError :: Overflow ) ;
210+ return Err ( WeightError :: Overflow ) ;
199211 }
200212 }
201213 self . subtotals [ index]
@@ -235,13 +247,10 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
235247 ///
236248 /// Returns an error if there are no elements or all weights are zero. This
237249 /// is unlike [`Distribution::sample`], which panics in those cases.
238- fn try_sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Result < usize , WeightedError > {
239- if self . subtotals . is_empty ( ) {
240- return Err ( WeightedError :: NoItem ) ;
241- }
242- let total_weight = self . subtotals [ 0 ] . clone ( ) ;
250+ fn try_sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Result < usize , WeightError > {
251+ let total_weight = self . subtotals . first ( ) . cloned ( ) . unwrap_or ( W :: ZERO ) ;
243252 if total_weight == W :: ZERO {
244- return Err ( WeightedError :: AllWeightsZero ) ;
253+ return Err ( WeightError :: InsufficientNonZero ) ;
245254 }
246255 let mut target_weight = rng. gen_range ( W :: ZERO ..total_weight) ;
247256 let mut index = 0 ;
@@ -296,19 +305,19 @@ mod test {
296305 let tree = WeightedTreeIndex :: < f64 > :: new ( & [ ] ) . unwrap ( ) ;
297306 assert_eq ! (
298307 tree. try_sample( & mut rng) . unwrap_err( ) ,
299- WeightedError :: NoItem
308+ WeightError :: InsufficientNonZero
300309 ) ;
301310 }
302311
303312 #[ test]
304313 fn test_overflow_error ( ) {
305314 assert_eq ! (
306315 WeightedTreeIndex :: new( & [ i32 :: MAX , 2 ] ) ,
307- Err ( WeightedError :: Overflow )
316+ Err ( WeightError :: Overflow )
308317 ) ;
309318 let mut tree = WeightedTreeIndex :: new ( & [ i32:: MAX - 2 , 1 ] ) . unwrap ( ) ;
310- assert_eq ! ( tree. push( 3 ) , Err ( WeightedError :: Overflow ) ) ;
311- assert_eq ! ( tree. update( 1 , 4 ) , Err ( WeightedError :: Overflow ) ) ;
319+ assert_eq ! ( tree. push( 3 ) , Err ( WeightError :: Overflow ) ) ;
320+ assert_eq ! ( tree. update( 1 , 4 ) , Err ( WeightError :: Overflow ) ) ;
312321 tree. update ( 1 , 2 ) . unwrap ( ) ;
313322 }
314323
@@ -318,22 +327,22 @@ mod test {
318327 let mut rng = crate :: test:: rng ( 0x9c9fa0b0580a7031 ) ;
319328 assert_eq ! (
320329 tree. try_sample( & mut rng) . unwrap_err( ) ,
321- WeightedError :: AllWeightsZero
330+ WeightError :: InsufficientNonZero
322331 ) ;
323332 }
324333
325334 #[ test]
326335 fn test_invalid_weight_error ( ) {
327336 assert_eq ! (
328337 WeightedTreeIndex :: <i32 >:: new( & [ 1 , -1 ] ) . unwrap_err( ) ,
329- WeightedError :: InvalidWeight
338+ WeightError :: InvalidWeight
330339 ) ;
331340 let mut tree = WeightedTreeIndex :: < i32 > :: new ( & [ ] ) . unwrap ( ) ;
332- assert_eq ! ( tree. push( -1 ) . unwrap_err( ) , WeightedError :: InvalidWeight ) ;
341+ assert_eq ! ( tree. push( -1 ) . unwrap_err( ) , WeightError :: InvalidWeight ) ;
333342 tree. push ( 1 ) . unwrap ( ) ;
334343 assert_eq ! (
335344 tree. update( 0 , -1 ) . unwrap_err( ) ,
336- WeightedError :: InvalidWeight
345+ WeightError :: InvalidWeight
337346 ) ;
338347 }
339348
0 commit comments