Skip to content

Commit 184b1dc

Browse files
committed
logreg: Fixup iteration termination handling
1 parent 58c1aec commit 184b1dc

1 file changed

Lines changed: 26 additions & 15 deletions

File tree

src/emlearn_logreg/logreg.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,24 @@ def train(model, X_train, y_train,
5555
score_logits, score_probs = _make_workspace_pair(n_classes)
5656
predict_logits, predict_probs = _make_predict_buffers(n_classes)
5757

58-
prev_loss = float('inf')
58+
prev_loss = None
59+
final_loss = float('inf')
60+
iterations_completed = 0
5961

6062
use_batches = batch_size < n_samples
63+
full_X_view = memoryview(X_train)
64+
full_y_view = memoryview(y_train)
6165
if use_batches:
6266
batch_X = array.array('f', [0.0] * (batch_size * n_features))
6367
batch_y = array.array('f', [0.0] * (batch_size * n_classes))
6468
batch_X_view = memoryview(batch_X)
6569
batch_y_view = memoryview(batch_y)
6670
else:
67-
batch_X_view = memoryview(X_train)
68-
batch_y_view = memoryview(y_train)
71+
batch_X_view = full_X_view
72+
batch_y_view = full_y_view
6973

70-
for iteration in range(max_iterations):
74+
for _ in range(max_iterations):
75+
iterations_completed += 1
7176
if use_batches:
7277
for start in range(0, n_samples, batch_size):
7378
count = min(batch_size, n_samples - start)
@@ -86,33 +91,39 @@ def train(model, X_train, y_train,
8691
else:
8792
model.step(batch_X_view, batch_y_view, logits_buf, probs_buf, bias_buf)
8893

89-
if iteration % check_interval != 0:
94+
if iterations_completed % check_interval != 0:
9095
continue
9196

92-
current_loss = model.score_logloss(batch_X_view, batch_y_view, score_logits, score_probs)
93-
change = abs(prev_loss - current_loss)
97+
current_loss = model.score_logloss(full_X_view, full_y_view, score_logits, score_probs)
98+
final_loss = current_loss
99+
change = float('inf') if prev_loss is None else abs(prev_loss - current_loss)
94100

95101
if verbose >= 2:
96-
print(log_prefix, f'Iteration {iteration} loss={current_loss}')
102+
print(log_prefix, f'Iteration {iterations_completed} loss={current_loss}')
97103

98-
converged = change < tolerance and iteration > check_interval * 2
104+
converged = change < tolerance and iterations_completed > check_interval * 2
99105

100-
if score_limit is not None:
101-
converged = converged or current_loss <= score_limit
106+
if score_limit is not None and current_loss <= score_limit:
107+
converged = True
102108

103-
diverged = (current_loss > prev_loss * divergence_factor) or not (current_loss == current_loss)
109+
diverged = not (current_loss == current_loss)
110+
if not diverged and prev_loss is not None:
111+
diverged = current_loss > prev_loss * divergence_factor
104112

105113
if converged:
106114
if verbose >= 1:
107-
print(log_prefix, f"Converged at iteration {iteration}")
115+
print(log_prefix, f"Converged at iteration {iterations_completed}")
108116
break
109117

110118
if diverged:
111119
if verbose >= 1:
112-
print(log_prefix, f"Diverged at iteration {iteration}")
120+
print(log_prefix, f"Diverged at iteration {iterations_completed}")
113121
break
114122

115123
prev_loss = current_loss
116124

117-
return iteration, prev_loss
125+
if final_loss == float('inf'):
126+
final_loss = model.score_logloss(full_X_view, full_y_view, score_logits, score_probs)
127+
128+
return iterations_completed, final_loss
118129

0 commit comments

Comments
 (0)