Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Commit 5265129

Browse files
sinhrksAtheMathmo
authored andcommitted
LinRegressor to use solve rather than inverse (#165)
* Use solve rather than inverse * add trees dataset * update error message
1 parent 34a5417 commit 5265129

6 files changed

Lines changed: 106 additions & 11 deletions

File tree

src/datasets/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use std::fmt::Debug;
22

33
/// Module for iris dataset.
44
pub mod iris;
5+
/// Module for trees dataset.
6+
pub mod trees;
57

68
/// Dataset container
79
#[derive(Clone, Debug)]

src/datasets/trees.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
use rulinalg::matrix::Matrix;
2+
use rulinalg::vector::Vector;
3+
4+
use super::Dataset;
5+
6+
/// Load trees dataset.
7+
///
8+
/// The data set contains a sample of 31 black cherry trees in the
9+
/// Allegheny National Forest, Pennsylvania.
10+
///
11+
/// ## Attribute Information
12+
///
13+
/// ### Data
14+
///
15+
/// ``Matrix<f64>`` contains following columns.
16+
///
17+
/// - diameter (inches)
18+
/// - height (feet)
19+
///
20+
/// ### Target
21+
///
22+
/// ``Vector<f64>`` contains volume (cubic feet) of trees.
23+
///
24+
/// Thomas A. Ryan, Brian L. Joiner, Barbara F. Ryan. (1976).
25+
/// Minitab student handbook. Duxbury Press
26+
pub fn load() -> Dataset<Matrix<f64>, Vector<f64>> {
27+
let data = matrix![8.3, 70.;
28+
8.6, 65.;
29+
8.8, 63.;
30+
10.5, 72.;
31+
10.7, 81.;
32+
10.8, 83.;
33+
11.0, 66.;
34+
11.0, 75.;
35+
11.1, 80.;
36+
11.2, 75.;
37+
11.3, 79.;
38+
11.4, 76.;
39+
11.4, 76.;
40+
11.7, 69.;
41+
12.0, 75.;
42+
12.9, 74.;
43+
12.9, 85.;
44+
13.3, 86.;
45+
13.7, 71.;
46+
13.8, 64.;
47+
14.0, 78.;
48+
14.2, 80.;
49+
14.5, 74.;
50+
16.0, 72.;
51+
16.3, 77.;
52+
17.3, 81.;
53+
17.5, 82.;
54+
17.9, 80.;
55+
18.0, 80.;
56+
18.0, 80.;
57+
20.6, 87.];
58+
let target = vec![10.3, 10.3, 10.2, 16.4, 18.8, 19.7, 15.6, 18.2, 22.6, 19.9,
59+
24.2, 21.0, 21.4, 21.3, 19.1, 22.2, 33.8, 27.4, 25.7, 24.9,
60+
34.5, 31.7, 36.3, 38.3, 42.6, 55.4, 55.7, 58.3, 51.5, 51.0,
61+
77.0];
62+
Dataset{ data: data,
63+
target: Vector::new(target) }
64+
}

src/learning/lin_reg.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,8 @@ impl SupModel<Matrix<f64>, Vector<f64>> for LinRegressor {
8787
let full_inputs = ones.hcat(inputs);
8888

8989
let xt = full_inputs.transpose();
90-
91-
self.parameters =
92-
Some(((&xt * full_inputs).inverse().expect("Could not compute (X_T X) inverse.") *
93-
&xt) * targets);
94-
90+
self.parameters = Some((&xt * full_inputs).solve(&xt * targets)
91+
.expect("Unable to solve linear equation."));
9592
Ok(())
9693
}
9794

tests/datasets.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
extern crate rusty_machine as rm;
22

3-
4-
#[cfg(datasets)]
3+
#[cfg(feature = "datasets")]
54
mod test {
65

7-
use rm::datasets::iris;
6+
use rm::datasets;
87
use rm::linalg::BaseMatrix;
98

109
#[test]
1110
fn test_iris() {
12-
let dt = iris::load_();
11+
let dt = datasets::iris::load();
1312
assert_eq!(dt.data().rows(), 150);
1413
assert_eq!(dt.data().cols(), 4);
1514

1615
assert_eq!(dt.target().size(), 150);
1716
}
18-
}
17+
18+
#[test]
19+
fn test_trees() {
20+
let dt = datasets::trees::load();
21+
assert_eq!(dt.data().rows(), 31);
22+
assert_eq!(dt.data().cols(), 2);
23+
24+
assert_eq!(dt.target().size(), 31);
25+
}
26+
}

tests/learning/lin_reg.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,4 +47,27 @@ fn test_no_train_predict() {
4747
let inputs = Matrix::new(3, 2, vec![1.0, 2.0, 1.0, 3.0, 1.0, 4.0]);
4848

4949
let _ = lin_mod.predict(&inputs).unwrap();
50+
}
51+
52+
#[cfg(feature = "datasets")]
53+
#[test]
54+
fn test_regression_datasets_trees() {
55+
use rm::datasets::trees;
56+
let trees = trees::load();
57+
58+
let mut lin_mod = LinRegressor::default();
59+
lin_mod.train(&trees.data(), &trees.target()).unwrap();
60+
let params = lin_mod.parameters().unwrap();
61+
assert_eq!(params, &Vector::new(vec![-57.98765891838409, 4.708160503017506, 0.3392512342447438]));
62+
63+
let predicted = lin_mod.predict(&trees.data()).unwrap();
64+
let expected = vec![4.837659653793278, 4.55385163347481, 4.816981265588826, 15.874115228921276,
65+
19.869008437727473, 21.018326956518717, 16.192688074961563, 19.245949183164257,
66+
21.413021404689726, 20.187581283767756, 22.015402271048487, 21.468464618616007,
67+
21.468464618616007, 20.50615412980805, 23.954109686181766, 27.852202904652785,
68+
31.583966481344966, 33.806481916796706, 30.60097760433255, 28.697035014921106,
69+
34.388184394951004, 36.008318964043994, 35.38525970948079, 41.76899799551756,
70+
44.87770231764652, 50.942867757643015, 52.223751092491256, 53.42851282520877,
71+
53.899328875510534, 53.899328875510534, 68.51530482306926];
72+
assert_eq!(predicted, Vector::new(expected));
5073
}

tests/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#[macro_use]
2+
extern crate rulinalg;
13
extern crate rusty_machine as rm;
24
extern crate num as libnum;
35

@@ -12,5 +14,4 @@ pub mod learning {
1214
}
1315
}
1416

15-
#[cfg(datasets)]
1617
pub mod datasets;

0 commit comments

Comments
 (0)