Skip to content

Commit 5ca83ee

Browse files
committed
linreg: Fixup test
1 parent 9d9b517 commit 5ca83ee

2 files changed

Lines changed: 17 additions & 11 deletions

File tree

src/emlearn_linreg/emlearn_linreg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11

2-
from emlearn_linreg_c import *
2+
# When used as external C module, the .py is the top-level import,
3+
# and we need to merge the native module symbols at import time
4+
# When used as dynamic native modules (.mpy), .py and native code is merged at build time
5+
try:
6+
from emlearn_linreg_c import *
7+
except ImportError as e:
8+
pass
39

410
log_prefix = 'emlearn_linreg:'
511

tests/test_linreg_california.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import gc
66
import time
77

8-
data_dir = './examples/datasets/california'
8+
data_dir = './examples/datasets/california/'
99

1010
def load_npy(filename):
1111
shape, data = npyfile.load(filename)
@@ -16,10 +16,10 @@ def test_elasticnet_full():
1616
print("\n=== Full Dataset Test ===")
1717

1818
# Load full datasets
19-
X_train_shape, X_train_data = load_npy('X_train.npy')
20-
y_train_shape, y_train_data = load_npy('y_train.npy')
21-
X_test_shape, X_test_data = load_npy('X_test.npy')
22-
y_test_shape, y_test_data = load_npy('y_test.npy')
19+
X_train_shape, X_train_data = load_npy(data_dir+'X_train.npy')
20+
y_train_shape, y_train_data = load_npy(data_dir+'y_train.npy')
21+
X_test_shape, X_test_data = load_npy(data_dir+'X_test.npy')
22+
y_test_shape, y_test_data = load_npy(data_dir+'y_test.npy')
2323

2424
n_features = X_train_shape[1]
2525
n_train = y_train_shape[0]
@@ -31,15 +31,15 @@ def test_elasticnet_full():
3131

3232
# Create model with different hyperparameters for full dataset
3333
# Lower learning rate for stability
34-
model = emlearn_linreg.new(n_features, 0.001, 0.5, 0.01)
34+
model = emlearn_linreg.new(n_features, 0.001, 0.5, 0.1)
3535

3636
train_start = time.ticks_ms()
3737
# Train on full dataset
3838
print("Training on full dataset...")
3939
stop_iter, stop_mse = emlearn_linreg.train(model,
4040
X_train_data, y_train_data,
41-
max_iterations=2000, check_interval=50,
42-
verbose=2, tolerance=0.001, score_limit=0.60,
41+
max_iterations=1000, check_interval=50,
42+
verbose=2, tolerance=0.0001, score_limit=0.60,
4343
)
4444
train_duration = time.ticks_diff(time.ticks_ms(), train_start)
4545
print('Train time (ms)', train_duration, 'per iter', train_duration/stop_iter)
@@ -48,8 +48,8 @@ def test_elasticnet_full():
4848
train_mse = model.score_mse(X_train_data, y_train_data)
4949
test_mse = model.score_mse(X_test_data, y_test_data)
5050

51-
assert train_mse <= 0.60, train_mse
52-
assert test_mse <= 0.60, test_mse
51+
assert train_mse <= 0.65, train_mse
52+
assert test_mse <= 0.65, test_mse
5353

5454

5555
def main():

0 commit comments

Comments
 (0)