@@ -55,19 +55,24 @@ def train(model, X_train, y_train,
5555 score_logits , score_probs = _make_workspace_pair (n_classes )
5656 predict_logits , predict_probs = _make_predict_buffers (n_classes )
5757
58- prev_loss = float ('inf' )
58+ prev_loss = None
59+ final_loss = float ('inf' )
60+ iterations_completed = 0
5961
6062 use_batches = batch_size < n_samples
63+ full_X_view = memoryview (X_train )
64+ full_y_view = memoryview (y_train )
6165 if use_batches :
6266 batch_X = array .array ('f' , [0.0 ] * (batch_size * n_features ))
6367 batch_y = array .array ('f' , [0.0 ] * (batch_size * n_classes ))
6468 batch_X_view = memoryview (batch_X )
6569 batch_y_view = memoryview (batch_y )
6670 else :
67- batch_X_view = memoryview ( X_train )
68- batch_y_view = memoryview ( y_train )
71+ batch_X_view = full_X_view
72+ batch_y_view = full_y_view
6973
70- for iteration in range (max_iterations ):
74+ for _ in range (max_iterations ):
75+ iterations_completed += 1
7176 if use_batches :
7277 for start in range (0 , n_samples , batch_size ):
7378 count = min (batch_size , n_samples - start )
@@ -86,33 +91,39 @@ def train(model, X_train, y_train,
8691 else :
8792 model .step (batch_X_view , batch_y_view , logits_buf , probs_buf , bias_buf )
8893
89- if iteration % check_interval != 0 :
94+ if iterations_completed % check_interval != 0 :
9095 continue
9196
92- current_loss = model .score_logloss (batch_X_view , batch_y_view , score_logits , score_probs )
93- change = abs (prev_loss - current_loss )
97+ current_loss = model .score_logloss (full_X_view , full_y_view , score_logits , score_probs )
98+ final_loss = current_loss
99+ change = float ('inf' ) if prev_loss is None else abs (prev_loss - current_loss )
94100
95101 if verbose >= 2 :
96- print (log_prefix , f'Iteration { iteration } loss={ current_loss } ' )
102+ print (log_prefix , f'Iteration { iterations_completed } loss={ current_loss } ' )
97103
98- converged = change < tolerance and iteration > check_interval * 2
104+ converged = change < tolerance and iterations_completed > check_interval * 2
99105
100- if score_limit is not None :
101- converged = converged or current_loss <= score_limit
106+ if score_limit is not None and current_loss <= score_limit :
107+ converged = True
102108
103- diverged = (current_loss > prev_loss * divergence_factor ) or not (current_loss == current_loss )
109+ diverged = not (current_loss == current_loss )
110+ if not diverged and prev_loss is not None :
111+ diverged = current_loss > prev_loss * divergence_factor
104112
105113 if converged :
106114 if verbose >= 1 :
107- print (log_prefix , f"Converged at iteration { iteration } " )
115+ print (log_prefix , f"Converged at iteration { iterations_completed } " )
108116 break
109117
110118 if diverged :
111119 if verbose >= 1 :
112- print (log_prefix , f"Diverged at iteration { iteration } " )
120+ print (log_prefix , f"Diverged at iteration { iterations_completed } " )
113121 break
114122
115123 prev_loss = current_loss
116124
117- return iteration , prev_loss
125+ if final_loss == float ('inf' ):
126+ final_loss = model .score_logloss (full_X_view , full_y_view , score_logits , score_probs )
127+
128+ return iterations_completed , final_loss
118129
0 commit comments