Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Commit f1e31dc

Browse files
committed
avoid repeated clone
1 parent 99d979e commit f1e31dc

1 file changed

Lines changed: 21 additions & 19 deletions

File tree

src/learning/tree.rs

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -257,21 +257,25 @@ impl SupModel<Matrix<f64>, Vector<usize>> for DecisionTreeClassifier {
257257
}
258258

259259

260-
/// Computes candidate split thresholds from a list of values.
///
/// The values are copied, sorted in ascending order and deduplicated; the
/// result holds the midpoint of every pair of adjacent distinct values, in
/// ascending order. An input with fewer than two distinct values (including
/// an empty input) yields an empty `Vec`.
///
/// # Panics
///
/// Panics if the sort has to compare a NaN, because NaN has no ordering.
fn get_splits(values: &Vec<f64>) -> Vec<f64> {
    // TODO (carried over from the original code): avoid the repeated sort.
    let mut sorted = values.clone();
    sorted.sort_by(|a, b| a.partial_cmp(b).expect("values must not contain NaN"));
    // After sorting, equal values are adjacent, so `dedup` leaves only distinct values.
    sorted.dedup();

    // Each window is a pair of neighbouring distinct values; `windows(2)` is
    // empty for zero or one element, so no bounds handling is needed.
    sorted
        .windows(2)
        .map(|pair| (pair[0] + pair[1]) / 2.)
        .collect()
}
276280

277281
/// Split Vec to left and right, depending on given bool Vec values
@@ -354,19 +358,17 @@ mod tests {
354358

355359
use linalg::Vector;
356360

357-
use super::{uniquify, get_splits, split_slice, xlogy, freq, Metrics};
358-
359-
#[test]
360-
fn test_uniquify() {
361-
assert_eq!(uniquify(&vec![0.1, 0.2, 0.1]), vec![0.1, 0.2]);
362-
assert_eq!(uniquify(&vec![0.3, 0.1, 0.1, 0.1, 0.2, 0.2]), vec![0.1, 0.2, 0.3]);
363-
}
361+
use super::{get_splits, split_slice, xlogy, freq, Metrics};
364362

365363
#[test]
fn test_get_splits() {
    // Unsorted input with duplicates: one midpoint per pair of adjacent distinct values.
    assert_eq!(get_splits(&vec![0.1, 0.2, 0.1]), vec![0.15000000000000002]);
    assert_eq!(get_splits(&vec![0.1, 0.2, 0.1, 0.1]), vec![0.15000000000000002]);
    assert_eq!(get_splits(&vec![0.3, 0.1, 0.1, 0.1, 0.2, 0.2]), vec![0.15000000000000002, 0.25]);
    assert_eq!(get_splits(&vec![1., 3., 7., 3., 7.]), vec![2., 5.]);
    // Negative values, with a midpoint that lands on zero.
    assert_eq!(get_splits(&vec![-1., -2., 1., -2.]), vec![-1.5, 0.]);
    // A single distinct value leaves nothing to split on.
    assert_eq!(get_splits(&vec![0.1, 0.1, 0.1]), Vec::<f64>::new());
}
371373

372374
#[test]

0 commit comments

Comments
 (0)