Commit e2510f8

logreg: Support mini-batch training with generator
1 parent 184b1dc commit e2510f8

5 files changed: 388 additions & 67 deletions

src/emlearn_logreg/logreg.py

Lines changed: 126 additions & 58 deletions
@@ -3,26 +3,10 @@
 log_prefix = 'emlearn_logreg:'
 
 
-def _make_workspace_triplet(n_classes):
-    return (
-        array.array('f', [0.0] * n_classes),
-        array.array('f', [0.0] * n_classes),
-        array.array('f', [0.0] * n_classes),
-    )
+def _make_buffer(n):
+    return array.array('f', [0.0] * n)
 
 
-def _make_workspace_pair(n_classes):
-    return (
-        array.array('f', [0.0] * n_classes),
-        array.array('f', [0.0] * n_classes),
-    )
-
-
-def _make_predict_buffers(n_classes):
-    return (
-        array.array('f', [0.0] * n_classes),
-        array.array('f', [0.0] * n_classes),
-    )
 
 def train(model, X_train, y_train,
           max_iterations=200,
@@ -31,12 +15,13 @@ def train(model, X_train, y_train,
           divergence_factor=2.0,
           score_limit=None,
           verbose=0,
-          batch_size=None,
           ):
-    """Mini-batch training loop for logistic regression.
+    """Full-dataset training loop for logistic regression."""
+    if max_iterations <= 0:
+        raise ValueError('max_iterations must be positive')
+    if check_interval <= 0:
+        raise ValueError('check_interval must be positive')
 
-    Copies data into a reusable buffer when mini-batching to limit peak memory.
-    """
     n_features = model.get_n_features()
     n_classes = model.get_n_classes()
     if len(X_train) % n_features != 0:
@@ -48,53 +33,27 @@ def train(model, X_train, y_train,
     if n_samples == 0:
         raise ValueError('y_train is empty')
 
-    if batch_size is None or batch_size <= 0 or batch_size > n_samples:
-        batch_size = n_samples
-
-    logits_buf, probs_buf, bias_buf = _make_workspace_triplet(n_classes)
-    score_logits, score_probs = _make_workspace_pair(n_classes)
-    predict_logits, predict_probs = _make_predict_buffers(n_classes)
+    logits_buf = _make_buffer(n_classes)
+    probs_buf = _make_buffer(n_classes)
+    bias_buf = _make_buffer(n_classes)
+    score_logits = _make_buffer(n_classes)
+    score_probs = _make_buffer(n_classes)
 
     prev_loss = None
     final_loss = float('inf')
     iterations_completed = 0
 
-    use_batches = batch_size < n_samples
-    full_X_view = memoryview(X_train)
-    full_y_view = memoryview(y_train)
-    if use_batches:
-        batch_X = array.array('f', [0.0] * (batch_size * n_features))
-        batch_y = array.array('f', [0.0] * (batch_size * n_classes))
-        batch_X_view = memoryview(batch_X)
-        batch_y_view = memoryview(batch_y)
-    else:
-        batch_X_view = full_X_view
-        batch_y_view = full_y_view
+    X_view = memoryview(X_train)
+    y_view = memoryview(y_train)
 
     for _ in range(max_iterations):
         iterations_completed += 1
-        if use_batches:
-            for start in range(0, n_samples, batch_size):
-                count = min(batch_size, n_samples - start)
-                base_feature = start * n_features
-                base_target = start * n_classes
-                # Copy features for current batch
-                end_f = base_feature + count * n_features
-                batch_X[:count * n_features] = X_train[base_feature:end_f]
-                # Copy targets
-                end_t = base_target + count * n_classes
-                batch_y[:count * n_classes] = y_train[base_target:end_t]
-
-                X_slice = batch_X_view[:count * n_features]
-                y_slice = batch_y_view[:count * n_classes]
-                model.step(X_slice, y_slice, logits_buf, probs_buf, bias_buf)
-        else:
-            model.step(batch_X_view, batch_y_view, logits_buf, probs_buf, bias_buf)
+        model.step(X_view, y_view, logits_buf, probs_buf, bias_buf)
 
         if iterations_completed % check_interval != 0:
             continue
 
-        current_loss = model.score_logloss(full_X_view, full_y_view, score_logits, score_probs)
+        current_loss = model.score_logloss(X_view, y_view, score_logits, score_probs)
         final_loss = current_loss
         change = float('inf') if prev_loss is None else abs(prev_loss - current_loss)
 
@@ -123,7 +82,116 @@ def train(model, X_train, y_train,
         prev_loss = current_loss
 
     if final_loss == float('inf'):
-        final_loss = model.score_logloss(full_X_view, full_y_view, score_logits, score_probs)
+        final_loss = model.score_logloss(X_view, y_view, score_logits, score_probs)
+
+    return iterations_completed, final_loss
+
+
+def train_batches(model,
+                  batch_iter_factory,
+                  max_iterations=200,
+                  tolerance=1e-4,
+                  check_interval=5,
+                  divergence_factor=2.0,
+                  score_limit=None,
+                  verbose=0,
+                  score_batches=None,
+                  ):
+    """Train logistic regression model using externally provided batches.
+
+    batch_iter_factory must be a callable that returns a fresh iterator for each
+    epoch. Each iterator should yield tuples of (X_batch, y_batch) where both are
+    float32 arrays compatible with model.step(). y_batch must be one-hot encoded.
+
+    score_batches is an optional callable taking the model and returning the
+    average log-loss over the data (computed however the caller prefers). When
+    provided, it is used for convergence checking.
+    """
+    if not callable(batch_iter_factory):
+        raise ValueError('batch_iter_factory must be callable')
+    if max_iterations <= 0:
+        raise ValueError('max_iterations must be positive')
+    if check_interval <= 0:
+        raise ValueError('check_interval must be positive')
+    if score_batches is not None and not callable(score_batches):
+        raise ValueError('score_batches must be callable')
+
+    n_features = model.get_n_features()
+    n_classes = model.get_n_classes()
+
+    logits_buf = _make_buffer(n_classes)
+    probs_buf = _make_buffer(n_classes)
+    bias_buf = _make_buffer(n_classes)
+
+    prev_loss = None
+    final_loss = float('inf')
+    iterations_completed = 0
+
+    for _ in range(max_iterations):
+        iterations_completed += 1
+        batches = batch_iter_factory()
+        try:
+            batch_iter = iter(batches)
+        except TypeError:
+            raise ValueError('batch iterator must be iterable')
+
+        batches_processed = 0
+
+        for batch in batch_iter:
+            batches_processed += 1
+            try:
+                X_batch, y_batch = batch
+            except Exception as exc:
+                raise ValueError('each batch must unpack into (X_batch, y_batch)') from exc
+
+            if len(X_batch) == 0:
+                continue
+            if len(X_batch) % n_features != 0:
+                raise ValueError('X_batch size mismatch with n_features')
+            n_samples = len(X_batch) // n_features
+            if len(y_batch) != n_samples * n_classes:
+                raise ValueError('y_batch must be one-hot encoded (len = n_samples * n_classes)')
+
+            model.step(X_batch, y_batch, logits_buf, probs_buf, bias_buf)
+
+        if batches_processed == 0:
+            raise ValueError('batch iterator produced no batches')
+
+        if iterations_completed % check_interval != 0:
+            continue
+        if score_batches is None:
+            continue
+
+        current_loss = float(score_batches(model))
+        final_loss = current_loss
+        change = float('inf') if prev_loss is None else abs(prev_loss - current_loss)
+
+        if verbose >= 2:
+            print(log_prefix, f'Iteration {iterations_completed} loss={current_loss}')
+
+        converged = change < tolerance and iterations_completed > check_interval * 2
+
+        if score_limit is not None and current_loss <= score_limit:
+            converged = True
+
+        diverged = not (current_loss == current_loss)
+        if not diverged and prev_loss is not None:
+            diverged = current_loss > prev_loss * divergence_factor
+
+        if converged:
+            if verbose >= 1:
+                print(log_prefix, f"Converged at iteration {iterations_completed}")
+            break
+
+        if diverged:
+            if verbose >= 1:
+                print(log_prefix, f"Diverged at iteration {iterations_completed}")
+            break
+
+        prev_loss = current_loss
+
+    if score_batches is not None and final_loss == float('inf'):
+        final_loss = float(score_batches(model))
 
     return iterations_completed, final_loss
tests/test_logreg.py

Lines changed: 59 additions & 8 deletions
@@ -160,12 +160,27 @@ def assert_raises_value_error(func, message='Expected ValueError'):
     raise AssertionError(message)
 
 
+def make_batch_factory(X, y, n_features, n_classes, batch_size):
+    n_samples = len(X) // n_features
+
+    def factory():
+        for start in range(0, n_samples, batch_size):
+            count = min(batch_size, n_samples - start)
+            feat_start = start * n_features
+            feat_end = feat_start + count * n_features
+            target_start = start * n_classes
+            target_end = target_start + count * n_classes
+            yield array.array('f', X[feat_start:feat_end]), array.array('f', y[target_start:target_end])
+
+    return factory
+
+
 def test_logreg_train_and_predict():
     X, y = make_dataset()
     model = emlearn_logreg.new(2, 2, 0.5, 0.0, 0.0)
 
     emlearn_logreg.train(model, X, y, max_iterations=400, tolerance=1e-5,
-                         check_interval=5, batch_size=2)
+                         check_interval=5)
 
     logits, probs = alloc_predict_buffers(model)
     run_predict(model, array.array('f', [1, 1]), logits, probs)
@@ -207,7 +222,7 @@ def test_logreg_train_minibatch_reduces_loss():
     initial_loss = model.score_logloss(X, y, logits, probs)
 
     emlearn_logreg.train(model, X, y, max_iterations=600, tolerance=1e-6,
-                         check_interval=10, batch_size=1)
+                         check_interval=10)
 
     final_loss = model.score_logloss(X, y, logits, probs)
     assert final_loss < initial_loss * 0.7, (initial_loss, final_loss)
@@ -270,7 +285,7 @@ def test_logreg_handles_ill_conditioned_features():
     initial_loss = model.score_logloss(X, y, logits, probs)
 
     emlearn_logreg.train(model, X, y, max_iterations=2000, tolerance=1e-6,
-                         check_interval=100, batch_size=4)
+                         check_interval=100)
 
     final_loss = model.score_logloss(X, y, logits, probs)
     assert final_loss < initial_loss, (initial_loss, final_loss)
@@ -289,7 +304,7 @@ def test_logreg_high_dimensional_sparse_case():
     initial_loss = model.score_logloss(X, y, logits, probs)
 
     emlearn_logreg.train(model, X, y, max_iterations=600, tolerance=1e-6,
-                         check_interval=30, batch_size=4)
+                         check_interval=30)
 
     final_loss = model.score_logloss(X, y, logits, probs)
 
@@ -312,6 +327,46 @@ def test_logreg_train_requires_targets():
     assert_raises_value_error(lambda: emlearn_logreg.train(model, X, y))
 
 
+def test_logreg_train_batches_matches_dense_training():
+    X, y = make_linearly_separable_dataset()
+    n_features = 2
+    n_classes = 2
+    batch_size = 2
+    model = emlearn_logreg.new(n_features, n_classes, 0.4, 0.01, 0.0)
+
+    def batch_factory():
+        return make_batch_factory(X, y, n_features, n_classes, batch_size)()
+
+    def scorer(m):
+        logits, probs = alloc_predict_buffers(m)
+        return m.score_logloss(X, y, logits, probs)
+
+    emlearn_logreg.train_batches(
+        model,
+        batch_factory,
+        max_iterations=200,
+        tolerance=1e-6,
+        check_interval=10,
+        score_batches=scorer,
+        score_limit=0.1,
+    )
+
+    logits, probs = alloc_predict_buffers(model)
+    final_loss = model.score_logloss(X, y, logits, probs)
+    assert final_loss < 0.1, final_loss
+
+
+def test_logreg_train_batches_validates_batches():
+    n_features = 2
+    n_classes = 2
+    model = emlearn_logreg.new(n_features, n_classes, 0.2, 0.0, 0.0)
+
+    def bad_factory_missing_batches():
+        return iter(())
+
+    assert_raises_value_error(lambda: emlearn_logreg.train_batches(model, bad_factory_missing_batches))
+
+
 def test_logreg_warm_start_sets_new_weights_and_bias():
     X, y = make_dataset()
     model = emlearn_logreg.new(2, 2, 0.3, 0.0, 0.0)
@@ -346,7 +401,6 @@ def test_logreg_multiclass_softmax_train_set_accuracy():
         max_iterations=1200,
         tolerance=1e-6,
         check_interval=60,
-        batch_size=3,
     )
 
     for idx in range(len(y) // n_classes):
@@ -373,7 +427,6 @@ def test_logreg_multiclass_softmax_generalization():
         max_iterations=1200,
        tolerance=1e-6,
         check_interval=60,
-        batch_size=3,
     )
 
     test_points = [
@@ -399,5 +452,3 @@ def test_logreg_multiclass_softmax_generalization():
     test_logreg_train_validates_dimensions()
     test_logreg_train_requires_targets()
     test_logreg_warm_start_sets_new_weights_and_bias()
-    test_logreg_one_vs_rest_classifies_training_samples()
-    test_logreg_one_vs_rest_generalizes_new_points()

tests/test_logreg_cancer.py

Lines changed: 0 additions & 1 deletion
@@ -77,7 +77,6 @@ def test_logreg_real_dataset_binary_classification():
         max_iterations=1500,
         tolerance=1e-5,
         check_interval=25,
-        batch_size=64,
         score_limit=0.28,
     )
 
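
Callers that relied on the removed batch_size argument, as this test did with
batch_size=64, can keep mini-batch updates by switching to train_batches(). A
hedged migration sketch for the call above, assuming the X, y, model, n_features
and n_classes objects this test already sets up (names are assumptions, not part
of the commit):

    def batch_factory():
        # Fresh generator per epoch, 64 samples per batch as before.
        n_samples = len(X) // n_features
        for start in range(0, n_samples, 64):
            count = min(64, n_samples - start)
            yield (X[start * n_features:(start + count) * n_features],
                   y[start * n_classes:(start + count) * n_classes])

    def scorer(m):
        logits = array.array('f', [0.0] * n_classes)
        probs = array.array('f', [0.0] * n_classes)
        return m.score_logloss(X, y, logits, probs)

    emlearn_logreg.train_batches(model, batch_factory,
                                 max_iterations=1500, tolerance=1e-5,
                                 check_interval=25, score_limit=0.28,
                                 score_batches=scorer)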
