Skip to content

Commit 2571a73

Browse files
committed
linreg: Benchmark on ESP32 S3
On California dataset with 16k rows and 8 features, can do one training step in 80 milliseconds
1 parent c229af1 commit 2571a73

2 files changed

Lines changed: 16 additions & 6 deletions

File tree

src/emlearn_linreg/linreg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,5 @@ def train(model, X_train, y_train,
5151

5252
prev_mse = current_mse
5353

54+
return iteration, prev_mse
55+

tests/test_linreg_california.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import npyfile
44
import array
55
import gc
6+
import time
67

78
def load_npy_as_array(filename, dtype='f'):
89
"""Load .npy file and convert to MicroPython array."""
@@ -39,7 +40,7 @@ def compare_regularization():
3940

4041
model = emlearn_linreg.new(n_features, alpha, l1_ratio, 0.01)
4142
emlearn_linreg.train(model, X_subset, y_subset,
42-
max_iterations=2000, check_interval=50, verbose=1)
43+
max_iterations=100, check_interval=50, verbose=1)
4344

4445
# Calculate predictions and MSE
4546
mse = model.score_mse(X_subset, y_subset)
@@ -75,11 +76,17 @@ def test_elasticnet_full():
7576
print("Creating ElasticNet model...")
7677
model = emlearn_linreg.new(n_features, 0.001, 0.5, 0.01) # Lower learning rate for stability
7778

79+
train_start = time.ticks_ms()
7880
# Train on full dataset
7981
print("Training on full dataset...")
80-
emlearn_linreg.train(model, X_train_data, y_train_data,
81-
max_iterations=2000, check_interval=50, verbose=1, tolerance=0.001, score_limit=0.60)
82-
82+
stop_iter, stop_mse = emlearn_linreg.train(model,
83+
X_train_data, y_train_data,
84+
max_iterations=2000, check_interval=50,
85+
verbose=2, tolerance=0.001, score_limit=0.60,
86+
)
87+
train_duration = time.ticks_diff(time.ticks_ms(), train_start)
88+
print('Train time (ms)', train_duration, 'per iter', train_duration/stop_iter)
89+
8390
# Get final parameters
8491
weights = array.array('f', [0.0] * n_features)
8592
model.get_weights(weights)
@@ -117,8 +124,9 @@ def main():
117124

118125
try:
119126
# Compare regularization approaches
120-
compare_regularization()
121-
127+
#compare_regularization()
128+
gc.collect()
129+
122130
test_elasticnet_full()
123131
print("\n=== All Tests Completed Successfully! ===")
124132

0 commit comments

Comments
 (0)