55import gc
66import time
77
8- data_dir = './examples/datasets/california'
8+ data_dir = './examples/datasets/california/ '
99
1010def load_npy (filename ):
1111 shape , data = npyfile .load (filename )
@@ -16,10 +16,10 @@ def test_elasticnet_full():
1616 print ("\n === Full Dataset Test ===" )
1717
1818 # Load full datasets
19- X_train_shape , X_train_data = load_npy ('X_train.npy' )
20- y_train_shape , y_train_data = load_npy ('y_train.npy' )
21- X_test_shape , X_test_data = load_npy ('X_test.npy' )
22- y_test_shape , y_test_data = load_npy ('y_test.npy' )
19+ X_train_shape , X_train_data = load_npy (data_dir + 'X_train.npy' )
20+ y_train_shape , y_train_data = load_npy (data_dir + 'y_train.npy' )
21+ X_test_shape , X_test_data = load_npy (data_dir + 'X_test.npy' )
22+ y_test_shape , y_test_data = load_npy (data_dir + 'y_test.npy' )
2323
2424 n_features = X_train_shape [1 ]
2525 n_train = y_train_shape [0 ]
@@ -31,15 +31,15 @@ def test_elasticnet_full():
3131
3232 # Create model with different hyperparameters for full dataset
3333 # Lower learning rate for stability
34- model = emlearn_linreg .new (n_features , 0.001 , 0.5 , 0.01 )
34+ model = emlearn_linreg .new (n_features , 0.001 , 0.5 , 0.1 )
3535
3636 train_start = time .ticks_ms ()
3737 # Train on full dataset
3838 print ("Training on full dataset..." )
3939 stop_iter , stop_mse = emlearn_linreg .train (model ,
4040 X_train_data , y_train_data ,
41- max_iterations = 2000 , check_interval = 50 ,
42- verbose = 2 , tolerance = 0.001 , score_limit = 0.60 ,
41+ max_iterations = 1000 , check_interval = 50 ,
42+ verbose = 2 , tolerance = 0.0001 , score_limit = 0.60 ,
4343 )
4444 train_duration = time .ticks_diff (time .ticks_ms (), train_start )
4545 print ('Train time (ms)' , train_duration , 'per iter' , train_duration / stop_iter )
@@ -48,8 +48,8 @@ def test_elasticnet_full():
4848 train_mse = model .score_mse (X_train_data , y_train_data )
4949 test_mse = model .score_mse (X_test_data , y_test_data )
5050
51- assert train_mse <= 0.60 , train_mse
52- assert test_mse <= 0.60 , test_mse
51+ assert train_mse <= 0.65 , train_mse
52+ assert test_mse <= 0.65 , test_mse
5353
5454
5555def main ():
0 commit comments