@@ -257,21 +257,25 @@ impl SupModel<Matrix<f64>, Vector<usize>> for DecisionTreeClassifier {
257257}
258258
259259
260- /// Uniquify Vec<f64>, result is sorted
261- fn uniquify ( values : & Vec < f64 > ) -> Vec < f64 > {
260+ /// Uniquify values, then get splitter values, i.e. midpoints of unique values
261+ fn get_splits ( values : & Vec < f64 > ) -> Vec < f64 > {
262+ debug_assert ! ( values. len( ) > 0 , "values can't be empty" ) ;
263+
264+ // ToDo: must avoid repeated sort
262265 let mut values = values. clone ( ) ;
263266 values. sort_by ( |a, b| a. partial_cmp ( b) . unwrap ( ) ) ;
264- values. dedup ( ) ;
265- values
266- }
267267
268- /// Uniquify values, then get splitter values, i.e. midpoints of unique values
269- fn get_splits ( values : & Vec < f64 > ) -> Vec < f64 > {
270- let uniques = uniquify ( values) ;
271- uniques[ ..uniques. len ( ) ] . iter ( )
272- . zip ( uniques[ 1 ..] . iter ( ) )
273- . map ( |( & x, & y) | ( x + y) / 2. )
274- . collect ( )
268+ let mut splits: Vec < f64 > = Vec :: with_capacity ( values. len ( ) ) ;
269+
270+ let mut prev: f64 = unsafe { * values. get_unchecked ( 0 ) } ;
271+ for & v in values. iter ( ) . skip ( 0 ) {
272+ if prev != v {
273+ splits. push ( ( prev + v) / 2. ) ;
274+ prev = v;
275+ }
276+
277+ }
278+ splits
275279}
276280
277281/// Split Vec to left and right, depending on given bool Vec values
@@ -354,19 +358,17 @@ mod tests {
354358
355359 use linalg:: Vector ;
356360
357- use super :: { uniquify, get_splits, split_slice, xlogy, freq, Metrics } ;
358-
359- #[ test]
360- fn test_uniquify ( ) {
361- assert_eq ! ( uniquify( & vec![ 0.1 , 0.2 , 0.1 ] ) , vec![ 0.1 , 0.2 ] ) ;
362- assert_eq ! ( uniquify( & vec![ 0.3 , 0.1 , 0.1 , 0.1 , 0.2 , 0.2 ] ) , vec![ 0.1 , 0.2 , 0.3 ] ) ;
363- }
361+ use super :: { get_splits, split_slice, xlogy, freq, Metrics } ;
364362
365363 #[ test]
366364 fn test_get_splits ( ) {
367365 assert_eq ! ( get_splits( & vec![ 0.1 , 0.2 , 0.1 ] ) , vec![ 0.15000000000000002 ] ) ;
368366 assert_eq ! ( get_splits( & vec![ 0.3 , 0.1 , 0.1 , 0.1 , 0.2 , 0.2 ] ) , vec![ 0.15000000000000002 , 0.25 ] ) ;
369367 assert_eq ! ( get_splits( & vec![ 1. , 3. , 7. , 3. , 7. ] ) , vec![ 2. , 5. ] ) ;
368+ assert_eq ! ( get_splits( & vec![ 0.1 , 0.2 , 0.1 ] ) , vec![ 0.15000000000000002 ] ) ;
369+ assert_eq ! ( get_splits( & vec![ 0.1 , 0.2 , 0.1 , 0.1 ] ) , vec![ 0.15000000000000002 ] ) ;
370+ assert_eq ! ( get_splits( & vec![ -1. , -2. , 1. , -2. ] ) , vec![ -1.5 , 0. ] ) ;
371+ assert_eq ! ( get_splits( & vec![ 0.1 , 0.1 , 0.1 ] ) , vec![ ] ) ;
370372 }
371373
372374 #[ test]
0 commit comments