@@ -47,4 +47,27 @@ fn test_no_train_predict() {
4747 let inputs = Matrix :: new ( 3 , 2 , vec ! [ 1.0 , 2.0 , 1.0 , 3.0 , 1.0 , 4.0 ] ) ;
4848
4949 let _ = lin_mod. predict ( & inputs) . unwrap ( ) ;
50+ }
51+
52+ #[ cfg( feature = "datasets" ) ]
53+ #[ test]
54+ fn test_regression_datasets_trees ( ) {
55+ use rm:: datasets:: trees;
56+ let trees = trees:: load ( ) ;
57+
58+ let mut lin_mod = LinRegressor :: default ( ) ;
59+ lin_mod. train ( & trees. data ( ) , & trees. target ( ) ) . unwrap ( ) ;
60+ let params = lin_mod. parameters ( ) . unwrap ( ) ;
61+ assert_eq ! ( params, & Vector :: new( vec![ -57.98765891838409 , 4.708160503017506 , 0.3392512342447438 ] ) ) ;
62+
63+ let predicted = lin_mod. predict ( & trees. data ( ) ) . unwrap ( ) ;
64+ let expected = vec ! [ 4.837659653793278 , 4.55385163347481 , 4.816981265588826 , 15.874115228921276 ,
65+ 19.869008437727473 , 21.018326956518717 , 16.192688074961563 , 19.245949183164257 ,
66+ 21.413021404689726 , 20.187581283767756 , 22.015402271048487 , 21.468464618616007 ,
67+ 21.468464618616007 , 20.50615412980805 , 23.954109686181766 , 27.852202904652785 ,
68+ 31.583966481344966 , 33.806481916796706 , 30.60097760433255 , 28.697035014921106 ,
69+ 34.388184394951004 , 36.008318964043994 , 35.38525970948079 , 41.76899799551756 ,
70+ 44.87770231764652 , 50.942867757643015 , 52.223751092491256 , 53.42851282520877 ,
71+ 53.899328875510534 , 53.899328875510534 , 68.51530482306926 ] ;
72+ assert_eq ! ( predicted, Vector :: new( expected) ) ;
5073}
0 commit comments