Skip to content

Commit 95e00fe

Browse files
committed
linreg: Include training loop in module
1 parent 725b420 commit 95e00fe

4 files changed

Lines changed: 54 additions & 83 deletions

File tree

src/emlearn_linreg/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ DIST_DIR := ../../dist/$(ARCH)_$(MPY_ABI_VERSION)
1919
MOD = emlearn_linreg
2020

2121
# Source files (.c or .py)
22-
SRC = linreg.c
22+
SRC = linreg.c linreg.py
2323

2424
# Include to get the rules for compiling and linking the module
2525
include $(MPY_DIR)/py/dynruntime.mk

src/emlearn_linreg/linreg.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
2+
3+
def train(model, X_train, y_train,
4+
max_iterations=100,
5+
tolerance=1e-6,
6+
check_interval=10,
7+
divergence_factor=10.0,
8+
verbose=0,
9+
):
10+
"""
11+
Simple training loop
12+
"""
13+
prev_mse = float('inf')
14+
15+
for iteration in range(max_iterations):
16+
# Perform one gradient descent step
17+
model.step(X_train, y_train)
18+
19+
# Only check progress at intervals
20+
if iteration % check_interval != 0:
21+
continue
22+
23+
# Calculate current MSE
24+
current_mse = model.score_mse(X_train, y_train)
25+
change = abs(prev_mse - current_mse)
26+
27+
if verbose >= 2:
28+
print(f'Iteration {iteration} mse={current_mse}')
29+
30+
# Check convergence
31+
converged = change < tolerance and iteration > check_interval * 2
32+
33+
# Check divergence
34+
diverged = current_mse > prev_mse * divergence_factor or not (current_mse == current_mse) # NaN check
35+
36+
if converged:
37+
if verbose:
38+
print(f"Converged at iteration {iteration}")
39+
break
40+
41+
if diverged:
42+
if verbose:
43+
print(f"Diverged at iteration {iteration}")
44+
break
45+
46+
prev_mse = current_mse
47+

tests/test_linreg.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
import emlearn_linreg
33
import array
44

5-
# Create model: 4 features, alpha=0.1, l1_ratio=0.5, lr=0.01
65
model = emlearn_linreg.new(4, 0.1, 0.5, 0.01)
76

87
# Training data (float32 arrays)
98
X = array.array('f', [1,2,3,4, 2,3,4,5]) # flattened
109
y = array.array('f', [10, 15])
1110

1211
# Train
13-
model.train(X, y, 1000, 1e-6)
12+
emlearn_linreg.train(model, X, y, max_iterations=100, tolerance=1e-6, verbose=0)
1413

1514
# Predict
1615
prediction = model.predict(array.array('f', [1,2,3,4]))

tests/test_linreg_california.py

Lines changed: 5 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,81 +13,6 @@ def load_npy_as_array(filename, dtype='f'):
1313
return shape, data
1414

1515

16-
def train_with_monitoring(model, X_train, y_train, max_iterations=1000,
17-
tolerance=1e-6, check_interval=10, verbose=True):
18-
"""
19-
Train model with progress monitoring using generator pattern.
20-
21-
Yields: (iteration, mse, change, converged, diverged)
22-
"""
23-
prev_mse = float('inf')
24-
n_features = model.get_n_features()
25-
26-
for iteration in range(max_iterations):
27-
# Perform one gradient descent step
28-
model.step(X_train, y_train)
29-
30-
# Check progress at intervals
31-
if iteration % check_interval == 0:
32-
# Calculate current MSE
33-
current_mse = model.score_mse(X_train, y_train)
34-
change = abs(prev_mse - current_mse)
35-
36-
# Check convergence
37-
converged = change < tolerance and iteration > check_interval * 2
38-
39-
# Check divergence
40-
diverged = current_mse > prev_mse * 10.0 or not (current_mse == current_mse) # NaN check
41-
42-
# Yield progress
43-
yield iteration, current_mse, change, converged, diverged
44-
45-
if converged:
46-
if verbose:
47-
print(f"Converged at iteration {iteration}")
48-
break
49-
50-
if diverged:
51-
if verbose:
52-
print(f"Diverged at iteration {iteration}")
53-
break
54-
55-
prev_mse = current_mse
56-
else:
57-
# Yield minimal info for non-check iterations
58-
yield iteration, None, None, False, False
59-
60-
def train_model(model, X_train, y_train, max_iterations=1000, tolerance=1e-6,
61-
check_interval=10, verbose=True):
62-
"""
63-
Simple training function with progress monitoring.
64-
"""
65-
if verbose:
66-
print(f"Training model for up to {max_iterations} iterations...")
67-
print(f"Checking convergence every {check_interval} iterations")
68-
print("Iter MSE Change Status")
69-
print("-" * 40)
70-
71-
for iter_num, mse, change, converged, diverged in train_with_monitoring(
72-
model, X_train, y_train, max_iterations, tolerance, check_interval, verbose=False
73-
):
74-
75-
if verbose and mse is not None: # Only print on check intervals
76-
status = ""
77-
if converged:
78-
status = "CONVERGED"
79-
elif diverged:
80-
status = "DIVERGED"
81-
82-
print(f"{iter_num:4d} {mse:8.6f} {change:8.6f} {status}")
83-
84-
if converged or diverged:
85-
break
86-
87-
if verbose:
88-
print("Training completed.\n")
89-
90-
9116
def compare_regularization():
9217
"""Compare different regularization settings."""
9318
print("\n=== Regularization Comparison ===")
@@ -98,7 +23,7 @@ def compare_regularization():
9823

9924
n_features = X_shape[1]
10025
n_samples = min(500, y_shape[0])
101-
26+
10227
X_subset = array.array('f', X_data[:n_samples * n_features])
10328
y_subset = array.array('f', y_data[:n_samples])
10429

@@ -108,13 +33,14 @@ def compare_regularization():
10833
("LASSO (L1)", 0.01, 1.0),
10934
("ElasticNet", 0.01, 0.5),
11035
]
111-
36+
11237
for name, alpha, l1_ratio in configs:
11338
print(f"\nTesting {name} (alpha={alpha}, l1_ratio={l1_ratio}):")
11439

11540
model = emlearn_linreg.new(n_features, alpha, l1_ratio, 0.01)
116-
train_model(model, X_subset, y_subset, max_iterations=300, check_interval=50, verbose=False)
117-
41+
emlearn_linreg.train(model, X_subset, y_subset,
42+
max_iterations=2000, check_interval=50, verbose=True)
43+
11844
# Calculate predictions and MSE
11945
mse = model.score_mse(X_subset, y_subset)
12046

@@ -123,7 +49,6 @@ def compare_regularization():
12349

12450
# Count non-zero weights (sparsity)
12551
non_zero = sum(1 for w in weights if abs(w) > 1e-6)
126-
12752
print(f" MSE: {mse:.6f}")
12853
print(f" Non-zero weights: {non_zero}/{n_features}")
12954
print(f" Max weight magnitude: {max(abs(w) for w in weights):.6f}")

0 commit comments

Comments
 (0)