Skip to content

Commit c229af1

Browse files
committed
linreg: Add back full test
1 parent 95e00fe commit c229af1

2 files changed

Lines changed: 70 additions & 6 deletions

File tree

src/emlearn_linreg/linreg.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11

22

3+
log_prefix = 'emlearn_linreg:'
4+
35
def train(model, X_train, y_train,
46
max_iterations=100,
57
tolerance=1e-6,
68
check_interval=10,
79
divergence_factor=10.0,
10+
score_limit=None,
811
verbose=0,
912
):
1013
"""
@@ -25,22 +28,25 @@ def train(model, X_train, y_train,
2528
change = abs(prev_mse - current_mse)
2629

2730
if verbose >= 2:
28-
print(f'Iteration {iteration} mse={current_mse}')
31+
print(log_prefix, f'Iteration {iteration} mse={current_mse}')
2932

3033
# Check convergence
3134
converged = change < tolerance and iteration > check_interval * 2
3235

36+
if score_limit is not None:
37+
converged = converged or current_mse <= score_limit
38+
3339
# Check divergence
3440
diverged = current_mse > prev_mse * divergence_factor or not (current_mse == current_mse) # NaN check
3541

3642
if converged:
37-
if verbose:
38-
print(f"Converged at iteration {iteration}")
43+
if verbose >= 1:
44+
print(log_prefix, f"Converged at iteration {iteration}")
3945
break
4046

4147
if diverged:
42-
if verbose:
43-
print(f"Diverged at iteration {iteration}")
48+
if verbose >= 1:
49+
print(log_prefix, f"Diverged at iteration {iteration}")
4450
break
4551

4652
prev_mse = current_mse

tests/test_linreg_california.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def compare_regularization():
3939

4040
model = emlearn_linreg.new(n_features, alpha, l1_ratio, 0.01)
4141
emlearn_linreg.train(model, X_subset, y_subset,
42-
max_iterations=2000, check_interval=50, verbose=True)
42+
max_iterations=2000, check_interval=50, verbose=1)
4343

4444
# Calculate predictions and MSE
4545
mse = model.score_mse(X_subset, y_subset)
@@ -53,6 +53,63 @@ def compare_regularization():
5353
print(f" Non-zero weights: {non_zero}/{n_features}")
5454
print(f" Max weight magnitude: {max(abs(w) for w in weights):.6f}")
5555

56+
def test_elasticnet_full():
57+
"""Test with full dataset."""
58+
print("\n=== Full Dataset Test ===")
59+
60+
# Load full datasets
61+
X_train_shape, X_train_data = load_npy_as_array('X_train.npy')
62+
y_train_shape, y_train_data = load_npy_as_array('y_train.npy')
63+
X_test_shape, X_test_data = load_npy_as_array('X_test.npy')
64+
y_test_shape, y_test_data = load_npy_as_array('y_test.npy')
65+
66+
n_features = X_train_shape[1]
67+
n_train = y_train_shape[0]
68+
n_test = y_test_shape[0]
69+
70+
print(f"Train set: {n_train} samples")
71+
print(f"Test set: {n_test} samples")
72+
print(f"Features: {n_features}")
73+
74+
# Create model with different hyperparameters for full dataset
75+
print("Creating ElasticNet model...")
76+
model = emlearn_linreg.new(n_features, 0.001, 0.5, 0.01) # Lower learning rate for stability
77+
78+
# Train on full dataset
79+
print("Training on full dataset...")
80+
emlearn_linreg.train(model, X_train_data, y_train_data,
81+
max_iterations=2000, check_interval=50, verbose=1, tolerance=0.001, score_limit=0.60)
82+
83+
# Get final parameters
84+
weights = array.array('f', [0.0] * n_features)
85+
model.get_weights(weights)
86+
bias = model.get_bias()
87+
88+
print(f"Final weights: {list(weights)}")
89+
print(f"Final bias: {bias}")
90+
91+
# Evaluate on train set
92+
print("Calculating training MSE...")
93+
train_mse = model.score_mse(X_train_data, y_train_data)
94+
print(f"Training MSE: {train_mse}")
95+
96+
# Evaluate on test set
97+
print("Calculating test MSE...")
98+
test_mse = model.score_mse(X_test_data, y_test_data)
99+
print(f"Test MSE: {test_mse}")
100+
101+
# Make some sample predictions
102+
print("\nSample predictions:")
103+
for i in range(min(5, n_test)):
104+
start_idx = i * n_features
105+
end_idx = start_idx + n_features
106+
test_features = array.array('f', X_test_data[start_idx:end_idx])
107+
prediction = model.predict(test_features)
108+
actual = y_test_data[i]
109+
print(f"Sample {i}: predicted={prediction:.3f}, actual={actual:.3f}, error={abs(prediction-actual):.3f}")
110+
111+
return model
112+
56113
def main():
57114
"""Main test function."""
58115
print("ElasticNet MicroPython Module Test")
@@ -62,6 +119,7 @@ def main():
62119
# Compare regularization approaches
63120
compare_regularization()
64121

122+
test_elasticnet_full()
65123
print("\n=== All Tests Completed Successfully! ===")
66124

67125
except Exception as e:

0 commit comments

Comments
 (0)